ChatPaper.aiChatPaper

TPLA: 효율적인 분산 프리필 및 디코드 추론을 위한 텐서 병렬 잠재 어텐션

TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill \& Decode Inference

August 21, 2025
저자: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI

초록

DeepSeek-V2에서 소개된 Multi-Head Latent Attention(MLA)은 키-값 상태를 저차원 잠재 벡터로 압축하고, 이 벡터만 캐싱하여 메모리를 절약합니다. 그러나 텐서 병렬화(TP)에서는 어텐션 헤드가 여러 장치에 걸쳐 계산되며, 각 장치는 전체 캐시를 로드해야 하기 때문에 MLA의 장점이 Grouped Query Attention(GQA)에 비해 약화됩니다. 우리는 Tensor-Parallel Latent Attention(TPLA)을 제안합니다: 이 방식은 잠재 표현과 각 헤드의 입력 차원을 장치 간에 분할하고, 각 샤드에서 독립적으로 어텐션을 수행한 후 all-reduce를 통해 결과를 결합합니다. TPLA는 압축된 KV 캐시의 이점을 유지하면서 TP 효율성을 극대화합니다. Grouped Latent Attention(GLA)과 달리, TPLA의 모든 헤드는 여전히 전체 잠재 표현을 활용하여 더 강력한 표현 능력을 유지합니다. TPLA는 MLA를 사용해 사전 학습된 모델과 즉시 호환됩니다: MLA 스타일의 프리필링을 지원하고 재학습 없이도 효율적인 텐서 병렬 디코딩을 가능하게 합니다. TP 슬라이싱 전에 Hadamard 변환이나 PCA와 같은 간단한 직교 변환을 적용하면 샤드 간 간섭을 추가로 완화하여 정확도 저하를 최소화할 수 있습니다. DeepSeek-V3과 Kimi-K2에서 장치당 KV 캐시를 줄임으로써, 32K 토큰 컨텍스트 길이에서 각각 1.79배와 1.93배의 속도 향상을 달성하면서도 상식 및 LongBench 벤치마크에서 성능을 유지했습니다. TPLA는 FlashAttention-3로 구현 가능하여 실용적인 종단 간 가속을 가능하게 합니다.
English
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses key-value states into a low-rank latent vector, caching only this vector to reduce memory. In tensor parallelism (TP), however, attention heads are computed across multiple devices, and each device must load the full cache, eroding the advantage of MLA over Grouped Query Attention (GQA). We propose Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the latent representation and each head's input dimension across devices, performs attention independently per shard, and then combines results with an all-reduce. TPLA preserves the benefits of a compressed KV cache while unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in TPLA still leverages the full latent representation, maintaining stronger representational capacity. TPLA is drop-in compatible with models pre-trained using MLA: it supports MLA-style prefilling and enables efficient tensor-parallel decoding without retraining. Applying simple orthogonal transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further mitigates cross-shard interference, yielding minimal accuracy degradation. By reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x and 1.93x speedups, respectively, at a 32K-token context length while maintaining performance on commonsense and LongBench benchmarks. TPLA can be implemented with FlashAttention-3, enabling practical end-to-end acceleration.
PDF52August 25, 2025