符號微分 |
|
符號代數的第一步是定義表徵。將表達式放入適當的形式實際上很直接;不需要剖析表達式,因為我們已有 Lua 幫我們那樣做。使用 pl.func
函式庫會完成所有繁重的工作;它重新定義算術運算以作用於佔位符表達式 (PE),其中包含稱為佔位符的虛擬變數的 Lua 表達式。pl.func
為稱為 _1
、_2
等的參數定義標準佔位符,但 Var
函式會建立我們選擇的新佔位符
utils.import 'pl.func' a,b,c,d = Var 'a,b,c,d' print(a+b+c+d)
這的確會以可讀形式印出表達式。PE 運算符表達式儲存為類似的表格組合 {op='+',x,y}
,其中有一個相關的元表格,定義諸如 __add
等的後設方法。作為一棵樹,具有 Lua 運算符的慣常結合性,我們得到
繪製這些圖形很討厭,因此更好的表記法是 Lisp 風格的 S 表達式
1: (+ (+ (+ a b) c) d)
然而,藉由我們將執行的各種操作,這種標準形式並不只可能是 a+b+c+d
的表徵
2: (+ a (+ b (+ c d))) 3: (+ (+ a b) (+ c d))
現在,經驗顯示這會導致瘋狂。取而代之的是,轉而使用標準的 Lisp 表徵會比較容易
4: (+ a b c d)
一旦這麼做,許多運算就會很直接,例如與 (+ a c b d)
比較就只是對參數執行「不帶順序地比較」。以這種形式顯示 PE 很直接。isPE
只要檢查表達式即可得知是否為佔位符表達式,方法是查看元表格。op=='X'
的 PE 是佔位符變數,因此其餘的都一定是表達式節點。
function sexpr (e) if isPE(e) then if e.op ~= 'X' then local args = tablex.imap(sexpr,e) return '('..e.op..' '..table.concat(args,' ')..')' else return e.repr end else return tostring(e) end end
第一個工作是平衡表達式,這將表徵 1-3 轉換為 4。
function balance (e) if isPE(e) and e.op ~= 'X' then local op,args = e.op if op == '+' or op == '*' then args = rcollect(e) else args = imap(balance,e) end for i = 1,#args do e[i] = args[i] end end return e end
對於非交換式運算符,這個想法也只是透過對 PE 的陣列部分(也就是參數清單)實作 balance
來平衡所有子表達式。然後將其原封不動地複製回來。非必要的部份是處理 + 和 *,其中有必要從看起來像 1、2 或 3 的表達式樹收集所有參數,並將其轉換為第四種形式。
function tcollect (op,e,ls) if isPE(e) and e.op == op then for i = 1,#e do tcollect(op,e[i],ls) end else ls:append(e) return end end function rcollect (e) local res = List() tcollect(e.op,e,res) return res end
這會遞迴地沿著相同運算符鏈向下移動(前面提到的 (+ (+ ...)
),並收集參數,將其扁平化成 n 元 + 或 * 表達式。
以下是遵循相同遞迴模式的有用函式
-- does this PE contain a reference to x? function references (e,x) if isPE(e) then if e.op == 'X' then return x.repr == e.repr else return find_if(e,references,x) end else return false end end
以下是建立 n 元乘積和總和的函式
function muli (args) return PE{op='*',unpack(args)} end function addi (args) return PE{op='+',unpack(args)} end
有了這些函式,基本的微分規則不難。首先,只考慮真的包含變數的子表達式
function diff (e,x) if isPE(e) and references(e,x) then local op = e.op if op == 'X' then return 1 else local a,b = e[1],e[2] if op == '+' then -- differentiation is linear local args = imap(diff,e,x) return balance(addi(args)) elseif op == '*' then -- product rule local res,d,ee = {} for i = 1,#e do d = fold(diff(e[i],x)) if d ~= 0 then ee = {unpack(e)} -- make a copy ee[i] = d append(res,balance(muli(ee))) end end if #res > 1 then return addi(res) else return res[1] end elseif op == '^' and isnumber(b) then -- power rule return b*x^(b-1) end end else return 0 end end
總和表達式的微分就是微分的總和。同樣地,imap
執行在子表達式上遞迴地套用函式的任務。建構結果後,我們重新平衡看看情況。
以下提供此處產品規則的一般形式,並明確檢查會產生零的項目 - 這是 fold
的工作,即將要討論的工作。
(uvw..)' = u'vw.. + uv'w... + uvw'... + ...
最後,是簡單的冪規則。請注意,由於這些運算子都是作用於 PE,因此結果可以用直接的方式表達。
事實上,如果您使用 1 式、二進制的 + 和 *,所有這些規則肯定會更清楚!但這樣一來簡化就變得令人難以忍受了。而且簡化(「縮疊」)是難以正確執行的步驟。fold
是相當長的函式,所以我將分區處理。
local op = e.op local addmul = op == '*' or op == '+' -- first fold all arguments local args = imap(fold,e) if not addmul and not find_if(args,isPE) then -- no placeholders in these args, we can fold the expression. local opfn = optable[op] if opfn then return opfn(unpack(args)) else return '?' end elseif addmul then
第一個 if
函式會尋找子表達式沒有符號的情況,也就是類似 2*5
或 10^2
的情況;在此情況下,常數可以完全摺疊。optable
(在 pl.operator
中定義)會提供運算子名稱和實際實作它們的函式之間的對應關係。
elseif op == '^' then if args[2] == 1 then return args[1] end -- identity if args[2] == 0 then return 1 end end return PE{op=op,unpack(args)}
此子句會清理 x^1
和 y^0
等從 diff
的冪次運算規則中自然產生的表達式。處理過 args
之後,便可以重新組合表達式。
此例程的主體處理難以處理的一對運算子, yaitu + 和 *。
-- split the args into two classes, PE args and non-PE args. local classes = List.partition(args,isPE) local pe,npe = classes[true],classes[false]
List.partition
函式會取得一個清單,以及一個只需傳入一個參數並傳回單一值的函式。結果會是一個表,其鍵值是傳回的值,而值則是函式傳回該值的元素清單。因此
List{1,2,3,4}:partition(function(x) return x > 2 end) --> {false={1,2},true={3,4}} List{'one',math.sin,10,20,{1,2}}:partition(type) --> {function={function: 00369110},string={one},number={10,20},table={{{1,2}} }
(數學上,這些稱為 等價類別,而 分割
會稱為 商集)
在此情況下,我們要區分非符號參數和符號參數;順序無關緊要。非符號參數 npe
可以摺疊成常數。此時運算子同一性規則便能發揮作用,因此我們可以捨去 (* 0 x)
並將 (+ 0 x)
簡化成 x
。
最後的簡化動作是取代重複的值,因此 (* x x)
應成為 (^ x 2)
,而 (+ x x x)
應成為 (* x 3)
。來自 pl.tablex
的 count_map
會執行此工作。它會取得一個類似清單的表,以及一個定義等價關係的函式,並傳回一個從值到其出現次數的對應關係表,因此 count_map{'a','b','a'}
會是 {a=2,b=1}
。
考量以下測試函式
function testdiff (e) balance(e) e = diff(e,x) balance(e) print('+ ',e) e = fold(e) print('- ',e) end
以及這些範例
testdiff(x^2+1) testdiff(3*x^2) testdiff(x^2 + 2*x^3) testdiff(x^2 + 2*a*x^3 + x^4) testdiff(2*a*x^3) testdiff(x*x*x)
我們會得到以下結果,顯示出為何需要像 fold
這樣的工作來處理 diff
的結果。
+ 2 * x ^ 1 + 0 - 2 * x + 3 * 2 * x - 6 * x + 2 * x ^ 1 + 2 * 3 * x ^ 2 - 2 * x + 6 * x ^ 2 + 2 * x ^ 1 + 2 * a * 3 * x ^ 2 + 4 * x ^ 3 - 6 * a * x ^ 2 + 4 * x ^ 3 + 2 * x + 2 * a * 3 * x ^ 2 - 6 * a * x ^ 2 + 1 * x * x + x * 1 * x + x * x * 1 - x ^ 2 * 3
https://github.com/stevedonovan/Penlight/blob/master/examples/symbols.lua
https://github.com/stevedonovan/Penlight/blob/master/examples/test-symbols.lua