SkipDecode: Decodifica Autoregressiva con Salto, Batching e Caching per un'Inferenza Efficiente nei Modelli Linguistici di Grande Dimensione
SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference
July 5, 2023
Autori: Luciano Del Corro, Allie Del Giorno, Sahaj Agarwal, Bin Yu, Ahmed Awadallah, Subhabrata Mukherjee
cs.AI
Abstract
I modelli linguistici autoregressivi di grandi dimensioni (LLM) hanno compiuto progressi significativi in vari compiti di generazione del linguaggio naturale. Tuttavia, comportano un elevato costo computazionale e una latenza derivante dalla generazione token per token di tipo autoregressivo. Per affrontare questo problema, sono state proposte diverse strategie per ridurre il costo computazionale utilizzando approcci di uscita anticipata. Queste strategie consentono una generazione più rapida del testo riducendo il calcolo senza applicare il grafo computazionale completo a ciascun token. Sebbene i metodi esistenti di uscita anticipata a livello di token mostrino risultati promettenti per l'inferenza online, non possono essere facilmente applicati per l'inferenza in batch e la memorizzazione Key-Value (KV). Ciò è dovuto al fatto che devono attendere che l'ultimo token in un batch esca prima di poter interrompere il calcolo. Questo limita fortemente l'applicazione pratica di tali tecniche. In questo articolo, proponiamo un metodo semplice ed efficace di uscita anticipata a livello di token, denominato SkipDecode, progettato per funzionare in modo fluido con l'inferenza in batch e la memorizzazione KV. Supera i limiti precedenti stabilendo un punto di uscita singolo per ogni token in un batch in ciascuna posizione della sequenza. Garantisce inoltre una diminuzione monotona dei punti di uscita, eliminando così la necessità di ricalcolare le cache KV per i token precedenti. Piuttosto che interrompere prematuramente il calcolo come nei lavori precedenti, il nostro approccio bypassa gli strati intermedi e inferiori, dedicando la maggior parte delle risorse computazionali agli strati superiori, consentendo ai token successivi di beneficiare del calcolo effettuato dai token precedenti. I nostri risultati sperimentali dimostrano che SkipDecode può ottenere un'accelerazione dell'inferenza da 2x a 5x con una regressione trascurabile in una varietà di compiti. Ciò è stato raggiunto utilizzando modelli OPT con 1,3 miliardi e 6,7 miliardi di parametri, mantenendo al contempo la compatibilità diretta con le tecniche di ottimizzazione del batching e della memorizzazione KV.
English
Autoregressive large language models (LLMs) have made remarkable progress in
various natural language generation tasks. However, they incur high computation
cost and latency resulting from the autoregressive token-by-token generation.
To address this issue, several approaches have been proposed to reduce
computational cost using early-exit strategies. These strategies enable faster
text generation using reduced computation without applying the full computation
graph to each token. While existing token-level early exit methods show
promising results for online inference, they cannot be readily applied for
batch inferencing and Key-Value caching. This is because they have to wait
until the last token in a batch exits before they can stop computing. This
severely limits the practical application of such techniques. In this paper, we
propose a simple and effective token-level early exit method, SkipDecode,
designed to work seamlessly with batch inferencing and KV caching. It overcomes
prior constraints by setting up a singular exit point for every token in a
batch at each sequence position. It also guarantees a monotonic decrease in
exit points, thereby eliminating the need to recompute KV Caches for preceding
tokens. Rather than terminating computation prematurely as in prior works, our
approach bypasses lower to middle layers, devoting most of the computational
resources to upper layers, allowing later tokens to benefit from the compute
expenditure by earlier tokens. Our experimental results show that SkipDecode
can obtain 2x to 5x inference speedups with negligible regression across a
variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7
billion parameters, all the while being directly compatible with batching and
KV caching optimization techniques.