Aprendendo a (Aprender no Momento do Teste): RNNs com Estados Ocultos Expressivos
Learning to (Learn at Test Time): RNNs with Expressive Hidden States
July 5, 2024
Autores: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
cs.AI
Resumo
A autoatenção tem um bom desempenho em contextos longos, mas possui complexidade quadrática. As camadas de RNN existentes têm complexidade linear, porém seu desempenho em contextos longos é limitado pela capacidade expressiva de seu estado oculto. Propomos uma nova classe de camadas de modelagem de sequência com complexidade linear e um estado oculto expressivo. A ideia principal é tornar o estado oculto um modelo de aprendizado de máquina em si e a regra de atualização um passo de aprendizado auto-supervisionado. Uma vez que o estado oculto é atualizado por meio de treinamento, mesmo em sequências de teste, nossas camadas são chamadas de camadas de Treinamento em Tempo de Teste (TTT). Consideramos duas instâncias: TTT-Linear e TTT-MLP, cujo estado oculto é um modelo linear e um MLP de duas camadas, respectivamente. Avaliamos nossas instâncias na escala de 125M a 1.3B parâmetros, comparando com um Transformer forte e o Mamba, uma RNN moderna. Tanto o TTT-Linear quanto o TTT-MLP correspondem ou superam as referências. Assim como o Transformer, eles conseguem reduzir a perplexidade ao condicionar em mais tokens, enquanto o Mamba não consegue após 16k contextos. Com otimização preliminar dos sistemas, o TTT-Linear já é mais rápido que o Transformer em 8k contextos e corresponde ao Mamba em tempo de parede. O TTT-MLP ainda enfrenta desafios em memória I/O, mas mostra um maior potencial em contextos longos, apontando para uma direção promissora para pesquisas futuras.
English
Self-attention performs well in long context but has quadratic complexity.
Existing RNN layers have linear complexity, but their performance in long
context is limited by the expressive power of their hidden state. We propose a
new class of sequence modeling layers with linear complexity and an expressive
hidden state. The key idea is to make the hidden state a machine learning model
itself, and the update rule a step of self-supervised learning. Since the
hidden state is updated by training even on test sequences, our layers are
called Test-Time Training (TTT) layers. We consider two instantiations:
TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer
MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B
parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both
TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer,
they can keep reducing perplexity by conditioning on more tokens, while Mamba
cannot after 16k context. With preliminary systems optimization, TTT-Linear is
already faster than Transformer at 8k context and matches Mamba in wall-clock
time. TTT-MLP still faces challenges in memory I/O, but shows larger potential
in long context, pointing to a promising direction for future research.