拯救Transformer推理能力!DeepMind新研究TransNAR:給模型嵌入「算法推理大腦」

0 評論 836 瀏覽 1 收藏 19 分鐘

DeepMind最近發表的一篇論文提出用混合架構的方法解決Transformer模型的推理缺陷。將Transformer的NLU技能與基于GNN的神經算法推理器(NAR)的強大算法推理能力相結合,可以實現更加泛化、穩健、準確的LLM推理。

如今的NLP領域,已然是Transformer架構的天下。

從Bert到GPT,再到Llama、Claude,LLM模型使用Transformer已經是再正常不過的事情。

Transformer的「大一統」局面正是由于其簡單、高效的架構,以及在理解自然語言方面無與倫比的泛化能力。

然而,隨著研究的逐漸深入,Transformer的一個致命缺陷也逐漸暴露出來——無法勝任算法推理任務,尤其是不能進行精確、穩健的推理。

這嚴重限制了模型在數學、代碼等領域下游任務的應用,近年來對Transformer的各種調優、修改似乎也收效甚微。

于是DeepMind的研究人員想到了混合架構——將Transformers的語言理解能力與基于圖神經網絡(GNN)的神經算法推理器(NAR)的穩健性結合起來,提升其算法推理能力。

他們最近在arxiv上的一篇論文就提出了這個名為TransNAR的架構,但遺憾的是,目前還沒有公布源代碼。

論文地址:https://arxiv.org/abs/2406.09308

神經算法推理(NAR)由本文作者之一Petar Veleckovic在2021年與人合著的一篇論文中提出,并被接收為Patterns期刊的opinion paper。

論文地址:https://arxiv.org/abs/2105.02761

NAR被稱為「構建能執行算法的神經網絡的藝術」。作者提出,算法與深度學習的本質不同,但如果神經網絡能夠更好地模仿算法,它甚至可能具備算法的強泛化性。

更進一步,神經網絡若能表示出算法中連續空間內的元素,就會使已知算法更接近現實世界的問題,提出的解決方案可能超過人類科學家。

如上圖所示,NAR的整體想法是訓練出一個高維隱空間中的處理器網絡P(processor network),旨在不斷逼近算法的運行結果A(x)。

但由于算法的輸入和輸出一般是圖、樹、矩陣等抽象、結構化的形式,這與深度學習模型高維、嘈雜且多變的輸入很不兼容,因此還需要訓練編碼器f和解碼器g,將抽象形式轉換為自然形式。

NAR發布后,有多項研究證實了它有同時執行多種算法的能力,也能部署在各種下游任務中。更重要的是,它的泛化能力似乎遠遠優于Transformer架構。

原則上,NAR可以擴展到比訓練數據的分布大幾個數量級的系統上,有時這個數量級能達到1.8萬倍。

在使用適當的歸納偏差(inductive biases)時,即使輸入比訓練集大6倍,NAR也能在高度復雜的算法任務中保持完美的泛化能力。

找到了Transformer和NAR這兩種十分強大且各有所長的架構,下面最關鍵的問題就是如何進行相應的調整和修改,使這兩個似乎完全不相容的模型真正實現溝通和Embedding交換。

TransNAR:用預訓練NAR增強Transformer

如何實現NAR+Transformer的有效溝通?作者從多模態LLM中找到了靈感。

多模態LLM可以同時接收文本和圖像兩種模態的輸入,TransNAR也是如此。一邊是算法運行需要的圖結構,一邊是描述問題的自然語言。

作者的設想是,將預訓練的NAR作為Transformer中編碼的調制器(modulator),二者通過embedding溝通,同時借鑒VLM和Flamingo模型中所用的交叉注意算子,融合不同模態的信息。

TransNAR接受雙重輸入,包括文本形式的算法問題規范(T個token)及其對應的圖表征(N個節點),并輸出問題的文本答案。其中輸入的圖表征遵循算法推理基準CLRS-30的格式。

我們可以假設,編碼完成后,文本輸入存儲在T ∈ R^(T×k)中,圖輸入存儲在G ∈ R^(N×l)中。

TransNAR的前向傳播過程如下:

首先,我們通過設置T^(0) = T和G^(0) = G來正確初始化輸入。

接下來,為了計算第(t+1)步的表征,文本(token)表征被輸入到Transformer的當前層:

其中,Qt,Kt ∈ Rk×d_k,Vt ∈ Rk×k分別是鍵、查詢和值矩陣的變換,FFN是一個前饋神經網絡。

