Transformadores en Tándem para LLMs Eficientes en Inferencia
Tandem Transformers for Inference Efficient LLMs
February 13, 2024
Autores: Aishwarya P S, Pranav Ajit Nair, Yashas Samaga, Toby Boyd, Sanjiv Kumar, Prateek Jain, Praneeth Netrapalli
cs.AI
Resumen
La naturaleza autorregresiva de los modelos de lenguaje grandes (LLMs) convencionales limita inherentemente la velocidad de inferencia, ya que los tokens se generan secuencialmente. Aunque las técnicas de decodificación especulativa y paralela intentan mitigar esto, enfrentan limitaciones: ya sea dependiendo de modelos más pequeños y menos precisos para la generación o no aprovechando completamente las representaciones del LLM base.
Introducimos una arquitectura novedosa, los transformadores en tándem, para abordar estos problemas. Esta arquitectura combina de manera única (1) un modelo autorregresivo pequeño y (2) un modelo grande que opera en modo de bloque (procesando múltiples tokens simultáneamente). La precisión predictiva del modelo pequeño se mejora sustancialmente al permitirle prestar atención a las representaciones más ricas del modelo grande. En el conjunto de datos de preentrenamiento de PaLM2, un tándem de PaLM2-Bison y PaLM2-Gecko demuestra una mejora del 3.3% en la precisión de predicción del siguiente token en comparación con un PaLM2-Gecko independiente, ofreciendo una aceleración de 1.16x en comparación con un modelo PaLM2-Otter con un rendimiento comparable en tareas posteriores. Además, integramos el modelo tándem dentro del marco de decodificación especulativa (SPEED), donde el modelo grande valida los tokens generados por el modelo pequeño. Esto garantiza que el tándem de PaLM2-Bison y PaLM2-Gecko logre una aceleración sustancial (alrededor de 1.14x más rápido que usar PaLM2-Gecko estándar en SPEED) mientras mantiene una precisión idéntica en las tareas posteriores.
English
The autoregressive nature of conventional large language models (LLMs)
inherently limits inference speed, as tokens are generated sequentially. While
speculative and parallel decoding techniques attempt to mitigate this, they
face limitations: either relying on less accurate smaller models for generation
or failing to fully leverage the base LLM's representations.
We introduce a novel architecture, Tandem transformers, to address these
issues. This architecture uniquely combines (1) a small autoregressive model
and (2) a large model operating in block mode (processing multiple tokens
simultaneously). The small model's predictive accuracy is substantially
enhanced by granting it attention to the large model's richer representations.
On the PaLM2 pretraining dataset, a tandem of PaLM2-Bison and PaLM2-Gecko
demonstrates a 3.3% improvement in next-token prediction accuracy over a
standalone PaLM2-Gecko, offering a 1.16x speedup compared to a PaLM2-Otter
model with comparable downstream performance. We further incorporate the tandem
model within the speculative decoding (SPEED) framework where the large model
validates tokens from the small model. This ensures that the Tandem of
PaLM2-Bison and PaLM2-Gecko achieves substantial speedup (around 1.14x faster
than using vanilla PaLM2-Gecko in SPEED) while maintaining identical downstream
task accuracy.