咖哩 lua |
|
你可以在任何支援函式為一階物件的語言中實作咖哩函式。例如,這裡有一個小型的[咖哩 JavaScript 簡介]。
以下是 lua 的一個簡單咖哩函式範例
function sum(number) return function(anothernumber) return number + anothernumber end end local f = sum(5) print(f(3)) --> 8
-- WalterCruz
以下另一個由 [GavinWraith] 提供,它會接受以「()
」結束的變數個數量的輸入
function addup(x) local sum = 0 local function f(n) if type(n) == "number" then sum = sum + n return f else return sum end end return f(x) end print(addup (1) (2) (3) ()) --> 6 print(addup (4) (5) (6) ()) --> 15
雖然這些預先作咖哩化的函式很實用,我們真正想做的是建立一個一般用途的函式,可以在任何其他函式上執行咖哩化作業。要做到這一點,我們需要了解函式可以由「高階函式」作業,也就是將函式視為輸入的函式。以下咖哩函式就是一個例子,它將一個 2 輸入的函式做咖哩化
function curry(f) return function (x) return function (y) return f(x,y) end end end powcurry = curry(math.pow) powcurry (2) (4) --> 16 pow2 = powcurry(2) pow2(3) --> 8 pow2(4) --> 16 pow2(8) --> 256
將咖哩化 2 個輸入變成咖哩化『n』個輸入會再複雜一些。我們需要儲存一個不確定的部分應用函式數量,而且不幸的是 lua 沒有辦法知道函式需要多少輸入參數;lua 函式可以順利接受任何數量的輸入值,不管是太多還是太少。因此,有必要告訴咖哩函式在套用這些收集好的輸入值到原始函式之前,它要接受多少個單一輸入呼叫。
(這個程式碼可以從 http://tinylittlelife.org/?p=249 免費取得,並且包含完全探討如何處理此問題的內容。)
-- curry(func, num_args) : take a function requiring a tuple for num_args arguments -- and turn it into a series of 1-argument functions -- e.g.: you have a function dosomething(a, b, c) -- curried_dosomething = curry(dosomething, 3) -- we want to curry 3 arguments -- curried_dosomething (a1) (b1) (c1) -- returns the result of dosomething(a1, b1, c1) -- partial_dosomething1 = curried_dosomething (a_value) -- returns a function -- partial_dosomething2 = partial_dosomething1 (b_value) -- returns a function -- partial_dosomething2 (c_value) -- returns the result of dosomething(a_value, b_value, c_value) function curry(func, num_args) -- currying 2-argument functions seems to be the most popular application num_args = num_args or 2 -- no sense currying for 1 arg or less if num_args <= 1 then return func end -- helper takes an argtrace function, and number of arguments remaining to be applied local function curry_h(argtrace, n) if 0 == n then -- kick off argtrace, reverse argument list, and call the original function return func(reverse(argtrace())) else -- "push" argument (by building a wrapper function) and decrement n return function (onearg) return curry_h(function () return onearg, argtrace() end, n - 1) end end end -- push the terminal case of argtrace into the function first return curry_h(function () return end, num_args) end -- reverse(...) : take some tuple and return a tuple of elements in reverse order -- -- e.g. "reverse(1,2,3)" returns 3,2,1 function reverse(...) --reverse args by building a function to do it, similar to the unpack() example local function reverse_h(acc, v, ...) if 0 == select('#', ...) then return v, acc() else return reverse_h(function () return v, acc() end, ...) end end -- initial acc is the end of the list return reverse_h(function () return end, ...) end
以上的程式碼相容於 lua 5.1。
由於 lua 5.2 (或 LuaJIT 2.0) 提供了進階的 debug.getinfo 函式,讓我們知道函式需要多少個輸入值,我們可以建立一個實用的函式,混合咖哩化與部分應用技巧。以下是程式碼
function curry(func, num_args) num_args = num_args or debug.getinfo(func, "u").nparams if num_args < 2 then return func end local function helper(argtrace, n) if n < 1 then return func(unpack(flatten(argtrace))) else return function (...) return helper({argtrace, ...}, n - select("#", ...)) end end end return helper({}, num_args) end function flatten(t) local ret = {} for _, v in ipairs(t) do if type(v) == 'table' then for _, fv in ipairs(flatten(v)) do ret[#ret + 1] = fv end else ret[#ret + 1] = v end end return ret end function multiplyAndAdd (a, b, c) return a * b + c end curried_multiplyAndAdd = curry(multiplyAndAdd) multiplyBySevenAndAdd = curried_multiplyAndAdd(7) multiplySevenByEightAndAdd_v1 = multiplyBySevenAndAdd(8) multiplySevenByEightAndAdd_v2 = curried_multiplyAndAdd(7, 8) assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v1(9)) assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v2(9)) assert(multiplyAndAdd(7, 8, 9) == multiplyBySevenAndAdd(8, 9)) assert(multiplyAndAdd(7, 8, 9) == curried_multiplyAndAdd(7, 8, 9))