TPLA: Atenção Latente com Paralelismo de Tensores para Inferência Eficiente de Preenchimento e Decodificação Desagregados
TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill \& Decode Inference
August 21, 2025
Autores: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI
Resumo
A Multi-Head Latent Attention (MLA), introduzida no DeepSeek-V2, comprime os estados de chave-valor em um vetor latente de baixa classificação, armazenando em cache apenas esse vetor para reduzir a memória. No paralelismo de tensores (TP), no entanto, as cabeças de atenção são computadas em vários dispositivos, e cada dispositivo deve carregar o cache completo, o que diminui a vantagem da MLA em relação à Grouped Query Attention (GQA). Propomos o Tensor-Parallel Latent Attention (TPLA): um esquema que particiona tanto a representação latente quanto a dimensão de entrada de cada cabeça entre dispositivos, realiza a atenção de forma independente por fragmento e, em seguida, combina os resultados com um all-reduce. O TPLA preserva os benefícios de um cache KV comprimido enquanto desbloqueia a eficiência do TP. Diferente do Grouped Latent Attention (GLA), cada cabeça no TPLA ainda aproveita a representação latente completa, mantendo uma capacidade representacional mais forte. O TPLA é compatível com modelos pré-treinados usando MLA: ele suporta o preenchimento no estilo MLA e permite a decodificação eficiente em paralelismo de tensores sem retreinamento. A aplicação de transformações ortogonais simples — por exemplo, a transformada de Hadamard ou PCA — antes do corte do TP mitiga ainda mais a interferência entre fragmentos, resultando em uma degradação mínima da precisão. Ao reduzir o cache KV por dispositivo para o DeepSeek-V3 e o Kimi-K2, alcançamos acelerações de 1,79x e 1,93x, respectivamente, em um contexto de 32K tokens, mantendo o desempenho em benchmarks de senso comum e LongBench. O TPLA pode ser implementado com o FlashAttention-3, permitindo uma aceleração prática de ponta a ponta.
English
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses
key-value states into a low-rank latent vector, caching only this vector to
reduce memory. In tensor parallelism (TP), however, attention heads are
computed across multiple devices, and each device must load the full cache,
eroding the advantage of MLA over Grouped Query Attention (GQA). We propose
Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the
latent representation and each head's input dimension across devices, performs
attention independently per shard, and then combines results with an
all-reduce. TPLA preserves the benefits of a compressed KV cache while
unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in
TPLA still leverages the full latent representation, maintaining stronger
representational capacity. TPLA is drop-in compatible with models pre-trained
using MLA: it supports MLA-style prefilling and enables efficient
tensor-parallel decoding without retraining. Applying simple orthogonal
transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further
mitigates cross-shard interference, yielding minimal accuracy degradation. By
reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x
and 1.93x speedups, respectively, at a 32K-token context length while
maintaining performance on commonsense and LongBench benchmarks. TPLA can be
implemented with FlashAttention-3, enabling practical end-to-end acceleration.