通用輸入演算法

lua-users home
wiki

Lua 中的通用函數和演算法

本文件說明了 func 函式庫的設施,該函式庫旨在讓輸入的工作更直覺。可以在以下位置找到原始碼:檔案:wiki_insecure/func.lua

有兩個特殊的輸入迭代器,numbers()words(),它們的工作方式類似於非常有用的 io.lines() 迭代器。要列印標準輸入中找到的所有字詞

-- words.lua
require 'func'
for w in words() do
  print(w)
end
要在某些文字上測試此功能,您可以在作業系統指令提示字元中輸入
$ lua words.lua < test.txt
列印迭代器產生的值是很常見的操作,因此 func 提供了非常方便的函數 printall()。它會將序列的所有成員寫入標準輸出。預設情況下,它會以空格分隔每行輸出 7 個項目,但您可以選擇變更這些值。在這種情況下,我們希望每個值在自己的行上
printall(words(),'\n')
numbers() 會建立輸入中找到的所有數值的序列。例如,要加總所有輸入數字
require 'func'
local s 
for x in numbers() do
  s = s + x
end
print(s)
加總在分析資料時是很常見的操作,因此 func 定義了一般的 sum() 函數。它會回傳加總值和欄位數目,因此可以輕鬆計算平均值。
local s,n = sum(numbers())
print('average value =',s/n)
請注意,這些迭代器會尋找適當的模式,因此它們不依賴於字詞或數字以空格分隔。numbers() 會尋找檔案中看起來像數字的所有內容,並會安全地忽略所有其他內容。因此,它對於高度註解的資料或輸出檔案很有用。這些迭代器會接收一個選用的額外參數,可以是檔案或字串。例如,要列印命令列參數傳遞的檔案中項目的加總值和數目
f = io.open(arg[1])
print(sum(numbers(f)))
f:close()
將迭代器的輸出收集為表格很有用。由於它非常簡單且具有指導性,以下是 copy() 的簡化定義
function copy(iter)
  local res = {}
  local k = 1
  for v in iter do
     res[k] = v
     k = k + 1
  end
  return res
end
以下範例會製作字串中找到的所有數字的陣列。顯然這是個簡單的案例,但必須從字串中萃取數字是很常見的情況,而且可能會很棘手。特別是,這會確保數字確實轉換正確 - 我曾不止一次因為 arr['1'] 和 arr[1] 不同而困擾!
t = copy(numbers '10 20 30') -- will be {10,20,30}
s = sum(list(t))             -- will be 60
請注意表格序列適配器 list(),它允許表格用作序列。使用這些函數來操作陣列很常見,因此如果您傳遞表格,會自動假設為 list()。要以特定格式列印數字陣列,可以使用類似 printall(t,' ',5,'%7.3f') 的指令來適當地格式化它們。以下是系統指令 sort 的實作方式,它使用 printall() 函數來輸出序列的每個值。我不能簡單地說 table.foreach(t,print),因為該操作會同時傳遞索引和值,因此我實際上還會取得行號!
t = copy(io.lines())
table.sort(t)
printall(t,'\n')   -- try table.foreach(t,print) and see!
使用 sort() 函數後,它就會變成一行式的
printall(sort(io.lines()),'\n')
您可以使用 slice() 函數來迭代序列的部分。這是通過一個迭代函數、一個起始索引和一個項目數量來實現的。例如,這是 head 命令的一個簡單版本;它顯示了輸入的前十行。
printall(slice(io.lines(),1,10),'\n')
有時我們只想計數一個序列;例如,這是計算文件中所有字數的完整腳本
require 'func'
print(count(words()))
使用這種形式時,count() 函數並不太有用。但是它可以用一個函數來選擇要計數的項目。例如,這給我提供了 Lua 文件中有多少個公共函數的粗略了解。(如果我沒有將匹配約束在開頭,它也會選取本地函數和匿名函數)
require 'func'
print(count(io.lines(),matching '^%s*function'))
其中 matching() 是以下簡單函數。它建立了一個封閉函數(與本地環境綁定的函數),並對序列中的每一個項目進行調用
function matching(s)
  local strfind = string.find
  return function(v)
    return strfind(v,s)
  end
end

