dotfiles/.config/nvim/pack/tree/start/plenary.nvim/lua/luassert/state.lua

128 lines
3 KiB
Lua
Raw Normal View History

2025-09-16 01:01:02 +02:00
-- Assert-engine state, kept as a chain of snapshots linked through
-- `previous` (older) and `next` (newer). Each snapshot records the
-- formatters, parameters and spies/stubs registered while it was the newest.

-- Metatable shared by all live snapshots; calling a snapshot as a function
-- reverts it.
local state_mt = {}
state_mt.__call = function(snap)
  snap:revert()
end

-- Per-snapshot spy registries hold both their keys and values weakly.
local spies_mt = { __mode = "kv" }

-- Sentinel stored in place of an explicit nil parameter value, so that a nil
-- set in a newer snapshot still hides a value set in an older one.
local nilvalue = {}

-- The newest snapshot in the chain (assigned once the initial snapshot is
-- created at the bottom of this file).
local current

-- Module table handed back to callers.
local state = {}
------------------------------------------------------
-- Reverts (undoes) a snapshot.
-- Any newer snapshots chained after it are reverted first. Then the
-- formatters and parameters it recorded are dropped, every spy/stub
-- registered in it has its `revert` method called, the snapshot is
-- invalidated, and the snapshot before it becomes the current one.
-- @param self (optional) the snapshot to revert. If omitted, the current
--   snapshot is reverted; when that is the initial snapshot nothing happens.
-- Errors (reported at the caller's level) if `self` is not a valid snapshot,
-- or if it is the initial snapshot, which anchors the chain and cannot be
-- reverted.
state.revert = function(self)
  if not self then
    -- no snapshot given, so revert the current (newest) one
    self = current
    if not self.previous then
      -- initial snapshot, nothing to revert
      return
    end
  end
  if getmetatable(self) ~= state_mt then error("Value provided is not a valid snapshot", 2) end
  if not self.previous then
    -- Reverting the initial snapshot would set `current` to nil (crashing on
    -- `current.next` below and breaking every later call into this module),
    -- so refuse before anything has been modified.
    error("The initial snapshot cannot be reverted", 2)
  end
  if self.next then
    self.next:revert()
  end
  -- drop formatters and parameters recorded in this snapshot
  self.formatters = {}
  self.parameters = {}
  -- revert spies/stubs recorded in this snapshot; clearing existing keys
  -- while traversing with pairs is allowed
  for s in pairs(self.spies) do
    self.spies[s] = nil
    s:revert()
  end
  setmetatable(self, nil) -- invalidate as a snapshot
  current = self.previous
  current.next = nil
end
------------------------------------------------------
-- Creates a new snapshot, links it after the current one and makes it the
-- current snapshot. Subsequent add_formatter / set_parameter / add_spy calls
-- record into it while it is the newest snapshot.
-- @return snapshot table; call it (or its `revert` method) to undo it.
state.snapshot = function()
  local snap = {
    previous = current,
    formatters = {},
    parameters = {},
    spies = setmetatable({}, spies_mt),
    revert = state.revert,
  }
  setmetatable(snap, state_mt)
  if current then
    current.next = snap
  end
  current = snap
  return snap
end
-- FORMATTERS
--- Registers a formatter on the current snapshot.
-- It is inserted at the front of the list, so `format_argument` tries it
-- before formatters added earlier to the same snapshot.
-- @param callback function(val, fmtargs) returning a formatted value, or nil
--   to let the next formatter try
state.add_formatter = function(callback)
  local list = current.formatters
  table.insert(list, 1, callback)
end
--- Removes a formatter.
-- Starting at snapshot `s` (default: the current snapshot) and walking back
-- to the initial one, the first occurrence of `callback` is removed from each
-- snapshot's formatter list. The walk does not stop at the first snapshot
-- where a match was found.
-- @param callback the formatter function to remove
-- @param s (optional) snapshot to start from
state.remove_formatter = function(callback, s)
  local snap = s or current
  repeat
    local list = snap.formatters
    for idx, fmt in ipairs(list) do
      if fmt == callback then
        table.remove(list, idx)
        break
      end
    end
    snap = snap.previous
  until not snap
end
--- Formats a value with the registered formatters.
-- The formatters of snapshot `s` (default: the current one) are tried in
-- list order, then those of each older snapshot; the first non-nil result
-- is returned.
-- @param val value to format
-- @param s (optional) snapshot to start searching from
-- @param fmtargs extra argument passed through to every formatter
-- @return the formatted value, or nil if no formatter handled it
state.format_argument = function(val, s, fmtargs)
  local snap = s or current
  repeat
    for _, formatter in ipairs(snap.formatters) do
      local formatted = formatter(val, fmtargs)
      if formatted ~= nil then
        return formatted
      end
    end
    snap = snap.previous
  until not snap
  return nil -- reached the initial snapshot without a match
end
-- PARAMETERS
--- Sets a parameter on the current snapshot.
-- A nil value is stored as the `nilvalue` sentinel so that it still hides
-- any value for `name` set in an older snapshot.
-- @param name parameter name
-- @param value parameter value (may be nil)
state.set_parameter = function(name, value)
  local stored = value
  if stored == nil then
    stored = nilvalue
  end
  current.parameters[name] = stored
end
--- Looks up a parameter.
-- Searches snapshot `s` (default: the current one) and then each older
-- snapshot, stopping at the first one that has an entry for `name`.
-- @param name parameter name
-- @param s (optional) snapshot to start searching from
-- @return the value, or nil when it is unset or was explicitly set to nil
state.get_parameter = function(name, s)
  local snap = s or current
  local val = snap.parameters[name]
  while val == nil and snap.previous do
    snap = snap.previous
    val = snap.parameters[name]
  end
  if val == nilvalue then
    return nil
  end
  return val
end
-- SPIES / STUBS
--- Registers a spy or stub on the current snapshot, so that reverting the
-- snapshot also calls the spy's `revert` method.
-- @param spy spy/stub object with a `revert` method
state.add_spy = function(spy)
  local registry = current.spies
  registry[spy] = true
end

-- Create the initial snapshot so `current` is set before any other function
-- of this module is used, then export the module.
state.snapshot()
return state