最佳化 Str Rep

lua-users home
wiki


[!] VersionNotice:下列程式碼屬於舊版本 Lua,即 Lua 4。無法在 Lua 5 中執行。

以下是 strrep 函式的 lua 版本,此版本的速度比原始以 C 寫成的程式碼快。平均而言,速度快約 3 倍。

所有測試都在 Pentium II 上進行,以單使用者模式執行 linux,並具有足夠的記憶體以避免換頁,以及 Lua 版本 4.0。

此演算法的靈感來自 LTN 9,其作者為 RobertoIerusalimschy,並由 LuizCarlosSilveira 編寫。

此演算法的核心準則是執行最少次數的串接。

僅供參考:與原始 strrep 相比,lua strrepO() 函式疑問較低,但這有待進一步探討。


繪製出其「重複次數」X「時間」曲線,我們可以觀察到這兩個函式的行為
時間以秒為單位,重複次數以位元組為單位)


以下是包含此演算法的程式。此程式用於測試實作是否正確,以及產生用於繪製曲線的資料。

function log2(n)
    local _n = 2
    local x = 1
    if (_n < n) then
        repeat
            x = x + 1
            _n = _n + _n
        until (_n >= n)
    elseif (_n > n) then
        if (n == 1) then
            return 0
        else
            return nil
        end
    end 
    if (_n > n) then
        return x-1
    else
        return x
    end 
end 
    
function get_bits(n)
    local bits = {}
    local rest = n
    repeat
        local major_bit = log2(rest)
        rest = rest - 2^major_bit
        bits[major_bit] = 1
        if (bits.count == nil) then
            bits.count = major_bit
        end
    until (rest == 0)
    return bits
end



function fast_strrep(str, times)
    local bits = get_bits(times)
    local strs = {[0] = str}

    local count = bits.count

    for i = 1, count do
        strs[i] = strs[i-1] .. strs[i-1]
    end

    local result = ''
    for i = 0, count do
        if (bits[i]) then
            result = result .. strs[i]
        end
    end

    return result

end

for numreps = 1024, 30*1024*1024, 1024*64 do

    a = nil
    b = nil
    collectgarbage()

    start = clock()
    a = fast_strrep("a", numreps)
    print("L:"..numreps.." "..(clock() - start))
    start = clock()
    b = strrep("a", numreps)
    print("C:"..numreps.." "..(clock() - start))

    if (a~=b) then
        print("the algorithm is wrong!")
    else
        print("ok")
    end

    flush(_STDOUT)

end

        


您開發的版本較快,這其實不讓我覺得驚訝;strrep 的 lua 函式庫版本(至少是 v 4.0)在函式呼叫方面負擔很重(每個字元呼叫一次),而串接函式則沒有此負擔。我覺得很奇怪,函式庫版本沒有直接找出需要多少個字元,並分配記憶體,然後重複將來源字串 memcpy 進去,我認為這樣應該快一個數量級,且為 O(MN + M)(M 份長度為 N 的字串複本)。但我猜想他們並不是真的考慮過要串接這麼龐大的資料。

您使用的演算法讓我想到,透過極小乘法演算法來運算指數。以下是一個版本,它利用了 Lua 將 a .. b .. c 最佳化為單一運算的事實,且還能避免建立暫存向量。我想您會發現此版本的速度大約是您的版本兩倍(而且短很多)。另外,這也可能是一個範例,說明如何在不造成太多負擔的情況下,執行看起來像是位元運算的作業。 -- RiciLake

  -- Suppose that x = b[n]*2^n + b[n-1]*2^(n-1) + ... + b[0]*2^0
  --   (where every b[i] is either 0 or 1)
  -- This is exactly equivalent to:
  --    b[0] + 2 * (b[1] + 2 * (b[2] + (... + b[n])))
  -- So we've effectively eliminated all the multiplications, replacing them with doubling.

  -- Now, x * y (for any y) can be calculated by distributing multiplication over the
  -- above, which effectively replaces every b[i] with b[i] * y. However, every b[i]
  -- is either 0 or 1, so the product is either 0 or y.

  -- Now, if k is an integer and str1 and str2 are strings, and we write:
  --   str1 + str2       for the concatenation of str1 and str2
  --   k * str1          for "k copies of str1 concatenated"
  -- we can see that we have + and * are "just like" integer arithmetic in the sense that
  -- + and * are commutative and associative, and * distributes over +. So the equivalence
  -- continues to work, except that every term must be either "" (for 0) or y (the string).

  -- All that is left is to compute the expression from the inside out: each step is
  -- either 2 * r or y + 2 * r, where r is the cumulated value and y is the original string.
  -- In string terms, we can write these as result .. result (2 * r) and
  -- result .. result .. str (2 * r + y)

  -- We could use the same idea to compute integer exponents in the minimum number of
  -- multiplications, using * and ^ instead of + and * (which is where this algorithm
  -- comes from.)

  -- This makes use of the fact that Lua optimises a .. b .. c into a single concatenation.
  -- With a bit more work, we could use any base we wanted to, not just base 2. But it would
  -- require more options in the if statement.

