函數元組

lua-users home
wiki

本文說明一種新穎的設計模式,能僅使用函數表達元組。

元組是不可變的物件序列。許多編程語言都支援元組,包括 [Python][Erlang],以及大多數的 [函數式語言]。Lua 字串是特定類型的元組,其元素僅限於單一字元。

由於元組不可變,因此無需複製即可共用。另一方面,無法修改元組;修改元組必須建立新的元組。

為了說明此概念,我們實作以元組 <x, y, z> 儲存三維空間中的點。

Lua 提供許多實作元組的方法;以下是改編自優秀教科書 [結構與電腦程式解讀] 的實作。

和艾伯森與蘇斯曼一樣,我們將元組表示成一個引數為函數的函數;引數本身必須是函數;我們可以將它視為一種方法或插槽存取控制項。

首先,我們需要一個建構函數和一些成員選擇器

function Point(_x, _y, _z)
  return function(fn) return fn(_x, _y, _z) end
end

function x(_x, _y, _z) return _x end
function y(_x, _y, _z) return _y end
function z(_x, _y, _z) return _z end

由此可知,Point 會使用三個引數 (點的座標) 並傳回一個函數;在此目的中,我們將傳回值視為不透明。呼叫含有函數的 Point,等於將組成元組的物件作為引數提供給該函數;選擇器只會傳回其中一個物件並忽略其他物件。

> p1 = Point(1, 2, 3)
> =p1(x)
1
> =p1(z)
3

不過,我們不限於選擇器;我們可以寫入任何任意函數

function vlength(_x, _y, _z)
  return math.sqrt(_x * _x + _y * _y + _z * _z)
end

> =p1(vlength)
3.7416573867739

現在,儘管我們無法修改元組,但我們可以寫入函數,建立具有特定修改的元組 (這類似於標準 Lua 函式庫中的 string.gsub)

function subst_x(_x)
  return function(_, _y, _z) return Point(_x, _y, _z) end
end
function subst_y(_y)
  return function(_x, _, _z) return Point(_x, _y, _z) end
end
function subst_z(_z)
  return function(_x, _y, _) return Point(_x, _y, _z) end
end

gsub 相同,這些函數不會影響原始點的內容

> p2 = p1(subst_x(42))
> =p1(x)
1
> =p2(x)
42

值得注意的是,我們可以使用接受任意引數序列的任何函數

> p2(print)
42      2       3

同樣,我們可以寫入將兩個點結合的函數

function vadd(v2)
  return function(_x, _y, _z)
    return Point(_x + v2(x), _y + v2(y), _z + v2(z))
  end
end

function vsubtract(v2)
  return function(_x, _y, _z)
    return Point(_x - v2(x), _y - v2(y), _z - v2(z))
  end
end

> =p1(vadd(p1))(print)
2       4       6

仔細檢視 vaddvsubtract (還有各種替換函數),會發現它們實際上會建立具有封閉 upvalue (它們的原始引數) 的暫時函數。不過,這些函數沒有必要是暫時的。事實上,我們實際上可能希望多次使用特定轉換,這種情況下我們可以將其儲存起來

> shiftDiagonally = vadd(Point(1, 1, 1))
> p2(print)
42      2       3
> p2(shiftDiagonally)(print)
43      3       4
> p2(shiftDiagonally)(shiftDiagonally)(print)
44      4       5

這可能會讓想要重新檢視 vadd 的定義,以避免建立、然後對引數解構

function subtractPoint(x, y, z)
  return function(_x, _y, _z) return _x - x, _y - y, _z - z end
end

function addPoint(x, y, z)
  return function(_x, _y, _z) return _x + x, _y + y, _z + z end
end

同時在檢視時,我們來新增其他幾項轉換

function scaleBy(q)
  return function(_x, _y, _z) return q * _x, q * _y, q * _z end
end

function rotateBy(theta)
  local sintheta, costheta = math.sin(theta), math.cos(theta)
  return function(_x, _y, _z)
    return _x * costheta - _y * sintheta, _x * sintheta + _y * costheta, _z
  end
end

請注意,在 rotateBy 中,我們預先計算正弦和餘弦,以便在每次套用函數時不需呼叫數學函數庫。

現在這些函式並非傳回 Point;它們僅傳回構成 Point 的值。若要使用它們,我們必須明確建立點

> p3 = Point(p1(scaleBy(10)))
> p3(print)
10      20      30

這樣有點麻煩。但我們將看到,這樣做有其優點。

但首先,讓我們再次檢視 addPoint。如果我們構思一個變換,這樣是很好的,但如果我們想要針對特定點位移呢?p1(addPoint(p2)) 顯然無法運作。然而,答案非常簡單

> centre = Point(0.5, 0.5, 0.5)
> -- This doesn't work
> =p1(subtractPoint(centre))
stdin:2: attempt to perform arithmetic on a function value
stack traceback:
        stdin:2: in function <stdin:2>
        (tail call): ?
        (tail call): ?
        [C]: ?
> -- But this works just fine:
> =p1(centre(subtractPoint))
0.5     1.5     2.5

此外,這些新函式可以進行組合;實際上,我們可以將一組變換建立為單一基本元素

-- A complex transformation
function transform(centre, expand, theta)
  local shift = centre(subtractPoint)
  local exp = scaleBy(expand)
  local rot = rotateBy(theta)
  local unshift = centre(addPoint)
  return function(_x, _y, _z)
    return unshift(exp(rot(shift(_x, _y, _z))))
  end
end

> xform = transform(centre, 10, math.pi / 4)
> =p1(xform)
-6.5710678118655        14.642135623731 25.5

