Le Hérisson et le Porc-épic : Des Attentions Linéaires Expressives avec Mimétisme de Softmax
The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry
February 6, 2024
Auteurs: Michael Zhang, Kush Bhatia, Hermann Kumbong, Christopher Ré
cs.AI
Résumé
Les attentions linéaires ont montré un potentiel pour améliorer l'efficacité des Transformers, réduisant la complexité quadratique de l'attention à une complexité linéaire par rapport à la longueur de la séquence. Cela ouvre des perspectives prometteuses pour (1) l'entraînement de Transformers linéaires à partir de zéro, (2) la "conversion par ajustement fin" de Transformers spécifiques à une tâche en versions linéaires qui retrouvent les performances de la tâche, et (3) la "conversion à partir de modèles pré-entraînés" de Transformers, tels que les grands modèles de langage, en versions linéaires pouvant être ajustées sur des tâches en aval. Cependant, les attentions linéaires sous-performent souvent l'attention softmax standard en termes de qualité. Pour combler cet écart de performance, nous constatons que les attentions linéaires antérieures manquent de propriétés clés de l'attention softmax liées à de bonnes performances : des poids à faible entropie (ou "pointus") et une monotonie du produit scalaire. Nous observons également des cartes de caractéristiques étonnamment simples qui conservent ces propriétés et égalent les performances de l'attention softmax, mais qui sont inefficaces à calculer dans le cadre de l'attention linéaire. Nous proposons donc Hedgehog, une attention linéaire apprenable qui conserve les propriétés pointues et monotones de l'attention softmax tout en maintenant une complexité linéaire. Hedgehog utilise des MLPs simples et entraînables pour produire des poids d'attention imitant l'attention softmax. Les expériences montrent que Hedgehog récupère plus de 99 % de la qualité du Transformer standard dans des configurations d'entraînement à partir de zéro et de conversion par ajustement fin, surpassant les attentions linéaires antérieures jusqu'à 6 points de perplexité sur WikiText-103 avec des GPT causaux, et jusqu'à 8,7 points de score GLUE sur des BERT bidirectionnels ajustés finement. Hedgehog permet également la conversion à partir de modèles pré-entraînés. La conversion d'un GPT-2 pré-entraîné en une variante d'attention linéaire atteint un état de l'art de 16,7 en perplexité sur WikiText-103 pour des modèles décodeurs sous-quadratiques de 125M. Nous transformons enfin un Llama-2 7B pré-entraîné en un Llama à attention linéaire viable. Avec une adaptation de bas rang, Hedgehog-Llama2 7B atteint 28,1 points ROUGE-1 de plus que le modèle de base à attention standard, là où les attentions linéaires antérieures entraînent une baisse de 16,5 points.
English
Linear attentions have shown potential for improving Transformer efficiency,
reducing attention's quadratic complexity to linear in sequence length. This
holds exciting promise for (1) training linear Transformers from scratch, (2)
"finetuned-conversion" of task-specific Transformers into linear versions that
recover task performance, and (3) "pretrained-conversion" of Transformers such
as large language models into linear versions finetunable on downstream tasks.
However, linear attentions often underperform standard softmax attention in
quality. To close this performance gap, we find prior linear attentions lack
key properties of softmax attention tied to good performance: low-entropy (or
"spiky") weights and dot-product monotonicity. We further observe surprisingly
simple feature maps that retain these properties and match softmax performance,
but are inefficient to compute in linear attention. We thus propose Hedgehog, a
learnable linear attention that retains the spiky and monotonic properties of
softmax attention while maintaining linear complexity. Hedgehog uses simple
trainable MLPs to produce attention weights mimicking softmax attention.
Experiments show Hedgehog recovers over 99% of standard Transformer quality in
train-from-scratch and finetuned-conversion settings, outperforming prior
linear attentions up to 6 perplexity points on WikiText-103 with causal GPTs,
and up to 8.7 GLUE score points on finetuned bidirectional BERTs. Hedgehog also
enables pretrained-conversion. Converting a pretrained GPT-2 into a linear
attention variant achieves state-of-the-art 16.7 perplexity on WikiText-103 for
125M subquadratic decoder models. We finally turn a pretrained Llama-2 7B into
a viable linear attention Llama. With low-rank adaptation, Hedgehog-Llama2 7B
achieves 28.1 higher ROUGE-1 points over the base standard attention model,
where prior linear attentions lead to 16.5 point drops.