SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference
July 5, 2023
Authors: Luciano Del Corro, Allie Del Giorno, Sahaj Agarwal, Bin Yu, Ahmed Awadallah, Subhabrata Mukherjee
cs.AI
Abstract
Autoregressive large language models (LLMs) have made remarkable progress in
various natural language generation tasks. However, they incur high computation
cost and latency resulting from the autoregressive token-by-token generation.
To address this issue, several approaches have been proposed to reduce
computational cost using early-exit strategies. These strategies enable faster
text generation using reduced computation without applying the full computation
graph to each token. While existing token-level early exit methods show
promising results for online inference, they cannot be readily applied for
batch inferencing and Key-Value caching. This is because they have to wait
until the last token in a batch exits before they can stop computing. This
severely limits the practical application of such techniques. In this paper, we
propose a simple and effective token-level early exit method, SkipDecode,
designed to work seamlessly with batch inferencing and KV caching. It overcomes
prior constraints by setting up a singular exit point for every token in a
batch at each sequence position. It also guarantees a monotonic decrease in
exit points, thereby eliminating the need to recompute KV Caches for preceding
tokens. Rather than terminating computation prematurely as in prior works, our
approach bypasses lower to middle layers, devoting most of the computational
resources to upper layers, allowing later tokens to benefit from the compute
expenditure by earlier tokens. Our experimental results show that SkipDecode
can obtain 2x to 5x inference speedups with negligible regression across a
variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7
billion parameters, all the while being directly compatible with batching and
KV caching optimization techniques.
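The core mechanism described above, a single exit point shared by every token in the batch at a given sequence position, decreasing monotonically over positions, with the budget spent on the upper layers rather than stopping computation early, can be illustrated with a minimal sketch. The names and choices below (`exit_point_schedule`, `skip_decode_forward`, `min_exit`, `max_exit`, and the linear-decay schedule) are illustrative assumptions, not the paper's exact implementation.

```python
import torch
import torch.nn as nn

def exit_point_schedule(seq_len: int, num_layers: int,
                        min_exit: int, max_exit: int) -> torch.Tensor:
    # Monotonically non-increasing exit points across sequence positions:
    # early tokens get more layers, later tokens fewer. A linear decay is
    # assumed here; the paper may use a different decay function.
    points = torch.linspace(float(max_exit), float(min_exit), steps=seq_len)
    return points.round().long().clamp(min=1, max=num_layers)

def skip_decode_forward(hidden: torch.Tensor, layers: nn.ModuleList,
                        exit_point: int) -> torch.Tensor:
    # Skip the lower/middle layers and spend the whole budget on the top
    # `exit_point` layers, so every token in the batch at this position
    # exits at the same depth and earlier tokens' KV caches remain valid.
    start = len(layers) - exit_point
    for layer in layers[start:]:
        hidden = layer(hidden)
    return hidden

# Hypothetical usage: a 24-layer decoder generating up to 128 tokens.
layers = nn.ModuleList([nn.Linear(512, 512) for _ in range(24)])  # stand-in blocks
schedule = exit_point_schedule(seq_len=128, num_layers=24, min_exit=6, max_exit=20)
hidden = torch.randn(4, 512)  # batch of 4 token states at the current position
out = skip_decode_forward(hidden, layers, int(schedule[0]))
```

Because the per-position budget is shared across the batch and never increases, no token ever needs layers that an earlier token skipped, which is what keeps batched decoding and KV caching compatible with the early-exit schedule.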