SkipDecode: バッチ処理とキャッシュを活用した自己回帰型スキップデコードによる効率的な大規模言語モデル推論
SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference
July 5, 2023
著者: Luciano Del Corro, Allie Del Giorno, Sahaj Agarwal, Bin Yu, Ahmed Awadallah, Subhabrata Mukherjee
cs.AI
要旨
自己回帰型大規模言語モデル(LLM)は、様々な自然言語生成タスクにおいて顕著な進歩を遂げてきました。しかし、トークンごとに生成を行う自己回帰的な性質から、高い計算コストとレイテンシが発生します。この問題に対処するため、早期終了戦略を用いて計算コストを削減するいくつかのアプローチが提案されています。これらの戦略は、各トークンに完全な計算グラフを適用することなく、計算量を削減しながら高速なテキスト生成を可能にします。既存のトークンレベルの早期終了手法は、オンライン推論において有望な結果を示していますが、バッチ推論やKey-Valueキャッシングには容易に適用できません。これは、バッチ内の最後のトークンが終了するまで計算を停止できないためです。この制約により、そのような技術の実用的な応用が大幅に制限されています。本論文では、バッチ推論とKVキャッシングとシームレスに連携する、シンプルで効果的なトークンレベルの早期終了手法「SkipDecode」を提案します。この手法は、バッチ内の各トークンに対して各シーケンス位置で単一の終了点を設定することで、従来の制約を克服します。また、終了点が単調減少することを保証し、先行するトークンのKVキャッシュを再計算する必要をなくします。従来の研究のように計算を早期に終了するのではなく、本手法は中下位層をバイパスし、計算リソースの大部分を上位層に集中させることで、後続のトークンが先行するトークンの計算支出の恩恵を受けられるようにします。実験結果から、SkipDecodeは、1.3億パラメータと6.7億パラメータのOPTモデルを使用して、様々なタスクにおいて無視できる程度の精度低下で2倍から5倍の推論速度向上を達成できることが示されています。これは、バッチ処理とKVキャッシングの最適化技術と直接互換性を保ちながら実現されています。
English
Autoregressive large language models (LLMs) have made remarkable progress in
various natural language generation tasks. However, they incur high computation
cost and latency resulting from the autoregressive token-by-token generation.
To address this issue, several approaches have been proposed to reduce
computational cost using early-exit strategies. These strategies enable faster
text generation using reduced computation without applying the full computation
graph to each token. While existing token-level early exit methods show
promising results for online inference, they cannot be readily applied for
batch inferencing and Key-Value caching. This is because they have to wait
until the last token in a batch exits before they can stop computing. This
severely limits the practical application of such techniques. In this paper, we
propose a simple and effective token-level early exit method, SkipDecode,
designed to work seamlessly with batch inferencing and KV caching. It overcomes
prior constraints by setting up a singular exit point for every token in a
batch at each sequence position. It also guarantees a monotonic decrease in
exit points, thereby eliminating the need to recompute KV Caches for preceding
tokens. Rather than terminating computation prematurely as in prior works, our
approach bypasses lower to middle layers, devoting most of the computational
resources to upper layers, allowing later tokens to benefit from the compute
expenditure by earlier tokens. Our experimental results show that SkipDecode
can obtain 2x to 5x inference speedups with negligible regression across a
variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7
billion parameters, all the while being directly compatible with batching and
KV caching optimization techniques.