Simple linear attention language models balance the recall-throughput tradeoff
February 28, 2024
Authors: Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, Dylan Zinsley, James Zou, Atri Rudra, Christopher Ré
cs.AI
Abstract
Recent work has shown that attention-based language models excel at recall,
the ability to ground generations in tokens previously seen in context.
However, the efficiency of attention-based models is bottlenecked during
inference by the KV-cache's aggressive memory consumption. In this work, we
explore whether we can improve language model efficiency (e.g. by reducing
memory consumption) without compromising on recall. By applying experiments and
theory to a broad set of architectures, we identify a key tradeoff between a
model's state size and recall ability. We show that efficient alternatives to
attention (e.g. H3, Mamba, RWKV) maintain a fixed-size recurrent state, but
struggle at recall. We propose BASED, a simple architecture combining linear and
sliding window attention. By varying BASED's window size and linear attention
feature dimension, we can dial the state size and traverse the Pareto frontier
of the recall-memory tradeoff curve, recovering the full quality of attention
on one end and the small state size of attention-alternatives on the other. We
train language models up to 1.3b parameters and show that BASED matches the
strongest sub-quadratic models (e.g. Mamba) in perplexity and outperforms them
on real-world recall-intensive tasks by 6.22 accuracy points. Implementations
of linear attention are often less efficient than optimized standard attention
implementations. To make BASED competitive, we develop IO-aware algorithms that
enable 24x higher throughput on language generation than FlashAttention-2, when
generating 1024 tokens using 1.3b parameter models. Code for this work is
provided at: https://github.com/HazyResearch/based.
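
As a rough illustration of the hybrid described in the abstract, the sketch below combines a causal linear-attention primitive (whose fixed-size state grows with the feature dimension, not the sequence length) with exact softmax attention over a sliding window of recent tokens. The feature map, layer wiring, and shapes are illustrative assumptions, not the paper's exact BASED implementation; see the repository above for the real code.

```python
# Minimal sketch (assumptions, not the official BASED code): linear attention
# with a generic positive feature map plus exact sliding-window attention.
import torch
import torch.nn.functional as F


def feature_map(x):
    # Hypothetical positive feature map; its output size plays the role of the
    # "linear attention feature dimension" mentioned in the abstract.
    return F.elu(x) + 1


def linear_attention(q, k, v):
    # q, k, v: (batch, seq, dim). The recurrent state is the running sums of
    # k v^T and k, so its size is independent of sequence length.
    q, k = feature_map(q), feature_map(k)
    kv = torch.einsum("bsd,bse->bsde", k, v).cumsum(dim=1)  # running sum of k v^T
    z = k.cumsum(dim=1)                                     # running sum of k
    num = torch.einsum("bsd,bsde->bse", q, kv)
    den = torch.einsum("bsd,bsd->bs", q, z).clamp(min=1e-6)
    return num / den.unsqueeze(-1)


def sliding_window_attention(q, k, v, window):
    # Exact causal softmax attention restricted to the last `window` tokens.
    b, s, d = q.shape
    scores = q @ k.transpose(-1, -2) / d**0.5
    idx = torch.arange(s)
    mask = (idx[None, :] > idx[:, None]) | (idx[:, None] - idx[None, :] >= window)
    scores = scores.masked_fill(mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


# Usage: the window size and feature dimension together set the total state
# size that must be kept around during generation.
q = k = v = torch.randn(2, 16, 8)
out = linear_attention(q, k, v) + sliding_window_attention(q, k, v, window=4)
print(out.shape)  # torch.Size([2, 16, 8])
```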