Curry 化記值

lua-users home
wiki

另一種函式記值實作。負責處理 M 引數,N 回傳值函式的通用案例。且會稍微注意保留出現在引數或結果清單中的 nil。

例如

  local mtest = weak_memoize_m_to_n( function(...) print 'exec' return ... end )

  print( mtest(nil,2) )    --> "exec", "nil,2"
  print( mtest(nil,2) )    --> "nil,2"
  collectgarbage()
  print( mtest ( nil,2 ) ) --> "exec", "nil,2"
  print( mtest ( ) )       --> "exec", ""
  print( mtest ( nil ) )   --> "exec", "nil"

設計等同於使用引數樹技巧,例如 [memoize.lua]。不過,樹是隱含產生而非明示產生,其方式為遞迴縮減 M>1 案例為 M-1 案例。

Lua 程式碼

  local _ENV = setmetatable({},{__index=_G})

  -- this code can be made more memory and speed efficient by
  -- defining catch() in C.  but, a table.unpack approach will also
  -- work. 
  function catch(...)
    local rvals = {...}
    local n = select('#',...)
    return function()
      return table.unpack(rvals,1,n)
    end
  end

  local weak_mt= {__mode='kv'}
  local function weak_table() return setmetatable({},weak_mt) end
  local function strong_table() return {} end

  local null = {}

  local function arg2key(arg)
    return (arg == nil and null) or arg
  end

  -- build a memoization function that can handle the 1-argument 
  -- to n rvals case.  
  local function new_memoizer_1_to_n(newtable)
    return function(fun)
      local lookup = newtable()

      return function (arg)
        local k = arg2key(arg)
        local r=lookup[k]
        if r then
          return r()
        end
        r=catch( fun(arg) )
        lookup[k] = r
        return r()
      end

    end
  end

  local function new_memoizer_m_to_n( newtable, memoize_1_to_n )

    -- return a memoization of f that assumes m arguments.
    local function memoize_m_to_n(m,f)
      --  handle the m==0 case
      if m==0 then
        local memoized
        return function()
          if memoized then
            return memoized()
          end
          memoized = catch(f())
          return memoized()
        end
      end

      if m==1 then
        return memoize_1_to_n(f)
      end

      local lookup = newtable()

      -- handle the general m-to-n case, for m>=2. 
      return function(arg, ...)

        local k = arg2key(arg)
        local r = lookup[k]

        if r then
          return r(...)
        end
    
        -- create a new (m-1) argument memoizer that will handle 
        -- this arg value in the future.  
        r = memoize_m_to_n(m-1, function(...)
          return f(arg,...)
        end)

        lookup[k]=r
        return r(...)
      end
    end

    -- return a memoizer that dispatches between the different m-argument cases.
    return function(f)
      local m_to_memoized = newtable()
      return function(...)
        local m = select('#',...)
        local memoized = m_to_memoized[m]
        if memoized then
          return memoized(...)
        end
        memoized = memoize_m_to_n(m,f)
        m_to_memoized[m]=memoized
        return memoized(...)
      end
    end
  end

  weak_memoize_1_to_n =  new_memoizer_1_to_n(weak_table)
  strong_memoize_1_to_n =  new_memoizer_1_to_n(strong_table)

  weak_memoize_m_to_n = new_memoizer_m_to_n(weak_table,weak_memoize_1_to_n)
  strong_memoize_m_to_n = new_memoizer_m_to_n(strong_table,strong_memoize_1_to_n)

  return _ENV

在 C 中實作 catch()

透過在 C-API 內部實作 catch(),可以同時改善記憶體使用量和效能。儘管其儲存容量僅限於 255 個值,但與 Lua 的泛用表格相較之下,C 閉包是更精簡、快速的資料結構。

  static int throw_upvalues(lua_State *L) {
    int n1=lua_tointeger(L,lua_upvalueindex(1));
    luaL_checkstack(L,n1-1,"too many upvalues");
    for(int i=2; i<=n1; i++) {
      lua_pushvalue(L,lua_upvalueindex(i));
    }
    return n1-1;
  }

  static int catch_args(lua_State *L) {
    int n1 = lua_gettop(L)+1;
    if(n1>MAXUPVAL) {
      return luaL_error(L,"can't catch more than %d args. (catch() called with %d arguments).",MAXUPVAL-1, n1-1);
    }
    lua_pushinteger(L,n1);
    lua_insert(L,1);
    lua_pushcclosure(L,throw_upvalues,n1);
    return 1;
  }

另請參閱

網路上分散著許多其他 Lua 記值實作。FuncTables 頁面似乎是實際上的 wiki 連結集散地。不過,這個主題也常在 lua 使用者清單中出現;而檔案中張貼了幾個基於字串序列化方法的實作,都很棒 [1][2]


RecentChanges · 偏好設定
編輯 · 歷程
最後編輯於 2013 年 12 月 31 日,下午 7:50 GMT (差異)