Griffin : Combinaison de récurrences linéaires à portes et d'attention locale pour des modèles de langage efficaces
Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
February 29, 2024
Auteurs: Soham De, Samuel L. Smith, Anushan Fernando, Aleksandar Botev, George Cristian-Muraru, Albert Gu, Ruba Haroun, Leonard Berrada, Yutian Chen, Srivatsan Srinivasan, Guillaume Desjardins, Arnaud Doucet, David Budden, Yee Whye Teh, Razvan Pascanu, Nando De Freitas, Caglar Gulcehre
cs.AI
Résumé
Les réseaux de neurones récurrents (RNN) offrent une inférence rapide et une mise à l'échelle efficace sur des séquences longues, mais ils sont difficiles à entraîner et à scaler. Nous proposons Hawk, un RNN avec des récurrences linéaires à portes, et Griffin, un modèle hybride qui combine des récurrences linéaires à portes avec une attention locale. Hawk dépasse les performances rapportées de Mamba sur des tâches en aval, tandis que Griffin atteint les performances de Llama-2 malgré un entraînement sur plus de 6 fois moins de tokens. Nous montrons également que Griffin peut extrapoler sur des séquences significativement plus longues que celles vues pendant l'entraînement. Nos modèles égalent l'efficacité matérielle des Transformers pendant l'entraînement, et pendant l'inférence, ils ont une latence plus faible et un débit significativement plus élevé. Nous avons mis à l'échelle Griffin jusqu'à 14 milliards de paramètres, et expliquons comment partitionner nos modèles pour un entraînement distribué efficace.
English
Recurrent neural networks (RNNs) have fast inference and scale efficiently on
long sequences, but they are difficult to train and hard to scale. We propose
Hawk, an RNN with gated linear recurrences, and Griffin, a hybrid model that
mixes gated linear recurrences with local attention. Hawk exceeds the reported
performance of Mamba on downstream tasks, while Griffin matches the performance
of Llama-2 despite being trained on over 6 times fewer tokens. We also show
that Griffin can extrapolate on sequences significantly longer than those seen
during training. Our models match the hardware efficiency of Transformers
during training, and during inference they have lower latency and significantly
higher throughput. We scale Griffin up to 14B parameters, and explain how to
shard our models for efficient distributed training.