SparDA: 稀疏解耦注意力用于高效长上下文大语言模型推理
SparDA: Sparse Decoupled Attention for Efficient Long-Context LLM Inference
June 3, 2026
作者: Yaosheng Fu, Guangxuan Xiao, Xin Dong, Song Han, Oreste Villa
cs.AI
摘要
稀疏注意力机制降低了长上下文大语言模型推理时的计算量和内存带宽消耗。然而,仍存在两个关键挑战:(1)KV缓存容量随序列长度增长,卸载至CPU内存会引入PCIe传输瓶颈;(2)稀疏选择步骤本身仍保持O(T²)复杂度,在长上下文场景下可能占据注意力计算的主要成本。我们提出SparDA——一种解耦稀疏注意力架构,在查询、键、值之外引入第四层投影:预测器。该预测器可预判下一层所需的KV块,实现前瞻性选择,使GPU到CPU的预取与当前层执行重叠。由于预测器与注意力查询解耦,我们的分组查询注意力实现方案在每个GQA组中使用一个预测头,相较于原始多头选择器降低了选择开销。SparDA仅增加少于0.5%的参数,通过匹配原始选择器的注意力分布仅训练预测投影。在两个稀疏预训练的8B模型上,SparDA在保持或略微提升精度的同时,相较于稀疏注意力卸载基线实现了高达1.25倍的预填充加速和1.7倍的解码加速。通过支持单GPU上更大的可行批量大小,SparDA进一步使解码吞吐量比未卸载的稀疏基线提升高达5.3倍。我们的源代码已开源至https://github.com/NVlabs/SparDA。
English
Sparse attention reduces compute and memory bandwidth for long-context LLM inference. However, two key challenges remain: (1) KV cache capacity still grows with sequence length, and offloading to CPU memory introduces a PCIe transfer bottleneck; (2) the sparse selection step itself retains O(T^2) complexity and can dominate attention cost at long contexts. We propose SparDA, a decoupled sparse attention architecture that introduces a fourth per-layer projection, the Forecast, alongside Query, Key, and Value. The Forecast predicts the KV blocks needed by the next layer, enabling lookahead selection that overlaps CPU-to-GPU prefetch with current-layer execution. Because Forecast is decoupled from the attention query, our GQA implementation uses one Forecast head per GQA group, reducing selection overhead versus the original multi-head selector. SparDA adds <0.5% parameters and trains only the Forecast projections by matching the original selector's attention distribution. On two sparse-pretrained 8B models, SparDA matches or slightly improves accuracy and delivers up to 1.25times prefill speedup and 1.7times decode speedup over the sparse-attention offload baseline. By enabling larger feasible batch sizes on a single GPU, SparDA further reaches up to 5.3times higher decode throughput than the non-offload sparse baseline. Our source code is available at https://github.com/NVlabs/SparDA.