SparDA: Atención Dispersa Desacoplada para Inferencia Eficiente de LLM de Contexto Largo
SparDA: Sparse Decoupled Attention for Efficient Long-Context LLM Inference
June 3, 2026
Autores: Yaosheng Fu, Guangxuan Xiao, Xin Dong, Song Han, Oreste Villa
cs.AI
Resumen
La atención dispersa reduce el cómputo y el ancho de banda de memoria para la inferencia de LLM con contexto largo. Sin embargo, persisten dos desafíos clave: (1) la capacidad de la caché KV sigue creciendo con la longitud de la secuencia, y su descarga a la memoria de la CPU introduce un cuello de botella de transferencia PCIe; (2) el propio paso de selección dispersa conserva una complejidad de O(T²) y puede dominar el costo de atención en contextos largos. Proponemos SparDA, una arquitectura de atención dispersa desacoplada que introduce una cuarta proyección por capa, el Forecast, junto con Query, Key y Value. El Forecast predice los bloques KV que necesitará la siguiente capa, lo que permite una selección anticipada que superpone la precarga de CPU a GPU con la ejecución de la capa actual. Debido a que el Forecast está desacoplado de la consulta de atención, nuestra implementación de GQA utiliza una cabeza Forecast por grupo GQA, reduciendo la sobrecarga de selección en comparación con el selector multi-cabeza original. SparDA añade menos del 0,5% de parámetros y entrena solo las proyecciones Forecast igualando la distribución de atención del selector original. En dos modelos de 8B preentrenados con dispersión, SparDA iguala o mejora ligeramente la precisión y ofrece hasta 1,25 veces de aceleración en prefill y 1,7 veces en decodificación con respecto a la línea base de atención dispersa con descarga. Al permitir tamaños de lote factibles más grandes en una sola GPU, SparDA alcanza además hasta 5,3 veces mayor rendimiento de decodificación que la línea base dispersa sin descarga. Nuestro código fuente está disponible en https://github.com/NVlabs/SparDA.
English
Sparse attention reduces compute and memory bandwidth for long-context LLM inference. However, two key challenges remain: (1) KV cache capacity still grows with sequence length, and offloading to CPU memory introduces a PCIe transfer bottleneck; (2) the sparse selection step itself retains O(T^2) complexity and can dominate attention cost at long contexts. We propose SparDA, a decoupled sparse attention architecture that introduces a fourth per-layer projection, the Forecast, alongside Query, Key, and Value. The Forecast predicts the KV blocks needed by the next layer, enabling lookahead selection that overlaps CPU-to-GPU prefetch with current-layer execution. Because Forecast is decoupled from the attention query, our GQA implementation uses one Forecast head per GQA group, reducing selection overhead versus the original multi-head selector. SparDA adds <0.5% parameters and trains only the Forecast projections by matching the original selector's attention distribution. On two sparse-pretrained 8B models, SparDA matches or slightly improves accuracy and delivers up to 1.25times prefill speedup and 1.7times decode speedup over the sparse-attention offload baseline. By enabling larger feasible batch sizes on a single GPU, SparDA further reaches up to 5.3times higher decode throughput than the non-offload sparse baseline. Our source code is available at https://github.com/NVlabs/SparDA.