
TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill & Decode Inference

August 21, 2025
Authors: Xiaojuan Tang, Fanxu Meng, Pingzhi Tang, Yuxuan Wang, Di Yin, Xing Sun, Muhan Zhang
cs.AI

Abstract

Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses key-value states into a low-rank latent vector, caching only this vector to reduce memory. In tensor parallelism (TP), however, attention heads are computed across multiple devices, and each device must load the full cache, eroding the advantage of MLA over Grouped Query Attention (GQA). We propose Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the latent representation and each head's input dimension across devices, performs attention independently per shard, and then combines results with an all-reduce. TPLA preserves the benefits of a compressed KV cache while unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in TPLA still leverages the full latent representation, maintaining stronger representational capacity. TPLA is drop-in compatible with models pre-trained using MLA: it supports MLA-style prefilling and enables efficient tensor-parallel decoding without retraining. Applying simple orthogonal transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further mitigates cross-shard interference, yielding minimal accuracy degradation. By reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x and 1.93x speedups, respectively, at a 32K-token context length while maintaining performance on commonsense and LongBench benchmarks. TPLA can be implemented with FlashAttention-3, enabling practical end-to-end acceleration.
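The abstract describes the TPLA decode path only at a high level, so the following NumPy sketch is an illustration of that description rather than the paper's implementation: for a single decode step it rotates the cached latent matrix with an orthonormal Hadamard transform, splits the latent dimension into shards, runs softmax attention independently on each shard, and sums the partial outputs (the sum stands in for the all-reduce across tensor-parallel devices). All names (tpla_decode_step, W_uv, n_shards), the per-shard scaling, and the normalization choices are assumptions made for illustration.

# Minimal NumPy sketch of the sharded-latent attention idea in the abstract.
# Names, scaling, and normalization are illustrative assumptions, not the paper's API.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def hadamard(n):
    """Naive Sylvester construction of an orthonormal n x n Hadamard matrix (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)

def mla_decode_step(q_c, C, W_uv):
    """Reference MLA-style decode: one query head attends over the full latent cache C."""
    scores = q_c @ C.T / np.sqrt(q_c.shape[-1])          # (T,)
    return softmax(scores) @ C @ W_uv                    # (d_v,)

def tpla_decode_step(q_c, C, W_uv, n_shards=2, transform=None):
    """TPLA-style sketch: rotate the latent dimension, split it into shards,
    run attention independently per shard, then sum the partial outputs
    (the sum emulates the all-reduce across tensor-parallel devices)."""
    d_c = q_c.shape[-1]
    if transform is not None:                            # e.g. Hadamard or a PCA basis
        q_c, C, W_uv = q_c @ transform, C @ transform, transform.T @ W_uv
    out = 0.0
    for s in np.array_split(np.arange(d_c), n_shards):
        # Per-shard scaling by the shard width is a guess; the paper may scale differently.
        scores_s = q_c[s] @ C[:, s].T / np.sqrt(len(s))
        out = out + softmax(scores_s) @ C[:, s] @ W_uv[s]
    return out

# Toy comparison on random data (illustrative only; the paper's claim is that the
# gap stays small for trained MLA models once the orthogonal transform is applied).
rng = np.random.default_rng(0)
T, d_c, d_v = 16, 8, 4
q_c, C, W_uv = rng.normal(size=d_c), rng.normal(size=(T, d_c)), rng.normal(size=(d_c, d_v))
ref = mla_decode_step(q_c, C, W_uv)
approx = tpla_decode_step(q_c, C, W_uv, n_shards=2, transform=hadamard(d_c))
print(np.abs(ref - approx).max())

Because the rotation is orthonormal, the shard scores still sum to the original query-key score; the only approximation in this sketch is that each shard normalizes its own softmax over partial scores, which is the cross-shard interference the abstract says the Hadamard/PCA transform mitigates.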