以類似的方式,圖表征被輸入到NAR層,例如實現一個標準的max-MPNN:

其中,ψ,? : Rk × Rk → Rk分別是可學習的消息函數和更新函數,max是逐元素最大值聚合。

需要注意的是,方程2僅簡要提供了節點之間的成對交互——實際上,這里的NAR是一個Triplet-GMPNN,它還包含三元組交互和一個門控機制。

此外,還需注意,NAR的可學習部分沒有時間步索引——每一步都應用相同的共享函數。這很好地契合了圖算法計算的迭代和重復性質。

一旦兩個流都準備好它們的表征Θt+1和Gt+1,圖中的節點嵌入將對Transformer的token嵌入進行條件設置,從而產生Transformer流中TransNAR塊的最終結果:

其中,Qt×,Kt× ∈ Rk×d_k, Vtx ∈ Rk×k分別是交叉注意力的鍵、查詢和值變換。在結束這一層之前,對Gt+1不進行額外的變換。

這個過程會一直重復,直到最后的第Nl層,在這一層中,從TN_l讀取最終的文本輸出。

最終輸出通過最后一層生成的預測頭轉換為token logits,并通過標準的下一個token預測來監督訓練。

在開始TransNAR微調之前,首先預訓練NAR,使其能夠穩健地執行CLRS-30覆蓋的三十個算法。這種方法已知可以在圖空間中實現高達4倍輸入規模的分布外泛化。

在微調過程中,NAR的參數通常保持凍結狀態,因為額外的梯度會削弱模型的原有穩健性特性。同樣的原因,圖嵌入不會執行交叉注意力。

LLM本身可以在大規模數據集上進行預訓練,以建立其一般語言先驗,即使在開始時隨機初始化LM,也能獲得相同的實驗結果。

實驗設置

在實驗中,作者展示了TransNAR為大語言模型架構中的分布外推理帶來的顯著優勢。Transformer架構和初始化

論文使用Chinchilla家族的一個decoder-only架構、6層的Transformer模型,首先在MassiveText上進行了預訓練,參數量有70M,上下文大小為2048。

為了探究初始化設置的影響,作者設計了兩個變體進行消融實驗。

第一個變體中,Transformer權重用預訓練的結果初始化,模擬微調場景;第二個變體則是完全隨機的初始化。這兩個模型分別被標記為「預訓練」和「未訓練」。隨機位置編碼

之前DeepMind的一篇論文論證過,隨機位置編碼可以增強Transformer的長度泛化與推理穩健性。

論文地址:https://arxiv.org/abs/2305.16843

作者也提到,隨機位置嵌入確實在基線模型和TransNAR上都帶來了顯著增益,因此本文中的所有實驗也都使用隨機位置嵌入。預訓練NAR

論文使用CLRS-30基準中的問題預訓練了一個多任務、基于MPNN的NAR,輸入問題規模最多達16個。

由于CLRS-30的標準圖結構表達,這樣訓練出來的NAR有很強的分布外(OOD)泛化能力,有時在4倍大小的圖上仍保持競爭力,這種豐富的知識表達正是文本模型可資利用的。結合節點和邊緣的跨注意力貢獻

在上述的算法描述中,我們將NAR模型的圖輸入限于N個節點,但作者注意到了之前的研究曾嘗試過,同時對圖的節點和邊生成隱變量表達,也許可以添加有用的互補信息。

于是實驗中引入圖中邊的特征E(t) ∈ RN×N×k,并再次應用公式3讓Θ(t)對E(t)進行交叉注意力。

作者也嘗試其他方法,希望將E(t)和G(t)結合起來,比如拼接后加線性層組合、向量求和、2層MLP,或者用Gram-Schmidt過程使二者的貢獻正交化,但這些都沒有給原始方法帶來提升。數據集

訓練數據使用CLRS-Text基準,即CLRS-30基準的文本版本,以確定性的方式直接從基于圖的CLRS-30中派生,因此這兩個數據集傳達的是完全相同的信息。

表1展示了該數據集的幾個樣本,以及它們的輸入大小和token數量。

由于語言模型上下文長度的限制,實驗選擇用規模為4、8、12的問題訓練,并在規模為110、12、14的問題上評估。

值得注意的是,與當前的評估環境相比,CLRS-Text是對LM最具挑戰性的長程推理任務之一——相比小學數學,復雜度顯著提高。

