SparseD: Atención Dispersa para Modelos de Lenguaje de Difusión
SparseD: Sparse Attention for Diffusion Language Models
September 28, 2025
Autores: Zeqing Wang, Gongfan Fang, Xinyin Ma, Xingyi Yang, Xinchao Wang
cs.AI
Resumen
Si bien los modelos de lenguaje basados en difusión (DLMs, por sus siglas en inglés) representan una alternativa prometedora a los modelos autorregresivos (ARs), los DLMs de código abierto existentes presentan una latencia de inferencia elevada. Este cuello de botella se debe principalmente a la complejidad cuadrática de la atención con respecto a la longitud del contexto al calcular todos los pares consulta-clave. Intuitivamente, para reducir esta complejidad, una estrategia natural es restringir la atención a patrones dispersos que retengan únicamente las conexiones más relevantes. Estos enfoques están bien establecidos en los ARs, donde la atención sigue patrones dispersos fijos y claramente definidos. Sin embargo, en los DLMs, observamos comportamientos de dispersión distintos: (1) los patrones de atención varían entre las cabezas, (2) los patrones de atención en cada cabeza permanecen altamente similares a lo largo de los pasos de desruido, y (3) los pasos iniciales de desruido son críticos para la generación. Estos hallazgos hacen que los métodos de atención dispersa diseñados para ARs sean en gran medida incompatibles con los DLMs, ya que no logran capturar estructuras específicas de cada cabeza y corren el riesgo de degradar la generación cuando se aplican en los pasos iniciales de desruido. Para abordar estos desafíos, proponemos SparseD, un novedoso método de atención dispersa para DLMs. Aprovechando las observaciones, SparseD solo requiere precalcular una vez los patrones dispersos específicos de cada cabeza y los reutiliza en todos los pasos. Esto evita recalcular los patrones dispersos en cada paso de desruido. Al mismo tiempo, SparseD utiliza atención completa en los pasos iniciales y luego cambia a atención dispersa en etapas posteriores para mantener la calidad de la generación. En conjunto, esto establece a SparseD como una solución práctica y eficiente para implementar DLMs en aplicaciones de contexto largo. Los resultados experimentales demuestran que SparseD logra una aceleración sin pérdidas, alcanzando una velocidad hasta 1.50 veces mayor que FlashAttention con una longitud de contexto de 64k y 1,024 pasos de desruido.
English
While diffusion language models (DLMs) offer a promising alternative to
autoregressive models (ARs), existing open-source DLMs suffer from high
inference latency. This bottleneck is mainly due to the attention's quadratic
complexity with respect to context length in computing all query-key pairs.
Intuitively, to reduce this complexity, a natural strategy is to restrict
attention to sparse patterns that retain only the most relevant connections.
Such approaches are well-established in ARs, where attention follows fixed and
clearly defined sparse patterns. However, in DLMs, we observe distinct sparsity
behaviors: (1) attention patterns vary across heads, (2) attention patterns in
each head remain highly similar across denoising steps, and (3) early denoising
steps are critical for generation. These findings render sparse attention
methods designed for ARs largely incompatible with DLMs, as they fail to
capture head-specific structures and risk degrading generation when applied in
early denoising steps. To address these challenges, we propose SparseD, a novel
sparse attention method for DLMs. Leveraging the observations, SparseD only
requires pre-computing head-specific sparse patterns one time, and reuses them
across all steps. This prevents recomputing sparse patterns at each denoising
step. Meanwhile, SparseD uses full attention in the early steps, then switches
to sparse attention later to maintain generation quality. Together, these
establish SparseD as a practical and efficient solution for deploying DLMs in
long-context applications. Experimental results demonstrate that SparseD
achieves lossless acceleration, delivering up to 1.50times speedup over
FlashAttention at a 64k context length with 1,024 denoising steps.