TPLA: Tensor-Parallele Latente Aufmerksamkeit für effiziente disaggregierte Prefill- und Decode-Inferenz
TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill \& Decode Inference
August 21, 2025
papers.authors: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI
papers.abstract
Multi-Head Latent Attention (MLA), eingeführt in DeepSeek-V2, komprimiert
Key-Value-Zustände in einen niedrigrangigen latenten Vektor und speichert nur diesen Vektor, um den Speicherbedarf zu reduzieren. Bei Tensor-Parallelität (TP) jedoch werden die Aufmerksamkeitsköpfe über mehrere Geräte hinweg berechnet, und jedes Gerät muss den gesamten Cache laden, wodurch der Vorteil von MLA gegenüber Grouped Query Attention (GQA) geschmälert wird. Wir schlagen Tensor-Parallel Latent Attention (TPLA) vor: ein Schema, das sowohl die latente Repräsentation als auch die Eingabedimension jedes Kopfes über die Geräte partitioniert, die Aufmerksamkeit unabhängig pro Shard berechnet und dann die Ergebnisse mit einem All-Reduce kombiniert. TPLA bewahrt die Vorteile eines komprimierten KV-Caches, während es die Effizienz von TP freisetzt. Im Gegensatz zu Grouped Latent Attention (GLA) nutzt jeder Kopf in TPLA weiterhin die vollständige latente Repräsentation, wodurch eine stärkere Repräsentationskapazität erhalten bleibt. TPLA ist abwärtskompatibel mit Modellen, die mit MLA vortrainiert wurden: Es unterstützt MLA-ähnliches Prefilling und ermöglicht effizientes tensorparalleles Decodieren ohne Neutraining. Die Anwendung einfacher orthogonaler Transformationen – z.B. der Hadamard-Transformation oder PCA – vor dem TP-Slicing mildert weiterhin die Interferenz zwischen den Shards, was zu minimaler Genauigkeitseinbuße führt. Durch die Reduzierung des KV-Caches pro Gerät für DeepSeek-V3 und Kimi-K2 erreichen wir jeweils eine Beschleunigung um den Faktor 1,79x und 1,93x bei einer Kontextlänge von 32K Tokens, während die Leistung auf Commonsense- und LongBench-Benchmarks erhalten bleibt. TPLA kann mit FlashAttention-3 implementiert werden, was eine praktische end-to-end-Beschleunigung ermöglicht.
English
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses
key-value states into a low-rank latent vector, caching only this vector to
reduce memory. In tensor parallelism (TP), however, attention heads are
computed across multiple devices, and each device must load the full cache,
eroding the advantage of MLA over Grouped Query Attention (GQA). We propose
Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the
latent representation and each head's input dimension across devices, performs
attention independently per shard, and then combines results with an
all-reduce. TPLA preserves the benefits of a compressed KV cache while
unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in
TPLA still leverages the full latent representation, maintaining stronger
representational capacity. TPLA is drop-in compatible with models pre-trained
using MLA: it supports MLA-style prefilling and enables efficient
tensor-parallel decoding without retraining. Applying simple orthogonal
transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further
mitigates cross-shard interference, yielding minimal accuracy degradation. By
reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x
and 1.93x speedups, respectively, at a 32K-token context length while
maintaining performance on commonsense and LongBench benchmarks. TPLA can be
implemented with FlashAttention-3, enabling practical end-to-end acceleration.