対角バッチングが長文脈におけるリカレントメモリートランスフォーマーの並列処理を実現
Diagonal Batching Unlocks Parallelism in Recurrent Memory Transformers for Long Contexts
June 5, 2025
著者: Danil Sivtsov, Ivan Rodkin, Gleb Kuzmin, Yuri Kuratov, Ivan Oseledets
cs.AI
要旨
Transformerモデルは、その二次時間および線形メモリ複雑性のため、長文脈推論に苦戦しています。Recurrent Memory Transformers(RMT)は、漸近コストを線形時間および定数メモリ使用量に削減することで解決策を提供します。しかし、そのメモリ更新メカニズムは逐次実行を引き起こし、パフォーマンスのボトルネックとなります。
本論文では、RMTにおいてセグメント間の並列性を解き放ちながら、正確な再帰を維持するスケジューリング手法であるDiagonal Batchingを提案します。このアプローチは逐次制約を排除し、複雑なバッチ処理やパイプライン技術を必要とせずに、単一の長文脈入力に対する効率的なGPU推論を可能にします。この技術は純粋に実行時の計算順序変更であるため、既存のRMTモデルは再学習なしで採用できます。
LLaMA-1B ARMTモデルに適用した場合、Diagonal Batchingは、標準的な完全注意機構のLLaMA-1Bと比較して3.3倍、逐次RMT実装と比較して1.8倍の高速化を131,072トークン列で達成します。逐次ボトルネックを除去することで、Diagonal Batchingは推論コストとレイテンシを削減し、RMTを現実世界の長文脈アプリケーションに対する実用的なソリューションとして強化します。
English
Transformer models struggle with long-context inference due to their
quadratic time and linear memory complexity. Recurrent Memory Transformers
(RMTs) offer a solution by reducing the asymptotic cost to linear time and
constant memory usage. However, their memory update mechanism leads to
sequential execution, causing a performance bottleneck.
We introduce Diagonal Batching, a scheduling scheme that unlocks parallelism
across segments in RMTs while preserving exact recurrence. This approach
eliminates the sequential constraint, enabling efficient GPU inference even for
single long-context inputs without complex batching and pipelining techniques.
Because the technique is purely a run-time computation reordering, existing RMT
models adopt it with no retraining.
Applied to a LLaMA-1B ARMT model, Diagonal Batching yields a 3.3x speedup
over standard full-attention LLaMA-1B and a 1.8x speedup over the sequential
RMT implementation on 131,072-token sequences. By removing sequential
bottleneck, Diagonal Batching reduces inference cost and latency, thereby
strengthening RMTs as a practical solution for real-world, long-context
applications.