ChatPaper.aiChatPaper

TPLA: 効率的な分散型プリフィル&デコード推論のためのテンソル並列潜在アテンション

TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill \& Decode Inference

August 21, 2025
著者: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI

要旨

DeepSeek-V2で導入されたMulti-Head Latent Attention(MLA)は、キー・バリューの状態を低ランクの潜在ベクトルに圧縮し、このベクトルのみをキャッシュすることでメモリを削減します。しかし、テンソル並列処理(TP)では、アテンションヘッドが複数のデバイスにまたがって計算され、各デバイスはフルキャッシュをロードする必要があるため、MLAのGrouped Query Attention(GQA)に対する利点が損なわれます。本論文では、Tensor-Parallel Latent Attention(TPLA)を提案します。TPLAは、潜在表現と各ヘッドの入力次元をデバイス間で分割し、シャードごとに独立してアテンションを実行し、その後all-reduceで結果を結合する方式です。TPLAは、圧縮されたKVキャッシュの利点を維持しながら、TPの効率性を引き出します。Grouped Latent Attention(GLA)とは異なり、TPLAの各ヘッドは依然として完全な潜在表現を活用し、より強力な表現能力を維持します。TPLAは、MLAを使用して事前学習されたモデルにそのまま適用可能であり、MLAスタイルのプリフィリングをサポートし、再学習なしで効率的なテンソル並列デコードを可能にします。TPスライシングの前に、アダマール変換やPCAなどの単純な直交変換を適用することで、シャード間の干渉をさらに軽減し、精度の低下を最小限に抑えます。DeepSeek-V3とKimi-K2において、32Kトークンのコンテキスト長で、それぞれ1.79倍と1.93倍の高速化を達成し、常識推論およびLongBenchベンチマークでの性能を維持します。TPLAはFlashAttention-3で実装可能であり、実用的なエンドツーエンドの高速化を実現します。
English
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses key-value states into a low-rank latent vector, caching only this vector to reduce memory. In tensor parallelism (TP), however, attention heads are computed across multiple devices, and each device must load the full cache, eroding the advantage of MLA over Grouped Query Attention (GQA). We propose Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the latent representation and each head's input dimension across devices, performs attention independently per shard, and then combines results with an all-reduce. TPLA preserves the benefits of a compressed KV cache while unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in TPLA still leverages the full latent representation, maintaining stronger representational capacity. TPLA is drop-in compatible with models pre-trained using MLA: it supports MLA-style prefilling and enables efficient tensor-parallel decoding without retraining. Applying simple orthogonal transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further mitigates cross-shard interference, yielding minimal accuracy degradation. By reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x and 1.93x speedups, respectively, at a 32K-token context length while maintaining performance on commonsense and LongBench benchmarks. TPLA can be implemented with FlashAttention-3, enabling practical end-to-end acceleration.
PDF52August 25, 2025