SparDA: Atenção Desacoplada Esparsa para Inferência Eficiente de LLM com Contexto Longo
SparDA: Sparse Decoupled Attention for Efficient Long-Context LLM Inference
June 3, 2026
Autores: Yaosheng Fu, Guangxuan Xiao, Xin Dong, Song Han, Oreste Villa
cs.AI
Resumo
A atenção esparsa reduz o uso de computação e largura de banda de memória na inferência de LLMs com contexto longo. No entanto, dois desafios centrais persistem: (1) a capacidade do cache KV ainda cresce com o comprimento da sequência, e o descarregamento para a memória da CPU introduz um gargalo de transferência PCIe; (2) a própria etapa de seleção esparsa mantém complexidade O(T²) e pode dominar o custo da atenção em contextos longos. Propomos o SparDA, uma arquitetura de atenção esparsa desacoplada que introduz uma quarta projeção por camada, a Previsão, ao lado de Consulta, Chave e Valor. A Previsão prevê os blocos KV necessários para a próxima camada, permitindo uma seleção antecipada que sobrepõe a pré-busca da GPU para a CPU com a execução da camada atual. Como a Previsão é desacoplada da consulta de atenção, nossa implementação GQA usa uma cabeça de Previsão por grupo GQA, reduzindo a sobrecarga de seleção em comparação com o seletor multi-cabeça original. O SparDA adiciona <0,5% de parâmetros e treina apenas as projeções de Previsão, igualando a distribuição de atenção do seletor original. Em dois modelos pré-treinados esparsos de 8B, o SparDA iguala ou melhora ligeiramente a acurácia e oferece até 1,25× de aceleração no preenchimento e 1,7× de aceleração na decodificação em relação à linha de base de descarregamento com atenção esparsa. Ao permitir tamanhos de lote viáveis maiores em uma única GPU, o SparDA alcança ainda até 5,3× maior throughput de decodificação do que a linha de base esparsa sem descarregamento. Nosso código-fonte está disponível em https://github.com/NVlabs/SparDA.
English
Sparse attention reduces compute and memory bandwidth for long-context LLM inference. However, two key challenges remain: (1) KV cache capacity still grows with sequence length, and offloading to CPU memory introduces a PCIe transfer bottleneck; (2) the sparse selection step itself retains O(T^2) complexity and can dominate attention cost at long contexts. We propose SparDA, a decoupled sparse attention architecture that introduces a fourth per-layer projection, the Forecast, alongside Query, Key, and Value. The Forecast predicts the KV blocks needed by the next layer, enabling lookahead selection that overlaps CPU-to-GPU prefetch with current-layer execution. Because Forecast is decoupled from the attention query, our GQA implementation uses one Forecast head per GQA group, reducing selection overhead versus the original multi-head selector. SparDA adds <0.5% parameters and trains only the Forecast projections by matching the original selector's attention distribution. On two sparse-pretrained 8B models, SparDA matches or slightly improves accuracy and delivers up to 1.25times prefill speedup and 1.7times decode speedup over the sparse-attention offload baseline. By enabling larger feasible batch sizes on a single GPU, SparDA further reaches up to 5.3times higher decode throughput than the non-offload sparse baseline. Our source code is available at https://github.com/NVlabs/SparDA.