Leren om (te leren tijdens testtijd): RNN's met expressieve verborgen toestanden
Learning to (Learn at Test Time): RNNs with Expressive Hidden States
July 5, 2024
Auteurs: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
cs.AI
Samenvatting
Self-attention presteert goed in lange contexten maar heeft een kwadratische complexiteit.
Bestaande RNN-lagen hebben een lineaire complexiteit, maar hun prestaties in lange
contexten worden beperkt door de expressieve kracht van hun verborgen toestand. Wij stellen een
nieuwe klasse van sequentiemodelleringslagen voor met lineaire complexiteit en een expressieve
verborgen toestand. Het kernidee is om de verborgen toestand zelf een machine learning-model
te maken, en de update-regel een stap van zelfgesuperviseerd leren. Omdat de
verborgen toestand wordt bijgewerkt door training, zelfs op testsequenties, worden onze lagen
Test-Time Training (TTT) lagen genoemd. We beschouwen twee instantiaties:
TTT-Linear en TTT-MLP, waarvan de verborgen toestand respectievelijk een lineair model en een tweelaags
MLP is. We evalueren onze instantiaties op een schaal van 125M tot 1,3B
parameters, in vergelijking met een sterke Transformer en Mamba, een moderne RNN. Zowel
TTT-Linear als TTT-MLP evenaren of overtreffen de referentiemodellen. Net als Transformer,
kunnen ze de perplexiteit blijven verlagen door zich te conditioneren op meer tokens, terwijl Mamba
dat niet kan na een context van 16k. Met voorlopige systeemoptimalisaties is TTT-Linear
al sneller dan Transformer bij een context van 8k en evenaart het Mamba in wall-clock
tijd. TTT-MLP heeft nog steeds uitdagingen met geheugen-I/O, maar toont groter potentieel
in lange contexten, wat wijst op een veelbelovende richting voor toekomstig onderzoek.
English
Self-attention performs well in long context but has quadratic complexity.
Existing RNN layers have linear complexity, but their performance in long
context is limited by the expressive power of their hidden state. We propose a
new class of sequence modeling layers with linear complexity and an expressive
hidden state. The key idea is to make the hidden state a machine learning model
itself, and the update rule a step of self-supervised learning. Since the
hidden state is updated by training even on test sequences, our layers are
called Test-Time Training (TTT) layers. We consider two instantiations:
TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer
MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B
parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both
TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer,
they can keep reducing perplexity by conditioning on more tokens, while Mamba
cannot after 16k context. With preliminary systems optimization, TTT-Linear is
already faster than Transformer at 8k context and matches Mamba in wall-clock
time. TTT-MLP still faces challenges in memory I/O, but shows larger potential
in long context, pointing to a promising direction for future research.