Focused Transformer – Contrastive Training for Context Scaling
本文為 “Focused Transformer: Contrastive Training for Context Scaling” (2023.07) 的論文摘要
論文全文參考
Open Source
LongLLaMA: Focused Transformer Training for Context Scalinghttps://github.com/cstankonrad/long_llama
Description
Goal
- 有效的整合文本知識
- solution 1: fine-tuning → issue: 需要大量的資源 & 難以管理
- solution 2: 在上下文插入知識 → issue: model 有效輸入長度限制 & 分心 (distraction issue)
- 本文主要是解決這個問題
Contributions
- 開發了 FOT 機制來解決分心問題, 透過外部 memory 和 KNN lookup 來擴展上下文長度
- 需要額外針對 個別指定的 attention layer 進行 fine-tune
- 魔改了 OpenLLaMA 3B 和 7B, 使其成為了 LONGLLaMA
- proof: 256k 文長的 passkey 實驗 (在一堆無意義的文字中埋入 passkey, 並讓模型正確地找到)
Methodology
Focused Transformer (FOT) 機制是通過對比學習,增強了 attention layer (key, value) 中的空間結構,從而擴展了上下文長度
- Follow
memorizing transformer
的方法來降低 attention 計算複雜度 - 具體來說,對於外部插入的知識,會去訓練 value (local) 和 key (external) 間的 attention, 所以需要一定資料進行 pre-train 或 fine-tune
- 如何訓練出一個好的 attention 則是透過對比學習的方式,利用目前文章(正樣本)和前一篇文章(負樣本)來實現
- 更進一步可以透過 cross-batch 的手法打包多筆資料(超參數 )讓訓練結果更 robust,這也暗示著可以解決分心的問題
Conclusion
- 此篇論文實作了將原本只能處理 2k tokens 的 OpenLLaMA 透過 FOT 訓練後的 LONGLLaMA 可以處理高達 256k tokens。其他 LLM 均可以透過此方法訓練出更長的處理長度的版本。
- 有了這麼大量的有效處理長度,往後在 domain 應用場景也許就不需要再額外做 pre-train,就可以達成不錯的 inference 水準。
- 這篇的限制為如何加速 KNN search 的搜尋速度和論文提到增加參考文章量 是一個提升 performance 的有效方法,這種外掛式知識的手段極度考驗 GPU memory,通常會使用到分布式運算。實作上需要再研究此超參數設定帶來的 trade off。