ChatPaper.aiChatPaper

Leren om (te leren tijdens testtijd): RNN's met expressieve verborgen toestanden

Learning to (Learn at Test Time): RNNs with Expressive Hidden States

July 5, 2024
Auteurs: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
cs.AI

Samenvatting

Self-attention presteert goed in lange contexten maar heeft een kwadratische complexiteit. Bestaande RNN-lagen hebben een lineaire complexiteit, maar hun prestaties in lange contexten worden beperkt door de expressieve kracht van hun verborgen toestand. Wij stellen een nieuwe klasse van sequentiemodelleringslagen voor met lineaire complexiteit en een expressieve verborgen toestand. Het kernidee is om de verborgen toestand zelf een machine learning-model te maken, en de update-regel een stap van zelfgesuperviseerd leren. Omdat de verborgen toestand wordt bijgewerkt door training, zelfs op testsequenties, worden onze lagen Test-Time Training (TTT) lagen genoemd. We beschouwen twee instantiaties: TTT-Linear en TTT-MLP, waarvan de verborgen toestand respectievelijk een lineair model en een tweelaags MLP is. We evalueren onze instantiaties op een schaal van 125M tot 1,3B parameters, in vergelijking met een sterke Transformer en Mamba, een moderne RNN. Zowel TTT-Linear als TTT-MLP evenaren of overtreffen de referentiemodellen. Net als Transformer, kunnen ze de perplexiteit blijven verlagen door zich te conditioneren op meer tokens, terwijl Mamba dat niet kan na een context van 16k. Met voorlopige systeemoptimalisaties is TTT-Linear al sneller dan Transformer bij een context van 8k en evenaart het Mamba in wall-clock tijd. TTT-MLP heeft nog steeds uitdagingen met geheugen-I/O, maar toont groter potentieel in lange contexten, wat wijst op een veelbelovende richting voor toekomstig onderzoek.
English
Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning. Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training (TTT) layers. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer, they can keep reducing perplexity by conditioning on more tokens, while Mamba cannot after 16k context. With preliminary systems optimization, TTT-Linear is already faster than Transformer at 8k context and matches Mamba in wall-clock time. TTT-MLP still faces challenges in memory I/O, but shows larger potential in long context, pointing to a promising direction for future research.
PDF342February 8, 2026