Hydragen: High-Throughput LLM Inference with Shared Prefixes
February 7, 2024
Authors: Jordan Juravsky, Bradley Brown, Ryan Ehrlich, Daniel Y. Fu, Christopher Ré, Azalia Mirhoseini
cs.AI
Abstract
Transformer-based large language models (LLMs) are now deployed to hundreds
of millions of users. LLM inference is commonly performed on batches of
sequences that share a prefix, such as few-shot examples or a chatbot system
prompt. Decoding in this large-batch setting can be bottlenecked by the
attention operation, which reads large key-value (KV) caches from memory and
computes inefficient matrix-vector products for every sequence in the batch. In
this work, we introduce Hydragen, a hardware-aware exact implementation of
attention with shared prefixes. Hydragen computes attention over the shared
prefix and unique suffixes separately. This decomposition enables efficient
prefix attention by batching queries together across sequences, reducing
redundant memory reads and enabling the use of hardware-friendly matrix
multiplications. Our method can improve end-to-end LLM throughput by up to 32x
against competitive baselines, with speedup growing with the batch size and
shared prefix length. Hydragen also enables the use of very long shared
contexts: with a high batch size, increasing the prefix length from 1K to 16K
tokens decreases Hydragen throughput by less than 15%, while the throughput of
baselines drops by over 90%. Hydragen generalizes beyond simple prefix-suffix
decomposition and can be applied to tree-based prompt sharing patterns,
allowing us to further reduce inference time on competitive programming
problems by 55%.
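
The abstract describes splitting attention into a shared-prefix part and a per-sequence suffix part, computing each separately, and recombining the results exactly. The sketch below is a minimal, single-head PyTorch illustration of that idea; it is not the authors' implementation, and the shapes and helper names (`attn_with_lse`, `hydragen_style_attention`) are illustrative assumptions. The exact recombination step uses the standard log-sum-exp re-weighting trick for merging partial softmax attention results, which I assume is comparable to the correction the paper applies.

```python
# Minimal single-head sketch of prefix/suffix attention decomposition (not the
# authors' code). Shapes and function names are illustrative assumptions.
import torch

def attn_with_lse(q, k, v):
    """Softmax attention that also returns the log-sum-exp of the scores,
    so two partial results can later be merged exactly."""
    scores = q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5    # (..., q_len, kv_len)
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)      # (..., q_len, 1)
    out = torch.softmax(scores, dim=-1) @ v                  # (..., q_len, d)
    return out, lse

def hydragen_style_attention(q, prefix_k, prefix_v, suffix_k, suffix_v):
    """q:          (batch, q_len, d)        decoding queries, one or a few per sequence
       prefix_k/v: (prefix_len, d)          KV cache shared by every sequence in the batch
       suffix_k/v: (batch, suffix_len, d)   per-sequence KV caches."""
    batch, q_len, d = q.shape

    # Prefix attention: fold the batch into the query-length dimension so all
    # sequences attend to the single shared KV cache with one matrix-matrix
    # multiply, instead of `batch` separate matrix-vector products.
    q_flat = q.reshape(1, batch * q_len, d)
    prefix_out, prefix_lse = attn_with_lse(
        q_flat, prefix_k.unsqueeze(0), prefix_v.unsqueeze(0)
    )
    prefix_out = prefix_out.reshape(batch, q_len, d)
    prefix_lse = prefix_lse.reshape(batch, q_len, 1)

    # Suffix attention: ordinary per-sequence attention over the unique suffixes.
    suffix_out, suffix_lse = attn_with_lse(q, suffix_k, suffix_v)

    # Exact recombination: re-weight each partial output by its softmax
    # denominator (via log-sum-exp), recovering full attention over prefix+suffix.
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    w_prefix = torch.exp(prefix_lse - max_lse)
    w_suffix = torch.exp(suffix_lse - max_lse)
    return (w_prefix * prefix_out + w_suffix * suffix_out) / (w_prefix + w_suffix)
```

The point of the reshape in the prefix step is the throughput claim in the abstract: because every sequence shares the same prefix KV cache, their queries can be stacked and multiplied against it once as a dense matrix-matrix product, avoiding redundant reads of the prefix cache and the memory-bound matrix-vector products that dominate ordinary batched decoding.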