這有一個巨大的好處,就是建立 xform 後,它可以在不建立任何堆疊物件的情況下執行。所有記憶體消耗都在堆疊之上。當然,這有點虛偽——為建立組元 (函式封裝和三個上層值) 以及建立個別變壓器做了相當多的儲存空間配置。

此外,我們尚未設法處理一些重要的語法問題,例如讓一般程式設計人員實際如何使用這些組元。

--RiciLake

推廣到任意大小 N

若要讓上述方案適用於任意大小的元組,我們可以使用 CodeGeneration,如下所示--DavidManura

function all(n, ...) return ... end     -- return all elements in tuple
function size(n) return n end           -- return size of tuple
function first(n,e, ...) return e end     -- return first element in tuple
function second(n,_,e, ...) return e end  -- return second element in tuple
function third(n,_,_,e, ...) return e end -- return third element in tuple
local nthf = {first, second, third}
function nth(n)
  return nthf[n] or function(...) return select(n+1, ...) end
end

local function make_tuple_equals(n)
  local ta, tb, te = {}, {}, {}
  for i=1,n do
    ta[#ta+1] = "a" .. i
    tb[#tb+1] = "b" .. i
    te[#te+1] = "a" .. i .. "==b" .. i
  end
  local alist = table.concat(ta, ",")
  if alist ~= "" then alist = "," .. alist end
  local blist = table.concat(tb, ",")
  if blist ~= "" then blist = "," .. blist end
  local elist = table.concat(te, " and ")
  if elist ~= "" then elist = "and " .. elist end
  local s = [[
    local t, n1 %s = ...
    local f = function(n2 %s)
      return n1==n2 %s
    end
    return t(f)
  ]]
  s = string.format(s, alist, blist, elist)
  return assert(loadstring(s))
end

local cache = {}
function equals(t)
  local n = t(size)
  local f = cache[n]; if not f then
    f = make_tuple_equals(n)
    cache[n] = f
  end
  return function(...) return f(t, ...) end
end

local function equals2(t1, t2)
  return t1(equals(t2))
end

local ops = {
  ['#'] = size,
  ['*'] = all,
}
local ops2 = {
  ["number"]   = function(x) return nth(x) end,
  ["function"] = function(x) return x end,
  ["string"]   = function(x) return ops[x] end
}

local function make_tuple_constructor(n)
  local ts = {}
  for i=1,n do ts[#ts+1] = "a" .. i end
  local slist = table.concat(ts, ",")
  local c = slist ~= "" and "," or ""
  local s =
    "local ops2 = ... " ..
    "return function(" .. slist .. ") " ..
    "  return function(f) " ..
     "    return (ops2[type(f)](f))(" ..
                 n .. c .. slist .. ") end " ..
    "end"
  return assert(loadstring(s))(ops2)
end

local cache = {}
function tuple(...)
  local n = select('#', ...)
  local f = cache[n]; if not f then
    f = make_tuple_constructor(n)
    cache[n] = f
  end
  return f(...)
end

測試

-- test suite
local t = tuple(1,nil,2,nil)
;(function(a,b,c,d) assert(a==1 and b==nil and c==2 and d==nil) end)(t(all))
;(function(a,b,c,d) assert(a==1 and b==nil and c==2 and d==nil) end)(t '*')
assert(t(size) == 4)
assert(t '#' == 4)
assert(t(nth(1)) == 1 and t(nth(2)) == nil and t(nth(3)) == 2 and
       t(nth(4)) == nil)
assert(t(1) == 1 and t(2) == nil and t(3) == 2 and t(4) == nil)
assert(t(first) == 1 and t(second) == nil and t(third) == 2)
local t = tuple(3,4,5,6)
assert(t(nth(1)) == 3 and t(nth(2)) == 4 and t(nth(3)) == 5 and
       t(nth(4)) == 6)
assert(t(first) == 3 and t(second) == 4 and t(third) == 5)
assert(tuple()(size) == 0 and tuple(3)(size) == 1 and tuple(3,4)(size) == 2)
assert(tuple(nil)(size) == 1)
assert(tuple(3,nil,5)(equals(tuple(3,nil,5))))
assert(not tuple(3,nil,5)(equals(tuple(3,1,5))))
assert(not tuple(3,nil)(equals(tuple(3,nil,5))))
assert(not tuple(3,5,nil)(equals(tuple(3,5))))
assert(tuple()(equals(tuple())))
assert(tuple(nil)(equals(tuple(nil))))
assert(tuple(1)(equals(tuple(1))))
assert(not tuple(1)(equals(tuple())))
assert(not tuple()(equals(tuple(1))))
assert(equals2(tuple(3,nil,5), tuple(3,nil,5)))
assert(not equals2(tuple(3,nil,5), tuple(3,1,5)))


-- example
function trace(f)
  return function(...)
    print("+function")
    local t = tuple(f(...))
    print("-function")
    return t(all)
  end
end
local test = trace(function (a,b,c)
  print("test",a+b+c)
end)
test(2,3,4)
--[[OUTPUT:
+function
test    9
-function
]]

意見

我認為這個頁面具有誤導性。這些不是元組。只有當元組按值比較時,才會很有用,以便可以在其中編制索引 (即用作表鍵)。如果沒有此屬性,它們就沒有比表更好。請參閱 [1],了解使用內部索引樹的 n 元組實作。--CosminApreutesei

另請參閱


RecentChanges · 偏好設定
編輯 · 歷史記錄
最後編輯 2014 年 9 月 12 日晚上 4:09 GMT (差異)