Les gradients des Transformers à plusieurs couches peuvent être approximés en temps presque linéaire.
Multi-Layer Transformers Gradient Can be Approximated in Almost Linear Time
August 23, 2024
Auteurs: Yingyu Liang, Zhizhou Sha, Zhenmei Shi, Zhao Song, Yufa Zhou
cs.AI
Résumé
La complexité computationnelle quadratique dans le mécanisme d'auto-attention des architectures de transformer populaires pose des défis importants pour l'entraînement et l'inférence, notamment en termes d'efficacité et d'exigences en mémoire. Pour relever ces défis, cet article présente une nouvelle méthode de calcul rapide des gradients dans les modèles de transformer à plusieurs couches. Notre approche permet le calcul des gradients pour l'ensemble du modèle de transformer à plusieurs couches en un temps presque linéaire n^{1+o(1)}, où n est la longueur de la séquence d'entrée. Cette percée réduit considérablement le goulot d'étranglement computationnel associé à la complexité temporelle quadratique traditionnelle. Notre théorie est valable pour toute fonction de perte et maintient une erreur d'approximation bornée sur l'ensemble du modèle. De plus, notre analyse peut être appliquée lorsque le modèle de transformer à plusieurs couches contient de nombreux sous-modules pratiques, tels que la connexion résiduelle, le masque causal et l'attention multi-têtes. En améliorant l'efficacité du calcul des gradients dans les grands modèles de langage, nous espérons que notre travail facilitera l'entraînement et le déploiement plus efficaces de modèles de langage à long contexte basés sur nos résultats théoriques.
English
The quadratic computational complexity in the self-attention mechanism of
popular transformer architectures poses significant challenges for training and
inference, particularly in terms of efficiency and memory requirements. Towards
addressing these challenges, this paper introduces a novel fast computation
method for gradient calculation in multi-layer transformer models. Our approach
enables the computation of gradients for the entire multi-layer transformer
model in almost linear time n^{1+o(1)}, where n is the input sequence
length. This breakthrough significantly reduces the computational bottleneck
associated with the traditional quadratic time complexity. Our theory holds for
any loss function and maintains a bounded approximation error across the entire
model. Furthermore, our analysis can hold when the multi-layer transformer
model contains many practical sub-modules, such as residual connection, casual
mask, and multi-head attention. By improving the efficiency of gradient
computation in large language models, we hope that our work will facilitate the
more effective training and deployment of long-context language models based on
our theoretical results.Summary
AI-Generated Summary