ChatPaper.aiChatPaper

Mamba Recheada: Colapso de Estado e Capacidade de Estado de Modelagem de Longo Contexto Baseada em RNN

Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling

October 9, 2024
Autores: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI

Resumo

Uma vantagem essencial das redes neurais recorrentes (RNNs) sobre os modelos de linguagem baseados em transformadores é a sua complexidade computacional linear em relação ao comprimento da sequência, o que as torna muito mais rápidas no processamento de sequências longas durante a inferência. No entanto, a maioria das RNNs disponíveis publicamente (por exemplo, Mamba e RWKV) são treinadas em sequências com menos de 10 mil tokens, e sua eficácia em contextos mais longos tem sido amplamente insatisfatória até o momento. Neste artigo, estudamos a causa da incapacidade de processar contextos longos para as RNNs e sugerimos mitigadores críticos. Examinamos duas preocupações práticas ao aplicar RNNs de última geração a contextos longos: (1) a incapacidade de extrapolar para entradas mais longas do que o comprimento de treinamento e (2) o limite superior da capacidade de memória. Para abordar a primeira preocupação, investigamos inicialmente o *colapso de estado* (SC), um fenômeno que causa degradação severa de desempenho em comprimentos de sequência não encontrados durante o treinamento. Com experimentos controlados, atribuímos isso ao overfitting devido ao estado recorrente estar superparametrizado para o comprimento de treinamento. Para a segunda preocupação, treinamos uma série de modelos Mamba-2 em documentos longos para estimar empiricamente a capacidade do estado recorrente em modelagem de linguagem e recuperação de passkey. Em seguida, três métodos de mitigação de SC são propostos para melhorar a capacidade de generalização de comprimento do Mamba-2, permitindo que o modelo processe mais de 1 milhão de tokens sem SC. Também descobrimos que a capacidade do estado recorrente na recuperação de passkey escala exponencialmente com o tamanho do estado, e treinamos empiricamente um Mamba-2 370M com precisão de recuperação de passkey quase perfeita em um comprimento de contexto de 256 mil. Isso sugere um futuro promissor para a modelagem de longo contexto baseada em RNNs.
English
One essential advantage of recurrent neural networks (RNNs) over transformer-based language models is their linear computational complexity concerning the sequence length, which makes them much faster in handling long sequences during inference. However, most publicly available RNNs (e.g., Mamba and RWKV) are trained on sequences with less than 10K tokens, and their effectiveness in longer contexts remains largely unsatisfying so far. In this paper, we study the cause of the inability to process long context for RNNs and suggest critical mitigations. We examine two practical concerns when applying state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to inputs longer than the training length and (2) the upper bound of memory capacity. Addressing the first concern, we first investigate *state collapse* (SC), a phenomenon that causes severe performance degradation on sequence lengths not encountered during training. With controlled experiments, we attribute this to overfitting due to the recurrent state being overparameterized for the training length. For the second concern, we train a series of Mamba-2 models on long documents to empirically estimate the recurrent state capacity in language modeling and passkey retrieval. Then, three SC mitigation methods are proposed to improve Mamba-2's length generalizability, allowing the model to process more than 1M tokens without SC. We also find that the recurrent state capacity in passkey retrieval scales exponentially to the state size, and we empirically train a Mamba-2 370M with near-perfect passkey retrieval accuracy on 256K context length. This suggests a promising future for RNN-based long-context modeling.

Summary

AI-Generated Summary

PDF23November 16, 2024