ChatPaper.aiChatPaper

O Agrupamento Diagonal Desbloqueia o Paralelismo em Transformadores de Memória Recorrente para Contextos Longos

Diagonal Batching Unlocks Parallelism in Recurrent Memory Transformers for Long Contexts

June 5, 2025
Autores: Danil Sivtsov, Ivan Rodkin, Gleb Kuzmin, Yuri Kuratov, Ivan Oseledets
cs.AI

Resumo

Os modelos Transformer enfrentam dificuldades na inferência de contextos longos devido à sua complexidade quadrática de tempo e linear de memória. Os Transformers com Memória Recorrente (RMTs) oferecem uma solução ao reduzir o custo assintótico para tempo linear e uso constante de memória. No entanto, seu mecanismo de atualização de memória resulta em execução sequencial, causando um gargalo de desempenho. Apresentamos o *Diagonal Batching*, um esquema de agendamento que desbloqueia o paralelismo entre segmentos em RMTs enquanto preserva a recorrência exata. Essa abordagem elimina a restrição sequencial, permitindo inferência eficiente em GPUs mesmo para entradas únicas de contexto longo, sem técnicas complexas de *batching* e *pipelining*. Como a técnica é puramente uma reordenação de computação em tempo de execução, os modelos RMT existentes podem adotá-la sem necessidade de retreinamento. Aplicado a um modelo LLaMA-1B ARMT, o *Diagonal Batching* proporciona um ganho de velocidade de 3,3x em relação ao LLaMA-1B com atenção completa padrão e 1,8x em relação à implementação sequencial de RMT em sequências de 131.072 *tokens*. Ao remover o gargalo sequencial, o *Diagonal Batching* reduz o custo e a latência de inferência, fortalecendo os RMTs como uma solução prática para aplicações do mundo real com contextos longos.
English
Transformer models struggle with long-context inference due to their quadratic time and linear memory complexity. Recurrent Memory Transformers (RMTs) offer a solution by reducing the asymptotic cost to linear time and constant memory usage. However, their memory update mechanism leads to sequential execution, causing a performance bottleneck. We introduce Diagonal Batching, a scheduling scheme that unlocks parallelism across segments in RMTs while preserving exact recurrence. This approach eliminates the sequential constraint, enabling efficient GPU inference even for single long-context inputs without complex batching and pipelining techniques. Because the technique is purely a run-time computation reordering, existing RMT models adopt it with no retraining. Applied to a LLaMA-1B ARMT model, Diagonal Batching yields a 3.3x speedup over standard full-attention LLaMA-1B and a 1.8x speedup over the sequential RMT implementation on 131,072-token sequences. By removing sequential bottleneck, Diagonal Batching reduces inference cost and latency, thereby strengthening RMTs as a practical solution for real-world, long-context applications.
PDF373June 6, 2025