ChatPaper.aiChatPaper

詰め物マンバ:RNNベースの長文脈モデリングの状態崩壊と状態容量

Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling

October 9, 2024
著者: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI

要旨

再帰ニューラルネットワーク(RNN)がトランスフォーマーベースの言語モデルに対して持つ重要な利点の1つは、シーケンス長に関する線形計算の複雑さであり、これにより推論中に長いシーケンスを処理する際にはるかに高速になります。ただし、ほとんどの公開されているRNN(例:MambaおよびRWKV)は、1万トークン未満のシーケンスで訓練されており、これまでに長い文脈での効果が不十分であることが大きな問題となっています。本論文では、RNNが長い文脈を処理できない原因を研究し、重要な緩和策を提案します。最先端のRNNを長い文脈に適用する際の2つの実用的な懸念を検討します:(1)訓練長よりも長い入力に外挿できないこと、および(2)メモリ容量の上限。最初の懸念に対処するために、まず、*state collapse*(SC)という現象を調査します。これは、訓練中に遭遇しなかったシーケンス長での性能劣化を引き起こす現象です。制御された実験により、これを訓練長に対して再帰状態が過剰にパラメータ化されることによる過学習と特定します。2つ目の懸念に対して、長い文書で一連のMamba-2モデルを訓練し、言語モデリングとパスキー検索における再帰状態の容量を経験的に推定します。その後、Mamba-2の長さの一般化性を向上させるために3つのSC緩和方法が提案され、モデルがSCなしで100万トークン以上を処理できるようになります。また、パスキー検索における再帰状態の容量は状態サイズに指数関数的にスケールし、256Kのコンテキスト長でほぼ完璧なパスキー検索精度を持つMamba-2 370Mを経験的に訓練します。これは、RNNに基づく長い文脈モデリングに有望な未来を示唆しています。
English
One essential advantage of recurrent neural networks (RNNs) over transformer-based language models is their linear computational complexity concerning the sequence length, which makes them much faster in handling long sequences during inference. However, most publicly available RNNs (e.g., Mamba and RWKV) are trained on sequences with less than 10K tokens, and their effectiveness in longer contexts remains largely unsatisfying so far. In this paper, we study the cause of the inability to process long context for RNNs and suggest critical mitigations. We examine two practical concerns when applying state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to inputs longer than the training length and (2) the upper bound of memory capacity. Addressing the first concern, we first investigate *state collapse* (SC), a phenomenon that causes severe performance degradation on sequence lengths not encountered during training. With controlled experiments, we attribute this to overfitting due to the recurrent state being overparameterized for the training length. For the second concern, we train a series of Mamba-2 models on long documents to empirically estimate the recurrent state capacity in language modeling and passkey retrieval. Then, three SC mitigation methods are proposed to improve Mamba-2's length generalizability, allowing the model to process more than 1M tokens without SC. We also find that the recurrent state capacity in passkey retrieval scales exponentially to the state size, and we empirically train a Mamba-2 370M with near-perfect passkey retrieval accuracy on 256K context length. This suggests a promising future for RNN-based long-context modeling.

Summary

AI-Generated Summary

PDF23November 16, 2024