(테스트 시간에 학습하기): 표현력 있는 은닉 상태를 가진 RNN 학습
Learning to (Learn at Test Time): RNNs with Expressive Hidden States
July 5, 2024
저자: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
cs.AI
초록
셀프 어텐션은 긴 문맥에서 우수한 성능을 보이지만 이차 복잡도를 가집니다. 기존의 RNN 계층은 선형 복잡도를 가지지만, 긴 문맥에서의 성능은 은닉 상태의 표현력에 의해 제한됩니다. 우리는 선형 복잡도와 표현력 있는 은닉 상태를 가진 새로운 시퀀스 모델링 계층을 제안합니다. 핵심 아이디어는 은닉 상태를 머신 러닝 모델 자체로 만들고, 업데이트 규칙을 자기 지도 학습의 한 단계로 만드는 것입니다. 은닉 상태가 테스트 시퀀스에서도 학습을 통해 업데이트되기 때문에, 우리의 계층을 테스트 시간 학습(Test-Time Training, TTT) 계층이라고 부릅니다. 우리는 두 가지 구현체를 고려합니다: TTT-Linear와 TTT-MLP로, 각각 은닉 상태가 선형 모델과 2층 MLP인 경우입니다. 우리는 125M에서 1.3B 파라미터 규모에서 강력한 Transformer와 현대적인 RNN인 Mamba와 비교하여 구현체를 평가합니다. TTT-Linear와 TTT-MLP 모두 기준 모델과 동등하거나 더 나은 성능을 보입니다. Transformer와 유사하게, 이들은 더 많은 토큰을 조건으로 삼아 perplexity를 계속해서 줄일 수 있지만, Mamba는 16k 문맥 이후에는 이를 할 수 없습니다. 초기 시스템 최적화를 통해 TTT-Linear는 이미 8k 문맥에서 Transformer보다 빠르며, Mamba와 실시간 성능에서 동등합니다. TTT-MLP는 여전히 메모리 I/O에서 어려움을 겪지만, 긴 문맥에서 더 큰 잠재력을 보여 미래 연구를 위한 유망한 방향을 제시합니다.
English
Self-attention performs well in long context but has quadratic complexity.
Existing RNN layers have linear complexity, but their performance in long
context is limited by the expressive power of their hidden state. We propose a
new class of sequence modeling layers with linear complexity and an expressive
hidden state. The key idea is to make the hidden state a machine learning model
itself, and the update rule a step of self-supervised learning. Since the
hidden state is updated by training even on test sequences, our layers are
called Test-Time Training (TTT) layers. We consider two instantiations:
TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer
MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B
parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both
TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer,
they can keep reducing perplexity by conditioning on more tokens, while Mamba
cannot after 16k context. With preliminary systems optimization, TTT-Linear is
already faster than Transformer at 8k context and matches Mamba in wall-clock
time. TTT-MLP still faces challenges in memory I/O, but shows larger potential
in long context, pointing to a promising direction for future research.