추론 효율적인 대형 언어 모델을 위한 탠덤 트랜스포머
Tandem Transformers for Inference Efficient LLMs
February 13, 2024
저자: Aishwarya P S, Pranav Ajit Nair, Yashas Samaga, Toby Boyd, Sanjiv Kumar, Prateek Jain, Praneeth Netrapalli
cs.AI
초록
기존 대규모 언어 모델(LLM)의 자기회귀적 특성은 토큰이 순차적으로 생성되기 때문에 추론 속도를 본질적으로 제한한다. 추측적 및 병렬 디코딩 기법이 이를 완화하려 시도하지만, 이들 기법은 정확도가 낮은 소형 모델에 의존하거나 기본 LLM의 표현을 완전히 활용하지 못하는 한계에 직면한다.
이러한 문제를 해결하기 위해 우리는 새로운 아키텍처인 탠덤 트랜스포머(Tandem transformers)를 소개한다. 이 아키텍처는 (1) 소형 자기회귀 모델과 (2) 블록 모드(여러 토큰을 동시에 처리)로 작동하는 대형 모델을 독창적으로 결합한다. 소형 모델의 예측 정확도는 대형 모델의 더 풍부한 표현에 주목함으로써 크게 향상된다. PaLM2 사전 학습 데이터셋에서, PaLM2-Bison과 PaLM2-Gecko로 구성된 탠덤 모델은 독립적인 PaLM2-Gecko 대비 다음 토큰 예측 정확도에서 3.3% 향상을 보였으며, 비슷한 다운스트림 성능을 가진 PaLM2-Otter 모델 대비 1.16배의 속도 향상을 제공한다. 또한, 우리는 탠덤 모델을 추측적 디코딩(SPEED) 프레임워크 내에 통합하여 대형 모델이 소형 모델의 토큰을 검증하도록 한다. 이를 통해 PaLM2-Bison과 PaLM2-Gecko로 구성된 탠덤 모델은 SPEED에서 일반 PaLM2-Gecko를 사용하는 경우 대비 약 1.14배 빠른 속도 향상을 달성하면서도 동일한 다운스트림 작업 정확도를 유지한다.
English
The autoregressive nature of conventional large language models (LLMs)
inherently limits inference speed, as tokens are generated sequentially. While
speculative and parallel decoding techniques attempt to mitigate this, they
face limitations: either relying on less accurate smaller models for generation
or failing to fully leverage the base LLM's representations.
We introduce a novel architecture, Tandem transformers, to address these
issues. This architecture uniquely combines (1) a small autoregressive model
and (2) a large model operating in block mode (processing multiple tokens
simultaneously). The small model's predictive accuracy is substantially
enhanced by granting it attention to the large model's richer representations.
On the PaLM2 pretraining dataset, a tandem of PaLM2-Bison and PaLM2-Gecko
demonstrates a 3.3% improvement in next-token prediction accuracy over a
standalone PaLM2-Gecko, offering a 1.16x speedup compared to a PaLM2-Otter
model with comparable downstream performance. We further incorporate the tandem
model within the speculative decoding (SPEED) framework where the large model
validates tokens from the small model. This ensures that the Tandem of
PaLM2-Bison and PaLM2-Gecko achieves substantial speedup (around 1.14x faster
than using vanilla PaLM2-Gecko in SPEED) while maintaining identical downstream
task accuracy.Summary
AI-Generated Summary