Griffin: Kombination von gated linearen Rekurrenzen mit lokalem Attention für effiziente Sprachmodelle
Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models
February 29, 2024
Autoren: Soham De, Samuel L. Smith, Anushan Fernando, Aleksandar Botev, George Cristian-Muraru, Albert Gu, Ruba Haroun, Leonard Berrada, Yutian Chen, Srivatsan Srinivasan, Guillaume Desjardins, Arnaud Doucet, David Budden, Yee Whye Teh, Razvan Pascanu, Nando De Freitas, Caglar Gulcehre
cs.AI
Zusammenfassung
Reziduelle neuronale Netze (RNNs) bieten schnelle Inferenz und skalieren effizient auf langen Sequenzen, sind jedoch schwierig zu trainieren und schwer zu skalieren. Wir stellen Hawk vor, ein RNN mit gated linearen Rekurrenzen, und Griffin, ein hybrides Modell, das gated lineare Rekurrenzen mit lokalem Attention-Mechanismus kombiniert. Hawk übertrifft die berichtete Leistung von Mamba bei nachgelagerten Aufgaben, während Griffin die Leistung von Llama-2 erreicht, obwohl es mit über sechsmal weniger Tokens trainiert wurde. Wir zeigen außerdem, dass Griffin auf Sequenzen extrapolieren kann, die deutlich länger sind als die während des Trainings gesehenen. Unsere Modelle erreichen die Hardware-Effizienz von Transformern während des Trainings und bieten während der Inferenz eine geringere Latenz und eine deutlich höhere Durchsatzrate. Wir skalieren Griffin auf bis zu 14 Milliarden Parameter und erläutern, wie unsere Modelle für effizientes verteiltes Training partitioniert werden können.
English
Recurrent neural networks (RNNs) have fast inference and scale efficiently on
long sequences, but they are difficult to train and hard to scale. We propose
Hawk, an RNN with gated linear recurrences, and Griffin, a hybrid model that
mixes gated linear recurrences with local attention. Hawk exceeds the reported
performance of Mamba on downstream tasks, while Griffin matches the performance
of Llama-2 despite being trained on over 6 times fewer tokens. We also show
that Griffin can extrapolate on sequences significantly longer than those seen
during training. Our models match the hardware efficiency of Transformers
during training, and during inference they have lower latency and significantly
higher throughput. We scale Griffin up to 14B parameters, and explain how to
shard our models for efficient distributed training.