O Gradiente dos Transformadores de Múltiplas Camadas Pode ser Aproximado em Quase Tempo Linear.
Multi-Layer Transformers Gradient Can be Approximated in Almost Linear Time
August 23, 2024
Autores: Yingyu Liang, Zhizhou Sha, Zhenmei Shi, Zhao Song, Yufa Zhou
cs.AI
Resumo
A complexidade computacional quadrática no mecanismo de autoatenção das arquiteturas de transformer populares apresenta desafios significativos para o treinamento e inferência, especialmente em termos de eficiência e requisitos de memória. Para enfrentar esses desafios, este artigo apresenta um novo método de cálculo de gradientes rápido para modelos de transformer de várias camadas. Nossa abordagem possibilita o cálculo de gradientes para todo o modelo de transformer de várias camadas em quase tempo linear n^{1+o(1)}, onde n é o comprimento da sequência de entrada. Essa inovação reduz significativamente o gargalo computacional associado à tradicional complexidade de tempo quadrático. Nossa teoria é válida para qualquer função de perda e mantém um erro de aproximação limitado em todo o modelo. Além disso, nossa análise é aplicável quando o modelo de transformer de várias camadas contém muitos submódulos práticos, como conexão residual, máscara casual e atenção multi-head. Ao melhorar a eficiência do cálculo de gradientes em grandes modelos de linguagem, esperamos que nosso trabalho facilite o treinamento e implantação mais eficazes de modelos de linguagem de longo contexto com base em nossos resultados teóricos.
English
The quadratic computational complexity in the self-attention mechanism of
popular transformer architectures poses significant challenges for training and
inference, particularly in terms of efficiency and memory requirements. Towards
addressing these challenges, this paper introduces a novel fast computation
method for gradient calculation in multi-layer transformer models. Our approach
enables the computation of gradients for the entire multi-layer transformer
model in almost linear time n^{1+o(1)}, where n is the input sequence
length. This breakthrough significantly reduces the computational bottleneck
associated with the traditional quadratic time complexity. Our theory holds for
any loss function and maintains a bounded approximation error across the entire
model. Furthermore, our analysis can hold when the multi-layer transformer
model contains many practical sub-modules, such as residual connection, casual
mask, and multi-head attention. By improving the efficiency of gradient
computation in large language models, we hope that our work will facilitate the
more effective training and deployment of long-context language models based on
our theoretical results.Summary
AI-Generated Summary