Наполненный Мамба: Крах государства и государственные возможности моделирования длинных контекстов на основе RNN
Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling
October 9, 2024
Авторы: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI
Аннотация
Одним из важных преимуществ рекуррентных нейронных сетей (RNN) перед языковыми моделями на основе трансформеров является их линейная вычислительная сложность по длине последовательности, что делает их намного быстрее в обработке длинных последовательностей во время вывода. Однако большинство общедоступных RNN (например, Mamba и RWKV) обучены на последовательностях с менее чем 10 тыс. токенов, и их эффективность в более длинных контекстах до сих пор остается в значительной степени неудовлетворительной. В данной статье мы изучаем причину неспособности обрабатывать длинный контекст для RNN и предлагаем критические меры по устранению этой проблемы. Мы рассматриваем две практические проблемы при применении современных RNN к длинным контекстам: (1) неспособность экстраполировать на входы длиннее длины обучения и (2) верхний предел памяти. Для решения первой проблемы мы в первую очередь исследуем *крах состояния* (SC), явление, которое вызывает серьезное снижение производительности на длинах последовательностей, не встреченных во время обучения. Проведя контролируемые эксперименты, мы приписываем это переобучению из-за избыточного параметризирования рекуррентного состояния для длины обучения. Для второй проблемы мы обучаем серию моделей Mamba-2 на длинных документах для эмпирической оценки емкости рекуррентного состояния в языковом моделировании и извлечении ключа. Затем предлагаются три метода устранения SC для улучшения обобщаемости Mamba-2 по длине, позволяя модели обрабатывать более 1 млн токенов без SC. Мы также обнаруживаем, что емкость рекуррентного состояния в извлечении ключа масштабируется экспоненциально по размеру состояния, и эмпирически обучаем Mamba-2 370M с практически идеальной точностью извлечения ключа на длине контекста 256 тыс. Это указывает на многообещающее будущее для моделирования длинного контекста на основе RNN.
English
One essential advantage of recurrent neural networks (RNNs) over
transformer-based language models is their linear computational complexity
concerning the sequence length, which makes them much faster in handling long
sequences during inference. However, most publicly available RNNs (e.g., Mamba
and RWKV) are trained on sequences with less than 10K tokens, and their
effectiveness in longer contexts remains largely unsatisfying so far. In this
paper, we study the cause of the inability to process long context for RNNs and
suggest critical mitigations. We examine two practical concerns when applying
state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to
inputs longer than the training length and (2) the upper bound of memory
capacity. Addressing the first concern, we first investigate *state collapse*
(SC), a phenomenon that causes severe performance degradation on sequence
lengths not encountered during training. With controlled experiments, we
attribute this to overfitting due to the recurrent state being
overparameterized for the training length. For the second concern, we train a
series of Mamba-2 models on long documents to empirically estimate the
recurrent state capacity in language modeling and passkey retrieval. Then,
three SC mitigation methods are proposed to improve Mamba-2's length
generalizability, allowing the model to process more than 1M tokens without SC.
We also find that the recurrent state capacity in passkey retrieval scales
exponentially to the state size, and we empirically train a Mamba-2 370M with
near-perfect passkey retrieval accuracy on 256K context length. This suggests a
promising future for RNN-based long-context modeling.Summary
AI-Generated Summary