SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference
July 5, 2023
Authors: Luciano Del Corro, Allie Del Giorno, Sahaj Agarwal, Bin Yu, Ahmed Awadallah, Subhabrata Mukherjee
cs.AI
Abstract
Autoregressive large language models (LLMs) have made remarkable progress in
various natural language generation tasks. However, they incur high computation
cost and latency resulting from the autoregressive token-by-token generation.
To address this issue, several approaches have been proposed to reduce
computational cost using early-exit strategies. These strategies enable faster
text generation using reduced computation without applying the full computation
graph to each token. While existing token-level early exit methods show
promising results for online inference, they cannot be readily applied to
batch inferencing and Key-Value (KV) caching. This is because they have to wait
until the last token in a batch exits before they can stop computing. This
severely limits the practical application of such techniques. In this paper, we
propose a simple and effective token-level early exit method, SkipDecode,
designed to work seamlessly with batch inferencing and KV caching. It overcomes
prior constraints by setting up a singular exit point for every token in a
batch at each sequence position. It also guarantees a monotonic decrease in
exit points, thereby eliminating the need to recompute KV caches for preceding
tokens. Rather than terminating computation prematurely as in prior works, our
approach bypasses lower to middle layers, devoting most of the computational
resources to upper layers, allowing later tokens to benefit from the compute
expenditure by earlier tokens. Our experimental results show that SkipDecode
can obtain 2x to 5x inference speedups with negligible regression across a
variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7
billion parameters, all the while being directly compatible with batching and
KV caching optimization techniques.
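To make the mechanism described in the abstract concrete, here is a minimal sketch of how a position-dependent skipping policy of this kind could be wired up. The linear decay schedule and all of the names (skip_schedule, run_token, max_active_layers, min_active_layers) are illustrative assumptions rather than the authors' implementation; the point is only that a shared, monotonically shrinking layer budget per sequence position keeps batched decoding and KV caching consistent.

```python
# Illustrative sketch only: the schedule and names below are assumptions,
# not the paper's actual implementation.

def skip_schedule(position, max_position, max_active_layers, min_active_layers):
    """Number of (top) layers to run for the token at this sequence position.

    The budget decreases monotonically with position, so a later token only
    attends at layers that every earlier token has already computed, and no
    KV cache entry ever needs to be recomputed.
    """
    frac = min(position / max_position, 1.0)
    budget = max_active_layers - frac * (max_active_layers - min_active_layers)
    return max(min_active_layers, int(round(budget)))


def run_token(hidden, position, max_position, layers,
              max_active_layers, min_active_layers):
    """Bypass lower/middle layers and run only the top `budget` layers.

    Every token in the batch at the same position shares this exit point,
    which is what keeps the scheme compatible with batched inference.
    """
    budget = skip_schedule(position, max_position,
                           max_active_layers, min_active_layers)
    start = len(layers) - budget          # skip layers[0:start]
    for layer in layers[start:]:
        hidden = layer(hidden)
    return hidden


# Toy usage: 24 "layers" that each add 1, with a budget decaying from 24 to 8.
layers = [lambda h: h + 1 for _ in range(24)]
for pos in (0, 512, 1024, 2048):
    out = run_token(0, pos, max_position=2048, layers=layers,
                    max_active_layers=24, min_active_layers=8)
    print(f"position {pos}: ran {out} layers")
```

Because the budget never grows with position, any layer a later token runs has already been run by every earlier token, so the KV cache entries it attends to always exist; and because lower rather than upper layers are bypassed, later tokens still reuse the compute spent by earlier tokens in the upper layers.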