Meh I'll figure out submodules later

This commit is contained in:
mustard 2025-09-16 01:01:02 +02:00
parent 4ca9d44a90
commit 8cb281f436
352 changed files with 66107 additions and 0 deletions

View file

@ -0,0 +1,70 @@
local assert = require('luassert.assert')
local say = require('say')
-- Example usage:
-- local arr = { "one", "two", "three" }
--
-- assert.array(arr).has.no.holes() -- checks the array to not contain holes --> passes
-- assert.array(arr).has.no.holes(4) -- sets explicit length to 4 --> fails
--
-- local first_hole = assert.array(arr).has.holes(4) -- check array of size 4 to contain holes --> passes
-- assert.equal(4, first_hole) -- passes, as the index of the first hole is returned
-- Unique key to store the object we operate on in the state object
-- key must be unique, to make sure we do not have name collissions in the shared state object
local ARRAY_STATE_KEY = "__array_state"
-- The modifier, to store the object in our state
local function array(state, args, level)
assert(args.n > 0, "No array provided to the array-modifier")
assert(rawget(state, ARRAY_STATE_KEY) == nil, "Array already set")
rawset(state, ARRAY_STATE_KEY, args[1])
return state
end
-- The actual assertion that operates on our object, stored via the modifier
local function holes(state, args, level)
local length = args[1]
local arr = rawget(state, ARRAY_STATE_KEY) -- retrieve previously set object
-- only check against nil, metatable types are allowed
assert(arr ~= nil, "No array set, please use the array modifier to set the array to validate")
if length == nil then
length = 0
for i in pairs(arr) do
if type(i) == "number" and
i > length and
math.floor(i) == i then
length = i
end
end
end
assert(type(length) == "number", "expected array length to be of type 'number', got: "..tostring(length))
-- let's do the actual assertion
local missing
for i = 1, length do
if arr[i] == nil then
missing = i
break
end
end
-- format arguments for output strings;
args[1] = missing
args.n = missing and 1 or 0
return missing ~= nil, { missing } -- assert result + first missing index as return value
end
-- Register the proper assertion messages
say:set("assertion.array_holes.positive", [[
Expected array to have holes, but none was found.
]])
say:set("assertion.array_holes.negative", [[
Expected array to not have holes, hole found at position: %s
]])
-- Register the assertion, and the modifier
assert:register("assertion", "holes", holes,
"assertion.array_holes.positive",
"assertion.array_holes.negative")
assert:register("modifier", "array", array)

View file

