Lernen, um (Zur Testzeit zu Lernen): RNNs mit Ausdrucksstarken Versteckten Zuständen
Learning to (Learn at Test Time): RNNs with Expressive Hidden States
July 5, 2024
Autoren: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
cs.AI
Zusammenfassung
Die Selbst-Aufmerksamkeit funktioniert gut bei langem Kontext, hat jedoch quadratische Komplexität. Bestehende RNN-Schichten haben lineare Komplexität, aber ihre Leistung bei langem Kontext wird durch die Ausdruckskraft ihres versteckten Zustands begrenzt. Wir schlagen eine neue Klasse von Sequenzmodellierungsschichten mit linearer Komplexität und einem ausdrucksstarken versteckten Zustand vor. Die Schlüsselidee besteht darin, den versteckten Zustand selbst zu einem maschinellen Lernmodell zu machen und die Aktualisierungsregel zu einem Schritt des selbstüberwachten Lernens. Da der versteckte Zustand durch Training auch auf Testsequenzen aktualisiert wird, werden unsere Schichten Testzeit-Trainings (TTT) Schichten genannt. Wir betrachten zwei Instantiierungen: TTT-Linear und TTT-MLP, deren versteckter Zustand jeweils ein lineares Modell und ein MLP mit zwei Schichten ist. Wir evaluieren unsere Instantiierungen im Maßstab von 125M bis 1.3B Parametern, verglichen mit einem leistungsstarken Transformer und Mamba, einem modernen RNN. Sowohl TTT-Linear als auch TTT-MLP entsprechen oder übertreffen die Basislinien. Ähnlich wie der Transformer können sie die Perplexität weiter reduzieren, indem sie sich auf mehr Tokens beziehen, während Mamba dies nach 16k Kontext nicht kann. Mit vorläufiger Systemoptimierung ist TTT-Linear bereits schneller als der Transformer bei 8k Kontext und entspricht Mamba in der Wanduhrzeit. TTT-MLP steht noch vor Herausforderungen im Speicher-I/O, zeigt jedoch ein größeres Potenzial bei langem Kontext und weist in eine vielversprechende Richtung für zukünftige Forschung.
English
Self-attention performs well in long context but has quadratic complexity.
Existing RNN layers have linear complexity, but their performance in long
context is limited by the expressive power of their hidden state. We propose a
new class of sequence modeling layers with linear complexity and an expressive
hidden state. The key idea is to make the hidden state a machine learning model
itself, and the update rule a step of self-supervised learning. Since the
hidden state is updated by training even on test sequences, our layers are
called Test-Time Training (TTT) layers. We consider two instantiations:
TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer
MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B
parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both
TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer,
they can keep reducing perplexity by conditioning on more tokens, while Mamba
cannot after 16k context. With preliminary systems optimization, TTT-Linear is
already faster than Transformer at 8k context and matches Mamba in wall-clock
time. TTT-MLP still faces challenges in memory I/O, but shows larger potential
in long context, pointing to a promising direction for future research.Summary
AI-Generated Summary