SkipDecode: Decodagem Autoregressiva com Batching e Caching para Inferência Eficiente em Modelos de Linguagem de Grande Escala
SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference
July 5, 2023
Autores: Luciano Del Corro, Allie Del Giorno, Sahaj Agarwal, Bin Yu, Ahmed Awadallah, Subhabrata Mukherjee
cs.AI
Resumo
Modelos de linguagem autoregressivos de grande escala (LLMs) têm alcançado progressos notáveis em diversas tarefas de geração de linguagem natural. No entanto, eles incorrem em altos custos computacionais e latência decorrentes da geração token por token autoregressiva. Para abordar esse problema, várias abordagens foram propostas para reduzir o custo computacional utilizando estratégias de saída antecipada. Essas estratégias permitem uma geração de texto mais rápida ao empregar computação reduzida, sem aplicar o grafo computacional completo a cada token. Embora os métodos existentes de saída antecipada em nível de token mostrem resultados promissores para inferência online, eles não podem ser facilmente aplicados para inferência em lote e armazenamento em cache de chave-valor (KV). Isso ocorre porque eles precisam aguardar até que o último token em um lote saia antes de interromper o cálculo. Isso limita severamente a aplicação prática de tais técnicas. Neste artigo, propomos um método simples e eficaz de saída antecipada em nível de token, chamado SkipDecode, projetado para funcionar de forma integrada com inferência em lote e armazenamento em cache KV. Ele supera as limitações anteriores ao estabelecer um ponto de saída único para cada token em um lote em cada posição da sequência. Ele também garante uma diminuição monotônica nos pontos de saída, eliminando a necessidade de recalcular caches KV para tokens anteriores. Em vez de interromper o cálculo prematuramente, como em trabalhos anteriores, nossa abordagem ignora as camadas inferiores e intermediárias, dedicando a maior parte dos recursos computacionais às camadas superiores, permitindo que tokens posteriores se beneficiem do gasto computacional dos tokens anteriores. Nossos resultados experimentais mostram que o SkipDecode pode obter acelerações de inferência de 2x a 5x com regressão negligenciável em uma variedade de tarefas. Isso é alcançado utilizando modelos OPT com 1,3 bilhão e 6,7 bilhões de parâmetros, mantendo compatibilidade direta com técnicas de otimização de lote e armazenamento em cache KV.
English
Autoregressive large language models (LLMs) have made remarkable progress in
various natural language generation tasks. However, they incur high computation
cost and latency resulting from the autoregressive token-by-token generation.
To address this issue, several approaches have been proposed to reduce
computational cost using early-exit strategies. These strategies enable faster
text generation using reduced computation without applying the full computation
graph to each token. While existing token-level early exit methods show
promising results for online inference, they cannot be readily applied for
batch inferencing and Key-Value caching. This is because they have to wait
until the last token in a batch exits before they can stop computing. This
severely limits the practical application of such techniques. In this paper, we
propose a simple and effective token-level early exit method, SkipDecode,
designed to work seamlessly with batch inferencing and KV caching. It overcomes
prior constraints by setting up a singular exit point for every token in a
batch at each sequence position. It also guarantees a monotonic decrease in
exit points, thereby eliminating the need to recompute KV Caches for preceding
tokens. Rather than terminating computation prematurely as in prior works, our
approach bypasses lower to middle layers, devoting most of the computational
resources to upper layers, allowing later tokens to benefit from the compute
expenditure by earlier tokens. Our experimental results show that SkipDecode
can obtain 2x to 5x inference speedups with negligible regression across a
variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7
billion parameters, all the while being directly compatible with batching and
KV caching optimization techniques.