CLRS-Text的挑戰性主要源于它允許顯式控制分布外泛化。然而,每個問題都有清晰的多項式時間解法,這意味當今典型LLM的參數量應該足以解決這些問題。

該數據集每種算法的每種輸入規模包含一萬個樣本,總共240萬個數據點,其中70%用于訓練、30%用于驗證。

訓練細節

實驗將batch大小設置為256訓練了7個epoch,并使用Adam優化器,學習率為10-4。

如前所述,在所有Chinchilla Transformer的旋轉位置編碼(RoPE)之上應用隨機位置編碼,最大長度為8192,且訓練期間保持NAR凍結。評估指標

作者提出,合適的評估指標應該反映模型在特定樣本上失敗的原因,且需要度量型輸出與正確答案的接近程度。因此,使用精確字符串匹配來計算模型準確性是絕對不可行的。

論文選擇的性能指標包括以下三個:

1. 形狀分數:一個二元指標,用于判斷輸出是否具有正確的形狀。例如,在排序任務中,輸出應與輸入有完全相同的元素數量?;蛘?,如果輸出是一個矩陣,我們需要確保其形狀與輸入和任務一致。

2. 解析分數:一個二元指標,用于判斷輸出是否不含任何非法字符。例如,在對數字列表進行排序的任務中,輸出不應包含任何字母。

3. CLRS分數:輸出中與真實答案匹配的元素百分比,也常用于CLRS-30測試。形狀分數為0時,CLRS分數也會自動置零。

這種多方面的指標設計能夠捕捉到LLM在文本上進行推理任務的各種失敗模式。

比如在某個問題規模上過度專門化訓練(導致輸出的形狀不正確)、無法處理看不見的數字組合(導致解析錯誤),由于推理錯誤造成的答案不一致則由CLRS分數反映。

結果

實驗結果顯示,TransNAR整體上顯著優于Transformer模型,在動態規劃、幾何、圖、貪心算法、排序、字符串等任務上的OOD推理能力都有大幅提升。

并且在大多數單個算法上,無論是在分布內還是分布外都表現更佳。

特別值得注意的是,這種方法不僅增強了Transformer原有的OOD泛化能力,還激發了一些模型先前完全不具備的能力。

比如Graham掃描(graham_scan)、最長公子串長度(lcs_length)、強連通分量(scc)這些經典問題中,基線模型得分為零或接近零,但TransNAR卻實現了突破。

分析形狀分數可以進一步解釋,為什么TransNAR表現如此出色。

首先,回顧一下,如果形狀不匹配,CLRS得分必然為零。

從形狀得分來看,將Transformer的輸出建立在NAR嵌入基礎上顯著提高了答案中形狀正確的比例——這表明TransNAR緩解了一種特定的LLM故障模式。

此外,通過對比「預訓練」和「未訓練」兩種初始化方式的分數,可以看到模型較好的穩定性和可用性。在隨機初始化時,也能訓練到與微調相當的水準。

然而,在一些算法中,TransNAR仍未能超越基線,且在分布內和分布外都是如此。

這些算法包括二分搜索、尋找最大子數組、最小值和快速選擇等,都涉及在輸入列表中按照索引搜索特定元素。

這暗示了TransNAR的一種故障模式:模型無法泛化到訓練數據中未見過的新索引邊界。因此,使用索引提示或許是一條有前景的改進途徑。

另一種可能的解釋是,NAR最終計算出的隱藏狀態難以在交叉注意力層以可泛化的方式被解碼。如果原因在此,解決途徑可以是增加交叉注意力的容量,或者采用漸進式解碼。

此外,TransNAR在架構上有一個本質的局限性,就是必需一個能得出ground truth的模擬器或者數據標簽,用于將輸入的文本轉換為圖結構,再作為模型輸入。

但是作者強調,TransNAR的概念對于未來研究是有借鑒意義的??梢钥紤]將這種混合架構的想法移植到單模態LLM,或者將TransNAR訓練后獲得的知識提煉出來注入到普通的Transformer中。參考資料:

https://arxiv.org/abs/2406.09308

新智元報道 編輯:喬楊 好困

本文由人人都是產品經理作者【新智元】,微信公眾號:【新智元】,原創/授權 發布于人人都是產品經理,未經許可,禁止轉載。

題圖來自Unsplash,基于 CC0 協議。

更多精彩內容,請關注人人都是產品經理微信公眾號或下載App
評論
評論請登錄
  1. 目前還沒評論,等你發揮!