Jump to content

မေႃႇၵျူး:memoize

လုၵ်ႉတီႈ ဝိၵ်ႇသျိၼ်ႇၼရီႇ မႃး

Documentation for this module may be created at မေႃႇၵျူး:memoize/doc

local format = string.format
local select = select
-- `unpack` is a global in Lua 5.1 but lives at `table.unpack` in Lua 5.2+;
-- fall back so the memoizer keeps working if Scribunto's Lua is upgraded.
local unpack = unpack or table.unpack

----- M E M O I Z A T I O N-----
-- Memoizes a function or callable table.
-- Supports any number of arguments and return values.
-- If the optional parameter `simple` is set, the memoizer uses a faster implementation that supports only one argument and one return value. Additional arguments are still passed through to the function, but only the first argument is used to look up the cache, so extra arguments should only be given if they will always be the same (otherwise a result cached for different extra arguments may be returned).

-- Sentinels that stand in, as table keys, for values that cannot be used as
-- keys directly. Each one is created lazily the first time it is needed.
-- `nil_` is also used by the simple memoizer below to record a cached nil
-- return value.
local nil_, neg_0, pos_nan, neg_nan

-- Converts an argument into a usable table key. For example,
-- f("foo", nil, "bar") is memoized at memo["foo"][nil_]["bar"][memo].
-- The values that need a sentinel are:
--   * nil, which cannot be a table key at all;
--   * -0, which equals 0 (and so would share its slot) but becomes "-0" on
--     conversion to string and behaves differently in some operations
--     (e.g. 1/a is inf if a is 0, but -inf if a is -0);
--   * NaN and -NaN, the only values for which n == n is false (so they cannot
--     be table keys); they seem to differ only on conversion to string
--     ("nan" and "-nan").
-- Every other value is returned unchanged.
local function get_key(input)
	if input == nil then
		nil_ = nil_ or {}
		return nil_
	end

	if input == 0 and 1 / input < 0 then
		-- -0: compares equal to 0, but dividing by it yields -inf.
		neg_0 = neg_0 or {}
		return neg_0
	end

	if input == input then
		return input
	end

	-- Only NaN reaches this point, being the only value not equal to itself;
	-- its two variants are told apart by their formatted string.
	if format("%f", input) == "nan" then
		pos_nan = pos_nan or {}
		return pos_nan
	end
	neg_nan = neg_nan or {}
	return neg_nan
end

-- Finds (creating where missing) the nested memo table for an argument list.
-- Each argument, converted with get_key, selects one level of nesting: the
-- arguments (1, 2, 3) lead to memo[1][2][3]. The caller stores the packed
-- return values in that table under the root memo table as key, so the result
-- for f(1, 2, 3), memo[1][2][3][memo], never collides with the one for
-- f(1, 2), memo[1][2][memo].
-- @param memo  the table for the current nesting level.
-- @param n     the 1-based position of `key` in the full argument list.
-- @param nargs the total argument count (select("#", ...)), so trailing nils
--              each get their own level.
-- @param key   the argument for this level; `...` holds the remaining ones.
-- @return the table reached after the last argument.
local function get_memo(memo, n, nargs, key, ...)
	local slot = get_key(key)
	local child = memo[slot]
	if child == nil then
		child = {}
		memo[slot] = child
	end
	if n == nargs then
		return child
	end
	return get_memo(child, n + 1, nargs, ...)
end

-- Packs a function's return values into a table whose field `n` is
-- select("#", ...): the number of values, including any nils returned after
-- the last non-nil value (e.g. select("#", nil) == 1, select("#") == 0,
-- select("#", nil, "foo", nil, nil) == 4). The distinction between nil and
-- nothing affects some native functions (e.g. tostring() throws an error, but
-- tostring(nil) returns "nil"), so it must be reconstructable from the memo.
-- This builds the same table the Lua 5.1 implicit `arg` vararg table did, but
-- explicitly, so it also works under LuaJIT and Lua 5.2+ (where `arg` is gone).
-- @return a table holding the arguments at 1..n, with `n` set to their count.
local function catch_output(...)
	return { n = select("#", ...), ... }
end

-- Module entry point: wraps `func` (a function or callable table) in a memoizer.
-- @param func   the function or callable table to memoize.
-- @param simple optional; when truthy, use the faster memoizer that keys on the
--               first argument only and caches only the first return value.
-- @return a function taking the same arguments as `func` and returning its
--         cached results.
return function(func, simple)
	-- Cache shared by every call of the returned wrapper; created on first call.
	local memo

	if simple then
		return function(...)
			local key = get_key(...)
			memo = memo or {}
			local cached = memo[key]
			if cached ~= nil then
				-- A cached nil result is stored as the `nil_` sentinel.
				if cached == nil_ then
					return nil
				end
				return cached
			end
			local result = func(...)
			if result == nil then
				nil_ = nil_ or {}
				memo[key] = nil_
				return nil
			end
			memo[key] = result
			return result
		end
	end

	return function(...)
		local nargs = select("#", ...)
		memo = memo or {}
		-- Every possible argument value (true, false, nil, ...) must be usable as a
		-- path key, so the root memo table itself is the key for the stored result.
		-- With no arguments, the root table holds the result directly.
		local node = memo
		if nargs > 0 then
			node = get_memo(memo, 1, nargs, ...)
		end
		local results = node[memo]
		if results == nil then
			results = catch_output(func(...))
			node[memo] = results
		end
		-- results.n restores the exact count of return values; unpack yields nil
		-- for any slot up to n that holds no value.
		return unpack(results, 1, results.n)
	end
end