Stufffed Mamba: RNN 기반 장기 맥락 모델링의 상태 붕괴와 상태 용량
Stuffed Mamba: State Collapse and State Capacity of RNN-Based Long-Context Modeling
October 9, 2024
저자: Yingfa Chen, Xinrong Zhang, Shengding Hu, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI
초록
순환 신경망(RNN)이 트랜스포머 기반 언어 모델에 비해 가지는 중요한 장점 중 하나는 시퀀스 길이에 대한 선형 계산 복잡성으로, 이는 추론 중에 긴 시퀀스를 처리할 때 훨씬 빠르게 만들어줍니다. 그러나 대부분의 공개적으로 이용 가능한 RNN(예: Mamba 및 RWKV)은 1만 토큰 미만의 시퀀스로 훈련되어 왔으며, 그들의 긴 문맥에서의 효과는 현재까지 대부분 만족스럽지 못한 상태입니다. 본 논문에서는 RNN이 긴 문맥을 처리할 수 없는 원인을 연구하고 중요한 완화 방안을 제안합니다. 우리는 최신 RNN을 긴 문맥에 적용할 때 고려해야 할 두 가지 실용적인 고려 사항을 검토합니다: (1) 훈련 길이를 초과하는 입력에 대한 추정 불가능성과 (2) 메모리 용량의 상한선. 첫 번째 고려 사항에 대해, 우리는 먼저 *상태 붕괴*(SC)를 조사합니다. 이는 훈련 중에 경험하지 않은 시퀀스 길이에서 심각한 성능 저하를 일으키는 현상입니다. 통제된 실험을 통해, 우리는 이를 훈련 길이에 대해 과도하게 매개변수화된 순환 상태로 인한 과적합으로 귀속합니다. 두 번째 고려 사항에 대해, 우리는 언어 모델링 및 패스키 검색에서 순환 상태 용량을 경험적으로 추정하기 위해 긴 문서에 대해 일련의 Mamba-2 모델을 훈련시킵니다. 그런 다음, Mamba-2의 길이 일반화 능력을 향상시키기 위해 세 가지 SC 완화 방법을 제안하여, SC 없이 100만 토큰 이상을 처리할 수 있도록 합니다. 또한, 패스키 검색에서의 순환 상태 용량이 상태 크기에 지수적으로 확장되는 것을 발견하고, 25만 6천 길이의 문맥에서 거의 완벽한 패스키 검색 정확도를 갖는 Mamba-2 3억 7천만을 경험적으로 훈련시킵니다. 이는 RNN 기반의 긴 문맥 모델링에 밝은 미래를 제시합니다.
English
One essential advantage of recurrent neural networks (RNNs) over
transformer-based language models is their linear computational complexity
concerning the sequence length, which makes them much faster in handling long
sequences during inference. However, most publicly available RNNs (e.g., Mamba
and RWKV) are trained on sequences with less than 10K tokens, and their
effectiveness in longer contexts remains largely unsatisfying so far. In this
paper, we study the cause of the inability to process long context for RNNs and
suggest critical mitigations. We examine two practical concerns when applying
state-of-the-art RNNs to long contexts: (1) the inability to extrapolate to
inputs longer than the training length and (2) the upper bound of memory
capacity. Addressing the first concern, we first investigate *state collapse*
(SC), a phenomenon that causes severe performance degradation on sequence
lengths not encountered during training. With controlled experiments, we
attribute this to overfitting due to the recurrent state being
overparameterized for the training length. For the second concern, we train a
series of Mamba-2 models on long documents to empirically estimate the
recurrent state capacity in language modeling and passkey retrieval. Then,
three SC mitigation methods are proposed to improve Mamba-2's length
generalizability, allowing the model to process more than 1M tokens without SC.
We also find that the recurrent state capacity in passkey retrieval scales
exponentially to the state size, and we empirically train a Mamba-2 370M with
near-perfect passkey retrieval accuracy on 256K context length. This suggests a
promising future for RNN-based long-context modeling.Summary
AI-Generated Summary