Mehrkopf-Aufmerksamkeit mit niedrigem Rang
Multi-Head Low-Rank Attention
March 2, 2026
Autoren: Songtao Liu, Hongwu Peng, Zhiwei Zhang, Zhengyu Chen, Yue Guo
cs.AI
Zusammenfassung
Die Langkontext-Inferenz in großen Sprachmodellen wird durch das Laden des Key-Value (KV)-Caches während der Dekodierphase zum Engpass, da der sequenzielle Charakter der Generierung eine wiederholte Übertragung des KV-Caches vom Off-Chip-Hochbandbreiten-Speicher (HBM) zum On-Chip-Static-Random-Access-Memory (SRAM) in jedem Schritt erfordert. Während Multi-Head Latent Attention (MLA) die Gesamtgröße des KV-Caches erheblich reduziert, leidet es unter einem Sharding-Engpass bei der verteilten Dekodierung mittels Tensor-Parallelismus (TP). Da sein einzelner latenter Kopf nicht partitioniert werden kann, ist jedes Gerät gezwungen, den vollständigen KV-Cache für jedes Token redundant zu laden, was übermäßigen Speicherverkehr verursacht und Vorteile von TP wie Gewichts-Sharding schmälert. In dieser Arbeit schlagen wir Multi-Head Low-Rank Attention (MLRA) vor, das partitionierbare latente Zustände für eine effiziente 4-Wege-TP-Dekodierung ermöglicht. Umfangreiche Experimente zeigen, dass MLRA state-of-the-art Perplexität und Leistung bei nachgelagerten Aufgaben erreicht und gleichzeitig eine 2,8-fache Beschleunigung der Dekodierung gegenüber MLA liefert. Der Code ist verfügbar unter https://github.com/SongtaoLiu0823/MLRA. Vortrainierte Gewichte sowie die Trainings- und Evaluierungsdaten sind verfügbar unter https://huggingface.co/Soughing/MLRA.
English
Long-context inference in large language models is bottlenecked by Key--Value (KV) cache loading during the decoding stage, where the sequential nature of generation requires repeatedly transferring the KV cache from off-chip High-Bandwidth Memory (HBM) to on-chip Static Random-Access Memory (SRAM) at each step. While Multi-Head Latent Attention (MLA) significantly reduces the total KV cache size, it suffers from a sharding bottleneck during distributed decoding via Tensor Parallelism (TP). Since its single latent head cannot be partitioned, each device is forced to redundantly load the complete KV cache for every token, consuming excessive memory traffic and diminishing TP benefits like weight sharding. In this work, we propose Multi-Head Low-Rank Attention (MLRA), which enables partitionable latent states for efficient 4-way TP decoding. Extensive experiments show that MLRA achieves state-of-the-art perplexity and downstream task performance, while also delivering a 2.8times decoding speedup over MLA. Code is available at https://github.com/SongtaoLiu0823/MLRA. Pretrained weights, along with the training and evaluation data, are available at https://huggingface.co/Soughing/MLRA.