SkipDecode : Décodage autoregressif avec saut, regroupement par lots et mise en cache pour une inférence efficace des grands modèles de langage
SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference
July 5, 2023
Auteurs: Luciano Del Corro, Allie Del Giorno, Sahaj Agarwal, Bin Yu, Ahmed Awadallah, Subhabrata Mukherjee
cs.AI
Résumé
Les grands modèles de langage (LLM) autoregressifs ont réalisé des progrès remarquables dans diverses tâches de génération de langage naturel. Cependant, ils entraînent des coûts de calcul élevés et une latence importante en raison de la génération token par token de manière autoregressive. Pour résoudre ce problème, plusieurs approches ont été proposées afin de réduire les coûts de calcul en utilisant des stratégies de sortie précoce. Ces stratégies permettent une génération de texte plus rapide en utilisant un calcul réduit, sans appliquer le graphe de calcul complet à chaque token. Bien que les méthodes existantes de sortie précoce au niveau des tokens montrent des résultats prometteurs pour l'inférence en ligne, elles ne peuvent pas être facilement appliquées à l'inférence par lots et à la mise en cache Key-Value (KV). En effet, elles doivent attendre que le dernier token d'un lot sorte avant de pouvoir arrêter le calcul. Cela limite considérablement l'application pratique de ces techniques. Dans cet article, nous proposons une méthode simple et efficace de sortie précoce au niveau des tokens, appelée SkipDecode, conçue pour fonctionner de manière transparente avec l'inférence par lots et la mise en cache KV. Elle surmonte les contraintes précédentes en établissant un point de sortie unique pour chaque token d'un lot à chaque position de séquence. Elle garantit également une diminution monotone des points de sortie, éliminant ainsi la nécessité de recalculer les caches KV pour les tokens précédents. Plutôt que d'interrompre prématurément le calcul comme dans les travaux précédents, notre approche contourne les couches inférieures à intermédiaires, consacrant la majeure partie des ressources de calcul aux couches supérieures, permettant ainsi aux tokens ultérieurs de bénéficier des dépenses de calcul des tokens précédents. Nos résultats expérimentaux montrent que SkipDecode peut obtenir des accélérations d'inférence de 2x à 5x avec une régression négligeable sur une variété de tâches. Cela est réalisé en utilisant des modèles OPT de 1,3 milliard et 6,7 milliards de paramètres, tout en étant directement compatible avec les techniques d'optimisation de batching et de mise en cache KV.
English
Autoregressive large language models (LLMs) have made remarkable progress in
various natural language generation tasks. However, they incur high computation
cost and latency resulting from the autoregressive token-by-token generation.
To address this issue, several approaches have been proposed to reduce
computational cost using early-exit strategies. These strategies enable faster
text generation using reduced computation without applying the full computation
graph to each token. While existing token-level early exit methods show
promising results for online inference, they cannot be readily applied for
batch inferencing and Key-Value caching. This is because they have to wait
until the last token in a batch exits before they can stop computing. This
severely limits the practical application of such techniques. In this paper, we
propose a simple and effective token-level early exit method, SkipDecode,
designed to work seamlessly with batch inferencing and KV caching. It overcomes
prior constraints by setting up a singular exit point for every token in a
batch at each sequence position. It also guarantees a monotonic decrease in
exit points, thereby eliminating the need to recompute KV Caches for preceding
tokens. Rather than terminating computation prematurely as in prior works, our
approach bypasses lower to middle layers, devoting most of the computational
resources to upper layers, allowing later tokens to benefit from the compute
expenditure by earlier tokens. Our experimental results show that SkipDecode
can obtain 2x to 5x inference speedups with negligible regression across a
variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7
billion parameters, all the while being directly compatible with batching and
KV caching optimization techniques.