當然您可以在這些操作中使用任何序列。如果您載入了非常有用的 lfs(Lua 文件系統)函式庫,則 t = copy_if(lfs.dir(path),matching '%.cpp$') 將使用擴展名為 .cpp 的所有文件填滿一個列表。

修改 count() 輸入的另一種有用的方法是使用 unique() 函數

-- number of unique words in a file
print(count(unique(words())))
unique() 函數並不是按照通常的方式實現的,後者需要先對序列進行排序。相反,它使用 count_map() 函數建立一個映射表,其中鍵為項目,值為計數。一旦我們有了 keys() 函數(這是 list() 函數的備選函數),其餘操作就很簡單了
function unique(iter)
  local t = count_map(iter)
  return keys(t)
end
經典的“計算文件中字數”範例為
table.foreach(count_map(words()),print)
比較兩個序列時,將它們 join() 起來會很有用。這將打印出兩個文件之間的差異
for x,y in join(numbers(f1),numbers(f2)) do
  print(x-y)
end

Lua 的 AWK 編程樣式

在我發現 Lua 之前,AWK 是我用來處理文字檔的最喜歡的語言。(我甚至用“AWK 是命令行的 Excel 等價物”這句口號說服了部分同事。)為了讓您品嘗一下,以下是一個完整的 AWK 程序,用於列印出文件的第 1 和第 3 行,並使用第 4 行進行縮放——請注意,對所有行的循環都是隱含的
{ print $1/$4, $3/$4 }
func 函式庫會為此目的提供迭代器 fields()。以下是等價的 Lua 程式碼
for x,y,z in fields{1,3,4} do
   print(x/z,y/z)
end
這是目前我最喜歡的單行程式碼。它計數第 7 列中有多少值大於 44000,並且速度大約為等效 AWK 程序(使用 MAWK 執行)的一半。這並不差,因為 AWK 的速度已經過優化以應付其專門任務!
print(count(fields{7},greater_than(44000)))
{ if ($7 > 44000) k++ } END { print(k) }
fields() 函數可以用任何輸入分隔符。這會從逗號分隔的檔案中讀取一組值——請注意,傳遞 n 而不傳遞欄位識別碼列表等於 {1,2,...n}
for x,y in fields({1,2},',',f) do ...
for x,y in fields(2,',',f) do ...  --equivalent--

效能及運算式

我認為很明顯的,在使用此泛型編程風格時,這些通用的運算表達地相當簡潔,但是依我對 C++ 中標準範本程式庫 (STL) 的經驗,在此時人們傾向提出兩個保留。第一個反對意見是,函式風格較不具效率。理論上這是正確的,但是實際上效率低落到什麼程度呢?例如,這裡是使用順序 random() 的轉錄記載,用來建立一個含隨機值的表格
> tt = copy(random(10000))
> = sum(tt)
5039.542771691  10000
這些運算在我的老筆電中幾乎是同時發生的,而我僅在 1e5 個項目時才開始注意到。對於 1e6 個項目,第一個運算需花費 2.14 秒,而明確迴圈只需花費 2.08 秒!如果我仔細使用區域變數,其時間會降至 1.92 秒,因此最理想的明確版本為
local t = {}
local random = math.random
for i = 1,1e6 do
   t[i] = random()
end
此範例顯示,用較長的方法進行運算並無令人信服的速度優勢。(我選擇此範例的精確原因,是因為它不涉及檔案輸入/輸出,這往往會主導 words()numbers() 的執行時間。)其優勢在於較少錯誤的程式碼;泛型編程人員認為明確迴圈「繁瑣且容易出錯」,如同 Stroustrup 所述。

第二個反對意見是,這會導致奇怪且不自然的程式碼。對於 C++ 來說,這確實有可能發生,這是因為(讓我們面對現實吧)C++ 實際上並未適合函式風格;沒有閉包,而且嚴格的靜態輸入會不斷造成阻礙,導致所有內容都必需為範本。此風格更適合 Lua - 使用 Boost Lambda 函式庫時,用 C++ 執行此程式並不會好讀一半

-- sum of squares of input data using an internal iterator
for_each(numbers(),function(v)
    s = s + v*v
end)
-- sum of squares of input data using an external iterator
for v in numbers() do
    s = s + v*v
