I Gradient dei Trasformatori Multi-Livello Possono Essere Approssimati in Tempo Quasi Lineare
Multi-Layer Transformers Gradient Can be Approximated in Almost Linear Time
August 23, 2024
Autori: Yingyu Liang, Zhizhou Sha, Zhenmei Shi, Zhao Song, Yufa Zhou
cs.AI
Abstract
La complessità computazionale quadratica nel meccanismo di self-attention delle architetture transformer più diffuse presenta sfide significative per l'addestramento e l'inferenza, in particolare in termini di efficienza e requisiti di memoria. Per affrontare queste sfide, questo articolo introduce un nuovo metodo di calcolo veloce per il calcolo del gradiente nei modelli transformer multi-strato. Il nostro approccio consente il calcolo dei gradienti per l'intero modello transformer multi-strato in un tempo quasi lineare n^{1+o(1)}, dove n è la lunghezza della sequenza di input. Questa svolta riduce significativamente il collo di bottiglia computazionale associato alla tradizionale complessità temporale quadratica. La nostra teoria è valida per qualsiasi funzione di perdita e mantiene un errore di approssimazione limitato in tutto il modello. Inoltre, la nostra analisi rimane valida quando il modello transformer multi-strato include molti sottomoduli pratici, come la connessione residua, la maschera causale e l'attenzione multi-testina. Migliorando l'efficienza del calcolo del gradiente nei modelli linguistici di grandi dimensioni, speriamo che il nostro lavoro possa facilitare un addestramento e un dispiegamento più efficaci dei modelli linguistici a contesto lungo, basandosi sui nostri risultati teorici.
English
The quadratic computational complexity in the self-attention mechanism of
popular transformer architectures poses significant challenges for training and
inference, particularly in terms of efficiency and memory requirements. Towards
addressing these challenges, this paper introduces a novel fast computation
method for gradient calculation in multi-layer transformer models. Our approach
enables the computation of gradients for the entire multi-layer transformer
model in almost linear time n^{1+o(1)}, where n is the input sequence
length. This breakthrough significantly reduces the computational bottleneck
associated with the traditional quadratic time complexity. Our theory holds for
any loss function and maintains a bounded approximation error across the entire
model. Furthermore, our analysis can hold when the multi-layer transformer
model contains many practical sub-modules, such as residual connection, casual
mask, and multi-head attention. By improving the efficiency of gradient
computation in large language models, we hope that our work will facilitate the
more effective training and deployment of long-context language models based on
our theoretical results.