DuoAttention : Inférence efficace de LLM à long contexte avec des têtes de recherche et de diffusion
DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads
October 14, 2024
Auteurs: Guangxuan Xiao, Jiaming Tang, Jingwei Zuo, Junxian Guo, Shang Yang, Haotian Tang, Yao Fu, Song Han
cs.AI
Résumé
Le déploiement des grands modèles de langage à long contexte (LLM) est essentiel mais pose des défis computationnels et de mémoire importants. Mettre en cache tous les états Clé et Valeur (KV) à travers toutes les têtes d'attention consomme une mémoire substantielle. Les méthodes existantes d'élagage du cache KV endommagent soit les capacités à long contexte des LLM, soit n'offrent que des améliorations d'efficacité limitées. Dans cet article, nous identifions qu'une fraction seulement des têtes d'attention, appelées Têtes de Récupération, sont cruciales pour le traitement des longs contextes et nécessitent une attention complète sur tous les jetons. En revanche, toutes les autres têtes, qui se concentrent principalement sur les jetons récents et les puits d'attention - appelées Têtes de Diffusion - ne nécessitent pas une attention complète. Sur la base de cette observation, nous introduisons DuoAttention, un cadre qui n'applique un cache KV complet qu'aux têtes de récupération tout en utilisant un cache KV léger et de longueur constante pour les têtes de diffusion, ce qui réduit à la fois la mémoire de décodage et de pré-remplissage des LLM ainsi que la latence sans compromettre leurs capacités à long contexte. DuoAttention utilise un algorithme léger basé sur l'optimisation avec des données synthétiques pour identifier précisément les têtes de récupération. Notre méthode réduit significativement la mémoire d'inférence à long contexte jusqu'à 2,55 fois pour les modèles MHA et 1,67 fois pour les modèles GQA tout en accélérant le décodage jusqu'à 2,18 fois et 1,50 fois, et en accélérant le pré-remplissage jusqu'à 1,73 fois et 1,63 fois pour les modèles MHA et GQA respectivement, avec une perte de précision minimale par rapport à une attention complète. Notamment, combiné à la quantification, DuoAttention permet le décodage de Llama-3-8B avec une longueur de contexte de 3,3 millions sur un seul GPU A100. Le code est disponible sur https://github.com/mit-han-lab/duo-attention.
English
Deploying long-context large language models (LLMs) is essential but poses
significant computational and memory challenges. Caching all Key and Value (KV)
states across all attention heads consumes substantial memory. Existing KV
cache pruning methods either damage the long-context capabilities of LLMs or
offer only limited efficiency improvements. In this paper, we identify that
only a fraction of attention heads, a.k.a, Retrieval Heads, are critical for
processing long contexts and require full attention across all tokens. In
contrast, all other heads, which primarily focus on recent tokens and attention
sinks--referred to as Streaming Heads--do not require full attention. Based on
this insight, we introduce DuoAttention, a framework that only applies a full
KV cache to retrieval heads while using a light-weight, constant-length KV
cache for streaming heads, which reduces both LLM's decoding and pre-filling
memory and latency without compromising its long-context abilities.
DuoAttention uses a lightweight, optimization-based algorithm with synthetic
data to identify retrieval heads accurately. Our method significantly reduces
long-context inference memory by up to 2.55x for MHA and 1.67x for GQA models
while speeding up decoding by up to 2.18x and 1.50x and accelerating
pre-filling by up to 1.73x and 1.63x for MHA and GQA models, respectively, with
minimal accuracy loss compared to full attention. Notably, combined with
quantization, DuoAttention enables Llama-3-8B decoding with 3.3 million context
length on a single A100 GPU. Code is provided in
https://github.com/mit-han-lab/duo-attention.Summary
AI-Generated Summary