Градиент многослойных трансформеров можно приблизить практически линейным образом.
Multi-Layer Transformers Gradient Can be Approximated in Almost Linear Time
August 23, 2024
Авторы: Yingyu Liang, Zhizhou Sha, Zhenmei Shi, Zhao Song, Yufa Zhou
cs.AI
Аннотация
Квадратическая вычислительная сложность в механизме самовнимания популярных архитектур трансформеров представляет существенные вызовы для обучения и вывода, особенно в плане эффективности и требований к памяти. Для решения этих проблем в данной статье представлен новый быстрый метод вычисления градиента в многослойных моделях трансформеров. Наш подход позволяет вычислять градиенты для всей многослойной модели трансформера практически за линейное время n^{1+o(1)}, где n - длина входной последовательности. Этот прорыв значительно снижает вычислительное узкое место, связанное с традиционной квадратичной сложностью по времени. Наша теория справедлива для любой функции потерь и обеспечивает ограниченную погрешность аппроксимации по всей модели. Более того, наш анализ может быть применен, когда многослойная модель трансформера содержит множество практических подмодулей, таких как остаточное соединение, случайная маска и многоголовое внимание. Улучшая эффективность вычисления градиента в больших языковых моделях, мы надеемся, что наша работа упростит более эффективное обучение и развертывание языковых моделей с длинным контекстом на основе наших теоретических результатов.
English
The quadratic computational complexity in the self-attention mechanism of
popular transformer architectures poses significant challenges for training and
inference, particularly in terms of efficiency and memory requirements. Towards
addressing these challenges, this paper introduces a novel fast computation
method for gradient calculation in multi-layer transformer models. Our approach
enables the computation of gradients for the entire multi-layer transformer
model in almost linear time n^{1+o(1)}, where n is the input sequence
length. This breakthrough significantly reduces the computational bottleneck
associated with the traditional quadratic time complexity. Our theory holds for
any loss function and maintains a bounded approximation error across the entire
model. Furthermore, our analysis can hold when the multi-layer transformer
model contains many practical sub-modules, such as residual connection, casual
mask, and multi-head attention. By improving the efficiency of gradient
computation in large language models, we hope that our work will facilitate the
more effective training and deployment of long-context language models based on
our theoretical results.Summary
AI-Generated Summary