ChatPaper.aiChatPaper

Gefüllte Mamba: Zustandskollaps und Staatskapazität von RNN-basiertem Langkontext-Modellieren

Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling

October 9, 2024
Autoren: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI

Zusammenfassung

Ein wesentlicher Vorteil von rekurrenten neuronalen Netzwerken (RNNs) gegenüber transformerbasierten Sprachmodellen ist ihre lineare Rechenkomplexität in Bezug auf die Sequenzlänge, was sie bei der Verarbeitung langer Sequenzen während der Inferenz wesentlich schneller macht. Die meisten öffentlich verfügbaren RNNs (z. B. Mamba und RWKV) sind jedoch auf Sequenzen mit weniger als 10.000 Tokens trainiert, und ihre Effektivität in längeren Kontexten bleibt bisher weitgehend unbefriedigend. In diesem Paper untersuchen wir die Ursache der Unfähigkeit von RNNs, lange Kontexte zu verarbeiten, und schlagen kritische Maßnahmen vor. Wir untersuchen zwei praktische Anliegen bei der Anwendung von modernen RNNs auf lange Kontexte: (1) die Unfähigkeit, auf Eingaben länger als die Trainingslänge zu extrapolieren, und (2) die obere Grenze der Speicherkapazität. Um das erste Anliegen anzugehen, untersuchen wir zunächst *state collapse* (SC), ein Phänomen, das zu schwerwiegenden Leistungseinbußen bei Sequenzlängen führt, die während des Trainings nicht aufgetreten sind. Mit kontrollierten Experimenten führen wir dies auf Overfitting zurück, das durch den überparametrisierten rekurrenten Zustand für die Trainingslänge verursacht wird. Für das zweite Anliegen trainieren wir eine Reihe von Mamba-2-Modellen auf langen Dokumenten, um die rekurrente Zustandskapazität in der Sprachmodellierung und Passwortabruf empirisch abzuschätzen. Anschließend werden drei SC-Minderungsmethoden vorgeschlagen, um die Längengeneralisierbarkeit von Mamba-2 zu verbessern und dem Modell zu ermöglichen, mehr als 1 Million Tokens ohne SC zu verarbeiten. Wir stellen auch fest, dass die rekurrente Zustandskapazität beim Passwortabruf exponentiell mit der Zustandsgröße skaliert, und wir trainieren empirisch ein Mamba-2 370M mit nahezu perfekter Passwortabrufgenauigkeit bei einer Kontextlänge von 256.000. Dies deutet auf eine vielversprechende Zukunft für RNN-basierte Modellierung langer Kontexte hin.
English
One essential advantage of recurrent neural networks (RNNs) over transformer-based language models is their linear computational complexity concerning the sequence length, which makes them much faster in handling long sequences during inference. However, most publicly available RNNs (e.g., Mamba and RWKV) are trained on sequences with less than 10K tokens, and their effectiveness in longer contexts remains largely unsatisfying so far. In this paper, we study the cause of the inability to process long context for RNNs and suggest critical mitigations. We examine two practical concerns when applying state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to inputs longer than the training length and (2) the upper bound of memory capacity. Addressing the first concern, we first investigate *state collapse* (SC), a phenomenon that causes severe performance degradation on sequence lengths not encountered during training. With controlled experiments, we attribute this to overfitting due to the recurrent state being overparameterized for the training length. For the second concern, we train a series of Mamba-2 models on long documents to empirically estimate the recurrent state capacity in language modeling and passkey retrieval. Then, three SC mitigation methods are proposed to improve Mamba-2's length generalizability, allowing the model to process more than 1M tokens without SC. We also find that the recurrent state capacity in passkey retrieval scales exponentially to the state size, and we empirically train a Mamba-2 370M with near-perfect passkey retrieval accuracy on 256K context length. This suggests a promising future for RNN-based long-context modeling.

Summary

AI-Generated Summary

PDF23November 16, 2024