@ -0,0 +1,180 @@
local s = require 'say'
local astate = require 'luassert.state'
local util = require 'luassert.util'
local unpack = util.unpack
local obj -- the returned module table
local level_mt = {}
-- list of namespaces
local namespace = require 'luassert.namespaces'
local function geterror(assertion_message, failure_message, args)
if util.hastostring(failure_message) then
failure_message = tostring(failure_message)
elseif failure_message ~= nil then
failure_message = astate.format_argument(failure_message)
end
local message = s(assertion_message, obj:format(args))
if message and failure_message then
message = failure_message .. "\n" .. message
end
return message or failure_message
end
local __state_meta = {
__call = function(self, ...)
local keys = util.extract_keys("assertion", self.tokens)
local assertion
for _, key in ipairs(keys) do
assertion = namespace.assertion[key] or assertion
end
if assertion then
for _, key in ipairs(keys) do
if namespace.modifier[key] then
namespace.modifier[key].callback(self)
end
end
local arguments = util.make_arglist(...)
local val, retargs = assertion.callback(self, arguments, util.errorlevel())
if (not val) == self.mod then
local message = assertion.positive_message
if not self.mod then
message = assertion.negative_message
end
local err = geterror(message, rawget(self,"failure_message"), arguments)
error(err or "assertion failed!", util.errorlevel())
end
if retargs then
return unpack(retargs)
end
return ...
else
local arguments = util.make_arglist(...)
self.tokens = {}
for _, key in ipairs(keys) do
if namespace.modifier[key] then
namespace.modifier[key].callback(self, arguments, util.errorlevel())
end
end
end
return self
end,
__index = function(self, key)
for token in key:lower():gmatch('[^_]+') do
table.insert(self.tokens, token)
end
return self
end
}
obj = {
state = function() return setmetatable({mod=true, tokens={}}, __state_meta) end,
-- registers a function in namespace
register = function(self, nspace, name, callback, positive_message, negative_message)
local lowername = name:lower()
if not namespace[nspace] then
namespace[nspace] = {}
end
namespace[nspace][lowername] = {
callback = callback,
name = lowername,
positive_message=positive_message,
negative_message=negative_message
}
end,
-- unregisters a function in a namespace
unregister = function(self, nspace, name)
local lowername = name:lower()
if not namespace[nspace] then
namespace[nspace] = {}
end
namespace[nspace][lowername] = nil
end,
-- registers a formatter
-- a formatter takes a single argument, and converts it to a string, or returns nil if it cannot format the argument
add_formatter = function(self, callback)
astate.add_formatter(callback)
end,
-- unregisters a formatter
remove_formatter = function(self, fmtr)
astate.remove_formatter(fmtr)
end,
format = function(self, args)
-- args.n specifies the number of arguments in case of 'trailing nil' arguments which get lost
local nofmt = args.nofmt or {} -- arguments in this list should not be formatted
local fmtargs = args.fmtargs or {} -- additional arguments to be passed to formatter
for i = 1, (args.n or #args) do -- cannot use pairs because table might have nils
if not nofmt[i] then
local val = args[i]
local valfmt = astate.format_argument(val, nil, fmtargs[i])
if valfmt == nil then valfmt = tostring(val) end -- no formatter found
args[i] = valfmt
end
end
return args
end,
set_parameter = function(self, name, value)
astate.set_parameter(name, value)
end,
get_parameter = function(self, name)
return astate.get_parameter(name)
end,
add_spy = function(self, spy)
astate.add_spy(spy)
end,
snapshot = function(self)
return astate.snapshot()
end,
level = function(self, level)
return setmetatable({
level = level
}, level_mt)
end,
-- returns the level if a level-value, otherwise nil
get_level = function(self, level)
if getmetatable(level) ~= level_mt then
return nil -- not a valid error-level
end
return level.level
end,
}
local __meta = {
__call = function(self, bool, message, level, ...)
if not bool then
local err_level = (self:get_level(level) or 1) + 1
error(message or "assertion failed!", err_level)
end
return bool , message , level , ...
end,
__index = function(self, key)
return rawget(self, key) or self.state()[key]
end,
}
return setmetatable(obj, __meta)

View file

@ -0,0 +1,334 @@
-- module will not return anything, only register assertions with the main assert engine
-- assertions take 2 parameters;
-- 1) state
-- 2) arguments list. The list has a member 'n' with the argument count to check for trailing nils
-- 3) level The level of the error position relative to the called function
-- returns; boolean; whether assertion passed
local assert = require('luassert.assert')
local astate = require ('luassert.state')
local util = require ('luassert.util')
local s = require('say')
local function format(val)
return astate.format_argument(val) or tostring(val)
end
local function set_failure_message(state, message)
if message ~= nil then
state.failure_message = message
end
end
local function unique(state, arguments, level)
local list = arguments[1]
local deep
local argcnt = arguments.n
if type(arguments[2]) == "boolean" or (arguments[2] == nil and argcnt > 2) then
deep = arguments[2]
set_failure_message(state, arguments[3])
else
if type(arguments[3]) == "boolean" then
deep = arguments[3]
end
set_failure_message(state, arguments[2])
end
for k,v in pairs(list) do
for k2, v2 in pairs(list) do
if k ~= k2 then
if deep and util.deepcompare(v, v2, true) then
return false
else
if v == v2 then
return false
end
end
end
end
end
return true
end
local function near(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 2, s("assertion.internal.argtolittle", { "near", 3, tostring(argcnt) }), level)
local expected = tonumber(arguments[1])
local actual = tonumber(arguments[2])
local tolerance = tonumber(arguments[3])
local numbertype = "number or object convertible to a number"
assert(expected, s("assertion.internal.badargtype", { 1, "near", numbertype, format(arguments[1]) }), level)
assert(actual, s("assertion.internal.badargtype", { 2, "near", numbertype, format(arguments[2]) }), level)
assert(tolerance, s("assertion.internal.badargtype", { 3, "near", numbertype, format(arguments[3]) }), level)
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
arguments[3] = tolerance
arguments.nofmt = arguments.nofmt or {}
arguments.nofmt[3] = true
set_failure_message(state, arguments[4])
return (actual >= expected - tolerance and actual <= expected + tolerance)
end
local function matches(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 1, s("assertion.internal.argtolittle", { "matches", 2, tostring(argcnt) }), level)
local pattern = arguments[1]
local actual = nil
if util.hastostring(arguments[2]) or type(arguments[2]) == "number" then
actual = tostring(arguments[2])
end
local err_message
local init_arg_num = 3
for i=3,argcnt,1 do
if arguments[i] and type(arguments[i]) ~= "boolean" and not tonumber(arguments[i]) then
if i == 3 then init_arg_num = init_arg_num + 1 end
err_message = util.tremove(arguments, i)
break
end
end
local init = arguments[3]
local plain = arguments[4]
local stringtype = "string or object convertible to a string"
assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level)
assert(actual, s("assertion.internal.badargtype", { 2, "matches", stringtype, format(arguments[2]) }), level)
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "matches", "number", type(arguments[3]) }), level)
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
set_failure_message(state, err_message)
local retargs
local ok
if plain then
ok = (actual:find(pattern, init, plain) ~= nil)
retargs = ok and { pattern } or {}
else
retargs = { actual:match(pattern, init) }
ok = (retargs[1] ~= nil)
end
return ok, retargs
end
local function equals(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 1, s("assertion.internal.argtolittle", { "equals", 2, tostring(argcnt) }), level)
local result = arguments[1] == arguments[2]
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
set_failure_message(state, arguments[3])
return result
end
local function same(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 1, s("assertion.internal.argtolittle", { "same", 2, tostring(argcnt) }), level)
if type(arguments[1]) == 'table' and type(arguments[2]) == 'table' then
local result, crumbs = util.deepcompare(arguments[1], arguments[2], true)
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
arguments.fmtargs = arguments.fmtargs or {}
arguments.fmtargs[1] = { crumbs = crumbs }
arguments.fmtargs[2] = { crumbs = crumbs }
set_failure_message(state, arguments[3])
return result
end
local result = arguments[1] == arguments[2]
-- switch arguments for proper output message
util.tinsert(arguments, 1, util.tremove(arguments, 2))
set_failure_message(state, arguments[3])
return result
end
local function truthy(state, arguments, level)
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "truthy", 1, tostring(argcnt) }), level)
set_failure_message(state, arguments[2])
return arguments[1] ~= false and arguments[1] ~= nil
end
local function falsy(state, arguments, level)
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "falsy", 1, tostring(argcnt) }), level)
return not truthy(state, arguments, level)
end
local function has_error(state, arguments, level)
local level = (level or 1) + 1
local retargs = util.shallowcopy(arguments)
local func = arguments[1]
local err_expected = arguments[2]
local failure_message = arguments[3]
assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error", "function or callable object", type(func) }), level)
local ok, err_actual = pcall(func)
if type(err_actual) == 'string' then
-- remove 'path/to/file:line: ' from string
err_actual = err_actual:gsub('^.-:%d+: ', '', 1)
end
retargs[1] = err_actual
arguments.nofmt = {}
arguments.n = 2
arguments[1] = (ok and '(no error)' or err_actual)
arguments[2] = (err_expected == nil and '(error)' or err_expected)
arguments.nofmt[1] = ok
arguments.nofmt[2] = (err_expected == nil)
set_failure_message(state, failure_message)
if ok or err_expected == nil then
return not ok, retargs
end
if type(err_expected) == 'string' then
-- err_actual must be (convertible to) a string
if util.hastostring(err_actual) then
err_actual = tostring(err_actual)
retargs[1] = err_actual
end
if type(err_actual) == 'string' then
return err_expected == err_actual, retargs
end
elseif type(err_expected) == 'number' then
if type(err_actual) == 'string' then
return tostring(err_expected) == tostring(tonumber(err_actual)), retargs
end
end
return same(state, {err_expected, err_actual, ["n"] = 2}), retargs
end
local function error_matches(state, arguments, level)
local level = (level or 1) + 1
local retargs = util.shallowcopy(arguments)
local argcnt = arguments.n
local func = arguments[1]
local pattern = arguments[2]
assert(argcnt > 1, s("assertion.internal.argtolittle", { "error_matches", 2, tostring(argcnt) }), level)
assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error_matches", "function or callable object", type(func) }), level)
assert(pattern == nil or type(pattern) == "string", s("assertion.internal.badargtype", { 2, "error", "string", type(pattern) }), level)
local failure_message
local init_arg_num = 3
for i=3,argcnt,1 do
if arguments[i] and type(arguments[i]) ~= "boolean" and not tonumber(arguments[i]) then
if i == 3 then init_arg_num = init_arg_num + 1 end
failure_message = util.tremove(arguments, i)
break
end
end
local init = arguments[3]
local plain = arguments[4]
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "matches", "number", type(arguments[3]) }), level)
local ok, err_actual = pcall(func)
if type(err_actual) == 'string' then
-- remove 'path/to/file:line: ' from string
err_actual = err_actual:gsub('^.-:%d+: ', '', 1)
end
retargs[1] = err_actual
arguments.nofmt = {}
arguments.n = 2
arguments[1] = (ok and '(no error)' or err_actual)
arguments[2] = pattern
arguments.nofmt[1] = ok
arguments.nofmt[2] = false
set_failure_message(state, failure_message)
if ok then return not ok, retargs end
if err_actual == nil and pattern == nil then
return true, {}
end
-- err_actual must be (convertible to) a string
if util.hastostring(err_actual) or
type(err_actual) == "number" or
type(err_actual) == "boolean" then
err_actual = tostring(err_actual)
retargs[1] = err_actual
end
if type(err_actual) == 'string' then
local ok
local retargs_ok
if plain then
retargs_ok = { pattern }
ok = (err_actual:find(pattern, init, plain) ~= nil)
else
retargs_ok = { err_actual:match(pattern, init) }
ok = (retargs_ok[1] ~= nil)
end
if ok then retargs = retargs_ok end
return ok, retargs
end
return false, retargs
end
local function is_true(state, arguments, level)
util.tinsert(arguments, 2, true)
set_failure_message(state, arguments[3])
return arguments[1] == arguments[2]
end
local function is_false(state, arguments, level)
util.tinsert(arguments, 2, false)
set_failure_message(state, arguments[3])
return arguments[1] == arguments[2]
end
local function is_type(state, arguments, level, etype)
util.tinsert(arguments, 2, "type " .. etype)
arguments.nofmt = arguments.nofmt or {}
arguments.nofmt[2] = true
set_failure_message(state, arguments[3])
return arguments.n > 1 and type(arguments[1]) == etype
end
local function returned_arguments(state, arguments, level)
arguments[1] = tostring(arguments[1])
arguments[2] = tostring(arguments.n - 1)
arguments.nofmt = arguments.nofmt or {}
arguments.nofmt[1] = true
arguments.nofmt[2] = true
if arguments.n < 2 then arguments.n = 2 end
return arguments[1] == arguments[2]
end
local function set_message(state, arguments, level)
state.failure_message = arguments[1]
end
local function is_boolean(state, arguments, level) return is_type(state, arguments, level, "boolean") end
local function is_number(state, arguments, level) return is_type(state, arguments, level, "number") end
local function is_string(state, arguments, level) return is_type(state, arguments, level, "string") end
local function is_table(state, arguments, level) return is_type(state, arguments, level, "table") end
local function is_nil(state, arguments, level) return is_type(state, arguments, level, "nil") end
local function is_userdata(state, arguments, level) return is_type(state, arguments, level, "userdata") end
local function is_function(state, arguments, level) return is_type(state, arguments, level, "function") end
local function is_thread(state, arguments, level) return is_type(state, arguments, level, "thread") end
assert:register("modifier", "message", set_message)
assert:register("assertion", "true", is_true, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "false", is_false, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "boolean", is_boolean, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "number", is_number, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "string", is_string, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "table", is_table, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "nil", is_nil, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "userdata", is_userdata, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "function", is_function, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "thread", is_thread, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "returned_arguments", returned_arguments, "assertion.returned_arguments.positive", "assertion.returned_arguments.negative")
assert:register("assertion", "same", same, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "matches", matches, "assertion.matches.positive", "assertion.matches.negative")
assert:register("assertion", "match", matches, "assertion.matches.positive", "assertion.matches.negative")
assert:register("assertion", "near", near, "assertion.near.positive", "assertion.near.negative")
assert:register("assertion", "equals", equals, "assertion.equals.positive", "assertion.equals.negative")
assert:register("assertion", "equal", equals, "assertion.equals.positive", "assertion.equals.negative")
assert:register("assertion", "unique", unique, "assertion.unique.positive", "assertion.unique.negative")
assert:register("assertion", "error", has_error, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "errors", has_error, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "error_matches", error_matches, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "error_match", error_matches, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "matches_error", error_matches, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "match_error", error_matches, "assertion.error.positive", "assertion.error.negative")
assert:register("assertion", "truthy", truthy, "assertion.truthy.positive", "assertion.truthy.negative")
assert:register("assertion", "falsy", falsy, "assertion.falsy.positive", "assertion.falsy.negative")

