ChatPaper.aiChatPaper

Imparare a (Imparare al Momento del Test): RNN con Stati Nascosti Espressivi

Learning to (Learn at Test Time): RNNs with Expressive Hidden States

July 5, 2024
Autori: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
cs.AI

Abstract

L'attenzione self-attention ottiene buoni risultati in contesti lunghi ma ha una complessità quadratica. Gli strati RNN esistenti hanno una complessità lineare, ma le loro prestazioni in contesti lunghi sono limitate dal potere espressivo del loro stato nascosto. Proponiamo una nuova classe di strati di modellazione sequenziale con complessità lineare e uno stato nascosto espressivo. L'idea chiave è rendere lo stato nascosto un modello di machine learning esso stesso, e la regola di aggiornamento un passo di apprendimento self-supervised. Poiché lo stato nascosto viene aggiornato attraverso l'addestramento anche su sequenze di test, i nostri strati sono chiamati strati Test-Time Training (TTT). Consideriamo due istanze: TTT-Linear e TTT-MLP, il cui stato nascosto è rispettivamente un modello lineare e un MLP a due strati. Valutiamo le nostre istanze su una scala da 125M a 1.3B parametri, confrontandole con un Transformer robusto e Mamba, un RNN moderno. Sia TTT-Linear che TTT-MLP eguagliano o superano i benchmark. Similmente al Transformer, possono continuare a ridurre la perplexità condizionandosi su più token, mentre Mamba non riesce dopo un contesto di 16k. Con un'ottimizzazione preliminare dei sistemi, TTT-Linear è già più veloce del Transformer a un contesto di 8k e eguaglia Mamba in termini di tempo reale. TTT-MLP affronta ancora sfide nell'I/O della memoria, ma mostra un potenziale maggiore in contesti lunghi, indicando una direzione promettente per la ricerca futura.
English
Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning. Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training (TTT) layers. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer, they can keep reducing perplexity by conditioning on more tokens, while Mamba cannot after 16k context. With preliminary systems optimization, TTT-Linear is already faster than Transformer at 8k context and matches Mamba in wall-clock time. TTT-MLP still faces challenges in memory I/O, but shows larger potential in long context, pointing to a promising direction for future research.
PDF342November 28, 2024