Ёж и дикобраз: Выразительные линейные механизмы внимания с имитацией Softmax
The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry
February 6, 2024
Авторы: Michael Zhang, Kush Bhatia, Hermann Kumbong, Christopher Ré
cs.AI
Аннотация
Линейные механизмы внимания продемонстрировали потенциал для повышения эффективности Transformer, снижая квадратичную сложность внимания до линейной относительно длины последовательности. Это открывает захватывающие перспективы для (1) обучения линейных Transformer с нуля, (2) "тонкой настройки-конвертации" специализированных Transformer в линейные версии, восстанавливающие производительность на задачах, и (3) "предварительной конвертации" Transformer, таких как крупные языковые модели, в линейные версии, которые можно дообучать на целевых задачах. Однако линейные механизмы внимания часто уступают стандартному softmax-вниманию по качеству. Чтобы сократить этот разрыв, мы обнаружили, что предыдущие линейные механизмы внимания не обладают ключевыми свойствами softmax-внимания, связанными с высокой производительностью: низкоэнтропийными (или "остроконечными") весами и монотонностью скалярного произведения. Мы также наблюдаем удивительно простые карты признаков, которые сохраняют эти свойства и соответствуют производительности softmax, но неэффективны для вычисления в линейном внимании. Таким образом, мы предлагаем Hedgehog, обучаемый линейный механизм внимания, который сохраняет остроконечные и монотонные свойства softmax-внимания, сохраняя при этом линейную сложность. Hedgehog использует простые обучаемые MLP для создания весов внимания, имитирующих softmax-внимание. Эксперименты показывают, что Hedgehog восстанавливает более 99% качества стандартного Transformer в настройках обучения с нуля и тонкой настройки-конвертации, превосходя предыдущие линейные механизмы внимания на до 6 пунктов perplexity на WikiText-103 с каузальными GPT и до 8.7 пунктов GLUE на дообученных двунаправленных BERT. Hedgehog также позволяет выполнять предварительную конвертацию. Конвертация предварительно обученной GPT-2 в линейную версию внимания достигает современного уровня 16.7 perplexity на WikiText-103 для 125M субквадратичных декодерных моделей. Наконец, мы превращаем предварительно обученную Llama-2 7B в жизнеспособную линейную версию внимания Llama. С использованием низкоранговой адаптации Hedgehog-Llama2 7B достигает 28.1 дополнительных пунктов ROUGE-1 по сравнению с базовой моделью стандартного внимания, тогда как предыдущие линейные механизмы внимания приводят к снижению на 16.5 пунктов.
English
Linear attentions have shown potential for improving Transformer efficiency,
reducing attention's quadratic complexity to linear in sequence length. This
holds exciting promise for (1) training linear Transformers from scratch, (2)
"finetuned-conversion" of task-specific Transformers into linear versions that
recover task performance, and (3) "pretrained-conversion" of Transformers such
as large language models into linear versions finetunable on downstream tasks.
However, linear attentions often underperform standard softmax attention in
quality. To close this performance gap, we find prior linear attentions lack
key properties of softmax attention tied to good performance: low-entropy (or
"spiky") weights and dot-product monotonicity. We further observe surprisingly
simple feature maps that retain these properties and match softmax performance,
but are inefficient to compute in linear attention. We thus propose Hedgehog, a
learnable linear attention that retains the spiky and monotonic properties of
softmax attention while maintaining linear complexity. Hedgehog uses simple
trainable MLPs to produce attention weights mimicking softmax attention.
Experiments show Hedgehog recovers over 99% of standard Transformer quality in
train-from-scratch and finetuned-conversion settings, outperforming prior
linear attentions up to 6 perplexity points on WikiText-103 with causal GPTs,
and up to 8.7 GLUE score points on finetuned bidirectional BERTs. Hedgehog also
enables pretrained-conversion. Converting a pretrained GPT-2 into a linear
attention variant achieves state-of-the-art 16.7 perplexity on WikiText-103 for
125M subquadratic decoder models. We finally turn a pretrained Llama-2 7B into
a viable linear attention Llama. With low-rank adaptation, Hedgehog-Llama2 7B
achieves 28.1 higher ROUGE-1 points over the base standard attention model,
where prior linear attentions lead to 16.5 point drops.