SparDA: 効率的な長文脈LLM推論のためのスパース分離アテンション
SparDA: Sparse Decoupled Attention for Efficient Long-Context LLM Inference
June 3, 2026
著者: Yaosheng Fu, Guangxuan Xiao, Xin Dong, Song Han, Oreste Villa
cs.AI
要旨
スパースアテンションは、長コンテキストのLLM推論における計算とメモリ帯域幅を削減する。しかし、二つの主要な課題が残る:(1)KVキャッシュ容量は依然としてシーケンス長とともに増加し、CPUメモリへのオフロードはPCIe転送のボトルネックをもたらす;(2)スパース選択ステップ自体がO(T^2)の複雑性を保持し、長コンテキストではアテンションコストを支配しうる。我々はSparDAを提案する。これは、Query、Key、Valueに加えて、第4のレイヤーごとの投影であるForecastを導入する分離型スパースアテンションアーキテクチャである。Forecastは次のレイヤーで必要とされるKVブロックを予測し、現在のレイヤーの実行とCPUからGPUへのプリフェッチをオーバーラップする先読み選択を可能にする。Forecastはアテンションクエリから分離されているため、我々のGQA実装ではGQAグループごとに一つのForecastヘッドを使用し、元のマルチヘッドセレクタと比較して選択オーバーヘッドを削減する。SparDAは0.5%未満のパラメータを追加し、元のセレクタのアテンション分布に一致させることでForecast投影のみを訓練する。二つのスパース事前学習済み8Bモデルにおいて、SparDAは精度を同等かわずかに向上させ、スパースアテンションオフロードベースラインに対して最大1.25倍のプリフィル高速化と1.7倍のデコード高速化を達成する。単一GPUでより大きな実現可能なバッチサイズを可能にすることにより、SparDAはさらに、オフロードなしのスパースベースラインと比較して最大5.3倍高いデコードスループットに達する。我々のソースコードはhttps://github.com/NVlabs/SparDAで入手可能である。
English
Sparse attention reduces compute and memory bandwidth for long-context LLM inference. However, two key challenges remain: (1) KV cache capacity still grows with sequence length, and offloading to CPU memory introduces a PCIe transfer bottleneck; (2) the sparse selection step itself retains O(T^2) complexity and can dominate attention cost at long contexts. We propose SparDA, a decoupled sparse attention architecture that introduces a fourth per-layer projection, the Forecast, alongside Query, Key, and Value. The Forecast predicts the KV blocks needed by the next layer, enabling lookahead selection that overlaps CPU-to-GPU prefetch with current-layer execution. Because Forecast is decoupled from the attention query, our GQA implementation uses one Forecast head per GQA group, reducing selection overhead versus the original multi-head selector. SparDA adds <0.5% parameters and trains only the Forecast projections by matching the original selector's attention distribution. On two sparse-pretrained 8B models, SparDA matches or slightly improves accuracy and delivers up to 1.25times prefill speedup and 1.7times decode speedup over the sparse-attention offload baseline. By enabling larger feasible batch sizes on a single GPU, SparDA further reaches up to 5.3times higher decode throughput than the non-offload sparse baseline. Our source code is available at https://github.com/NVlabs/SparDA.