Delta Attention: Fast and Accurate Sparse Attention Inference by Delta Correction
May 16, 2025
作者: Jeffrey Willette, Heejun Lee, Sung Ju Hwang
cs.AI
Abstract
The attention mechanism of a transformer has a quadratic complexity, leading
to high inference costs and latency for long sequences. However, attention
matrices are mostly sparse, which implies that many entries may be omitted from
computation for efficient inference. Sparse attention inference methods aim to
reduce this computational burden; however, they also come with a troublesome
performance degradation. We discover that one reason for this degradation is
that the sparse calculation induces a distributional shift in the attention
outputs. The distributional shift causes decoding-time queries to fail to align
well with the appropriate keys from the prefill stage, leading to a drop in
performance. We propose a simple, novel, and effective procedure for correcting
this distributional shift, bringing the distribution of sparse attention
outputs closer to that of quadratic attention. Our method can be applied on top
of any sparse attention method, and results in an average 36 percentage-point performance
increase, recovering 88% of quadratic attention accuracy on the 131K RULER
benchmark when applied on top of sliding window attention with sink tokens
while only adding a small overhead. Our method can maintain approximately 98.5%
sparsity over full quadratic attention, making our model 32 times faster than
Flash Attention 2 when processing 1M token prefills.
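For intuition, here is a minimal PyTorch sketch of the kind of setup the abstract describes: sliding window attention with sink tokens as the sparse baseline, plus a hypothetical `delta_correct` step that recomputes full attention for a strided subset of queries and adds the resulting (full - sparse) delta back, nudging the sparse output distribution toward that of quadratic attention. The function names, the stride-based subsampling, and the hyperparameters (`window`, `n_sink`, `stride`) are illustrative assumptions, not the paper's actual implementation.

```python
# Toy single-head illustration; not the authors' code.
import torch
import torch.nn.functional as F

def sparse_mask(seq_len, window=256, n_sink=4):
    """Boolean mask: each query attends to sink tokens plus a local causal window."""
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]            # causal constraint
    local = (idx[:, None] - idx[None, :]) < window   # sliding window
    sink = idx[None, :] < n_sink                     # always-visible sink tokens
    return causal & (local | sink)

def attention(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

def delta_correct(sparse_out, q, k, v, stride=64):
    """Hypothetical correction: recompute full (quadratic) attention for every
    `stride`-th query and add the resulting full-minus-sparse delta back to
    those rows. This only illustrates the idea of shifting the sparse output
    distribution toward the quadratic one; the paper's procedure may differ."""
    seq_len = q.shape[-2]
    rows = torch.arange(0, seq_len, stride)
    causal = sparse_mask(seq_len, window=seq_len, n_sink=0)   # full causal mask
    full_rows = attention(q[..., rows, :], k, v, mask=causal[rows])
    corrected = sparse_out.clone()
    corrected[..., rows, :] += full_rows - sparse_out[..., rows, :]
    return corrected

if __name__ == "__main__":
    T, d = 1024, 64
    q, k, v = (torch.randn(1, T, d) for _ in range(3))
    sparse_out = attention(q, k, v, mask=sparse_mask(T))
    print(delta_correct(sparse_out, q, k, v).shape)  # torch.Size([1, 1024, 64])
```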