De Egel & het Stokstaartje: Expressieve Lineaire Attenties met Softmax Mimicry
The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry
February 6, 2024
Auteurs: Michael Zhang, Kush Bhatia, Hermann Kumbong, Christopher Ré
cs.AI
Samenvatting
Lineaire aandachtssystemen hebben potentieel getoond voor het verbeteren van de efficiëntie van Transformers, waarbij de kwadratische complexiteit van aandacht wordt teruggebracht tot lineair in sequentielengte. Dit biedt spannende mogelijkheden voor (1) het trainen van lineaire Transformers vanaf nul, (2) "fine-tuned conversie" van taakspecifieke Transformers naar lineaire versies die de taakprestaties herstellen, en (3) "pretrained conversie" van Transformers zoals grote taalmodelen naar lineaire versies die kunnen worden gefinetuned voor downstream taken. Echter, lineaire aandachtssystemen presteren vaak minder goed dan standaard softmax-aandacht in kwaliteit. Om dit prestatiegat te dichten, constateren we dat eerdere lineaire aandachtssystemen essentiële eigenschappen van softmax-aandacht missen die gekoppeld zijn aan goede prestaties: laag-entropie (of "spiky") gewichten en dot-product monotoniciteit. We observeren verder verrassend eenvoudige feature maps die deze eigenschappen behouden en de prestaties van softmax evenaren, maar inefficiënt zijn om te berekenen in lineaire aandacht. Daarom stellen we Hedgehog voor, een leerbaar lineair aandachtssysteem dat de spiky en monotone eigenschappen van softmax-aandacht behoudt terwijl het lineaire complexiteit handhaaft. Hedgehog gebruikt eenvoudige trainbare MLPs om aandachtgewichten te produceren die softmax-aandacht nabootsen. Experimenten tonen aan dat Hedgehog meer dan 99% van de kwaliteit van standaard Transformers herstelt in train-from-scratch en fine-tuned conversie instellingen, en presteert beter dan eerdere lineaire aandachtssystemen met tot 6 perplexiteitspunten op WikiText-103 met causale GPTs, en tot 8,7 GLUE-scorepunten op gefinetunde bidirectionele BERTs. Hedgehog maakt ook pretrained conversie mogelijk. Het omzetten van een pretrained GPT-2 naar een lineaire aandachtvariant behaalt state-of-the-art 16,7 perplexiteit op WikiText-103 voor 125M subkwadratische decodermodellen. We zetten ten slotte een pretrained Llama-2 7B om in een levensvatbare lineaire aandacht Llama. Met low-rank aanpassing behaalt Hedgehog-Llama2 7B 28,1 hogere ROUGE-1 punten ten opzichte van het basisstandaard aandachtmodel, waar eerdere lineaire aandachtssystemen leiden tot dalingen van 16,5 punten.
English
Linear attentions have shown potential for improving Transformer efficiency,
reducing attention's quadratic complexity to linear in sequence length. This
holds exciting promise for (1) training linear Transformers from scratch, (2)
"finetuned-conversion" of task-specific Transformers into linear versions that
recover task performance, and (3) "pretrained-conversion" of Transformers such
as large language models into linear versions finetunable on downstream tasks.
However, linear attentions often underperform standard softmax attention in
quality. To close this performance gap, we find prior linear attentions lack
key properties of softmax attention tied to good performance: low-entropy (or
"spiky") weights and dot-product monotonicity. We further observe surprisingly
simple feature maps that retain these properties and match softmax performance,
but are inefficient to compute in linear attention. We thus propose Hedgehog, a
learnable linear attention that retains the spiky and monotonic properties of
softmax attention while maintaining linear complexity. Hedgehog uses simple
trainable MLPs to produce attention weights mimicking softmax attention.
Experiments show Hedgehog recovers over 99% of standard Transformer quality in
train-from-scratch and finetuned-conversion settings, outperforming prior
linear attentions up to 6 perplexity points on WikiText-103 with causal GPTs,
and up to 8.7 GLUE score points on finetuned bidirectional BERTs. Hedgehog also
enables pretrained-conversion. Converting a pretrained GPT-2 into a linear
attention variant achieves state-of-the-art 16.7 perplexity on WikiText-103 for
125M subquadratic decoder models. We finally turn a pretrained Llama-2 7B into
a viable linear attention Llama. With low-rank adaptation, Hedgehog-Llama2 7B
achieves 28.1 higher ROUGE-1 points over the base standard attention model,
where prior linear attentions lead to 16.5 point drops.