Atención Delta: Inferencia Rápida y Precisa de Atención Dispersa mediante Corrección Delta
Delta Attention: Fast and Accurate Sparse Attention Inference by Delta Correction
May 16, 2025
Autores: Jeffrey Willette, Heejun Lee, Sung Ju Hwang
cs.AI
Resumen
El mecanismo de atención de un transformador tiene una complejidad cuadrática, lo que conlleva altos costos de inferencia y latencia para secuencias largas. Sin embargo, las matrices de atención son mayormente dispersas, lo que implica que muchas entradas pueden omitirse del cálculo para una inferencia eficiente. Los métodos de inferencia de atención dispersa buscan reducir esta carga computacional; no obstante, también vienen acompañados de una problemática degradación del rendimiento. Descubrimos que una de las razones de esta degradación es que el cálculo disperso induce un cambio distribucional en las salidas de atención. Este cambio distribucional hace que las consultas en tiempo de decodificación no se alineen adecuadamente con las claves apropiadas de la etapa de prellenado, lo que resulta en una caída del rendimiento. Proponemos un procedimiento simple, novedoso y efectivo para corregir este cambio distribucional, acercando la distribución de las salidas de atención dispersa a la de la atención cuadrática. Nuestro método puede aplicarse sobre cualquier método de atención dispersa y resulta en un aumento promedio del rendimiento de 36 puntos porcentuales, recuperando el 88% de la precisión de la atención cuadrática en el benchmark RULER de 131K cuando se aplica sobre la atención de ventana deslizante con tokens sumidero, mientras añade solo un pequeño sobrecosto. Nuestro método puede mantener aproximadamente un 98.5% de dispersión sobre la atención cuadrática completa, haciendo que nuestro modelo sea 32 veces más rápido que Flash Attention 2 al procesar prellenados de 1 millón de tokens.
English
The attention mechanism of a transformer has a quadratic complexity, leading
to high inference costs and latency for long sequences. However, attention
matrices are mostly sparse, which implies that many entries may be omitted from
computation for efficient inference. Sparse attention inference methods aim to
reduce this computational burden; however, they also come with a troublesome
performance degradation. We discover that one reason for this degradation is
that the sparse calculation induces a distributional shift in the attention
outputs. The distributional shift causes decoding-time queries to fail to align
well with the appropriate keys from the prefill stage, leading to a drop in
performance. We propose a simple, novel, and effective procedure for correcting
this distributional shift, bringing the distribution of sparse attention
outputs closer to that of quadratic attention. Our method can be applied on top
of any sparse attention method, and results in an average 36%pt performance
increase, recovering 88% of quadratic attention accuracy on the 131K RULER
benchmark when applied on top of sliding window attention with sink tokens
while only adding a small overhead. Our method can maintain approximately 98.5%
sparsity over full quadratic attention, making our model 32 times faster than
Flash Attention 2 when processing 1M token prefills.Summary
AI-Generated Summary