TPLA: Atención Latente en Paralelo de Tensores para una Inferencia Eficiente de Prellenado y Decodificación Desagregada
TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill \& Decode Inference
August 21, 2025
Autores: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI
Resumen
Multi-Head Latent Attention (MLA), introducido en DeepSeek-V2, comprime los estados clave-valor en un vector latente de bajo rango, almacenando en caché solo este vector para reducir el uso de memoria. Sin embargo, en el paralelismo de tensores (TP), las cabezas de atención se calculan en múltiples dispositivos, y cada dispositivo debe cargar la caché completa, lo que erosiona la ventaja de MLA sobre Grouped Query Attention (GQA). Proponemos Tensor-Parallel Latent Attention (TPLA): un esquema que divide tanto la representación latente como la dimensión de entrada de cada cabeza entre dispositivos, realiza la atención de manera independiente por fragmento y luego combina los resultados con un all-reduce. TPLA preserva los beneficios de una caché KV comprimida mientras aprovecha la eficiencia del TP. A diferencia de Grouped Latent Attention (GLA), cada cabeza en TPLA sigue aprovechando la representación latente completa, manteniendo una mayor capacidad de representación. TPLA es compatible de manera directa con modelos preentrenados usando MLA: admite el prefilling al estilo MLA y permite una decodificación eficiente en paralelismo de tensores sin necesidad de reentrenamiento. La aplicación de transformaciones ortogonales simples —por ejemplo, la transformada de Hadamard o PCA— antes del corte en TP mitiga aún más la interferencia entre fragmentos, resultando en una degradación mínima de la precisión. Al reducir la caché KV por dispositivo para DeepSeek-V3 y Kimi-K2, logramos aceleraciones de 1.79x y 1.93x, respectivamente, en un contexto de 32K tokens, manteniendo el rendimiento en pruebas de sentido común y LongBench. TPLA puede implementarse con FlashAttention-3, permitiendo una aceleración práctica de extremo a extremo.
English
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses
key-value states into a low-rank latent vector, caching only this vector to
reduce memory. In tensor parallelism (TP), however, attention heads are
computed across multiple devices, and each device must load the full cache,
eroding the advantage of MLA over Grouped Query Attention (GQA). We propose
Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the
latent representation and each head's input dimension across devices, performs
attention independently per shard, and then combines results with an
all-reduce. TPLA preserves the benefits of a compressed KV cache while
unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in
TPLA still leverages the full latent representation, maintaining stronger
representational capacity. TPLA is drop-in compatible with models pre-trained
using MLA: it supports MLA-style prefilling and enables efficient
tensor-parallel decoding without retraining. Applying simple orthogonal
transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further
mitigates cross-shard interference, yielding minimal accuracy degradation. By
reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x
and 1.93x speedups, respectively, at a 32K-token context length while
maintaining performance on commonsense and LongBench benchmarks. TPLA can be
implemented with FlashAttention-3, enabling practical end-to-end acceleration.