Diagonal Batching Unlocks Parallelism in Recurrent Memory Transformers for Long Contexts
June 5, 2025
Authors: Danil Sivtsov, Ivan Rodkin, Gleb Kuzmin, Yuri Kuratov, Ivan Oseledets
cs.AI
Abstract
Transformer models struggle with long-context inference due to their
quadratic time and linear memory complexity. Recurrent Memory Transformers
(RMTs) offer a solution by reducing the asymptotic cost to linear time and
constant memory usage. However, their memory update mechanism leads to
sequential execution, causing a performance bottleneck.
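As a concrete picture of that bottleneck, the following is a minimal sketch of a generic segment-level RMT recurrence; the toy_layer function, tensor shapes, and memory update rule are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def toy_layer(hidden, memory):
    # Stand-in for one RMT layer: mixes the current segment with the carried memory.
    new_hidden = np.tanh(hidden + memory.mean())
    new_memory = 0.5 * memory + 0.5 * new_hidden.mean(axis=0)
    return new_hidden, new_memory

def sequential_rmt_forward(segments, num_layers, hidden_dim):
    # One memory state per layer; segment t at layer l needs the memory written
    # by segment t - 1 at the same layer, so the outer loop is strictly sequential.
    memories = [np.zeros(hidden_dim) for _ in range(num_layers)]
    outputs = []
    for segment in segments:                  # linear in the number of segments
        hidden = segment
        for layer_idx in range(num_layers):
            hidden, memories[layer_idx] = toy_layer(hidden, memories[layer_idx])
        outputs.append(hidden)
    return outputs

# 8 segments of 16 tokens with hidden size 32: linear time, constant memory per layer.
segments = [np.random.randn(16, 32) for _ in range(8)]
outputs = sequential_rmt_forward(segments, num_layers=4, hidden_dim=32)
```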
We introduce Diagonal Batching, a scheduling scheme that unlocks parallelism
across segments in RMTs while preserving exact recurrence. This approach
eliminates the sequential constraint, enabling efficient GPU inference even for
single long-context inputs without complex batching and pipelining techniques.
Because the technique is purely a run-time reordering of computation, existing RMT
models can adopt it without retraining.
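To make the reordering concrete, here is a minimal sketch of a diagonal (wavefront) schedule over the (layer, segment) grid, assuming the dependency structure of the recurrence sketched above: cell (l, t) needs only (l - 1, t) and (l, t - 1), so all cells on one anti-diagonal are mutually independent and can be launched as a single batched GPU call. The function name and grid sizes are illustrative, not the authors' code.

```python
def diagonal_schedule(num_layers, num_segments):
    # Yield groups of (layer, segment) cells that share no dependencies:
    # every cell on anti-diagonal d = layer + segment depends only on cells
    # from diagonal d - 1, so each group can run as one batched kernel launch.
    for d in range(num_layers + num_segments - 1):
        yield [(l, d - l) for l in range(num_layers) if 0 <= d - l < num_segments]

# With 4 layers and 6 segments, the sequential schedule takes 4 * 6 = 24 steps,
# while the diagonal schedule takes only 4 + 6 - 1 = 9 steps, each processing
# up to min(4, 6) = 4 cells in parallel.
for step, group in enumerate(diagonal_schedule(4, 6)):
    print(f"step {step}: {group}")
```

Because this only changes the order in which the already-required cells are computed, the recurrence itself is untouched, which is why existing models need no retraining.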
Applied to a LLaMA-1B ARMT model, Diagonal Batching yields a 3.3x speedup
over standard full-attention LLaMA-1B and a 1.8x speedup over the sequential
RMT implementation on 131,072-token sequences. By removing the sequential
bottleneck, Diagonal Batching reduces inference cost and latency, thereby
strengthening RMTs as a practical solution for real-world, long-context
applications.