Atenção Multi-Cabeça de Baixa Classificação
Multi-Head Low-Rank Attention
March 2, 2026
Autores: Songtao Liu, Hongwu Peng, Zhiwei Zhang, Zhengyu Chen, Yue Guo
cs.AI
Resumo
A inferência de contexto longo em modelos de linguagem de grande escala é limitada pelo carregamento da cache de Chave-Valor (KV) durante a fase de decodificação, onde a natureza sequencial da geração exige a transferência repetida da cache KV da Memória de Alta Largura de Banda (HBM) fora do chip para a Memória de Acesso Aleatório Estática (SRAM) dentro do chip a cada passo. Embora a Atenção Latente de Múltiplas Cabeças (MLA) reduza significativamente o tamanho total da cache KV, ela sofre com um gargalo de fragmentação durante a decodificação distribuída via Paralelismo de Tensores (TP). Como sua única cabeça latente não pode ser particionada, cada dispositivo é forçado a carregar redundantemente a cache KV completa para cada token, consumindo tráfego de memória excessivo e diminuindo os benefícios do TP, como a fragmentação de pesos. Neste trabalho, propomos a Atenção de Baixa Postura de Múltiplas Cabeças (MLRA), que permite estados latentes particionáveis para uma decodificação eficiente com TP de 4 vias. Experimentos extensivos mostram que a MLRA atinge a melhor perplexidade e desempenho em tarefas downstream do estado da arte, além de proporcionar uma aceleração de 2,8 vezes na velocidade de decodificação em comparação com a MLA. O código está disponível em https://github.com/SongtaoLiu0823/MLRA. Os pesos pré-treinados, juntamente com os dados de treinamento e avaliação, estão disponíveis em https://huggingface.co/Soughing/MLRA.
English
Long-context inference in large language models is bottlenecked by Key--Value (KV) cache loading during the decoding stage, where the sequential nature of generation requires repeatedly transferring the KV cache from off-chip High-Bandwidth Memory (HBM) to on-chip Static Random-Access Memory (SRAM) at each step. While Multi-Head Latent Attention (MLA) significantly reduces the total KV cache size, it suffers from a sharding bottleneck during distributed decoding via Tensor Parallelism (TP). Since its single latent head cannot be partitioned, each device is forced to redundantly load the complete KV cache for every token, consuming excessive memory traffic and diminishing TP benefits like weight sharding. In this work, we propose Multi-Head Low-Rank Attention (MLRA), which enables partitionable latent states for efficient 4-way TP decoding. Extensive experiments show that MLRA achieves state-of-the-art perplexity and downstream task performance, while also delivering a 2.8times decoding speedup over MLA. Code is available at https://github.com/SongtaoLiu0823/MLRA. Pretrained weights, along with the training and evaluation data, are available at https://huggingface.co/Soughing/MLRA.