Mamba imbottito: Collasso dello stato e capacità dello stato di modellizzazione a lungo contesto basata su RNN
Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling
October 9, 2024
Autori: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI
Abstract
Un vantaggio essenziale delle reti neurali ricorrenti (RNN) rispetto ai modelli linguistici basati su trasformatori è la loro complessità computazionale lineare rispetto alla lunghezza della sequenza, il che le rende molto più veloci nel gestire sequenze lunghe durante l'inferenza. Tuttavia, la maggior parte delle RNN disponibili pubblicamente (ad esempio, Mamba e RWKV) sono addestrate su sequenze con meno di 10.000 token, e la loro efficacia in contesti più lunghi finora è rimasta in gran parte insoddisfacente. In questo articolo, studiamo la causa dell'incapacità di elaborare contesti lunghi per le RNN e suggeriamo mitigazioni critiche. Esaminiamo due preoccupazioni pratiche nell'applicare le RNN all'avanguardia a contesti lunghi: (1) l'incapacità di estrapolare a input più lunghi della lunghezza di addestramento e (2) il limite superiore della capacità di memoria. Affrontando la prima preoccupazione, indaghiamo prima il *collasso dello stato* (SC), un fenomeno che causa un grave degrado delle prestazioni su lunghezze di sequenza non incontrate durante l'addestramento. Con esperimenti controllati, attribuiamo ciò all'overfitting dovuto allo stato ricorrente che è sovradimensionato rispetto alla lunghezza di addestramento. Per la seconda preoccupazione, addestriamo una serie di modelli Mamba-2 su documenti lunghi per stimare empiricamente la capacità dello stato ricorrente nella modellizzazione del linguaggio e nel recupero della chiave di accesso. Successivamente, vengono proposti tre metodi di mitigazione dello SC per migliorare la generalizzabilità della lunghezza di Mamba-2, consentendo al modello di elaborare più di 1 milione di token senza SC. Troviamo anche che la capacità dello stato ricorrente nel recupero della chiave di accesso scala in modo esponenziale rispetto alla dimensione dello stato, e addestriamo empiricamente un Mamba-2 370M con un'accuratezza di recupero della chiave di accesso quasi perfetta su una lunghezza di contesto di 256.000. Ciò suggerisce un futuro promettente per la modellizzazione di contesti lunghi basata su RNN.
English
One essential advantage of recurrent neural networks (RNNs) over
transformer-based language models is their linear computational complexity
concerning the sequence length, which makes them much faster in handling long
sequences during inference. However, most publicly available RNNs (e.g., Mamba
and RWKV) are trained on sequences with less than 10K tokens, and their
effectiveness in longer contexts remains largely unsatisfying so far. In this
paper, we study the cause of the inability to process long context for RNNs and
suggest critical mitigations. We examine two practical concerns when applying
state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to
inputs longer than the training length and (2) the upper bound of memory
capacity. Addressing the first concern, we first investigate *state collapse*
(SC), a phenomenon that causes severe performance degradation on sequence
lengths not encountered during training. With controlled experiments, we
attribute this to overfitting due to the recurrent state being
overparameterized for the training length. For the second concern, we train a
series of Mamba-2 models on long documents to empirically estimate the
recurrent state capacity in language modeling and passkey retrieval. Then,
three SC mitigation methods are proposed to improve Mamba-2's length
generalizability, allowing the model to process more than 1M tokens without SC.
We also find that the recurrent state capacity in passkey retrieval scales
exponentially to the state size, and we empirically train a Mamba-2 370M with
near-perfect passkey retrieval accuracy on 256K context length. This suggests a
promising future for RNN-based long-context modeling.