TPLA: 効率的な分散型プリフィル&デコード推論のためのテンソル並列潜在アテンション
TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill \& Decode Inference
August 21, 2025
著者: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI
要旨
DeepSeek-V2で導入されたMulti-Head Latent Attention(MLA)は、キー・バリューの状態を低ランクの潜在ベクトルに圧縮し、このベクトルのみをキャッシュすることでメモリを削減します。しかし、テンソル並列処理(TP)では、アテンションヘッドが複数のデバイスにまたがって計算され、各デバイスはフルキャッシュをロードする必要があるため、MLAのGrouped Query Attention(GQA)に対する利点が損なわれます。本論文では、Tensor-Parallel Latent Attention(TPLA)を提案します。TPLAは、潜在表現と各ヘッドの入力次元をデバイス間で分割し、シャードごとに独立してアテンションを実行し、その後all-reduceで結果を結合する方式です。TPLAは、圧縮されたKVキャッシュの利点を維持しながら、TPの効率性を引き出します。Grouped Latent Attention(GLA)とは異なり、TPLAの各ヘッドは依然として完全な潜在表現を活用し、より強力な表現能力を維持します。TPLAは、MLAを使用して事前学習されたモデルにそのまま適用可能であり、MLAスタイルのプリフィリングをサポートし、再学習なしで効率的なテンソル並列デコードを可能にします。TPスライシングの前に、アダマール変換やPCAなどの単純な直交変換を適用することで、シャード間の干渉をさらに軽減し、精度の低下を最小限に抑えます。DeepSeek-V3とKimi-K2において、32Kトークンのコンテキスト長で、それぞれ1.79倍と1.93倍の高速化を達成し、常識推論およびLongBenchベンチマークでの性能を維持します。TPLAはFlashAttention-3で実装可能であり、実用的なエンドツーエンドの高速化を実現します。
English
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses
key-value states into a low-rank latent vector, caching only this vector to
reduce memory. In tensor parallelism (TP), however, attention heads are
computed across multiple devices, and each device must load the full cache,
eroding the advantage of MLA over Grouped Query Attention (GQA). We propose
Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the
latent representation and each head's input dimension across devices, performs
attention independently per shard, and then combines results with an
all-reduce. TPLA preserves the benefits of a compressed KV cache while
unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in
TPLA still leverages the full latent representation, maintaining stronger
representational capacity. TPLA is drop-in compatible with models pre-trained
using MLA: it supports MLA-style prefilling and enables efficient
tensor-parallel decoding without retraining. Applying simple orthogonal
transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further
mitigates cross-shard interference, yielding minimal accuracy degradation. By
reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x
and 1.93x speedups, respectively, at a 32K-token context length while
maintaining performance on commonsense and LongBench benchmarks. TPLA can be
implemented with FlashAttention-3, enabling practical end-to-end acceleration.