Simple linear attention language models balance the recall-throughput tradeoff
February 28, 2024
Authors: Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, Dylan Zinsley, James Zou, Atri Rudra, Christopher Ré
cs.AI
Abstract
Recent work has shown that attention-based language models excel at recall,
the ability to ground generations in tokens previously seen in context.
However, the efficiency of attention-based models is bottle-necked during
inference by the KV-cache's aggressive memory consumption. In this work, we
explore whether we can improve language model efficiency (e.g. by reducing
memory consumption) without compromising on recall. By applying experiments and
theory to a broad set of architectures, we identify a key tradeoff between a
model's state size and recall ability. We show that efficient alternatives to
attention (e.g. H3, Mamba, RWKV) maintain a fixed-size recurrent state, but
struggle at recall. We propose BASED, a simple architecture combining linear and
sliding window attention. By varying the BASED window size and linear attention
feature dimension, we can dial the state size and traverse the Pareto frontier
of the recall-memory tradeoff curve, recovering the full quality of attention
on one end and the small state size of attention-alternatives on the other. We
train language models up to 1.3b parameters and show that BASED matches the
strongest sub-quadratic models (e.g. Mamba) in perplexity and outperforms them
on real-world recall-intensive tasks by 6.22 accuracy points. Implementations
of linear attention are often less efficient than optimized standard attention
implementations. To make BASED competitive, we develop IO-aware algorithms that
enable 24x higher throughput on language generation than FlashAttention-2 when
generating 1024 tokens using 1.3b parameter models. Code for this work is
provided at: https://github.com/HazyResearch/based.
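
To make the architecture the abstract describes more concrete, below is a minimal sketch of its two attention components: causal linear attention, whose recurrent state size is set by the feature dimension, and softmax attention restricted to a sliding window, whose state is set by the window length. The function names, tensor shapes, and the quadratic feature map here are illustrative assumptions for this sketch, not the released implementation at https://github.com/HazyResearch/based.

```python
# Minimal sketch (PyTorch) of the two attention components described above.
# All names and shapes are illustrative assumptions, not the BASED codebase.
import torch


def feature_map(x: torch.Tensor) -> torch.Tensor:
    """Project queries/keys into a small feature space: [1, x, pairwise products].

    Loosely in the spirit of a 2nd-order Taylor expansion of exp(q . k); the
    resulting dot product 1 + (q . k) + (q . k)^2 stays positive, so the
    normalizer below is well behaved.
    """
    b, t, d = x.shape
    x2 = (x.unsqueeze(-1) * x.unsqueeze(-2)).reshape(b, t, d * d)
    ones = torch.ones(b, t, 1, dtype=x.dtype, device=x.device)
    return torch.cat([ones, x, x2], dim=-1)          # (b, t, 1 + d + d^2)


def linear_attention(q, k, v):
    """Causal linear attention; the running sums are the fixed-size recurrent state."""
    qf, kf = feature_map(q), feature_map(k)                   # (b, t, f)
    kv = torch.einsum("btf,btd->btfd", kf, v).cumsum(dim=1)   # prefix sums of k ⊗ v
    z = kf.cumsum(dim=1)                                      # prefix sums of k features
    num = torch.einsum("btf,btfd->btd", qf, kv)
    den = torch.einsum("btf,btf->bt", qf, z).unsqueeze(-1)
    return num / (den + 1e-6)


def sliding_window_attention(q, k, v, window: int):
    """Exact softmax attention where each position attends to the last `window` tokens."""
    b, t, d = q.shape
    scores = torch.einsum("bqd,bkd->bqk", q, k) / d ** 0.5
    pos = torch.arange(t, device=q.device)
    keep = (pos[None, :] <= pos[:, None]) & (pos[None, :] > pos[:, None] - window)
    scores = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


# Toy usage: this sketch simply adds the two outputs; the actual model interleaves
# such components across layers. Shrinking `window` or the feature dimension
# shrinks the inference-time state, at some cost to recall.
q = k = v = torch.randn(2, 32, 16)
y = linear_attention(q, k, v) + sliding_window_attention(q, k, v, window=8)
print(y.shape)  # torch.Size([2, 32, 16])
```

Dialing the window length and the feature dimension is the knob the abstract refers to: both directly set how much state must be kept per layer during generation, which is what trades off against recall.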