View file

@ -0,0 +1,9 @@
-- no longer needed, only for backward compatibility
local unpack = require ("luassert.util").unpack
return {
unpack = function(...)
print(debug.traceback("WARN: calling deprecated function 'luassert.compatibility.unpack' use 'luassert.util.unpack' instead"))
return unpack(...)
end
}

View file

@ -0,0 +1,28 @@
local format = function (str)
if type(str) ~= "string" then return nil end
local result = "Binary string length; " .. tostring(#str) .. " bytes\n"
local i = 1
local hex = ""
local chr = ""
while i <= #str do
local byte = str:byte(i)
hex = string.format("%s%2x ", hex, byte)
if byte < 32 then byte = string.byte(".") end
chr = chr .. string.char(byte)
if math.floor(i/16) == i/16 or i == #str then
-- reached end of line
hex = hex .. string.rep(" ", 16 * 3 - #hex)
chr = chr .. string.rep(" ", 16 - #chr)
result = result .. hex:sub(1, 8 * 3) .. " " .. hex:sub(8*3+1, -1) .. " " .. chr:sub(1,8) .. " " .. chr:sub(9,-1) .. "\n"
hex = ""
chr = ""
end
i = i + 1
end
return result
end
return format

View file

@ -0,0 +1,255 @@
-- module will not return anything, only register formatters with the main assert engine
local assert = require('luassert.assert')
local match = require('luassert.match')
local util = require('luassert.util')
local isatty, colors do
local ok, term = pcall(require, 'term')
isatty = io.type(io.stdout) == 'file' and ok and term.isatty(io.stdout)
if not isatty then
local isWindows = package.config:sub(1,1) == '\\'
if isWindows and os.getenv("ANSICON") then
isatty = true
end
end
colors = setmetatable({
none = function(c) return c end
},{ __index = function(self, key)
return function(c)
for token in key:gmatch("[^%.]+") do
c = term.colors[token](c)
end
return c
end
end
})
end
local function fmt_string(arg)
if type(arg) == "string" then
return string.format("(string) '%s'", arg)
end
end
-- A version of tostring which formats numbers more precisely.
local function tostr(arg)
if type(arg) ~= "number" then
return tostring(arg)
end
if arg ~= arg then
return "NaN"
elseif arg == 1/0 then
return "Inf"
elseif arg == -1/0 then
return "-Inf"
end
local str = string.format("%.20g", arg)
if math.type and math.type(arg) == "float" and not str:find("[%.,]") then
-- Number is a float but looks like an integer.
-- Insert ".0" after first run of digits.
str = str:gsub("%d+", "%0.0", 1)
end
return str
end
local function fmt_number(arg)
if type(arg) == "number" then
return string.format("(number) %s", tostr(arg))
end
end
local function fmt_boolean(arg)
if type(arg) == "boolean" then
return string.format("(boolean) %s", tostring(arg))
end
end
local function fmt_nil(arg)
if type(arg) == "nil" then
return "(nil)"
end
end
local type_priorities = {
number = 1,
boolean = 2,
string = 3,
table = 4,
["function"] = 5,
userdata = 6,
thread = 7
}
local function is_in_array_part(key, length)
return type(key) == "number" and 1 <= key and key <= length and math.floor(key) == key
end
local function get_sorted_keys(t)
local keys = {}
local nkeys = 0
for key in pairs(t) do
nkeys = nkeys + 1
keys[nkeys] = key
end
local length = #t
local function key_comparator(key1, key2)
local type1, type2 = type(key1), type(key2)
local priority1 = is_in_array_part(key1, length) and 0 or type_priorities[type1] or 8
local priority2 = is_in_array_part(key2, length) and 0 or type_priorities[type2] or 8
if priority1 == priority2 then
if type1 == "string" or type1 == "number" then
return key1 < key2
elseif type1 == "boolean" then
return key1 -- put true before false
end
else
return priority1 < priority2
end
end
table.sort(keys, key_comparator)
return keys, nkeys
end
local function fmt_table(arg, fmtargs)
if type(arg) ~= "table" then
return
end
local tmax = assert:get_parameter("TableFormatLevel")
local showrec = assert:get_parameter("TableFormatShowRecursion")
local errchar = assert:get_parameter("TableErrorHighlightCharacter") or ""
local errcolor = assert:get_parameter("TableErrorHighlightColor")
local crumbs = fmtargs and fmtargs.crumbs or {}
local cache = {}
local type_desc
if getmetatable(arg) == nil then
type_desc = "(" .. tostring(arg) .. ") "
elseif not pcall(setmetatable, arg, getmetatable(arg)) then
-- cannot set same metatable, so it is protected, skip id
type_desc = "(table) "
else
-- unprotected metatable, temporary remove the mt
local mt = getmetatable(arg)
setmetatable(arg, nil)
type_desc = "(" .. tostring(arg) .. ") "
setmetatable(arg, mt)
end
local function ft(t, l, with_crumbs)
if showrec and cache[t] and cache[t] > 0 then
return "{ ... recursive }"
end
if next(t) == nil then
return "{ }"
end
if l > tmax and tmax >= 0 then
return "{ ... more }"
end
local result = "{"
local keys, nkeys = get_sorted_keys(t)
cache[t] = (cache[t] or 0) + 1
local crumb = crumbs[#crumbs - l + 1]
for i = 1, nkeys do
local k = keys[i]
local v = t[k]
local use_crumbs = with_crumbs and k == crumb
if type(v) == "table" then
v = ft(v, l + 1, use_crumbs)
elseif type(v) == "string" then
v = "'"..v.."'"
end
local ch = use_crumbs and errchar or ""
local indent = string.rep(" ",l * 2 - ch:len())
local mark = (ch:len() == 0 and "" or colors[errcolor](ch))
result = result .. string.format("\n%s%s[%s] = %s", indent, mark, tostr(k), tostr(v))
end
cache[t] = cache[t] - 1
return result .. " }"
end
return type_desc .. ft(arg, 1, true)
end
local function fmt_function(arg)
if type(arg) == "function" then
local debug_info = debug.getinfo(arg)
return string.format("%s @ line %s in %s", tostring(arg), tostring(debug_info.linedefined), tostring(debug_info.source))
end
end
local function fmt_userdata(arg)
if type(arg) == "userdata" then
return string.format("(userdata) '%s'", tostring(arg))
end
end
local function fmt_thread(arg)
if type(arg) == "thread" then
return string.format("(thread) '%s'", tostring(arg))
end
end
local function fmt_matcher(arg)
if not match.is_matcher(arg) then
return
end
local not_inverted = {
[true] = "is.",
[false] = "no.",
}
local args = {}
for idx = 1, arg.arguments.n do
table.insert(args, assert:format({ arg.arguments[idx], n = 1, })[1])
end
return string.format("(matcher) %s%s(%s)",
not_inverted[arg.mod],
tostring(arg.name),
table.concat(args, ", "))
end
local function fmt_arglist(arglist)
if not util.is_arglist(arglist) then
return
end
local formatted_vals = {}
for idx = 1, arglist.n do
table.insert(formatted_vals, assert:format({ arglist[idx], n = 1, })[1])
end
return "(values list) (" .. table.concat(formatted_vals, ", ") .. ")"
end
assert:add_formatter(fmt_string)
assert:add_formatter(fmt_number)
assert:add_formatter(fmt_boolean)
assert:add_formatter(fmt_nil)
assert:add_formatter(fmt_table)
assert:add_formatter(fmt_function)
assert:add_formatter(fmt_userdata)
assert:add_formatter(fmt_thread)
assert:add_formatter(fmt_matcher)
assert:add_formatter(fmt_arglist)
-- Set default table display depth for table formatter
assert:set_parameter("TableFormatLevel", 3)
assert:set_parameter("TableFormatShowRecursion", false)
assert:set_parameter("TableErrorHighlightCharacter", "*")
assert:set_parameter("TableErrorHighlightColor", isatty and "red" or "none")

View file

@ -0,0 +1,17 @@
local assert = require('luassert.assert')
assert._COPYRIGHT = "Copyright (c) 2018 Olivine Labs, LLC."
assert._DESCRIPTION = "Extends Lua's built-in assertions to provide additional tests and the ability to create your own."
assert._VERSION = "Luassert 1.8.0"
-- load basic asserts
require('luassert.assertions')
require('luassert.modifiers')
require('luassert.array')
require('luassert.matchers')
require('luassert.formatters')
-- load default language
require('luassert.languages.en')
return assert

View file

@ -0,0 +1,48 @@
local s = require('say')
s:set_namespace('en')
s:set("assertion.same.positive", "Expected objects to be the same.\nPassed in:\n%s\nExpected:\n%s")
s:set("assertion.same.negative", "Expected objects to not be the same.\nPassed in:\n%s\nDid not expect:\n%s")
s:set("assertion.equals.positive", "Expected objects to be equal.\nPassed in:\n%s\nExpected:\n%s")
s:set("assertion.equals.negative", "Expected objects to not be equal.\nPassed in:\n%s\nDid not expect:\n%s")
s:set("assertion.near.positive", "Expected values to be near.\nPassed in:\n%s\nExpected:\n%s +/- %s")
s:set("assertion.near.negative", "Expected values to not be near.\nPassed in:\n%s\nDid not expect:\n%s +/- %s")
s:set("assertion.matches.positive", "Expected strings to match.\nPassed in:\n%s\nExpected:\n%s")
s:set("assertion.matches.negative", "Expected strings not to match.\nPassed in:\n%s\nDid not expect:\n%s")
s:set("assertion.unique.positive", "Expected object to be unique:\n%s")
s:set("assertion.unique.negative", "Expected object to not be unique:\n%s")
s:set("assertion.error.positive", "Expected a different error.\nCaught:\n%s\nExpected:\n%s")
s:set("assertion.error.negative", "Expected no error, but caught:\n%s")
s:set("assertion.truthy.positive", "Expected to be truthy, but value was:\n%s")
s:set("assertion.truthy.negative", "Expected to not be truthy, but value was:\n%s")
s:set("assertion.falsy.positive", "Expected to be falsy, but value was:\n%s")
s:set("assertion.falsy.negative", "Expected to not be falsy, but value was:\n%s")
s:set("assertion.called.positive", "Expected to be called %s time(s), but was called %s time(s)")
s:set("assertion.called.negative", "Expected not to be called exactly %s time(s), but it was.")
s:set("assertion.called_at_least.positive", "Expected to be called at least %s time(s), but was called %s time(s)")
s:set("assertion.called_at_most.positive", "Expected to be called at most %s time(s), but was called %s time(s)")
s:set("assertion.called_more_than.positive", "Expected to be called more than %s time(s), but was called %s time(s)")
s:set("assertion.called_less_than.positive", "Expected to be called less than %s time(s), but was called %s time(s)")
s:set("assertion.called_with.positive", "Function was never called with matching arguments.\nCalled with (last call if any):\n%s\nExpected:\n%s")
s:set("assertion.called_with.negative", "Function was called with matching arguments at least once.\nCalled with (last matching call):\n%s\nDid not expect:\n%s")
s:set("assertion.returned_with.positive", "Function never returned matching arguments.\nReturned (last call if any):\n%s\nExpected:\n%s")
s:set("assertion.returned_with.negative", "Function returned matching arguments at least once.\nReturned (last matching call):\n%s\nDid not expect:\n%s")
s:set("assertion.returned_arguments.positive", "Expected to be called with %s argument(s), but was called with %s")
s:set("assertion.returned_arguments.negative", "Expected not to be called with %s argument(s), but was called with %s")
-- errors
s:set("assertion.internal.argtolittle", "the '%s' function requires a minimum of %s arguments, got: %s")
s:set("assertion.internal.badargtype", "bad argument #%s to '%s' (%s expected, got %s)")

View file

@ -0,0 +1,79 @@
local namespace = require 'luassert.namespaces'
local util = require 'luassert.util'
local matcher_mt = {
__call = function(self, value)
return self.callback(value) == self.mod
end,
}
local state_mt = {
__call = function(self, ...)
local keys = util.extract_keys("matcher", self.tokens)
self.tokens = {}
local matcher
for _, key in ipairs(keys) do
matcher = namespace.matcher[key] or matcher
end
if matcher then
for _, key in ipairs(keys) do
if namespace.modifier[key] then
namespace.modifier[key].callback(self)
end
end
local arguments = util.make_arglist(...)
local matches = matcher.callback(self, arguments, util.errorlevel())
return setmetatable({
name = matcher.name,
mod = self.mod,
callback = matches,
arguments = arguments,
}, matcher_mt)
else
local arguments = util.make_arglist(...)
for _, key in ipairs(keys) do
if namespace.modifier[key] then
namespace.modifier[key].callback(self, arguments, util.errorlevel())
end
end
end
return self
end,
__index = function(self, key)
for token in key:lower():gmatch('[^_]+') do
table.insert(self.tokens, token)
end
return self
end
}
local match = {
_ = setmetatable({mod=true, callback=function() return true end}, matcher_mt),
state = function() return setmetatable({mod=true, tokens={}}, state_mt) end,
is_matcher = function(object)
return type(object) == "table" and getmetatable(object) == matcher_mt
end,
is_ref_matcher = function(object)
local ismatcher = (type(object) == "table" and getmetatable(object) == matcher_mt)
return ismatcher and object.name == "ref"
end,
}
local mt = {
__index = function(self, key)
return rawget(self, key) or self.state()[key]
end,
}
return setmetatable(match, mt)

View file

@ -0,0 +1,61 @@
local assert = require('luassert.assert')
local match = require ('luassert.match')
local s = require('say')
local function none(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "none", 1, tostring(argcnt) }), level)
for i = 1, argcnt do
assert(match.is_matcher(arguments[i]), s("assertion.internal.badargtype", { 1, "none", "matcher", type(arguments[i]) }), level)
end
return function(value)
for _, matcher in ipairs(arguments) do
if matcher(value) then
return false
end
end
return true
end
end
local function any(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "any", 1, tostring(argcnt) }), level)
for i = 1, argcnt do
assert(match.is_matcher(arguments[i]), s("assertion.internal.badargtype", { 1, "any", "matcher", type(arguments[i]) }), level)
end
return function(value)
for _, matcher in ipairs(arguments) do
if matcher(value) then
return true
end
end
return false
end
end
local function all(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "all", 1, tostring(argcnt) }), level)
for i = 1, argcnt do
assert(match.is_matcher(arguments[i]), s("assertion.internal.badargtype", { 1, "all", "matcher", type(arguments[i]) }), level)
end
return function(value)
for _, matcher in ipairs(arguments) do
if not matcher(value) then
return false
end
end
return true
end
end
assert:register("matcher", "none_of", none)
assert:register("matcher", "any_of", any)
assert:register("matcher", "all_of", all)

