Diagonale Batchverwerking Ontgrendelt Parallelisme in Recurrente Geheugen Transformers voor Lange Contexten
Diagonal Batching Unlocks Parallelism in Recurrent Memory Transformers for Long Contexts
June 5, 2025
Auteurs: Danil Sivtsov, Ivan Rodkin, Gleb Kuzmin, Yuri Kuratov, Ivan Oseledets
cs.AI
Samenvatting
Transformer-modellen hebben moeite met inferentie in lange contexten vanwege hun
kwadratische tijdscomplexiteit en lineaire geheugencomplexiteit. Recurrent Memory Transformers
(RMT's) bieden een oplossing door de asymptotische kosten te reduceren naar lineaire tijd en
constant geheugengebruik. Hun geheugenupdate-mechanisme leidt echter tot sequentiële uitvoering,
wat een prestatieknelpunt veroorzaakt.
Wij introduceren Diagonal Batching, een planningsschema dat parallellisme mogelijk maakt
tussen segmenten in RMT's terwijl exacte recurrentie behouden blijft. Deze aanpak elimineert de
sequentiële beperking, waardoor efficiënte GPU-inferentie mogelijk wordt, zelfs voor enkele
lange-context inputs zonder complexe batching- en pipeliningtechnieken. Omdat de techniek puur
een herordening van runtime-berekeningen is, kunnen bestaande RMT-modellen deze zonder
hertraining toepassen.
Toegepast op een LLaMA-1B ARMT-model levert Diagonal Batching een 3,3x versnelling op
ten opzichte van standaard full-attention LLaMA-1B en een 1,8x versnelling ten opzichte van de
sequentiële RMT-implementatie op sequenties van 131.072 tokens. Door de sequentiële knelpunt
te verwijderen, verlaagt Diagonal Batching de inferentiekosten en latentie, waardoor RMT's worden
versterkt als een praktische oplossing voor real-world, lange-context toepassingen.
English
Transformer models struggle with long-context inference due to their
quadratic time and linear memory complexity. Recurrent Memory Transformers
(RMTs) offer a solution by reducing the asymptotic cost to linear time and
constant memory usage. However, their memory update mechanism leads to
sequential execution, causing a performance bottleneck.
We introduce Diagonal Batching, a scheduling scheme that unlocks parallelism
across segments in RMTs while preserving exact recurrence. This approach
eliminates the sequential constraint, enabling efficient GPU inference even for
single long-context inputs without complex batching and pipelining techniques.
Because the technique is purely a run-time computation reordering, existing RMT
models adopt it with no retraining.
Applied to a LLaMA-1B ARMT model, Diagonal Batching yields a 3.3x speedup
over standard full-attention LLaMA-1B and a 1.8x speedup over the sequential
RMT implementation on 131,072-token sequences. By removing sequential
bottleneck, Diagonal Batching reduces inference cost and latency, thereby
strengthening RMTs as a practical solution for real-world, long-context
applications.