Attention Multi-Têtes à Faible Rang
Multi-Head Low-Rank Attention
March 2, 2026
Auteurs: Songtao Liu, Hongwu Peng, Zhiwei Zhang, Zhengyu Chen, Yue Guo
cs.AI
Résumé
L'inférence à contexte long dans les grands modèles de langage est limitée par le chargement du cache Clé-Valeur (KV) lors de l'étape de décodage, où la nature séquentielle de la génération nécessite de transférer répétitivement le cache KV de la mémoire haute bande passante (HBM) externe vers la mémoire statique à accès aléatoire (SRAM) interne à chaque étape. Bien que l'attention latente multi-têtes (MLA) réduise considérablement la taille totale du cache KV, elle souffre d'un goulot d'étranglement de partitionnement lors du décodage distribué via le parallélisme par tenseurs (TP). Puisque sa tête latente unique ne peut pas être partitionnée, chaque dispositif est contraint de charger redondamment le cache KV complet pour chaque token, consommant un trafic mémoire excessif et réduisant les avantages du TP comme le partitionnement des poids. Dans ce travail, nous proposons l'attention à faible rang multi-têtes (MLRA), qui permet des états latents partitionnables pour un décodage TP à 4 voies efficace. Des expériences approfondies montrent que MLRA atteint une perplexité et des performances sur tâches en aval à l'état de l'art, tout en offrant une accélération du décodage de 2,8 fois par rapport à MLA. Le code est disponible à l'adresse https://github.com/SongtaoLiu0823/MLRA. Les poids pré-entraînés, ainsi que les données d'entraînement et d'évaluation, sont disponibles à l'adresse https://huggingface.co/Soughing/MLRA.
English
Long-context inference in large language models is bottlenecked by Key--Value (KV) cache loading during the decoding stage, where the sequential nature of generation requires repeatedly transferring the KV cache from off-chip High-Bandwidth Memory (HBM) to on-chip Static Random-Access Memory (SRAM) at each step. While Multi-Head Latent Attention (MLA) significantly reduces the total KV cache size, it suffers from a sharding bottleneck during distributed decoding via Tensor Parallelism (TP). Since its single latent head cannot be partitioned, each device is forced to redundantly load the complete KV cache for every token, consuming excessive memory traffic and diminishing TP benefits like weight sharding. In this work, we propose Multi-Head Low-Rank Attention (MLRA), which enables partitionable latent states for efficient 4-way TP decoding. Extensive experiments show that MLRA achieves state-of-the-art perplexity and downstream task performance, while also delivering a 2.8times decoding speedup over MLA. Code is available at https://github.com/SongtaoLiu0823/MLRA. Pretrained weights, along with the training and evaluation data, are available at https://huggingface.co/Soughing/MLRA.