填充式瑪巴:基於RNN的長文本建模的狀態崩潰和狀態能力
Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling
October 9, 2024
作者: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI
摘要
循環神經網絡(RNN)相對於基於Transformer的語言模型的一個重要優勢是其在序列長度方面具有線性計算複雜度,這使得它們在推理過程中處理長序列時更快。然而,大多數公開可用的RNN(例如Mamba和RWKV)是在少於10K標記的序列上進行訓練的,迄今為止它們在更長範境中的有效性仍然令人不滿。在本文中,我們研究了RNN無法處理長範境的原因並提出了關鍵的緩解方法。我們在應用最先進的RNN到長範境時考慮了兩個實際問題:(1)無法對長於訓練長度的輸入進行外推和(2)記憶容量的上限。針對第一個問題,我們首先研究了*狀態崩潰*(SC),這是一種現象,導致在訓練期間未遇到的序列長度上性能嚴重下降。通過控制實驗,我們將這歸因於由於循環狀態對於訓練長度而言被過度參數化而導致的過度擬合。對於第二個問題,我們在長文檔上訓練了一系列Mamba-2模型,以實證估計語言建模和密鑰檢索中的循環狀態容量。然後,提出了三種SC緩解方法,以提高Mamba-2的長度泛化能力,使模型能夠處理超過1M標記而無SC。我們還發現密鑰檢索中的循環狀態容量與狀態大小呈指數級增長,我們在256K上下文長度上實證訓練了一個Mamba-2 370M,其密鑰檢索準確率接近完美。這表明了基於RNN的長範境建模有著令人期待的未來。
English
One essential advantage of recurrent neural networks (RNNs) over
transformer-based language models is their linear computational complexity
concerning the sequence length, which makes them much faster in handling long
sequences during inference. However, most publicly available RNNs (e.g., Mamba
and RWKV) are trained on sequences with less than 10K tokens, and their
effectiveness in longer contexts remains largely unsatisfying so far. In this
paper, we study the cause of the inability to process long context for RNNs and
suggest critical mitigations. We examine two practical concerns when applying
state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to
inputs longer than the training length and (2) the upper bound of memory
capacity. Addressing the first concern, we first investigate *state collapse*
(SC), a phenomenon that causes severe performance degradation on sequence
lengths not encountered during training. With controlled experiments, we
attribute this to overfitting due to the recurrent state being
overparameterized for the training length. For the second concern, we train a
series of Mamba-2 models on long documents to empirically estimate the
recurrent state capacity in language modeling and passkey retrieval. Then,
three SC mitigation methods are proposed to improve Mamba-2's length
generalizability, allowing the model to process more than 1M tokens without SC.
We also find that the recurrent state capacity in passkey retrieval scales
exponentially to the state size, and we empirically train a Mamba-2 370M with
near-perfect passkey retrieval accuracy on 256K context length. This suggests a
promising future for RNN-based long-context modeling.Summary
AI-Generated Summary