Gevulde Mamba: Staatssamenstorting en Staatscapaciteit van op RNN Gebaseerde Lang-Context Modellering
Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling
October 9, 2024
Auteurs: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI
Samenvatting
Een essentieel voordeel van recurrente neurale netwerken (RNN's) ten opzichte van op transformatoren gebaseerde taalmodellen is hun lineaire rekenkundige complexiteit met betrekking tot de sequentielengte, waardoor ze veel sneller zijn in het verwerken van lange sequenties tijdens inferentie. Echter, de meeste publiekelijk beschikbare RNN's (bijv. Mamba en RWKV) zijn getraind op sequenties met minder dan 10K tokens, en hun effectiviteit in langere contexten blijft tot nu toe grotendeels onbevredigend. In dit artikel bestuderen we de oorzaak van het onvermogen van RNN's om lange contexten te verwerken en suggereren we kritieke verlichtingen. We onderzoeken twee praktische zorgen bij het toepassen van state-of-the-art RNN's op lange contexten: (1) het onvermogen om te extrapoleren naar invoer langer dan de trainingslengte en (2) de bovengrens van geheugencapaciteit. Om de eerste zorg aan te pakken, onderzoeken we eerst *state collapse* (SC), een fenomeen dat leidt tot ernstige prestatievermindering bij sequentielengtes die niet tijdens de training zijn tegengekomen. Met gecontroleerde experimenten schrijven we dit toe aan overfitting als gevolg van de overparameterisatie van de recurrente staat voor de trainingslengte. Voor de tweede zorg trainen we een reeks Mamba-2 modellen op lange documenten om empirisch de recurrente staatcapaciteit in taalmodellering en passkey-opvraging te schatten. Vervolgens worden drie SC-verminderingsmethoden voorgesteld om de lengtegeneraliseerbaarheid van Mamba-2 te verbeteren, waardoor het model meer dan 1M tokens kan verwerken zonder SC. We vinden ook dat de recurrente staatcapaciteit bij passkey-opvraging exponentieel schaalt met de staatgrootte, en we trainen empirisch een Mamba-2 370M met bijna perfecte passkey-opvraagnauwkeurigheid op een contextlengte van 256K. Dit wijst op een veelbelovende toekomst voor op RNN's gebaseerde modellering van lange contexten.
English
One essential advantage of recurrent neural networks (RNNs) over
transformer-based language models is their linear computational complexity
concerning the sequence length, which makes them much faster in handling long
sequences during inference. However, most publicly available RNNs (e.g., Mamba
and RWKV) are trained on sequences with less than 10K tokens, and their
effectiveness in longer contexts remains largely unsatisfying so far. In this
paper, we study the cause of the inability to process long context for RNNs and
suggest critical mitigations. We examine two practical concerns when applying
state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to
inputs longer than the training length and (2) the upper bound of memory
capacity. Addressing the first concern, we first investigate *state collapse*
(SC), a phenomenon that causes severe performance degradation on sequence
lengths not encountered during training. With controlled experiments, we
attribute this to overfitting due to the recurrent state being
overparameterized for the training length. For the second concern, we train a
series of Mamba-2 models on long documents to empirically estimate the
recurrent state capacity in language modeling and passkey retrieval. Then,
three SC mitigation methods are proposed to improve Mamba-2's length
generalizability, allowing the model to process more than 1M tokens without SC.
We also find that the recurrent state capacity in passkey retrieval scales
exponentially to the state size, and we empirically train a Mamba-2 370M with
near-perfect passkey retrieval accuracy on 256K context length. This suggests a
promising future for RNN-based long-context modeling.Summary
AI-Generated Summary