end
其想法並不是要取代所有迴圈,而只是其中通用的泛型模式。此類程式碼會較易讀,因為任何明確迴圈都將更加顯眼。Lua 特別適合此風格,這在 C++ 中常會顯得勉強。

撰寫自訂輸入物件

如果 f 不是字串,則 words(f) 將使用檔案物件 f。事實上,f 可以是任何具有 read 方法的物件。此程式碼假設的內容是 f:read() 將會傳回下一行的輸入文字。這裡是一個較複雜的範例,我在其中建立了一個類別 Files,用來允許我們從檔案列表中讀取內容。其顯而易見的應用程式是模仿 AWK 的行為,讓命令列中的每個檔案都成為標準輸入的一部分。
Files = {}
 
function Files.create(list)
   local files = {}
   files.list = {}
   local n = table.getn(list)
   for i = 1,n do
      files.list[i] = list[i]
   end
   files.open_next = Files.open_next
   files.read = Files.read
   files:open_next()
   return files
end
 
function Files:open_next()
   if self.f then self.f:close() end
   local nf = table.remove(self.list,1)
   if nf then
      self.f = io.open(nf)
      return true
   else
      self.f = nil
      return false
   end
end
 
function Files:read()
  local ret = self.f:read()
  if not ret then
     if not self:open_next() then return nil end
     return self.f:read()
  else
     return ret
  end
end
我需要說明一個明顯的不一致問題。在讚揚無迴圈編程的樂趣後,Files.create() 中有一個傳統的複製表格迴圈。Lua 程式會傳遞一個名為 arg 的全域表格,其中包含命令列引數 arg[1]arg[2] 等。但其中也有 arg[0],也就是腳本名稱,還有 arg[-1] 即為實際的程式名稱。有問題的明確迴圈就是要確定我們不會複製那些欄位!
files = Files.create(arg)
printall(words(files))

關於實作和進一步開發的注意事項

絕大多數的 func 都是該主題下直接的變型;函式和迭代器當作閉包使用。PiL [ref?] 的第 7.1 節很好地說明了這些問題,而我使用 allwords 範例當作 words()numbers() 的基礎。fields() 最初是用一種天真的方式實作,輪流擷取每個欄位,但過後則透過建立自訂正規表示式改用一通呼叫 string.find() 來實作。例如,如果需要以逗號分隔的欄位 1 和 3,那麼 regexp 看起來就像這樣 - 欄位定義為 不是 逗號的任何內容,我們使用 () 擷取所需的欄位。
'%s*([^,]+),[^,]+,([^,])'
串列 的概念非常概括,這表示很容易將 func 作業與提供迭代器的任何函式庫一起使用。這通常會大幅簡化程式碼。例如,以下是 luasql 如何使用它的方式。想一下對查詢結果的所有列進行存取的標準方式
cur = con:execute 'SELECT * FROM [Event Summaries]'
mag = -9
row = cur:fetch({},'a')
while row do
  if row.Magnitude > mag then 
     mag = row.Magnitude
  end
  row = cur:fetch(row,'n')
end
cur:close()

我只要建立一個能持續追蹤 row 的迭代器,就能讓這個過程變得簡單

function rows(cursor)
  local row = {}
  return function()
    return cursor:fetch(row,'a')
  end
end

for row in rows(cur) do
   if row.Magnitude > mag then 
      mag = row.Magnitude    
   end
end
這已經是一個更好的迴圈了,因為我們不需要呼叫 cursor:fetch 兩次,而要尋找一個區域的 row。我們也可以實作一個等同於 fields 的函式
function column(fieldname,cursor)
  local row = {}
  return function()
    row = cur:fetch(row,'a')
    if not row then return nil 
    else return row[fieldname]
    end
  end
end

local minm,maxm = minmax(column('Magnitude',cur))
不再有任何明確的迴圈!當然,通常更有效率的做法是用 SQL WHERE 子句限制串列。下列這段有作用,但不是完成這項工作的最適方式
print(count(column('Magnitude',cur),greater_than(2)))

-- SteveDonovan


最近變更 · 偏好設定
編輯 · 歷程
上次編輯時間為 2007 年 7 月 21 日下午 6:49 GMT (diff)