StateX: 사후 학습 상태 확장을 통한 RNN 회상 능력 향상
StateX: Enhancing RNN Recall via Post-training State Expansion
September 26, 2025
저자: Xingyu Shen, Yingfa Chen, Zhen Leng Thai, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI
초록
트랜스포머 기반 모델은 뛰어난 언어 모델링 성능을 보여주지만, 높은 복잡성으로 인해 긴 문맥을 처리할 때 비용이 많이 듭니다. 반면, 선형 어텐션(linear attention) 및 상태 공간 모델(state space models)과 같은 순환 신경망(RNN)은 토큰당 일정한 복잡성을 유지하기 때문에 인기를 끌고 있습니다. 그러나 이러한 순환 모델은 모든 문맥 정보가 일정한 크기의 순환 상태로 압축되기 때문에, 긴 문맥에서 정확한 정보 회상이 필요한 작업에는 어려움을 겪습니다. 선행 연구에 따르면 회상 능력은 순환 상태 크기와 양의 상관관계를 가지지만, 순환 상태 크기를 늘려 RNN을 직접 학습시키는 것은 높은 학습 비용을 초래합니다. 본 논문에서는 사전 학습된 RNN의 상태를 사후 학습을 통해 효율적으로 확장하는 StateX 학습 파이프라인을 소개합니다. 선형 어텐션 및 상태 공간 모델이라는 두 가지 인기 있는 RNN 클래스에 대해, 모델 파라미터를 증가시키지 않거나 미미하게 증가시키면서 상태 크기를 확장할 수 있는 사후 학습 아키텍처 수정을 설계했습니다. 최대 13억 파라미터 규모의 모델에 대한 실험을 통해 StateX가 높은 사후 학습 비용 없이 RNN의 회상 및 문맥 내 학습 능력을 효율적으로 향상시키며, 다른 기능을 저하시키지 않음을 입증했습니다.
English
While Transformer-based models have demonstrated remarkable language modeling
performance, their high complexities result in high costs when processing long
contexts. In contrast, recurrent neural networks (RNNs) such as linear
attention and state space models have gained popularity due to their constant
per-token complexities. However, these recurrent models struggle with tasks
that require accurate recall of contextual information from long contexts,
because all contextual information is compressed into a constant-size recurrent
state. Previous works have shown that recall ability is positively correlated
with the recurrent state size, yet directly training RNNs with larger recurrent
states results in high training costs. In this paper, we introduce StateX, a
training pipeline for efficiently expanding the states of pre-trained RNNs
through post-training. For two popular classes of RNNs, linear attention and
state space models, we design post-training architectural modifications to
scale up the state size with no or negligible increase in model parameters.
Experiments on models up to 1.3B parameters demonstrate that StateX efficiently
enhances the recall and in-context learning ability of RNNs without incurring
high post-training costs or compromising other capabilities.