ChatPaper.aiChatPaper

Gevulde Mamba: Staatssamenstorting en Staatscapaciteit van op RNN Gebaseerde Lang-Context Modellering

Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling

October 9, 2024
Auteurs: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI

Samenvatting

Een essentieel voordeel van recurrente neurale netwerken (RNN's) ten opzichte van op transformatoren gebaseerde taalmodellen is hun lineaire rekenkundige complexiteit met betrekking tot de sequentielengte, waardoor ze veel sneller zijn in het verwerken van lange sequenties tijdens inferentie. Echter, de meeste publiekelijk beschikbare RNN's (bijv. Mamba en RWKV) zijn getraind op sequenties met minder dan 10K tokens, en hun effectiviteit in langere contexten blijft tot nu toe grotendeels onbevredigend. In dit artikel bestuderen we de oorzaak van het onvermogen van RNN's om lange contexten te verwerken en suggereren we kritieke verlichtingen. We onderzoeken twee praktische zorgen bij het toepassen van state-of-the-art RNN's op lange contexten: (1) het onvermogen om te extrapoleren naar invoer langer dan de trainingslengte en (2) de bovengrens van geheugencapaciteit. Om de eerste zorg aan te pakken, onderzoeken we eerst *state collapse* (SC), een fenomeen dat leidt tot ernstige prestatievermindering bij sequentielengtes die niet tijdens de training zijn tegengekomen. Met gecontroleerde experimenten schrijven we dit toe aan overfitting als gevolg van de overparameterisatie van de recurrente staat voor de trainingslengte. Voor de tweede zorg trainen we een reeks Mamba-2 modellen op lange documenten om empirisch de recurrente staatcapaciteit in taalmodellering en passkey-opvraging te schatten. Vervolgens worden drie SC-verminderingsmethoden voorgesteld om de lengtegeneraliseerbaarheid van Mamba-2 te verbeteren, waardoor het model meer dan 1M tokens kan verwerken zonder SC. We vinden ook dat de recurrente staatcapaciteit bij passkey-opvraging exponentieel schaalt met de staatgrootte, en we trainen empirisch een Mamba-2 370M met bijna perfecte passkey-opvraagnauwkeurigheid op een contextlengte van 256K. Dit wijst op een veelbelovende toekomst voor op RNN's gebaseerde modellering van lange contexten.
English
One essential advantage of recurrent neural networks (RNNs) over transformer-based language models is their linear computational complexity concerning the sequence length, which makes them much faster in handling long sequences during inference. However, most publicly available RNNs (e.g., Mamba and RWKV) are trained on sequences with less than 10K tokens, and their effectiveness in longer contexts remains largely unsatisfying so far. In this paper, we study the cause of the inability to process long context for RNNs and suggest critical mitigations. We examine two practical concerns when applying state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to inputs longer than the training length and (2) the upper bound of memory capacity. Addressing the first concern, we first investigate *state collapse* (SC), a phenomenon that causes severe performance degradation on sequence lengths not encountered during training. With controlled experiments, we attribute this to overfitting due to the recurrent state being overparameterized for the training length. For the second concern, we train a series of Mamba-2 models on long documents to empirically estimate the recurrent state capacity in language modeling and passkey retrieval. Then, three SC mitigation methods are proposed to improve Mamba-2's length generalizability, allowing the model to process more than 1M tokens without SC. We also find that the recurrent state capacity in passkey retrieval scales exponentially to the state size, and we empirically train a Mamba-2 370M with near-perfect passkey retrieval accuracy on 256K context length. This suggests a promising future for RNN-based long-context modeling.

Summary

AI-Generated Summary

PDF23November 16, 2024