Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling
October 9, 2024
Authors: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI
Abstract
One essential advantage of recurrent neural networks (RNNs) over
transformer-based language models is their linear computational complexity with respect to sequence length, which makes them much faster at handling long
sequences during inference. However, most publicly available RNNs (e.g., Mamba
and RWKV) are trained on sequences of fewer than 10K tokens, and their effectiveness in longer contexts remains largely unsatisfactory. In this paper, we study why RNNs fail to process long contexts and propose key mitigations. We examine two practical concerns when applying
state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to
inputs longer than the training length and (2) the upper bound of memory
capacity. Addressing the first concern, we first investigate *state collapse*
(SC), a phenomenon that causes severe performance degradation on sequence
lengths not encountered during training. With controlled experiments, we
attribute this to overfitting due to the recurrent state being
overparameterized for the training length. For the second concern, we train a
series of Mamba-2 models on long documents to empirically estimate the
recurrent state capacity in language modeling and passkey retrieval. We then propose three SC mitigation methods that improve Mamba-2's length generalizability, allowing the model to process more than 1M tokens without SC.
We also find that the recurrent state capacity in passkey retrieval scales exponentially with the state size, and we empirically train a Mamba-2 370M model that achieves near-perfect passkey retrieval accuracy at a 256K context length. This suggests a
promising future for RNN-based long-context modeling.
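
As background for the linear-complexity claim: an RNN keeps a fixed-size recurrent state, so each token is processed in constant time and a T-token sequence costs O(T) overall, unlike a transformer's attention. The NumPy sketch below illustrates a Mamba-2-style gated linear recurrence under simplifying assumptions (single head, scalar per-step decay); the function name, shapes, and variable names are ours for illustration, not the paper's implementation.

```python
import numpy as np

def gated_linear_recurrence(x, A_log, B, C):
    """Minimal sketch of a Mamba-2-style gated linear recurrence
    (illustrative simplification, not the paper's code).

    x:     (T, d) input sequence
    A_log: (T,)   per-step log-decay; negative values give a_t in (0, 1)
    B:     (T, n) input projections
    C:     (T, n) output projections
    Returns y: (T, d)
    """
    T, d = x.shape
    n = B.shape[1]
    H = np.zeros((n, d))                     # fixed-size recurrent state
    y = np.empty((T, d))
    for t in range(T):                       # O(n*d) per token -> O(T) total
        a_t = np.exp(A_log[t])               # scalar decay gate
        H = a_t * H + np.outer(B[t], x[t])   # decay old state, write new token
        y[t] = C[t] @ H                      # linear readout from the state
    return y

# Usage with random inputs:
T, d, n = 16, 8, 4
rng = np.random.default_rng(0)
y = gated_linear_recurrence(
    rng.standard_normal((T, d)),
    -rng.random(T),                          # negative logs -> decay in (0, 1)
    rng.standard_normal((T, n)),
    rng.standard_normal((T, n)),
)
```

The key point is that H never grows with sequence length: memory is constant and compute is linear, which is the advantage the abstract highlights; the flip side is that this fixed-size state bounds how much context can be stored, which is the state-capacity question the paper studies.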
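For readers unfamiliar with passkey retrieval: the task hides a short random key inside long distractor text and asks the model to recall it, probing whether the recurrent state retains a specific fact across the full context. Below is a minimal sketch of how such an example might be constructed; the helper name, prompt wording, and filler text are illustrative assumptions, and the paper's exact setup may differ.

```python
import random

def make_passkey_example(num_filler_sentences: int, seed: int = 0):
    """Build a synthetic passkey-retrieval prompt (illustrative sketch)."""
    rng = random.Random(seed)
    passkey = rng.randint(10000, 99999)
    filler = "The grass is green. The sky is blue. The sun is bright. "
    insert_at = rng.randint(0, num_filler_sentences)
    haystack = (
        filler * insert_at
        + f"The pass key is {passkey}. Remember it. "
        + filler * (num_filler_sentences - insert_at)
    )
    prompt = haystack + "What is the pass key? The pass key is"
    return prompt, str(passkey)

# A model is scored by whether its continuation contains the key:
# prompt, answer = make_passkey_example(5000)
# assert answer in model.generate(prompt)   # `model` is hypothetical
```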