ChatPaper.aiChatPaper

Aprendiendo a (Aprender en Tiempo de Prueba): RNR con Estados Ocultos Expresivos

Learning to (Learn at Test Time): RNNs with Expressive Hidden States

July 5, 2024
Autores: Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
cs.AI

Resumen

La autoatención funciona bien en contextos largos pero tiene una complejidad cuadrática. Las capas RNN existentes tienen complejidad lineal, pero su rendimiento en contextos largos está limitado por la capacidad expresiva de su estado oculto. Proponemos una nueva clase de capas de modelado de secuencias con complejidad lineal y un estado oculto expresivo. La idea clave es hacer que el estado oculto sea un modelo de aprendizaje automático en sí mismo, y la regla de actualización un paso de aprendizaje auto-supervisado. Dado que el estado oculto se actualiza mediante el entrenamiento incluso en secuencias de prueba, nuestras capas se llaman capas de Entrenamiento en Tiempo de Prueba (TTT). Consideramos dos instanciaciones: TTT-Lineal y TTT-MLP, cuyo estado oculto es un modelo lineal y un MLP de dos capas respectivamente. Evaluamos nuestras instanciaciones en una escala de 125M a 1.3B parámetros, comparando con un Transformer sólido y Mamba, una RNN moderna. Tanto TTT-Lineal como TTT-MLP igualan o superan los resultados base. Al igual que Transformer, pueden seguir reduciendo la perplejidad condicionando más tokens, mientras que Mamba no puede hacerlo después de 16k contextos. Con la optimización preliminar de sistemas, TTT-Lineal ya es más rápido que Transformer en 8k contextos y coincide con Mamba en tiempo de reloj. TTT-MLP todavía enfrenta desafíos en la memoria de E/S, pero muestra un mayor potencial en contextos largos, apuntando en una dirección prometedora para futuras investigaciones.
English
Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning. Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training (TTT) layers. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer, they can keep reducing perplexity by conditioning on more tokens, while Mamba cannot after 16k context. With preliminary systems optimization, TTT-Linear is already faster than Transformer at 8k context and matches Mamba in wall-clock time. TTT-MLP still faces challenges in memory I/O, but shows larger potential in long context, pointing to a promising direction for future research.

Summary

AI-Generated Summary

PDF322November 28, 2024