Griffin: Mezcla de recurrencias lineales con compuertas y atención local para modelos de lenguaje eficientes
Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
February 29, 2024
Autores: Soham De, Samuel L. Smith, Anushan Fernando, Aleksandar Botev, George Cristian-Muraru, Albert Gu, Ruba Haroun, Leonard Berrada, Yutian Chen, Srivatsan Srinivasan, Guillaume Desjardins, Arnaud Doucet, David Budden, Yee Whye Teh, Razvan Pascanu, Nando De Freitas, Caglar Gulcehre
cs.AI
Resumen
Las redes neuronales recurrentes (RNN) tienen una inferencia rápida y escalan eficientemente en secuencias largas, pero son difíciles de entrenar y complicadas de escalar. Proponemos Hawk, una RNN con recurrencias lineales con compuertas, y Griffin, un modelo híbrido que combina recurrencias lineales con compuertas y atención local. Hawk supera el rendimiento reportado de Mamba en tareas posteriores, mientras que Griffin iguala el rendimiento de Llama-2 a pesar de haber sido entrenado con más de 6 veces menos tokens. También demostramos que Griffin puede extrapolar en secuencias significativamente más largas que las vistas durante el entrenamiento. Nuestros modelos igualan la eficiencia de hardware de los Transformers durante el entrenamiento, y durante la inferencia tienen una latencia más baja y un rendimiento significativamente mayor. Escalamos Griffin hasta 14B parámetros y explicamos cómo fragmentar nuestros modelos para un entrenamiento distribuido eficiente.
English
Recurrent neural networks (RNNs) have fast inference and scale efficiently on
long sequences, but they are difficult to train and hard to scale. We propose
Hawk, an RNN with gated linear recurrences, and Griffin, a hybrid model that
mixes gated linear recurrences with local attention. Hawk exceeds the reported
performance of Mamba on downstream tasks, while Griffin matches the performance
of Llama-2 despite being trained on over 6 times fewer tokens. We also show
that Griffin can extrapolate on sequences significantly longer than those seen
during training. Our models match the hardware efficiency of Transformers
during training, and during inference they have lower latency and significantly
higher throughput. We scale Griffin up to 14B parameters, and explain how to
shard our models for efficient distributed training.