SkipDecode: Autoregressives Überspringen der Dekodierung mit Batching und Caching für effiziente LLM-Inferenz
SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference
July 5, 2023
Autoren: Luciano Del Corro, Allie Del Giorno, Sahaj Agarwal, Bin Yu, Ahmed Awadallah, Subhabrata Mukherjee
cs.AI
Zusammenfassung
Autoregressive große Sprachmodelle (LLMs) haben bemerkenswerte Fortschritte in verschiedenen Aufgaben der natürlichen Sprachgenerierung erzielt. Allerdings verursachen sie hohe Rechenkosten und Latenzzeiten, die sich aus der autoregressiven Token-für-Token-Generierung ergeben. Um dieses Problem zu lösen, wurden mehrere Ansätze vorgeschlagen, um die Rechenkosten durch Early-Exit-Strategien zu reduzieren. Diese Strategien ermöglichen eine schnellere Textgenerierung mit reduziertem Rechenaufwand, ohne den vollen Berechnungsgraphen auf jedes Token anzuwenden. Während bestehende Token-Level-Early-Exit-Methoden vielversprechende Ergebnisse für Online-Inferenz zeigen, können sie nicht ohne Weiteres für Batch-Inferenz und Key-Value-Caching verwendet werden. Dies liegt daran, dass sie warten müssen, bis das letzte Token in einem Batch beendet ist, bevor sie die Berechnung stoppen können. Dies schränkt die praktische Anwendung solcher Techniken erheblich ein. In diesem Artikel schlagen wir eine einfache und effektive Token-Level-Early-Exit-Methode vor, SkipDecode, die nahtlos mit Batch-Inferenz und KV-Caching zusammenarbeitet. Sie überwindet frühere Einschränkungen, indem sie einen einzigen Ausstiegspunkt für jedes Token in einem Batch an jeder Sequenzposition festlegt. Sie gewährleistet auch eine monotone Abnahme der Ausstiegspunkte, wodurch die Notwendigkeit entfällt, KV-Caches für vorhergehende Token neu zu berechnen. Anstatt die Berechnung vorzeitig zu beenden wie in früheren Arbeiten, umgeht unser Ansatz die unteren bis mittleren Schichten und widmet den größten Teil der Rechenressourcen den oberen Schichten, sodass spätere Token von den Rechenaufwendungen früherer Token profitieren können. Unsere experimentellen Ergebnisse zeigen, dass SkipDecode eine 2x bis 5x schnellere Inferenz mit vernachlässigbarem Leistungsverlust über eine Vielzahl von Aufgaben erzielen kann. Dies wird mit OPT-Modellen mit 1,3 Milliarden und 6,7 Milliarden Parametern erreicht, wobei gleichzeitig eine direkte Kompatibilität mit Batch-Verarbeitung und KV-Caching-Optimierungstechniken gewährleistet ist.
English
Autoregressive large language models (LLMs) have made remarkable progress in
various natural language generation tasks. However, they incur high computation
cost and latency resulting from the autoregressive token-by-token generation.
To address this issue, several approaches have been proposed to reduce
computational cost using early-exit strategies. These strategies enable faster
text generation using reduced computation without applying the full computation
graph to each token. While existing token-level early exit methods show
promising results for online inference, they cannot be readily applied for
batch inferencing and Key-Value caching. This is because they have to wait
until the last token in a batch exits before they can stop computing. This
severely limits the practical application of such techniques. In this paper, we
propose a simple and effective token-level early exit method, SkipDecode,
designed to work seamlessly with batch inferencing and KV caching. It overcomes
prior constraints by setting up a singular exit point for every token in a
batch at each sequence position. It also guarantees a monotonic decrease in
exit points, thereby eliminating the need to recompute KV Caches for preceding
tokens. Rather than terminating computation prematurely as in prior works, our
approach bypasses lower to middle layers, devoting most of the computational
resources to upper layers, allowing later tokens to benefit from the compute
expenditure by earlier tokens. Our experimental results show that SkipDecode
can obtain 2x to 5x inference speedups with negligible regression across a
variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7
billion parameters, all the while being directly compatible with batching and
KV caching optimization techniques.