Delta Attention: Inferenza Rapida e Precisa dell'Attenzione Sparsa tramite Correzione Delta
Delta Attention: Fast and Accurate Sparse Attention Inference by Delta Correction
May 16, 2025
Autori: Jeffrey Willette, Heejun Lee, Sung Ju Hwang
cs.AI
Abstract
Il meccanismo di attenzione di un trasformatore ha una complessità quadratica, portando a costi di inferenza elevati e latenza per sequenze lunghe. Tuttavia, le matrici di attenzione sono per lo più sparse, il che implica che molte voci possono essere omesse dal calcolo per un'inferenza efficiente. I metodi di inferenza con attenzione sparsa mirano a ridurre questo onere computazionale; tuttavia, comportano anche un fastidioso degrado delle prestazioni. Scopriamo che una delle ragioni di questo degrado è che il calcolo sparso induce uno spostamento distributivo negli output di attenzione. Questo spostamento distributivo fa sì che le query al momento della decodifica non si allineino bene con le chiavi appropriate della fase di prefill, portando a un calo delle prestazioni. Proponiamo una procedura semplice, innovativa ed efficace per correggere questo spostamento distributivo, avvicinando la distribuzione degli output di attenzione sparsa a quella dell'attenzione quadratica. Il nostro metodo può essere applicato su qualsiasi metodo di attenzione sparsa e risulta in un aumento medio delle prestazioni del 36%pt, recuperando l'88% dell'accuratezza dell'attenzione quadratica sul benchmark RULER da 131K quando applicato su un'attenzione a finestra scorrevole con token sink, aggiungendo solo un piccolo overhead. Il nostro metodo può mantenere approssimativamente il 98,5% di sparsità rispetto all'attenzione quadratica completa, rendendo il nostro modello 32 volte più veloce di Flash Attention 2 quando elabora prefills da 1M token.
English
The attention mechanism of a transformer has a quadratic complexity, leading
to high inference costs and latency for long sequences. However, attention
matrices are mostly sparse, which implies that many entries may be omitted from
computation for efficient inference. Sparse attention inference methods aim to
reduce this computational burden; however, they also come with a troublesome
performance degradation. We discover that one reason for this degradation is
that the sparse calculation induces a distributional shift in the attention
outputs. The distributional shift causes decoding-time queries to fail to align
well with the appropriate keys from the prefill stage, leading to a drop in
performance. We propose a simple, novel, and effective procedure for correcting
this distributional shift, bringing the distribution of sparse attention
outputs closer to that of quadratic attention. Our method can be applied on top
of any sparse attention method, and results in an average 36%pt performance
increase, recovering 88% of quadratic attention accuracy on the 131K RULER
benchmark when applied on top of sliding window attention with sink tokens
while only adding a small overhead. Our method can maintain approximately 98.5%
sparsity over full quadratic attention, making our model 32 times faster than
Flash Attention 2 when processing 1M token prefills.