Mamba Rellena: Colapso del Estado y Capacidad del Estado en Modelado de Largo Contexto Basado en RNN
Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling
October 9, 2024
Autores: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI
Resumen
Una ventaja esencial de las redes neuronales recurrentes (RNN) sobre los modelos de lenguaje basados en transformadores es su complejidad computacional lineal con respecto a la longitud de la secuencia, lo que las hace mucho más rápidas para manejar secuencias largas durante la inferencia. Sin embargo, la mayoría de las RNN disponibles públicamente (por ejemplo, Mamba y RWKV) están entrenadas en secuencias con menos de 10K tokens, y su efectividad en contextos más largos sigue siendo en gran medida insatisfactoria hasta ahora. En este artículo, estudiamos la causa de la incapacidad de procesar contextos largos para las RNN y sugerimos mitigaciones críticas. Examinamos dos preocupaciones prácticas al aplicar RNN de última generación a contextos largos: (1) la incapacidad de extrapolar a entradas más largas que la longitud de entrenamiento y (2) el límite superior de la capacidad de memoria. Para abordar la primera preocupación, investigamos primero el *colapso de estado* (SC), un fenómeno que causa una degradación severa del rendimiento en longitudes de secuencia no encontradas durante el entrenamiento. Con experimentos controlados, atribuimos esto al sobreajuste debido a que el estado recurrente está sobreparametrizado para la longitud de entrenamiento. Para la segunda preocupación, entrenamos una serie de modelos Mamba-2 en documentos largos para estimar empíricamente la capacidad del estado recurrente en modelado de lenguaje y recuperación de clave. Luego, se proponen tres métodos de mitigación de SC para mejorar la capacidad de generalización de longitud de Mamba-2, permitiendo que el modelo procese más de 1M tokens sin SC. También encontramos que la capacidad del estado recurrente en la recuperación de clave escala de manera exponencial con el tamaño del estado, y entrenamos empíricamente un Mamba-2 370M con una precisión de recuperación de clave casi perfecta en una longitud de contexto de 256K. Esto sugiere un futuro prometedor para el modelado de contextos largos basado en RNN.
English
One essential advantage of recurrent neural networks (RNNs) over
transformer-based language models is their linear computational complexity
concerning the sequence length, which makes them much faster in handling long
sequences during inference. However, most publicly available RNNs (e.g., Mamba
and RWKV) are trained on sequences with less than 10K tokens, and their
effectiveness in longer contexts remains largely unsatisfying so far. In this
paper, we study the cause of the inability to process long context for RNNs and
suggest critical mitigations. We examine two practical concerns when applying
state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to
inputs longer than the training length and (2) the upper bound of memory
capacity. Addressing the first concern, we first investigate *state collapse*
(SC), a phenomenon that causes severe performance degradation on sequence
lengths not encountered during training. With controlled experiments, we
attribute this to overfitting due to the recurrent state being
overparameterized for the training length. For the second concern, we train a
series of Mamba-2 models on long documents to empirically estimate the
recurrent state capacity in language modeling and passkey retrieval. Then,
three SC mitigation methods are proposed to improve Mamba-2's length
generalizability, allowing the model to process more than 1M tokens without SC.
We also find that the recurrent state capacity in passkey retrieval scales
exponentially to the state size, and we empirically train a Mamba-2 370M with
near-perfect passkey retrieval accuracy on 256K context length. This suggests a
promising future for RNN-based long-context modeling.Summary
AI-Generated Summary