function another_strrep(str, times)
  local result = ""
  local high_bit = 1
  while high_bit < times do high_bit = high_bit * 2 end

  -- at this point, high_bit is the largest (integral) power of 2 smaller than times
  -- (unless times < 1 in which case high_bit is 1)
  -- The computation of highbit could be:
  --   local high_bit = 2 ^ floor(log(times) / log(2))
  -- which is probably faster but requires the math library

  -- we are now going to work through times, bit by bit, making use of the above formula:

  while high_bit >= 1 do
    if high_bit <= times then           -- the bit is 1 if times is >= high_bit
      times = times - high_bit          -- we "turn it off" for the next iteration
      result = result .. result .. str  -- and the next step is 2 * r + y
    else                                -- the bit is 0
      result = result .. result         -- so the next step is 2 * r
    end
    high_bit = high_bit / 2             -- Now go for the next bit
  end
  return result
end


您的演算法非常好。謝謝您。我們兩個(您和我的)的想法幾乎一樣,對吧?當我寫程式碼時,您執行了我無法做到的極佳化,這能防止產生補助部分。我繪製了這三個演算法的曲線,如下方所示。對於繪製在此曲線中的資料,三個函式的精確平均關係為
luiz/rici = 1.41  (rici is  1.41  times faster than luiz)
c/luiz    = 2.98  (luiz is  2.98  times faster than c)
c/rici    = 4.19  (rici is  4.19  times faster than c)
        

由於您的演算法較快(看上去使用較少的記憶體),因此屬於此頁面。如果您不同意此想法,我會移除我的演算法,並替換成您的。但是,在開始之前,請問您一件事:可以請您在程式碼中撰寫一些註解,以清楚說明此演算法嗎?由於經過您最佳化後,背後的概念已模糊不清…… --LuizCarlosSilveira

好的,我已經強制性地撰寫註解了。希望已經清楚;有時,我覺得程式碼本身比較清楚。此函式並非最佳化後可能達到的狀態,因為執行合併的次數比必要的次數多一次……我嘗試讓程式碼簡潔,因此倚賴 Lua 以極快的速度執行 "" .. "" .. str

只為好玩,也為了展示某樣事物(我不確定是什麼),我在上方演算法中新增了進位 10 版本。我使用 gsub 執行迴圈並將重複次數轉換為字串來找出數位,而不是使用條件式陳述式並進行逐位元計算。表格查詢(可能)比一連串的 if 陳述式快很多,因此我也使用了表格查詢。%state 是標準技巧,用於解決 Lua 4.0 沒有真正封閉式的事實。

我不宣稱這個函式易於閱讀,但是我的測試指出,它是更快的函式。(抱歉,沒有其他註解,但概念相同,所以您應該能夠找出原理。 :-) )只要展示出如果您的思考足夠異常,您就能做到什麼即可。我在某處還有這種情況的另一個範例:我編寫的 join 函式可以輕鬆編譯子程式,以執行問題的相同類型指數分解;即使它必須組合並編譯函式,但它比直接的 join 函式還要快很多。當然,函式必須經過備忘才能利用此特性。我也會嘗試張貼此程式。 -- RiciLake

do
  local concats = {
    ["0"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a end,
    ["1"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b end,
    ["2"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b end,
    ["3"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b end,
    ["4"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b end,
    ["5"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b end,
    ["6"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b .. b end,
    ["7"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b .. b .. b end,
    ["8"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b .. b .. b .. b end,
    ["9"] = function(a, b) return a .. a .. a .. a .. a .. a .. a .. a .. a .. a
                                    .. b .. b .. b .. b .. b .. b .. b .. b .. b end,
  }

  function decimal_strrep(str, times)
    local state = {r = "" }
    local concats = %concats
    times = tostring(times)
    if strfind(times, "^[0-9]+$") then
      gsub(times, "(.)",
           function(digit)
             %state.r = %concats[digit](%state.r, %str)
           end)
    end
    return state.r
  end
end

        


順帶一提,測試並不怎麼好,因為時間會隨著重複計數的二進位擴充中 1 位元的數量變動。(我想我的版本對於這方面的影響稍低,但仍會是一個因素。)您使用的測試計數在二進位擴充中幾乎只有 0。 -- RiciLake


我相信這是您表示您的演算法應比我的快兩倍的原因。實際發生的情況是,當我使用僅開啟 1 個位元所形成的重複數字時,這是記憶體消耗的最佳情況。最差的情況是,我的演算法似乎比您的演算法使用兩倍的記憶體。我同意我所做的測量應予檢討,但是此頁面才剛開始…… --LuizCarlosSilveira


相當不錯。嘗試使用比「a」字串更長的字串進行測試應該也非常有趣。由於垃圾回收時間會隨著總堆積大小而有所不同,因此我從未弄清楚如何充分評定 Lua 程式。最好先執行程式幾次以使堆積大小穩定化,然後再進行計時。 ——RiciLake

最新變更 · 喜好設定
編輯 · 歷程記錄
最後編輯於 2017 年 9 月 21 日 晚上 9:29 GMT (差異)