ChatPaper.aiChatPaper

Apprendre à (Apprendre au Moment du Test) : RNNs avec des États Cachés Expressifs

Learning to (Learn at Test Time): RNNs with Expressive Hidden States

July 5, 2024
papers.authors: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
cs.AI

papers.abstract

L'auto-attention performe bien dans des contextes longs mais présente une complexité quadratique. Les couches RNN existantes ont une complexité linéaire, mais leur performance dans des contextes longs est limitée par la puissance expressive de leur état caché. Nous proposons une nouvelle classe de couches de modélisation de séquences avec une complexité linéaire et un état caché expressif. L'idée clé est de faire de l'état caché un modèle d'apprentissage automatique lui-même, et de la règle de mise à jour une étape d'apprentissage auto-supervisé. Comme l'état caché est mis à jour par entraînement même sur des séquences de test, nos couches sont appelées couches d'Entraînement au Moment du Test (TTT). Nous considérons deux instanciations : TTT-Linéaire et TTT-MLP, dont l'état caché est respectivement un modèle linéaire et un MLP à deux couches. Nous évaluons nos instanciations à l'échelle de 125M à 1,3B de paramètres, en les comparant à un Transformer robuste et à Mamba, un RNN moderne. TTT-Linéaire et TTT-MLP égalent ou surpassent les bases de référence. Similairement au Transformer, ils peuvent continuer à réduire la perplexité en se conditionnant sur plus de tokens, alors que Mamba ne le peut pas après un contexte de 16k. Avec une optimisation préliminaire des systèmes, TTT-Linéaire est déjà plus rapide que le Transformer à un contexte de 8k et égalise Mamba en temps réel. TTT-MLP rencontre encore des défis en termes d'entrée/sortie mémoire, mais montre un potentiel plus important dans des contextes longs, indiquant une direction prometteuse pour de futures recherches.
English
Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning. Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training (TTT) layers. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer, they can keep reducing perplexity by conditioning on more tokens, while Mamba cannot after 16k context. With preliminary systems optimization, TTT-Linear is already faster than Transformer at 8k context and matches Mamba in wall-clock time. TTT-MLP still faces challenges in memory I/O, but shows larger potential in long context, pointing to a promising direction for future research.
PDF352November 28, 2024