SkipDecode: Autoregressief Overslaan Decoderen met Batchverwerking en Caching voor Efficiënte LLM-inferentie
SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference
July 5, 2023
Auteurs: Luciano Del Corro, Allie Del Giorno, Sahaj Agarwal, Bin Yu, Ahmed Awadallah, Subhabrata Mukherjee
cs.AI
Samenvatting
Autoregressieve grote taalmodellen (LLMs) hebben opmerkelijke vooruitgang geboekt in verschillende taken voor natuurlijke taalgeneratie. Ze brengen echter hoge rekenkosten en latentie met zich mee als gevolg van de autoregressieve token-voor-token-generatie. Om dit probleem aan te pakken, zijn verschillende benaderingen voorgesteld om de rekenkosten te verlagen met behulp van early-exit-strategieën. Deze strategieën maken snellere tekstgeneratie mogelijk met minder rekenkracht, zonder het volledige rekenkundige grafiek op elke token toe te passen. Hoewel bestaande token-level early-exit-methoden veelbelovende resultaten laten zien voor online inferentie, kunnen ze niet direct worden toegepast voor batch-inferentie en Key-Value-caching. Dit komt omdat ze moeten wachten tot de laatste token in een batch uitstapt voordat ze kunnen stoppen met rekenen. Dit beperkt de praktische toepassing van dergelijke technieken ernstig. In dit artikel stellen we een eenvoudige en effectieve token-level early-exit-methode voor, SkipDecode, die naadloos werkt met batch-inferentie en KV-caching. Het overwint eerdere beperkingen door een enkel uitstappunt in te stellen voor elke token in een batch op elke sequentiepositie. Het garandeert ook een monotone afname van uitstappunten, waardoor het opnieuw berekenen van KV-caches voor voorgaande tokens overbodig wordt. In plaats van de berekening voortijdig te beëindigen zoals in eerdere werken, omzeilt onze aanpak de lagere tot middelste lagen en wijdt het het grootste deel van de rekenkracht aan de bovenste lagen, waardoor latere tokens kunnen profiteren van de rekenkracht die door eerdere tokens is besteed. Onze experimentele resultaten tonen aan dat SkipDecode een 2x tot 5x versnelling van de inferentie kan bereiken met verwaarloosbare terugval over een verscheidenheid aan taken. Dit wordt bereikt met OPT-modellen van 1,3 miljard en 6,7 miljard parameters, terwijl het direct compatibel is met batchverwerking en KV-caching-optimalisatietechnieken.
English
Autoregressive large language models (LLMs) have made remarkable progress in
various natural language generation tasks. However, they incur high computation
cost and latency resulting from the autoregressive token-by-token generation.
To address this issue, several approaches have been proposed to reduce
computational cost using early-exit strategies. These strategies enable faster
text generation using reduced computation without applying the full computation
graph to each token. While existing token-level early exit methods show
promising results for online inference, they cannot be readily applied for
batch inferencing and Key-Value caching. This is because they have to wait
until the last token in a batch exits before they can stop computing. This
severely limits the practical application of such techniques. In this paper, we
propose a simple and effective token-level early exit method, SkipDecode,
designed to work seamlessly with batch inferencing and KV caching. It overcomes
prior constraints by setting up a singular exit point for every token in a
batch at each sequence position. It also guarantees a monotonic decrease in
exit points, thereby eliminating the need to recompute KV Caches for preceding
tokens. Rather than terminating computation prematurely as in prior works, our
approach bypasses lower to middle layers, devoting most of the computational
resources to upper layers, allowing later tokens to benefit from the compute
expenditure by earlier tokens. Our experimental results show that SkipDecode
can obtain 2x to 5x inference speedups with negligible regression across a
variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7
billion parameters, all the while being directly compatible with batching and
KV caching optimization techniques.