TPLA : Attention Latente Parallèle en Tenseurs pour une Inférence Efficace de Préremplissage et de Décodage Disagrégé
TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill \& Decode Inference
August 21, 2025
papers.authors: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI
papers.abstract
L'Attention Latente Multi-Têtes (Multi-Head Latent Attention, MLA), introduite dans DeepSeek-V2, compresse les états clé-valeur en un vecteur latent de faible rang, ne conservant en mémoire que ce vecteur pour réduire l'utilisation de la mémoire. Cependant, dans le parallélisme tensoriel (Tensor Parallelism, TP), les têtes d'attention sont calculées sur plusieurs dispositifs, et chaque dispositif doit charger l'intégralité du cache, ce qui réduit l'avantage de MLA par rapport à l'Attention par Requêtes Groupées (Grouped Query Attention, GQA). Nous proposons l'Attention Latente Parallèle Tensorielle (Tensor-Parallel Latent Attention, TPLA) : un schéma qui partitionne à la fois la représentation latente et la dimension d'entrée de chaque tête sur plusieurs dispositifs, effectue l'attention indépendamment par fragment, puis combine les résultats avec une opération de réduction globale (all-reduce). TPLA préserve les avantages d'un cache KV compressé tout en exploitant l'efficacité du TP. Contrairement à l'Attention Latente Groupée (Grouped Latent Attention, GLA), chaque tête dans TPLA exploite toujours la représentation latente complète, conservant ainsi une capacité de représentation plus forte. TPLA est compatible sans modification avec les modèles pré-entraînés utilisant MLA : il prend en charge le pré-remplissage de style MLA et permet un décodage parallèle tensoriel efficace sans nécessiter de réentraînement. L'application de transformations orthogonales simples — par exemple, la transformée de Hadamard ou l'ACP (Analyse en Composantes Principales) — avant le découpage TP atténue davantage les interférences entre fragments, entraînant une dégradation minimale de la précision. En réduisant le cache KV par dispositif pour DeepSeek-V3 et Kimi-K2, nous obtenons des accélérations respectives de 1,79x et 1,93x pour une longueur de contexte de 32 000 tokens, tout en maintenant les performances sur les benchmarks de bon sens et LongBench. TPLA peut être implémenté avec FlashAttention-3, permettant une accélération pratique de bout en bout.
English
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses
key-value states into a low-rank latent vector, caching only this vector to
reduce memory. In tensor parallelism (TP), however, attention heads are
computed across multiple devices, and each device must load the full cache,
eroding the advantage of MLA over Grouped Query Attention (GQA). We propose
Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the
latent representation and each head's input dimension across devices, performs
attention independently per shard, and then combines results with an
all-reduce. TPLA preserves the benefits of a compressed KV cache while
unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in
TPLA still leverages the full latent representation, maintaining stronger
representational capacity. TPLA is drop-in compatible with models pre-trained
using MLA: it supports MLA-style prefilling and enables efficient
tensor-parallel decoding without retraining. Applying simple orthogonal
transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further
mitigates cross-shard interference, yielding minimal accuracy degradation. By
reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x
and 1.93x speedups, respectively, at a 32K-token context length while
maintaining performance on commonsense and LongBench benchmarks. TPLA can be
implemented with FlashAttention-3, enabling practical end-to-end acceleration.