TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill & Decode Inference
August 21, 2025
Authors: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI
Abstract
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses key-value states into a low-rank latent vector, caching only this vector to reduce memory. In tensor parallelism (TP), however, attention heads are computed across multiple devices, and each device must load the full cache, eroding the advantage of MLA over Grouped Query Attention (GQA). We propose Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the latent representation and each head's input dimension across devices, performs attention independently per shard, and then combines results with an all-reduce. TPLA preserves the benefits of a compressed KV cache while unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in TPLA still leverages the full latent representation, maintaining stronger representational capacity. TPLA is drop-in compatible with models pre-trained using MLA: it supports MLA-style prefilling and enables efficient tensor-parallel decoding without retraining. Applying simple orthogonal transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further mitigates cross-shard interference, yielding minimal accuracy degradation. By reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x and 1.93x speedups, respectively, at a 32K-token context length while maintaining performance on commonsense and LongBench benchmarks. TPLA can be implemented with FlashAttention-3, enabling practical end-to-end acceleration.
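To make the decoding scheme described in the abstract concrete, below is a minimal single-head NumPy sketch of the stated idea: optionally rotate the latent dimension with a Hadamard transform, slice the latent KV cache across TP ranks, run attention independently on each slice, and sum the partial outputs, which is what the all-reduce performs on real hardware. The function and variable names (`tpla_decode_step`, `w_out`), the per-shard softmax scaling, and the omission of the decoupled RoPE branch are illustrative assumptions, not the paper's exact formulation.

```python
import numpy as np

def hadamard(n: int) -> np.ndarray:
    """Normalized Hadamard matrix via Sylvester construction; n must be a power of two."""
    H = np.ones((1, 1))
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)

def tpla_decode_step(q_abs, latent_cache, w_out, tp=4, rotate=True):
    """
    One decode step for a single head, simulated across `tp` shards.

    q_abs        : (d_c,)         query absorbed into the latent space
    latent_cache : (T, d_c)       cached latent vectors c_t (RoPE branch omitted)
    w_out        : (d_c, d_model) absorbed value/output projection
    """
    T, d_c = latent_cache.shape
    if rotate:
        # Orthogonal rotation of the latent dimension before slicing.
        # Full (unsliced) attention is exactly invariant to this, since
        # (q R)·(c R) = q·c when R Rᵀ = I; the rotation only rebalances
        # energy across the slices so each shard sees a representative part.
        R = hadamard(d_c)
        q_abs, latent_cache, w_out = q_abs @ R, latent_cache @ R, R.T @ w_out

    shard = d_c // tp
    out = np.zeros(w_out.shape[1])
    for g in range(tp):                               # each iteration = one TP device
        sl = slice(g * shard, (g + 1) * shard)
        q_g, c_g = q_abs[sl], latent_cache[:, sl]     # per-device latent slice
        scores = c_g @ q_g / np.sqrt(shard)           # partial scores on this shard
        attn = np.exp(scores - scores.max())
        attn /= attn.sum()                            # per-shard softmax (the approximation)
        o_g = attn @ c_g                              # aggregate the local latent slice
        out += o_g @ w_out[sl]                        # row-parallel output projection;
                                                      # the sum over g is the all-reduce
    return out
```

In an actual multi-GPU deployment, each iteration of the loop would run on its own device and the accumulation would be a single all-reduce, and the Hadamard (or PCA) rotation would presumably be folded into the projection weights offline so it adds no decode-time cost; the exact score scaling, RoPE handling, and shard-combination rule follow the paper rather than this sketch.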