DuoAttention: 検索とストリーミングヘッドを用いた効率的な長文脈LLM推論
DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads
October 14, 2024
著者: Guangxuan Xiao, Jiaming Tang, Jingwei Zuo, Junxian Guo, Shang Yang, Haotian Tang, Yao Fu, Song Han
cs.AI
要旨
長い文脈を持つ大規模言語モデル(LLM)を展開することは重要ですが、膨大な計算とメモリの課題を引き起こします。すべての注意ヘッドにわたるすべてのKeyとValue(KV)状態をキャッシュすることは、膨大なメモリを消費します。既存のKVキャッシュの剪定方法は、LLMの長い文脈能力を損なうか、効率の改善が限定されるものです。本論文では、ごく一部の注意ヘッド、すなわち、検索ヘッドとして知られるものが、長い文脈を処理する際に重要であり、すべてのトークンにわたって完全な注意を必要とすることを特定します。それに対し、最近のトークンと注意の焦点である他のすべてのヘッド、すなわち、ストリーミングヘッドは、完全な注意を必要としません。この洞察に基づいて、我々はDuoAttentionを導入します。これは、検索ヘッドにのみ完全なKVキャッシュを適用し、ストリーミングヘッドには軽量で一定長のKVキャッシュを使用するフレームワークです。これにより、LLMのデコーディングと事前充填のメモリと遅延が削減され、長い文脈能力が損なわれることなく、効率が向上します。DuoAttentionは、軽量で最適化ベースのアルゴリズムを使用し、合成データを用いて検索ヘッドを正確に特定します。当社の手法は、MHAモデルに対して最大2.55倍、GQAモデルに対して最大1.67倍の長い文脈推論メモリを削減し、デコーディングを最大2.18倍、1.50倍、事前充填を最大1.73倍、1.63倍高速化します。これにより、完全な注意と比較して最小限の精度損失で、Llama-3-8Bのデコーディングを単一のA100 GPUで330万のコンテキスト長で実現します。コードはhttps://github.com/mit-han-lab/duo-attentionで提供されています。
English
Deploying long-context large language models (LLMs) is essential but poses
significant computational and memory challenges. Caching all Key and Value (KV)
states across all attention heads consumes substantial memory. Existing KV
cache pruning methods either damage the long-context capabilities of LLMs or
offer only limited efficiency improvements. In this paper, we identify that
only a fraction of attention heads, a.k.a, Retrieval Heads, are critical for
processing long contexts and require full attention across all tokens. In
contrast, all other heads, which primarily focus on recent tokens and attention
sinks--referred to as Streaming Heads--do not require full attention. Based on
this insight, we introduce DuoAttention, a framework that only applies a full
KV cache to retrieval heads while using a light-weight, constant-length KV
cache for streaming heads, which reduces both LLM's decoding and pre-filling
memory and latency without compromising its long-context abilities.
DuoAttention uses a lightweight, optimization-based algorithm with synthetic
data to identify retrieval heads accurately. Our method significantly reduces
long-context inference memory by up to 2.55x for MHA and 1.67x for GQA models
while speeding up decoding by up to 2.18x and 1.50x and accelerating
pre-filling by up to 1.73x and 1.63x for MHA and GQA models, respectively, with
minimal accuracy loss compared to full attention. Notably, combined with
quantization, DuoAttention enables Llama-3-8B decoding with 3.3 million context
length on a single A100 GPU. Code is provided in
https://github.com/mit-han-lab/duo-attention.