View file

@ -0,0 +1,173 @@
-- module will return the list of matchers, and registers matchers with the main assert engine
-- matchers take 1 parameters;
-- 1) state
-- 2) arguments list. The list has a member 'n' with the argument count to check for trailing nils
-- 3) level The level of the error position relative to the called function
-- returns; function (or callable object); a function that, given an argument, returns a boolean
local assert = require('luassert.assert')
local astate = require('luassert.state')
local util = require('luassert.util')
local s = require('say')
local function format(val)
return astate.format_argument(val) or tostring(val)
end
local function unique(state, arguments, level)
local deep = arguments[1]
return function(value)
local list = value
for k,v in pairs(list) do
for k2, v2 in pairs(list) do
if k ~= k2 then
if deep and util.deepcompare(v, v2, true) then
return false
else
if v == v2 then
return false
end
end
end
end
end
return true
end
end
local function near(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 1, s("assertion.internal.argtolittle", { "near", 2, tostring(argcnt) }), level)
local expected = tonumber(arguments[1])
local tolerance = tonumber(arguments[2])
local numbertype = "number or object convertible to a number"
assert(expected, s("assertion.internal.badargtype", { 1, "near", numbertype, format(arguments[1]) }), level)
assert(tolerance, s("assertion.internal.badargtype", { 2, "near", numbertype, format(arguments[2]) }), level)
return function(value)
local actual = tonumber(value)
if not actual then return false end
return (actual >= expected - tolerance and actual <= expected + tolerance)
end
end
local function matches(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "matches", 1, tostring(argcnt) }), level)
local pattern = arguments[1]
local init = arguments[2]
local plain = arguments[3]
assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level)
assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { 2, "matches", "number", type(arguments[2]) }), level)
return function(value)
local actualtype = type(value)
local actual = nil
if actualtype == "string" or actualtype == "number" or
actualtype == "table" and (getmetatable(value) or {}).__tostring then
actual = tostring(value)
end
if not actual then return false end
return (actual:find(pattern, init, plain) ~= nil)
end
end
local function equals(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "equals", 1, tostring(argcnt) }), level)
return function(value)
return value == arguments[1]
end
end
local function same(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
assert(argcnt > 0, s("assertion.internal.argtolittle", { "same", 1, tostring(argcnt) }), level)
return function(value)
if type(value) == 'table' and type(arguments[1]) == 'table' then
local result = util.deepcompare(value, arguments[1], true)
return result
end
return value == arguments[1]
end
end
local function ref(state, arguments, level)
local level = (level or 1) + 1
local argcnt = arguments.n
local argtype = type(arguments[1])
local isobject = (argtype == "table" or argtype == "function" or argtype == "thread" or argtype == "userdata")
assert(argcnt > 0, s("assertion.internal.argtolittle", { "ref", 1, tostring(argcnt) }), level)
assert(isobject, s("assertion.internal.badargtype", { 1, "ref", "object", argtype }), level)
return function(value)
return value == arguments[1]
end
end
local function is_true(state, arguments, level)
return function(value)
return value == true
end
end
local function is_false(state, arguments, level)
return function(value)
return value == false
end
end
local function truthy(state, arguments, level)
return function(value)
return value ~= false and value ~= nil
end
end
local function falsy(state, arguments, level)
local is_truthy = truthy(state, arguments, level)
return function(value)
return not is_truthy(value)
end
end
local function is_type(state, arguments, level, etype)
return function(value)
return type(value) == etype
end
end
local function is_nil(state, arguments, level) return is_type(state, arguments, level, "nil") end
local function is_boolean(state, arguments, level) return is_type(state, arguments, level, "boolean") end
local function is_number(state, arguments, level) return is_type(state, arguments, level, "number") end
local function is_string(state, arguments, level) return is_type(state, arguments, level, "string") end
local function is_table(state, arguments, level) return is_type(state, arguments, level, "table") end
local function is_function(state, arguments, level) return is_type(state, arguments, level, "function") end
local function is_userdata(state, arguments, level) return is_type(state, arguments, level, "userdata") end
local function is_thread(state, arguments, level) return is_type(state, arguments, level, "thread") end
assert:register("matcher", "true", is_true)
assert:register("matcher", "false", is_false)
assert:register("matcher", "nil", is_nil)
assert:register("matcher", "boolean", is_boolean)
assert:register("matcher", "number", is_number)
assert:register("matcher", "string", is_string)
assert:register("matcher", "table", is_table)
assert:register("matcher", "function", is_function)
assert:register("matcher", "userdata", is_userdata)
assert:register("matcher", "thread", is_thread)
assert:register("matcher", "ref", ref)
assert:register("matcher", "same", same)
assert:register("matcher", "matches", matches)
assert:register("matcher", "match", matches)
assert:register("matcher", "near", near)
assert:register("matcher", "equals", equals)
assert:register("matcher", "equal", equals)
assert:register("matcher", "unique", unique)
assert:register("matcher", "truthy", truthy)
assert:register("matcher", "falsy", falsy)

View file

@ -0,0 +1,3 @@
-- load basic machers
require('luassert.matchers.core')
require('luassert.matchers.composite')

View file

@ -0,0 +1,61 @@
-- module will return a mock module table, and will not register any assertions
local spy = require 'luassert.spy'
local stub = require 'luassert.stub'
local function mock_apply(object, action)
if type(object) ~= "table" then return end
if spy.is_spy(object) then
return object[action](object)
end
for k,v in pairs(object) do
mock_apply(v, action)
end
return object
end
local mock
mock = {
new = function(object, dostub, func, self, key)
local visited = {}
local function do_mock(object, self, key)
local mock_handlers = {
["table"] = function()
if spy.is_spy(object) or visited[object] then return end
visited[object] = true
for k,v in pairs(object) do
object[k] = do_mock(v, object, k)
end
return object
end,
["function"] = function()
if dostub then
return stub(self, key, func)
elseif self==nil then
return spy.new(object)
else
return spy.on(self, key)
end
end
}
local handler = mock_handlers[type(object)]
return handler and handler() or object
end
return do_mock(object, self, key)
end,
clear = function(object)
return mock_apply(object, "clear")
end,
revert = function(object)
return mock_apply(object, "revert")
end
}
return setmetatable(mock, {
__call = function(self, ...)
-- mock originally was a function only. Now that it is a module table
-- the __call method is required for backward compatibility
return mock.new(...)
end
})

View file

@ -0,0 +1,19 @@
-- module will not return anything, only register assertions/modifiers with the main assert engine
local assert = require('luassert.assert')
local function is(state)
return state
end
local function is_not(state)
state.mod = not state.mod
return state
end
assert:register("modifier", "is", is)
assert:register("modifier", "are", is)
assert:register("modifier", "was", is)
assert:register("modifier", "has", is)
assert:register("modifier", "does", is)
assert:register("modifier", "not", is_not)
assert:register("modifier", "no", is_not)

View file

@ -0,0 +1,2 @@
-- stores the list of namespaces
return {}

View file

@ -0,0 +1,195 @@
-- module will return spy table, and register its assertions with the main assert engine
local assert = require('luassert.assert')
local util = require('luassert.util')
-- Spy metatable
local spy_mt = {
__call = function(self, ...)
local arguments = util.make_arglist(...)
table.insert(self.calls, util.copyargs(arguments))
local function get_returns(...)
local returnvals = util.make_arglist(...)
table.insert(self.returnvals, util.copyargs(returnvals))
return ...
end
return get_returns(self.callback(...))
end
}
local spy -- must make local before defining table, because table contents refers to the table (recursion)
spy = {
new = function(callback)
callback = callback or function() end
if not util.callable(callback) then
error("Cannot spy on type '" .. type(callback) .. "', only on functions or callable elements", util.errorlevel())
end
local s = setmetatable({
calls = {},
returnvals = {},
callback = callback,
target_table = nil, -- these will be set when using 'spy.on'
target_key = nil,
revert = function(self)
if not self.reverted then
if self.target_table and self.target_key then
self.target_table[self.target_key] = self.callback
end
self.reverted = true
end
return self.callback
end,
clear = function(self)
self.calls = {}
self.returnvals = {}
return self
end,
called = function(self, times, compare)
if times or compare then
local compare = compare or function(count, expected) return count == expected end
return compare(#self.calls, times), #self.calls
end
return (#self.calls > 0), #self.calls
end,
called_with = function(self, args)
local last_arglist = nil
if #self.calls > 0 then
last_arglist = self.calls[#self.calls].vals
end
local matching_arglists = util.matchargs(self.calls, args)
if matching_arglists ~= nil then
return true, matching_arglists.vals
end
return false, last_arglist
end,
returned_with = function(self, args)
local last_returnvallist = nil
if #self.returnvals > 0 then
last_returnvallist = self.returnvals[#self.returnvals].vals
end
local matching_returnvallists = util.matchargs(self.returnvals, args)
if matching_returnvallists ~= nil then
return true, matching_returnvallists.vals
end
return false, last_returnvallist
end
}, spy_mt)
assert:add_spy(s) -- register with the current state
return s
end,
is_spy = function(object)
return type(object) == "table" and getmetatable(object) == spy_mt
end,
on = function(target_table, target_key)
local s = spy.new(target_table[target_key])
target_table[target_key] = s
-- store original data
s.target_table = target_table
s.target_key = target_key
return s
end
}
local function set_spy(state, arguments, level)
state.payload = arguments[1]
if arguments[2] ~= nil then
state.failure_message = arguments[2]
end
end
local function returned_with(state, arguments, level)
local level = (level or 1) + 1
local payload = rawget(state, "payload")
if payload and payload.returned_with then
local assertion_holds, matching_or_last_returnvallist = state.payload:returned_with(arguments)
local expected_returnvallist = util.shallowcopy(arguments)
util.cleararglist(arguments)
util.tinsert(arguments, 1, matching_or_last_returnvallist)
util.tinsert(arguments, 2, expected_returnvallist)
return assertion_holds
else
error("'returned_with' must be chained after 'spy(aspy)'", level)
end
end
local function called_with(state, arguments, level)
local level = (level or 1) + 1
local payload = rawget(state, "payload")
if payload and payload.called_with then
local assertion_holds, matching_or_last_arglist = state.payload:called_with(arguments)
local expected_arglist = util.shallowcopy(arguments)
util.cleararglist(arguments)
util.tinsert(arguments, 1, matching_or_last_arglist)
util.tinsert(arguments, 2, expected_arglist)
return assertion_holds
else
error("'called_with' must be chained after 'spy(aspy)'", level)
end
end
local function called(state, arguments, level, compare)
local level = (level or 1) + 1
local num_times = arguments[1]
if not num_times and not state.mod then
state.mod = true
num_times = 0
end
local payload = rawget(state, "payload")
if payload and type(payload) == "table" and payload.called then
local result, count = state.payload:called(num_times, compare)
arguments[1] = tostring(num_times or ">0")
util.tinsert(arguments, 2, tostring(count))
arguments.nofmt = arguments.nofmt or {}
arguments.nofmt[1] = true
arguments.nofmt[2] = true
return result
elseif payload and type(payload) == "function" then
error("When calling 'spy(aspy)', 'aspy' must not be the original function, but the spy function replacing the original", level)
else
error("'called' must be chained after 'spy(aspy)'", level)
end
end
local function called_at_least(state, arguments, level)
local level = (level or 1) + 1
return called(state, arguments, level, function(count, expected) return count >= expected end)
end
local function called_at_most(state, arguments, level)
local level = (level or 1) + 1
return called(state, arguments, level, function(count, expected) return count <= expected end)
end
local function called_more_than(state, arguments, level)
local level = (level or 1) + 1
return called(state, arguments, level, function(count, expected) return count > expected end)
end
local function called_less_than(state, arguments, level)
local level = (level or 1) + 1
return called(state, arguments, level, function(count, expected) return count < expected end)
end
assert:register("modifier", "spy", set_spy)
assert:register("assertion", "returned_with", returned_with, "assertion.returned_with.positive", "assertion.returned_with.negative")
assert:register("assertion", "called_with", called_with, "assertion.called_with.positive", "assertion.called_with.negative")
assert:register("assertion", "called", called, "assertion.called.positive", "assertion.called.negative")
assert:register("assertion", "called_at_least", called_at_least, "assertion.called_at_least.positive", "assertion.called_less_than.positive")
assert:register("assertion", "called_at_most", called_at_most, "assertion.called_at_most.positive", "assertion.called_more_than.positive")
assert:register("assertion", "called_more_than", called_more_than, "assertion.called_more_than.positive", "assertion.called_at_most.positive")
assert:register("assertion", "called_less_than", called_less_than, "assertion.called_less_than.positive", "assertion.called_at_least.positive")
return setmetatable(spy, {
__call = function(self, ...)
return spy.new(...)
end
})

View file

@ -0,0 +1,127 @@
-- maintains a state of the assert engine in a linked-list fashion
-- records; formatters, parameters, spies and stubs
local state_mt = {
__call = function(self)
self:revert()
end
}
local spies_mt = { __mode = "kv" }
local nilvalue = {} -- unique ID to refer to nil values for parameters
-- will hold the current state
local current
-- exported module table
local state = {}
------------------------------------------------------
-- Reverts to a (specific) snapshot.
-- @param self (optional) the snapshot to revert to. If not provided, it will revert to the last snapshot.
state.revert = function(self)
if not self then
-- no snapshot given, so move 1 up
self = current
if not self.previous then
-- top of list, no previous one, nothing to do
return
end
end
if getmetatable(self) ~= state_mt then error("Value provided is not a valid snapshot", 2) end
if self.next then
self.next:revert()
end
-- revert formatters in 'last'
self.formatters = {}
-- revert parameters in 'last'
self.parameters = {}
-- revert spies/stubs in 'last'
for s,_ in pairs(self.spies) do
self.spies[s] = nil
s:revert()
end
setmetatable(self, nil) -- invalidate as a snapshot
current = self.previous
current.next = nil
end
------------------------------------------------------
-- Creates a new snapshot.
-- @return snapshot table
state.snapshot = function()
local new = setmetatable ({
formatters = {},
parameters = {},
spies = setmetatable({}, spies_mt),
previous = current,
revert = state.revert,
}, state_mt)
if current then current.next = new end
current = new
return current
end
-- FORMATTERS
state.add_formatter = function(callback)
table.insert(current.formatters, 1, callback)
end
state.remove_formatter = function(callback, s)
s = s or current
for i, v in ipairs(s.formatters) do
if v == callback then
table.remove(s.formatters, i)
break
end
end
-- wasn't found, so traverse up 1 state
if s.previous then
state.remove_formatter(callback, s.previous)
end
end
state.format_argument = function(val, s, fmtargs)
s = s or current
for _, fmt in ipairs(s.formatters) do
local valfmt = fmt(val, fmtargs)
if valfmt ~= nil then return valfmt end
end
-- nothing found, check snapshot 1 up in list
if s.previous then
return state.format_argument(val, s.previous, fmtargs)
end
return nil -- end of list, couldn't format
end
-- PARAMETERS
state.set_parameter = function(name, value)
if value == nil then value = nilvalue end
current.parameters[name] = value
end
state.get_parameter = function(name, s)
s = s or current
local val = s.parameters[name]
if val == nil and s.previous then
-- not found, so check 1 up in list
return state.get_parameter(name, s.previous)
end
if val ~= nilvalue then
return val
end
return nil
end
-- SPIES / STUBS
state.add_spy = function(spy)
current.spies[spy] = true
end
state.snapshot() -- create initial state
return state

View file

@ -0,0 +1,107 @@
-- module will return a stub module table
local assert = require 'luassert.assert'
local spy = require 'luassert.spy'
local util = require 'luassert.util'
local unpack = util.unpack
local pack = util.pack
local stub = {}
function stub.new(object, key, ...)
if object == nil and key == nil then
-- called without arguments, create a 'blank' stub
object = {}
key = ""
end
local return_values = pack(...)
assert(type(object) == "table" and key ~= nil, "stub.new(): Can only create stub on a table key, call with 2 params; table, key", util.errorlevel())
assert(object[key] == nil or util.callable(object[key]), "stub.new(): The element for which to create a stub must either be callable, or be nil", util.errorlevel())
local old_elem = object[key] -- keep existing element (might be nil!)
local fn = (return_values.n == 1 and util.callable(return_values[1]) and return_values[1])
local defaultfunc = fn or function()
return unpack(return_values)
end
local oncalls = {}
local callbacks = {}
local stubfunc = function(...)
local args = util.make_arglist(...)
local match = util.matchoncalls(oncalls, args)
if match then
return callbacks[match](...)
end
return defaultfunc(...)
end
object[key] = stubfunc -- set the stubfunction
local s = spy.on(object, key) -- create a spy on top of the stub function
local spy_revert = s.revert -- keep created revert function
s.revert = function(self) -- wrap revert function to restore original element
if not self.reverted then
spy_revert(self)
object[key] = old_elem
self.reverted = true
end
return old_elem
end
s.returns = function(...)
local return_args = pack(...)
defaultfunc = function()
return unpack(return_args)
end
return s
end
s.invokes = function(func)
defaultfunc = function(...)
return func(...)
end
return s
end
s.by_default = {
returns = s.returns,
invokes = s.invokes,
}
s.on_call_with = function(...)
local match_args = util.make_arglist(...)
match_args = util.copyargs(match_args)
return {
returns = function(...)
local return_args = pack(...)
table.insert(oncalls, match_args)
callbacks[match_args] = function()
return unpack(return_args)
end
return s
end,
invokes = function(func)
table.insert(oncalls, match_args)
callbacks[match_args] = function(...)
return func(...)
end
return s
end
}
end
return s
end
local function set_stub(state, arguments)
state.payload = arguments[1]
state.failure_message = arguments[2]
end
assert:register("modifier", "stub", set_stub)
return setmetatable(stub, {
__call = function(self, ...)
-- stub originally was a function only. Now that it is a module table
-- the __call method is required for backward compatibility
return stub.new(...)
end
})

View file

@ -0,0 +1,362 @@
local util = {}
local arglist_mt = {}
-- have pack/unpack both respect the 'n' field
local _unpack = table.unpack or unpack
local unpack = function(t, i, j) return _unpack(t, i or 1, j or t.n or #t) end
local pack = function(...) return { n = select("#", ...), ... } end
util.pack = pack
util.unpack = unpack
function util.deepcompare(t1,t2,ignore_mt,cycles,thresh1,thresh2)
local ty1 = type(t1)
local ty2 = type(t2)
-- non-table types can be directly compared
if ty1 ~= 'table' or ty2 ~= 'table' then return t1 == t2 end
local mt1 = debug.getmetatable(t1)
local mt2 = debug.getmetatable(t2)
-- would equality be determined by metatable __eq?
if mt1 and mt1 == mt2 and mt1.__eq then
-- then use that unless asked not to
if not ignore_mt then return t1 == t2 end
else -- we can skip the deep comparison below if t1 and t2 share identity
if rawequal(t1, t2) then return true end
end
-- handle recursive tables
cycles = cycles or {{},{}}
thresh1, thresh2 = (thresh1 or 1), (thresh2 or 1)
cycles[1][t1] = (cycles[1][t1] or 0)
cycles[2][t2] = (cycles[2][t2] or 0)
if cycles[1][t1] == 1 or cycles[2][t2] == 1 then
thresh1 = cycles[1][t1] + 1
thresh2 = cycles[2][t2] + 1
end
if cycles[1][t1] > thresh1 and cycles[2][t2] > thresh2 then
return true
end
cycles[1][t1] = cycles[1][t1] + 1
cycles[2][t2] = cycles[2][t2] + 1
for k1,v1 in next, t1 do
local v2 = t2[k1]
if v2 == nil then
return false, {k1}
end
local same, crumbs = util.deepcompare(v1,v2,nil,cycles,thresh1,thresh2)
if not same then
crumbs = crumbs or {}
table.insert(crumbs, k1)
return false, crumbs
end
end
for k2,_ in next, t2 do
-- only check whether each element has a t1 counterpart, actual comparison
-- has been done in first loop above
if t1[k2] == nil then return false, {k2} end
end
cycles[1][t1] = cycles[1][t1] - 1
cycles[2][t2] = cycles[2][t2] - 1
return true
end
function util.shallowcopy(t)
if type(t) ~= "table" then return t end
local copy = {}
setmetatable(copy, getmetatable(t))
for k,v in next, t do
copy[k] = v
end
return copy
end
function util.deepcopy(t, deepmt, cache)
local spy = require 'luassert.spy'
if type(t) ~= "table" then return t end
local copy = {}
-- handle recursive tables
local cache = cache or {}
if cache[t] then return cache[t] end
cache[t] = copy
for k,v in next, t do
copy[k] = (spy.is_spy(v) and v or util.deepcopy(v, deepmt, cache))
end
if deepmt then
debug.setmetatable(copy, util.deepcopy(debug.getmetatable(t), false, cache))
else
debug.setmetatable(copy, debug.getmetatable(t))
end
return copy
end
-----------------------------------------------
-- Copies arguments as a list of arguments
-- @param args the arguments of which to copy
-- @return the copy of the arguments
function util.copyargs(args)
local copy = {}
setmetatable(copy, getmetatable(args))
local match = require 'luassert.match'
local spy = require 'luassert.spy'
for k,v in pairs(args) do
copy[k] = ((match.is_matcher(v) or spy.is_spy(v)) and v or util.deepcopy(v))
end
return { vals = copy, refs = util.shallowcopy(args) }
end
-----------------------------------------------
-- Clear an arguments or return values list from a table
-- @param arglist the table to clear of arguments or return values and their count
-- @return No return values
function util.cleararglist(arglist)
for idx = arglist.n, 1, -1 do
util.tremove(arglist, idx)
end
arglist.n = nil
end
-----------------------------------------------
-- Test specs against an arglist in deepcopy and refs flavours.
-- @param args deepcopy arglist
-- @param argsrefs refs arglist
-- @param specs arguments/return values to match against args/argsrefs
-- @return true if specs match args/argsrefs, false otherwise
local function matcharg(args, argrefs, specs)
local match = require 'luassert.match'
for idx, argval in pairs(args) do
local spec = specs[idx]
if match.is_matcher(spec) then
if match.is_ref_matcher(spec) then
argval = argrefs[idx]
end
if not spec(argval) then
return false
end
elseif (spec == nil or not util.deepcompare(argval, spec)) then
return false
end
end
for idx, spec in pairs(specs) do
-- only check whether each element has an args counterpart,
-- actual comparison has been done in first loop above
local argval = args[idx]
if argval == nil then
-- no args counterpart, so try to compare using matcher
if match.is_matcher(spec) then
if not spec(argval) then
return false
end
else
return false
end
end
end
return true
end
-----------------------------------------------
-- Find matching arguments/return values in a saved list of
-- arguments/returned values.
-- @param invocations_list list of arguments/returned values to search (list of lists)
-- @param specs arguments/return values to match against argslist
-- @return the last matching arguments/returned values if a match is found, otherwise nil
function util.matchargs(invocations_list, specs)
-- Search the arguments/returned values last to first to give the
-- most helpful answer possible. In the cases where you can place
-- your assertions between calls to check this gives you the best
-- information if no calls match. In the cases where you can't do
-- that there is no good way to predict what would work best.
assert(not util.is_arglist(invocations_list), "expected a list of arglist-object, got an arglist")
for ii = #invocations_list, 1, -1 do
local val = invocations_list[ii]
if matcharg(val.vals, val.refs, specs) then
return val
end
end
return nil
end
-----------------------------------------------
-- Find matching oncall for an actual call.
-- @param oncalls list of oncalls to search
-- @param args actual call argslist to match against
-- @return the first matching oncall if a match is found, otherwise nil
function util.matchoncalls(oncalls, args)
for _, callspecs in ipairs(oncalls) do
-- This lookup is done immediately on *args* passing into the stub
-- so pass *args* as both *args* and *argsref* without copying
-- either.
if matcharg(args, args, callspecs.vals) then
return callspecs
end
end
return nil
end
-----------------------------------------------
-- table.insert() replacement that respects nil values.
-- The function will use table field 'n' as indicator of the
-- table length, if not set, it will be added.
-- @param t table into which to insert
-- @param pos (optional) position in table where to insert. NOTE: not optional if you want to insert a nil-value!
-- @param val value to insert
-- @return No return values
function util.tinsert(...)
-- check optional POS value
local args = {...}
local c = select('#',...)
local t = args[1]
local pos = args[2]
local val = args[3]
if c < 3 then
val = pos
pos = nil
end
-- set length indicator n if not present (+1)
t.n = (t.n or #t) + 1
if not pos then
pos = t.n
elseif pos > t.n then
-- out of our range
t[pos] = val
t.n = pos
end
-- shift everything up 1 pos
for i = t.n, pos + 1, -1 do
t[i]=t[i-1]
end
-- add element to be inserted
t[pos] = val
end
-----------------------------------------------
-- table.remove() replacement that respects nil values.
-- The function will use table field 'n' as indicator of the
-- table length, if not set, it will be added.
-- @param t table from which to remove
-- @param pos (optional) position in table to remove
-- @return No return values
function util.tremove(t, pos)
-- set length indicator n if not present (+1)
t.n = t.n or #t
if not pos then
pos = t.n
elseif pos > t.n then
local removed = t[pos]
-- out of our range
t[pos] = nil
return removed
end
local removed = t[pos]
-- shift everything up 1 pos
for i = pos, t.n do
t[i]=t[i+1]
end
-- set size, clean last
t[t.n] = nil
t.n = t.n - 1
return removed
end
-----------------------------------------------
-- Checks an element to be callable.
-- The type must either be a function or have a metatable
-- containing an '__call' function.
-- @param object element to inspect on being callable or not
-- @return boolean, true if the object is callable
function util.callable(object)
return type(object) == "function" or type((debug.getmetatable(object) or {}).__call) == "function"
end
-----------------------------------------------
-- Checks an element has tostring.
-- The type must either be a string or have a metatable
-- containing an '__tostring' function.
-- @param object element to inspect on having tostring or not
-- @return boolean, true if the object has tostring
function util.hastostring(object)
return type(object) == "string" or type((debug.getmetatable(object) or {}).__tostring) == "function"
end
-----------------------------------------------
-- Find the first level, not defined in the same file as the caller's
-- code file to properly report an error.
-- @param level the level to use as the caller's source file
-- @return number, the level of which to report an error
function util.errorlevel(level)
local level = (level or 1) + 1 -- add one to get level of the caller
local info = debug.getinfo(level)
local source = (info or {}).source
local file = source
while file and (file == source or source == "=(tail call)") do
level = level + 1
info = debug.getinfo(level)
source = (info or {}).source
end
if level > 1 then level = level - 1 end -- deduct call to errorlevel() itself
return level
end
-----------------------------------------------
-- Extract modifier and namespace keys from list of tokens.
-- @param nspace the namespace from which to match tokens
-- @param tokens list of tokens to search for keys
-- @return table, list of keys that were extracted
function util.extract_keys(nspace, tokens)
local namespace = require 'luassert.namespaces'
-- find valid keys by coalescing tokens as needed, starting from the end
local keys = {}
local key = nil
local i = #tokens
while i > 0 do
local token = tokens[i]
key = key and (token .. '_' .. key) or token
-- find longest matching key in the given namespace
local longkey = i > 1 and (tokens[i-1] .. '_' .. key) or nil
while i > 1 and longkey and namespace[nspace][longkey] do
key = longkey
i = i - 1
token = tokens[i]
longkey = (token .. '_' .. key)
end
if namespace.modifier[key] or namespace[nspace][key] then
table.insert(keys, 1, key)
key = nil
end
i = i - 1
end
-- if there's anything left we didn't recognize it
if key then
error("luassert: unknown modifier/" .. nspace .. ": '" .. key .."'", util.errorlevel(2))
end
return keys
end
-----------------------------------------------
-- store argument list for return values of a function in a table.
-- The table will get a metatable to identify it as an arglist
function util.make_arglist(...)
local arglist = { ... }
arglist.n = select('#', ...) -- add values count for trailing nils
return setmetatable(arglist, arglist_mt)
end
-----------------------------------------------
-- check a table to be an arglist type.
function util.is_arglist(object)
return getmetatable(object) == arglist_mt
end
return util