ChatPaper.aiChatPaper

SparDA: 효율적인 장문맥 LLM 추론을 위한 희소 분리 어텐션

SparDA: Sparse Decoupled Attention for Efficient Long-Context LLM Inference

June 3, 2026
저자: Yaosheng Fu, Guangxuan Xiao, Xin Dong, Song Han, Oreste Villa
cs.AI

초록

스파스 어텐션(Sparse attention)은 긴 문맥을 처리하는 대규모 언어 모델(LLM) 추론에서 계산량과 메모리 대역폭을 줄여준다. 그러나 여전히 두 가지 주요 과제가 남아 있다: (1) KV 캐시 용량이 시퀀스 길이에 따라 증가하며, 이를 CPU 메모리로 오프로드할 경우 PCIe 전송 병목이 발생한다. (2) 스파스 선택 단계 자체가 O(T²) 복잡도를 유지하여 긴 문맥에서 어텐션 비용을 지배할 수 있다. 본 논문에서는 SparDA라는 분리형 스파스 어텐션 아키텍처를 제안한다. SparDA는 Query, Key, Value 외에 네 번째 층별 투영(projection)인 Forecast를 도입한다. Forecast는 다음 층에서 필요한 KV 블록을 예측하여, 현재 층 실행과 CPU→GPU 프리페치를 중첩시키는 선행 선택(lookahead selection)을 가능하게 한다. Forecast는 어텐션 쿼리와 분리되어 있으므로, 본 구현에서는 GQA(Grouped Query Attention) 그룹당 하나의 Forecast 헤드를 사용하여 기존 다중 헤드 선택기 대비 선택 오버헤드를 줄인다. SparDA는 전체 매개변수의 0.5% 미만을 추가하며, 기존 선택기의 어텐션 분포를 일치시키는 방식으로 Forecast 투영만 학습시킨다. 두 개의 스파스 사전학습 8B 모델에서 SparDA는 정확도를 유지하거나 소폭 향상시키며, 스파스 어텐션 오프로드 기준선 대비 최대 1.25배의 프리필 속도 향상과 1.7배의 디코드 속도 향상을 제공한다. 단일 GPU에서 더 큰 배치 크기를 가능하게 함으로써, SparDA는 오프로드하지 않는 스파스 기준선 대비 최대 5.3배 높은 디코드 처리량을 달성한다. 소스 코드는 https://github.com/NVlabs/SparDA에서 확인할 수 있다.
English
Sparse attention reduces compute and memory bandwidth for long-context LLM inference. However, two key challenges remain: (1) KV cache capacity still grows with sequence length, and offloading to CPU memory introduces a PCIe transfer bottleneck; (2) the sparse selection step itself retains O(T^2) complexity and can dominate attention cost at long contexts. We propose SparDA, a decoupled sparse attention architecture that introduces a fourth per-layer projection, the Forecast, alongside Query, Key, and Value. The Forecast predicts the KV blocks needed by the next layer, enabling lookahead selection that overlaps CPU-to-GPU prefetch with current-layer execution. Because Forecast is decoupled from the attention query, our GQA implementation uses one Forecast head per GQA group, reducing selection overhead versus the original multi-head selector. SparDA adds <0.5% parameters and trains only the Forecast projections by matching the original selector's attention distribution. On two sparse-pretrained 8B models, SparDA matches or slightly improves accuracy and delivers up to 1.25times prefill speedup and 1.7times decode speedup over the sparse-attention offload baseline. By enabling larger feasible batch sizes on a single GPU, SparDA further reaches up to 5.3times higher decode throughput than the non-offload sparse baseline. Our source code is available at https://github.com/NVlabs/SparDA.