ChatPaper.aiChatPaper

SparDA:用於高效長上下文LLM推理的稀疏解耦注意力

SparDA: Sparse Decoupled Attention for Efficient Long-Context LLM Inference

June 3, 2026
作者: Yaosheng Fu, Guangxuan Xiao, Xin Dong, Song Han, Oreste Villa
cs.AI

摘要

稀疏注意力機制可降低長上下文大型語言模型推論的計算量與記憶體頻寬需求。然而,仍有兩項關鍵挑戰待解決:(1) KV快取容量隨序列長度增長,若卸載至CPU記憶體則會產生PCIe傳輸瓶頸;(2) 稀疏選取步驟本身仍維持O(T²) 的計算複雜度,在長上下文情境下可能主導注意力機制的整體成本。我們提出SparDA,一種分離式稀疏注意力架構,它在查詢、鍵與值之外,為每一層新增第四個投影——預測投影(Forecast)。該預測投影能推斷下一層所需的KV區塊,從而實現前瞻性選取,可將CPU到GPU的預提取與當前層的執行重疊。由於預測投影與注意力查詢相互分離,本研究的GQA實作中每個GQA群組僅使用一個預測頭,相較於原始多頭選取器可降低選取開銷。SparDA僅增加不到0.5%的參數量,且僅透過匹配原始選取器的注意力分佈來訓練預測投影。在兩個經過稀疏預訓練的8B模型上,SparDA達到與基準相當或略優的正確率,相較於稀疏注意力卸載基準可提供高達1.25倍的預填充加速及1.7倍的解碼加速。透過在單一GPU上支援更大的可行批次規模,SparDA更進一步達到比未卸載稀疏基準高出5.3倍的解碼吞吐量。我們的原始碼已於 https://github.com/NVlabs/SparDA 公開。
English
Sparse attention reduces compute and memory bandwidth for long-context LLM inference. However, two key challenges remain: (1) KV cache capacity still grows with sequence length, and offloading to CPU memory introduces a PCIe transfer bottleneck; (2) the sparse selection step itself retains O(T^2) complexity and can dominate attention cost at long contexts. We propose SparDA, a decoupled sparse attention architecture that introduces a fourth per-layer projection, the Forecast, alongside Query, Key, and Value. The Forecast predicts the KV blocks needed by the next layer, enabling lookahead selection that overlaps CPU-to-GPU prefetch with current-layer execution. Because Forecast is decoupled from the attention query, our GQA implementation uses one Forecast head per GQA group, reducing selection overhead versus the original multi-head selector. SparDA adds <0.5% parameters and trains only the Forecast projections by matching the original selector's attention distribution. On two sparse-pretrained 8B models, SparDA matches or slightly improves accuracy and delivers up to 1.25times prefill speedup and 1.7times decode speedup over the sparse-attention offload baseline. By enabling larger feasible batch sizes on a single GPU, SparDA further reaches up to 5.3times higher decode throughput than the non-offload sparse baseline. Our source code is available at https://github.com/NVlabs/SparDA.