SkipDecode: 배치 처리와 캐싱을 통한 자동회귀 건너뛰기 디코딩으로 효율적인 LLM 추론 구현
SkipDecode: Autoregressive Skip Decoding with Batching and Caching for Efficient LLM Inference
July 5, 2023
저자: Luciano Del Corro, Allie Del Giorno, Sahaj Agarwal, Bin Yu, Ahmed Awadallah, Subhabrata Mukherjee
cs.AI
초록
자기회귀적 대규모 언어 모델(LLMs)은 다양한 자연어 생성 작업에서 놀라운 진전을 이루었습니다. 그러나 이러한 모델들은 토큰 단위의 자기회귀적 생성 방식으로 인해 높은 계산 비용과 지연 시간을 초래합니다. 이 문제를 해결하기 위해, 조기 종료 전략을 사용하여 계산 비용을 줄이는 여러 접근 방식이 제안되었습니다. 이러한 전략은 각 토큰에 대해 전체 계산 그래프를 적용하지 않고도 계산량을 줄여 더 빠른 텍스트 생성을 가능하게 합니다. 기존의 토큰 수준 조기 종료 방법들은 온라인 추론에서 유망한 결과를 보여주지만, 배치 추론과 Key-Value 캐싱에는 바로 적용하기 어렵습니다. 이는 배치 내 마지막 토큰이 종료될 때까지 계산을 멈출 수 없기 때문이며, 이로 인해 이러한 기술의 실용적 적용이 심각하게 제한됩니다. 본 논문에서는 배치 추론과 KV 캐싱과 원활하게 동작하도록 설계된 간단하면서도 효과적인 토큰 수준 조기 종료 방법인 SkipDecode를 제안합니다. 이 방법은 각 시퀀스 위치에서 배치 내 모든 토큰에 대해 단일 종료 지점을 설정함으로써 기존의 제약을 극복합니다. 또한 종료 지점이 단조롭게 감소함을 보장하여 선행 토큰들에 대한 KV 캐시를 재계산할 필요를 없앱니다. 기존 연구들처럼 계산을 조기에 종료하는 대신, 우리의 접근 방식은 하위 및 중간 계층을 우회하고 대부분의 계산 자원을 상위 계층에 집중시켜, 후속 토큰들이 선행 토큰들의 계산 지출로부터 이익을 얻을 수 있도록 합니다. 실험 결과, SkipDecode는 다양한 작업에서 1.3억 및 67억 개의 파라미터를 가진 OPT 모델을 사용하여 최소한의 성능 저하와 함께 2배에서 5배의 추론 속도 향상을 달성할 수 있음을 보여줍니다. 이는 배치 처리 및 KV 캐싱 최적화 기술과 직접 호환되면서 이루어집니다.
English
Autoregressive large language models (LLMs) have made remarkable progress in
various natural language generation tasks. However, they incur high computation
cost and latency resulting from the autoregressive token-by-token generation.
To address this issue, several approaches have been proposed to reduce
computational cost using early-exit strategies. These strategies enable faster
text generation using reduced computation without applying the full computation
graph to each token. While existing token-level early exit methods show
promising results for online inference, they cannot be readily applied for
batch inferencing and Key-Value caching. This is because they have to wait
until the last token in a batch exits before they can stop computing. This
severely limits the practical application of such techniques. In this paper, we
propose a simple and effective token-level early exit method, SkipDecode,
designed to work seamlessly with batch inferencing and KV caching. It overcomes
prior constraints by setting up a singular exit point for every token in a
batch at each sequence position. It also guarantees a monotonic decrease in
exit points, thereby eliminating the need to recompute KV Caches for preceding
tokens. Rather than terminating computation prematurely as in prior works, our
approach bypasses lower to middle layers, devoting most of the computational
resources to upper layers, allowing later tokens to benefit from the compute
expenditure by earlier tokens. Our experimental results show that SkipDecode
can obtain 2x to 5x inference speedups with negligible regression across a
variety of tasks. This is achieved using OPT models of 1.3 billion and 6.7
billion parameters, all the while being directly compatible with batching and
KV caching optimization techniques.