StateX: 学習後の状態拡張によるRNNのリコール性能向上
StateX: Enhancing RNN Recall via Post-training State Expansion
September 26, 2025
著者: Xingyu Shen, Yingfa Chen, Zhen Leng Thai, Xu Han, Zhiyuan Liu, Maosong Sun
cs.AI
要旨
Transformerベースのモデルは、言語モデリングにおいて顕著な性能を発揮しているが、その高い複雑性により、長いコンテキストを処理する際に高コストが発生する。一方、線形アテンションや状態空間モデルなどのリカレントニューラルネットワーク(RNN)は、トークンごとの計算量が一定であることから人気を集めている。しかし、これらのリカレントモデルは、長いコンテキストから正確に情報を想起する必要があるタスクにおいて苦戦する。なぜなら、すべてのコンテキスト情報が一定サイズのリカレント状態に圧縮されるためである。これまでの研究では、想起能力はリカレント状態のサイズと正の相関があることが示されているが、リカレント状態を大きくしてRNNを直接訓練すると、高い訓練コストが発生する。本論文では、事前訓練済みRNNの状態を効率的に拡張するための訓練パイプラインであるStateXを提案する。線形アテンションと状態空間モデルという2つの人気のあるRNNクラスに対して、モデルパラメータの増加を最小限またはゼロに抑えつつ、状態サイズを拡大するためのアーキテクチャ変更を設計する。最大1.3Bパラメータのモデルを用いた実験により、StateXが高い事後訓練コストを発生させることなく、RNNの想起能力とコンテキスト内学習能力を効率的に向上させ、他の能力を損なわないことが実証された。
English
While Transformer-based models have demonstrated remarkable language modeling
performance, their high complexities result in high costs when processing long
contexts. In contrast, recurrent neural networks (RNNs) such as linear
attention and state space models have gained popularity due to their constant
per-token complexities. However, these recurrent models struggle with tasks
that require accurate recall of contextual information from long contexts,
because all contextual information is compressed into a constant-size recurrent
state. Previous works have shown that recall ability is positively correlated
with the recurrent state size, yet directly training RNNs with larger recurrent
states results in high training costs. In this paper, we introduce StateX, a
training pipeline for efficiently expanding the states of pre-trained RNNs
through post-training. For two popular classes of RNNs, linear attention and
state space models, we design post-training architectural modifications to
scale up the state size with no or negligible increase in model parameters.
Experiments on models up to 1.3B parameters demonstrate that StateX efficiently
enhances the recall and in-context learning ability of RNNs without incurring
high post-training costs or compromising other capabilities.