大模型最強架構(gòu)TTT問世!斯坦福UCSD等5年磨一劍, 一夜推翻Transformer
超越Transformer和Mamba的新架構(gòu),剛剛誕生了。斯坦福UCSD等機構(gòu)研究者提出的TTT方法,直接替代了注意力機制,語言模型方法從此或?qū)氐赘淖儭?/p>
一覺醒來,超越Transformer和Mamba的新架構(gòu)誕生了?
斯坦福、UCSD、UC伯克利和Meta的研究人員提出了一種全新架構(gòu),用機器學習模型取代RNN的隱藏狀態(tài)。
論文地址:https://arxiv.org/abs/2407.04620
這個模型通過對輸入token進行梯度下降來壓縮上下文,這種方法被稱為「測試時間訓練層(Test-Time-Training layers,TTT)」。
TTT層直接替代了注意力機制,解鎖了具有表現(xiàn)力記憶的線性復雜度架構(gòu),使我們能夠在上下文中訓練包含數(shù)百萬(未來可能是數(shù)十億)個token的LLM。
作者相信,這個研究了一年多的項目,將從根本上改變我們的語言模型方法。
而結(jié)果證明,TTT-Linear和TTT-MLP直接趕超或擊敗了最強的Transformer和Mamba!
作者之一的Xiaolong Wang驚喜地表示:不敢相信,我們真的做到了。
更令人興奮的是,雖然目前TTT只應(yīng)用于語言建模,但在未來,它也可以用在長視頻上,可謂前景遠大。
在將來,當我們對長視頻進行建模時,就可以對幀進行密集采樣,而不是采樣1FPS了。這些密集幀對Transformer是一種負擔,但對于TTT層來說,這卻是一種福音!
01 一個5年多的想法,終于實現(xiàn)了
作者表示,在過去的1.5年里,團隊一直在開發(fā)一種新的LLM架構(gòu),可以具有線性復雜度和更強的隱藏狀態(tài),用于長上下文建模。
而這個測試時訓練(TTT)的想法,已經(jīng)研究了超過5年。
Xiaolong清晰記得,在剛開始做博士后時,Alyosha曾讓自己去找Yu Sun討論TTT。
這次會面,就是這項研究的起點。
序列模型會把歷史上下文存儲在一個隱藏狀態(tài)中。
像Mamba這樣的RNN層,會隨著時間的推移壓縮成一個固定大小的狀態(tài),它們雖然效率很高,但性能受限于其表達能力。
注意力機制有一個KV緩存,它會隨著時間的推移不斷增長。這個狀態(tài)不會壓縮任何歷史上下文,但隨著上下文長度的增加,成本也會越來越高。
團隊成員想:既然這樣,為什么不把上下文壓縮到模型的權(quán)重中——就像LLM處理互聯(lián)網(wǎng)數(shù)據(jù)那樣呢?
這種「隱藏狀態(tài)模型」既能在時間上保持固定大小,又能大大增強表達能力。
研究人員使用了自監(jiān)督學習來更新隱藏狀態(tài)的權(quán)重,對每個token進行一次梯度下降。在處理一個序列時,該狀態(tài)已經(jīng)在其上下文窗口中的token上「訓練」過了。
值得注意的是,隱藏狀態(tài)只存在于端到端架構(gòu)中的一層。其他組件,比如QKV投影矩陣,是在預訓練期間通過標準的交叉熵目標函數(shù)學習的。
因此,端到端架構(gòu)實際上是在進行元學習,尋找壓縮上下文的最佳方式,以便更好地預測下一個token,也就是在「學習如何在測試時學習」。
結(jié)果顯示,與Mamba相比,TTT-Linear具有更好的困惑度和更少的FLOP(左),并且更好地利用了長上下文(右)。
下圖顯示了批大小為16的情況下,隨著上下文長度的變化,每個token的前向時間(延遲)。所有模型的參數(shù)都是1.3B(Mamba為1.4B)。
可以看到,隨著上下文長度的增加,Transformer每個token的前向時間呈線性增長,但其他兩種方法的前向時間基本保持不變。
在8k上下文時,TTT-Linear比Transformer更快,與Mamba相當。
02 RNN的尷尬現(xiàn)實
2020年,OpenAI縮放定律論文表明LSTM(RNN的一種)無法像Transformer那樣進行縮放,或有效地使用長上下文。
真的是這樣嗎?
在這個項目中,研究人員重新評估了圖2中的這些發(fā)現(xiàn)。
在左側(cè),可以觀察到Mamba(當今最流行的RNN之一)的擴展性與強大的Transformer類似,這是自2020年的LSTM以來顯示出的巨大進步。
然而,在右側(cè),可以觀察到與OpenAI相同的Mamba問題。
平均而言,序列中靠后的token應(yīng)該更容易預測,因為它們以更多信息為條件。
對Transformer來說確實如此,每個token索引的平均復雜度在其32k上下文中不斷減少。相比之下,Mamba在16k后就出現(xiàn)了同樣的情況。
對于現(xiàn)有的RNN來說,這個結(jié)果代表了一個尷尬的現(xiàn)實——
一方面,RNN(相對于Transformer)的主要優(yōu)勢就是它們的線性(相對于二次)復雜性。這種漸進優(yōu)勢實際上只會在長上下文中實現(xiàn)。
另一方面,一旦上下文足夠長,現(xiàn)有的RNN(如Mamba)就很難真正利用額外的條件信息。
長上下文的困難是RNN層本質(zhì)上的問題:與自注意力機制不同,RNN層必須將上下文壓縮為固定大小的隱藏狀態(tài)。
作為一種壓縮啟發(fā)式,更新規(guī)則需要發(fā)現(xiàn)成千上萬甚至數(shù)百萬個token之間的底層結(jié)構(gòu)和關(guān)系。
研究人員首先觀察到,自監(jiān)督學習可以將大量訓練集壓縮為LLM等模型的權(quán)重,該模型通常表現(xiàn)出對其訓練數(shù)據(jù)之間語義聯(lián)系的深刻理解,而這,恰恰是他們所需要的。
1. TTT層
受此啟發(fā),研究人員設(shè)計了一類新的序列建模層,其中隱藏狀態(tài)是模型,更新規(guī)則是自監(jiān)督學習的一個步驟。
由于更新測試序列上隱藏狀態(tài)的過程,相當于在測試時訓練模型,因此此類新層稱為測試時訓練(TTT)層。
研究人員引入兩個簡單的實例:TTT-Linear和TTT-MLP,其中隱藏狀態(tài)分別是線性模型和兩層MLP。TTT層可以集成到任何網(wǎng)絡(luò)架構(gòu)中并進行端到端優(yōu)化,類似于RNN層和自注意力。
2. 實際運行時間
TTT層在FLOP方面已經(jīng)非常高效,研究人員則更進一步地提出了兩項創(chuàng)新,使其在實際運行時間內(nèi)也能保持高效。
首先,與在常規(guī)訓練中對mini-batch序列采取梯度步進以實現(xiàn)更好的并行性類似,他們也在TTT中使用了mini-batch的token。
其次,研究人員為每個TTT mini-batch內(nèi)的操作開發(fā)了一種對偶形式,以更好地利用現(xiàn)代GPU和TPU。這種對偶形式的輸出與原始實現(xiàn)相當,但訓練速度卻快了5倍以上。
正如圖3所示,TTT-Linear在8k上下文中比Transformer更快,并且與Mamba相當。
03 Transformer殺手——TTT
如圖4所示,所有的序列建模層,都可以從將歷史上下文存儲到隱藏狀態(tài)的角度來看待。
比如,RNN層——如LSTM、RWKV和Mamba層——將上下文壓縮成一個固定大小的狀態(tài),這個狀態(tài)隨時間變化。
這種壓縮帶來了兩種結(jié)果:優(yōu)勢是處理效率高,因為每個token的處理時間是恒定的。劣勢是在處理長上下文時,RNN性能受限于隱藏狀態(tài)的「表達能力」。
自注意力機制(Self-attention)也可以從如上角度來理解。
不同之處在于,它的隱藏狀態(tài),通常稱為鍵值(KV)緩存是一個隨t增長的線性list。
它可以存儲所有的上下文,并且不會進行壓縮,具有很好的表達能力,不過其處理時間隨上下文長度線性增長。
因此,為了在長上下文中既保持效率,又具有表達能力,需要一個更好的“壓縮啟發(fā)式”(compression heuristic)方法。
具體來說,就需要將數(shù)百萬個token壓縮成一個能有效捕捉其底層結(jié)構(gòu)和關(guān)系的隱藏狀態(tài)。
1. TTT隱藏狀態(tài)
研究人員的關(guān)鍵思想是,使用自監(jiān)督學習來將歷史上下文x1,…,xt壓縮成一個隱藏狀態(tài)St。
方法是將上下文視為一個無標簽數(shù)據(jù)集,而將狀態(tài)視為一個模型。
具體來說,隱藏狀態(tài)St現(xiàn)在等同于一個模型f的權(quán)重Wt,這個模型f可以是線性模型、小型神經(jīng)網(wǎng)絡(luò)或其他任何形式。輸出規(guī)則簡單地表示為:
直觀來講,輸出token就是由更新后權(quán)重Wt的模型f對xt所做的預測。更新規(guī)則是在某個自監(jiān)督損失?上進行的一步梯度下降:
其中學習率為η。從壓縮的角度來看,每種啟發(fā)式方法都需要決定記住/忘記哪些輸入。W會記住那些產(chǎn)生大梯度的輸入——直觀地說,就是那些使W學習很多的輸入。
?的一種選擇是重構(gòu)xt本身。為了使學習問題變得非平凡,作者首先將xt處理成一個被破壞的輸入x?t,然后優(yōu)化:
類似于去噪自編碼器,f需要發(fā)現(xiàn)xt各維度之間的相關(guān)性,以便從部分信息x?t中重構(gòu)出xt。
如圖5所示,梯度下降能夠減少?,但無法將其降至零。
與其他RNN層和自注意力機制一樣,研究人員將輸入序列x1,…,xT映射到輸出序列Z1,…,ZT的算法可以被編程到序列建模層的前向傳播中,使用上述的隱藏狀態(tài)、更新規(guī)則和輸出規(guī)則。
即使在測試時,新層仍然為每個輸入序列訓練一個不同的權(quán)重序列W1,…,WT。
因此,研究人員將其稱之為測試-時間訓練層(TTT)。
4. 使用TTT層訓練神經(jīng)網(wǎng)絡(luò)
TTT層的前向傳播,也有相應(yīng)的后向傳播。
TTT層與RNN層、自注意力機制有著相同的接口,因此可以在任何更大的神經(jīng)網(wǎng)絡(luò)架構(gòu)中替換它們。
值得一提的是,訓練帶有TTT層神經(jīng)網(wǎng)絡(luò)的方式,與訓練任何其他Transformer模型相同。
可以使用相同的數(shù)據(jù)、方法和目標(如下一個token預測)來優(yōu)化網(wǎng)絡(luò)其余部分的參數(shù)。
在此,研究人員將訓練更大的神經(jīng)網(wǎng)絡(luò)稱為外循環(huán)(outer loop),而在每個TTT層內(nèi)訓練W稱為內(nèi)循環(huán)(inner loop)。
它們之間梯度計算的區(qū)別是,內(nèi)循環(huán)針對的是W(即模型f的參數(shù)),外循環(huán)針對的是網(wǎng)絡(luò)其余部分的參數(shù)θrest。
5. TTT學習自監(jiān)督任務(wù)
可以說,TTT最重要的部分是自監(jiān)督任務(wù),因為它決定了W從測試序列中學習的特征類型。
在這個任務(wù)的設(shè)計上,研究人員采取了更加端到端的方法——直接優(yōu)化自監(jiān)督任務(wù)以實現(xiàn)下一個token預測的最終目標。
具體來說,研究者將自監(jiān)督任務(wù)的學習,作為外循環(huán)的一部分。
從如上公式3中的簡單重構(gòu)任務(wù)開始,添加了一些外循環(huán)參數(shù)來讓這個任務(wù)可學習。最新的自監(jiān)督損失是:
在內(nèi)循環(huán)中,只有W被優(yōu)化,因此作為?的參數(shù)寫出;θ們是這個損失函數(shù)的“超參數(shù)”。在外循環(huán)中,θK,θV,θQ與θrest一起被優(yōu)化,而W僅僅是一個隱藏狀態(tài),不是參數(shù)。
圖6用代碼說明了這種區(qū)別,其中θK和θQ被實現(xiàn)為TTT層的參數(shù),類似于自注意力中的KV參數(shù)。
總的來說,θK,θV,θQ所有可能的選擇構(gòu)成了一系列多視圖重構(gòu)任務(wù),外循環(huán)可以被理解為從這個任務(wù)組中選擇一個具體任務(wù)。為了簡單起見,研究人員在這里將所有視圖設(shè)計為線性投影。
6. mini-batch TTT并行化
目前,開發(fā)的原生TTT層在浮點運算(FLOP)次數(shù)方面已經(jīng)非常高效。
然而,其更新規(guī)則:
無法實現(xiàn)并行化,因為Wt在兩個位置上依賴于Wt-1:負號和▽l。
對此,研究人員提出了mini-batch梯度下降,用b表示TTT批大小。
研究中使用Gt = ▽l(Wt’;xt),其中t’ = t – mod(t,b),其中代表著前一個mini-batch的最后一個時間步(或者第一個mini-batch 0),因此,可以一次并行b個梯度計算。
7. 對偶形式
上面介紹的并行化是必要的,但對于“實際運行時間”(wall-clock time)的效率來說還不夠。
然而,現(xiàn)實中,是無法對單個matmul來計算GtS所有的b。相反,需要b個外積來對其進行一一計算。更糟糕的是,對于每個
Gt是d×d,這會比大dXt產(chǎn)生更大的內(nèi)存占用和I/O成本。
為了解決這兩個問題,研究人員觀察到:我們實際上并不需要具體化G1, . . . , Gb,只要要我們可以在mini-batch結(jié)束時計算Wb,并且輸出token z1, . . . , zb(如上圖7所示)。
現(xiàn)在,就可以用上面簡化的TTT-Linear情況來演示這些計算,表示X = [x1, . . . , xb]:
所以Wb可以用matmul方便地計算出來。為了計算Z = [z1, . . . , zb],我們知道:
表示
和矩陣
可以得出:
如上過程,研究人員將其稱為「對偶形式」。
8. 理論等價
前面已經(jīng)提到f可以是線性模型,也可以是神經(jīng)網(wǎng)絡(luò)。還有更新規(guī)則的三種變體:online GD、batch GD和mini-batch GD。
如下圖所示,在這些2×3組合中,每一種都會引起TTT層的不同實例化。
研究中,作者分別從2個定理證明了在這些誘導實例中,具有線性模型和batch GD的TTT層等同于線性注意力——一個廣為人知的RNN層。
圖10總結(jié)了所有序列建模層的更廣泛范圍內(nèi)TTT層的一般定義。
9. 兩種變體
研究中,作者提出了TTT層的兩種變體TTT-Linear和TTT-MLP,僅在f的實例化方面有所不同。
對于TTT-Linear,
,其中W是平方。對于TTT-MLP,有兩層,類似于Transfomer的MLP。
具體來說,隱藏維度是4×輸入維度,然后是GELU激活。為了在TTT期間獲得更好的穩(wěn)定性,f始終包含層歸一化 (LN) 和殘差連接。
即,
,其中,可以是或。
04 實驗
通過與兩個基線Transformer和Mamba(現(xiàn)代RNN)比較,研究人員評估了TTT-Linear和TTT-MLP。
數(shù)據(jù)集
繼續(xù)Mamba論文之后,研究人員在Pile上執(zhí)行了2k和8k上下文長度的標準實驗,Pile是一個用于訓練開源LLM的流行文檔數(shù)據(jù)集。
主架構(gòu)
Transformer和Mamba使用不同的,除非另有說明,TTT-Linear和TTT-MLP始終使用Mamba架構(gòu)。
1. 短上下文:the Pile
在2k上下文中,TTT-Linear(M)、Mamba和Transformer具有相當?shù)男阅埽€條大部分重疊。
TTT-MLP(M)在較大的FLOP預算下表現(xiàn)稍差。盡管TTT-MLP在每個模型大小上,都比TTT-Linear具有更好的復雜度,但FLOP的額外成本抵消了這種優(yōu)勢。
在8k上下文中,TTT-Linear(M)和TTT-MLP(M)的表現(xiàn)均明顯優(yōu)于Mamba。即使是具有Transformer架構(gòu)的TTT-MLP(T),性能也比Mamba略好。
另外,研究人員還觀察到了一個非常明顯的現(xiàn)象:隨著上下文長度變長,TTT層相對于Mamba的優(yōu)勢就更大了。
2. 長上下文:Books
為了評估長上下文中的功能,研究人員使用了Pile的一個流行子集——Books,對從1k到32k以2個增量的上下文長度進行了實驗。
根據(jù)上圖,可以觀察到——
在Books的2k上下文中,Pile 2k的所有觀察結(jié)果仍然成立,唯一的例外是Mamba的表現(xiàn)略好于TTT-Linear。
在32k上下文中,TTT-Linear(M)和TTT-MLP(M)的性能均優(yōu)于Mamba,與Pile 8k的觀察結(jié)果類似。即使具有Transformer架構(gòu)的TTT-MLP(T),在32k上下文中的表現(xiàn)也比Mamba稍好。
在1.3B尺度上,TTT-MLP(T)僅比TTT-MLP(M)稍差。由于缺之清晰的線性擬合,很難推導出經(jīng)驗縮放定律。然而,TTT-MLP(T)的強勁趨勢表明,Transformer架構(gòu)可能更適合超出評估的更大模型和更長上下文。
上下文長度作為超參數(shù)
雖然輸入序列的長度由用戶確定,但語言模型處理輸入的上下文長度可以由工程師確定。因此,上下文長度也是一個可以選擇的超參數(shù)。
對于具有線性復雜度的LLM,研究人員選擇了困惑度中的argmin,因為每個上下文長度都有相同的FLOP。
從圖13中,可以觀察到以下結(jié)果——
- 性能最好的方法TTT-Linear和TTT-MLP的線幾乎完全重疊。Mamba和TF Finetune的線在10^20 FLOP后也大部分重疊。
- TF Finetune的性能明顯優(yōu)于TF Pretrain,因為它受益于長上下文,而不會在訓練FLOP中產(chǎn)生極大的成本。
- 對于所有從頭開始訓練的方法(包括TF預訓練),一旦上下文長度變得太大,困惑度就會變得更糟。
從上圖可見,與TTT-Linear相比,TTT-MLP在短上下文中表現(xiàn)稍差,但在長上下文中表現(xiàn)更好。
這一觀察結(jié)果正符合研究人員的預期,即作為隱藏狀態(tài)的MLP比線性模型更具表現(xiàn)力。同樣,所有方法都具有與Mamba 1.4B相同的訓練FLOP。
3. 實際運行時間
LLM訓練和推理可以分解為前向、后向和生成。
由于前向(在訓練和推理期間)和后向都可以并行化,因此研究人員使用對偶形式。生成新token(也稱為解碼)本質(zhì)上是順序的,因此研究人員使用原始形式。
由于資源限制,這項實驗是用JAX編寫并在TPU上運行的。
然而,由于Mamba(在PyTorch、Triton和CUDA中實現(xiàn))只能在GPU上運行,因此為了公平比較,研究人員還重寫了方法,以在GPU上運行。
具體來說,研究人員在ThunderKittens中編寫了一個用于前向的GPU內(nèi)核。從歷史上看,由于并行性和矩陣相乘的使用不當,RNN在前向和后向過程中效率低下。
這個前向內(nèi)核的目標,是證明mini-batch TTT和這些問題對偶形式的有效性。
圖15的左圖顯示了前向內(nèi)核批大小為16的延遲。所有模型參數(shù)均為1.3B(Mamba為 1.4B)。
對于Transformer,每個token的時間隨著上下文長度的增加而線性增長,但對于其他方法則大致保持不變。
此外,研究人員在Triton中編寫了另一個用于生成的GPU內(nèi)核,并在圖15的右圖中對批大小為512的速度進行了基準測試。
可以看出,TTT-Linear和Mamba的延遲幾乎相同,明顯小于Transformer和TTT-MLP。
Mamba之后,又看到TTT這么能打的新架構(gòu)誕生,少不了AI社區(qū)的熱議。
有網(wǎng)友稱,這會不會是最接近實時上下文的方法?很想聽聽大家的想法。這意味著TTT甚至在使用過程中,也能夠?qū)W習和適應(yīng),為長上下文提供更好的性能,而不會產(chǎn)生通常與Transformer相關(guān)的高昂計算成本。
OpenAI視頻生成研究人員對此表示,這項研究看起來很有趣。
如果scaling law依然存在,TTT將帶來難以置信的影響。對于長序列,Transformer的計算成本往往很高,當長序列變得更長時,RNN會遺忘。TTT訓練巧妙地利用神經(jīng)網(wǎng)絡(luò)解決RNN的不足。
作者介紹
論文最后,分別列出了這篇研究的作者貢獻。
其中的核心作者是,Yu Sun、Xinhao Li和Karan Dalal。
Yu Sun
Yu Sun是斯坦福大學計算機專業(yè)的博士后,導師是Carlos Guestrin、Tatsu Hashimoto和Sanmi Koyejo。
此前,他曾在加州大學伯克利分校完成了電子工程科學博士學位,導師是Alyosha Efros和Moritz Hardt。他還在康奈爾大學拿到了學士學位。
個人主頁中,他介紹自己的研究重點是一種名為測試時間訓練(test-time training)的算法框架。其核心思想是,每個測試實例都定義了自己的學習問題,都有自己的泛化目標。這通常使用自監(jiān)督學習,為每個實例即時訓練一個不同的模型來實現(xiàn)的。
在最新研究中,Yu Sun與Xinhao Li在2022年11月共同啟動了這一項目。自2023年6月起,Yu Sun專職負責該項目。
他提出了項目的概念框架,設(shè)計了mini-batch TTT和對偶形式(dual form)。
Xinhao Li
Xinhao Li是UC San Diego研二的學生,導師是Xiaolong Wang教授。他本人的研究興趣主要是深度學習和計算機視覺。
他在斯坦福大學Tatsunori Hashimoto教授的團隊中作為訪問學生,與Yu Sun博士和其他導師朋友一起工作。在此之前,他曾在電子科技大學獲得了學士學位。
在2024年3月之前,Xinhao Li是TTT早期代碼庫的主要貢獻者,這些代碼庫塑造了最新項目。
Karan Dalal
Karan Dalal是UC Berkeley電子工程科學系的本科生。他于2023年6月全職加入該項目,與Xinhao Li合作共同領(lǐng)導了當前代碼庫的開發(fā)工作。
參考資料:
https://x.com/karansdalal/status/1810338845659131940
https://x.com/xiaolonw/status/1810387662060269668
https://arxiv.org/abs/2407.04620
本文由人人都是產(chǎn)品經(jīng)理作者【新智元】,微信公眾號:【新智元】,原創(chuàng)/授權(quán) 發(fā)布于人人都是產(chǎn)品經(jīng)理,未經(jīng)許可,禁止轉(zhuǎn)載。
題圖來自Unsplash,基于 CC0 協(xié)議。
- 目前還沒評論,等你發(fā)揮!