(テスト時に学習することを学ぶ):表現力豊かな隠れ状態を持つRNN
Learning to (Learn at Test Time): RNNs with Expressive Hidden States
July 5, 2024
著者: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
cs.AI
要旨
Self-attentionは長い文脈において優れた性能を発揮しますが、計算量が二次的に増加するという課題があります。既存のRNN層は計算量が線形であるものの、その隠れ状態の表現力によって長い文脈での性能が制限されています。本研究では、線形計算量でありながら表現力の高い隠れ状態を持つ新しいシーケンスモデリング層を提案します。鍵となるアイデアは、隠れ状態自体を機械学習モデルとし、更新ルールを自己教師あり学習のステップとすることです。隠れ状態はテストシーケンスにおいても訓練によって更新されるため、我々の層はTest-Time Training (TTT)層と呼ばれます。具体的な実装として、隠れ状態が線形モデルであるTTT-Linearと、2層MLPであるTTT-MLPの2つを考案しました。125Mから1.3Bパラメータの規模で評価を行い、強力なTransformerと現代的なRNNであるMambaと比較しました。その結果、TTT-LinearとTTT-MLPは両方ともベースラインを上回るか同等の性能を示しました。Transformerと同様に、これらはより多くのトークンを条件付けすることでパープレキシティを継続的に低減できますが、Mambaは16k以上の文脈ではそれができませんでした。システムの最適化を予備的に行った結果、TTT-Linearは8k文脈においてすでにTransformerよりも高速であり、Mambaと同等の実時間性能を達成しました。TTT-MLPはメモリI/Oにおいてまだ課題を抱えていますが、長い文脈においてより大きな可能性を示しており、今後の研究にとって有望な方向性を指し示しています。
English
Self-attention performs well in long context but has quadratic complexity.
Existing RNN layers have linear complexity, but their performance in long
context is limited by the expressive power of their hidden state. We propose a
new class of sequence modeling layers with linear complexity and an expressive
hidden state. The key idea is to make the hidden state a machine learning model
itself, and the update rule a step of self-supervised learning. Since the
hidden state is updated by training even on test sequences, our layers are
called Test-Time Training (TTT) layers. We consider two instantiations:
TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer
MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B
parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both
TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer,
they can keep reducing perplexity by conditioning on more tokens, while Mamba
cannot after 16k context. With preliminary systems optimization, TTT-Linear is
already faster than Transformer at 8k context and matches Mamba in wall-clock
time. TTT-MLP still faces challenges in memory I/O, but shows larger potential
in long context, pointing to a promising direction for future research.Summary
AI-Generated Summary