ChatPaper.aiChatPaper

TPLA: Atención Latente en Paralelo de Tensores para una Inferencia Eficiente de Prellenado y Decodificación Desagregada

TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill \& Decode Inference

August 21, 2025
Autores: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI

Resumen

Multi-Head Latent Attention (MLA), introducido en DeepSeek-V2, comprime los estados clave-valor en un vector latente de bajo rango, almacenando en caché solo este vector para reducir el uso de memoria. Sin embargo, en el paralelismo de tensores (TP), las cabezas de atención se calculan en múltiples dispositivos, y cada dispositivo debe cargar la caché completa, lo que erosiona la ventaja de MLA sobre Grouped Query Attention (GQA). Proponemos Tensor-Parallel Latent Attention (TPLA): un esquema que divide tanto la representación latente como la dimensión de entrada de cada cabeza entre dispositivos, realiza la atención de manera independiente por fragmento y luego combina los resultados con un all-reduce. TPLA preserva los beneficios de una caché KV comprimida mientras aprovecha la eficiencia del TP. A diferencia de Grouped Latent Attention (GLA), cada cabeza en TPLA sigue aprovechando la representación latente completa, manteniendo una mayor capacidad de representación. TPLA es compatible de manera directa con modelos preentrenados usando MLA: admite el prefilling al estilo MLA y permite una decodificación eficiente en paralelismo de tensores sin necesidad de reentrenamiento. La aplicación de transformaciones ortogonales simples —por ejemplo, la transformada de Hadamard o PCA— antes del corte en TP mitiga aún más la interferencia entre fragmentos, resultando en una degradación mínima de la precisión. Al reducir la caché KV por dispositivo para DeepSeek-V3 y Kimi-K2, logramos aceleraciones de 1.79x y 1.93x, respectivamente, en un contexto de 32K tokens, manteniendo el rendimiento en pruebas de sentido común y LongBench. TPLA puede implementarse con FlashAttention-3, permitiendo una aceleración práctica de extremo a extremo.
English
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses key-value states into a low-rank latent vector, caching only this vector to reduce memory. In tensor parallelism (TP), however, attention heads are computed across multiple devices, and each device must load the full cache, eroding the advantage of MLA over Grouped Query Attention (GQA). We propose Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the latent representation and each head's input dimension across devices, performs attention independently per shard, and then combines results with an all-reduce. TPLA preserves the benefits of a compressed KV cache while unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in TPLA still leverages the full latent representation, maintaining stronger representational capacity. TPLA is drop-in compatible with models pre-trained using MLA: it supports MLA-style prefilling and enables efficient tensor-parallel decoding without retraining. Applying simple orthogonal transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further mitigates cross-shard interference, yielding minimal accuracy degradation. By reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x and 1.93x speedups, respectively, at a 32K-token context length while maintaining performance on commonsense and LongBench benchmarks. TPLA can be implemented with FlashAttention-3, enabling practical end-to-end acceleration.
PDF52August 25, 2025