The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry
February 6, 2024
Authors: Michael Zhang, Kush Bhatia, Hermann Kumbong, Christopher Ré
cs.AI
Abstract
Linear attentions have shown potential for improving Transformer efficiency,
reducing attention's quadratic complexity to linear in sequence length. This
holds exciting promise for (1) training linear Transformers from scratch, (2)
"finetuned-conversion" of task-specific Transformers into linear versions that
recover task performance, and (3) "pretrained-conversion" of Transformers such
as large language models into linear versions finetunable on downstream tasks.
However, linear attentions often underperform standard softmax attention in
quality. To close this performance gap, we find prior linear attentions lack
key properties of softmax attention tied to good performance: low-entropy (or
"spiky") weights and dot-product monotonicity. We further observe surprisingly
simple feature maps that retain these properties and match softmax performance,
but are inefficient to compute in linear attention. We thus propose Hedgehog, a
learnable linear attention that retains the spiky and monotonic properties of
softmax attention while maintaining linear complexity. Hedgehog uses simple
trainable MLPs to produce attention weights mimicking softmax attention.
Experiments show Hedgehog recovers over 99% of standard Transformer quality in
train-from-scratch and finetuned-conversion settings, outperforming prior
linear attentions by up to 6 perplexity points on WikiText-103 with causal GPTs
and by up to 8.7 GLUE score points on finetuned bidirectional BERTs. Hedgehog also
enables pretrained-conversion. Converting a pretrained GPT-2 into a linear
attention variant achieves state-of-the-art 16.7 perplexity on WikiText-103 for
125M subquadratic decoder models. Finally, we turn a pretrained Llama-2 7B into
a viable linear attention Llama: with low-rank adaptation, Hedgehog-Llama2 7B
achieves a 28.1-point ROUGE-1 improvement over the base standard-attention model,
whereas prior linear attentions incur a 16.5-point drop.
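
To make the mechanism concrete, below is a minimal PyTorch sketch (not the authors' released code) of linear attention with a small trainable MLP feature map in the spirit of Hedgehog: the feature map keeps attention weights positive and "spiky", and the (QK^T)V product is reordered as Q(K^T V) so cost grows linearly rather than quadratically in sequence length. The names HedgehogFeatureMap and feature_dim, the single-head non-causal setup, and the exact nonlinearity are illustrative assumptions, not the paper's implementation details.

```python
# Minimal sketch of linear attention with a learnable feature map.
# Assumptions: single head, non-causal; names are illustrative only.
import torch
import torch.nn as nn


class HedgehogFeatureMap(nn.Module):
    """Trainable MLP-style feature map meant to mimic softmax attention weights."""

    def __init__(self, head_dim: int, feature_dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(head_dim, feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Exponential nonlinearity keeps features positive and low-entropy ("spiky");
        # concatenating +/- projections is one simple symmetrization choice.
        z = self.proj(x)
        return torch.cat([torch.exp(z), torch.exp(-z)], dim=-1)


def linear_attention(q, k, v, feature_map):
    # q, k, v: (batch, seq_len, head_dim)
    phi_q = feature_map(q)                        # (B, N, F)
    phi_k = feature_map(k)                        # (B, N, F)
    # Reorder (phi_q @ phi_k^T) @ v  ->  phi_q @ (phi_k^T @ v): linear in N.
    kv = torch.einsum("bnf,bnd->bfd", phi_k, v)   # (B, F, D)
    z = phi_k.sum(dim=1)                          # (B, F) running normalizer
    out = torch.einsum("bnf,bfd->bnd", phi_q, kv)
    denom = torch.einsum("bnf,bf->bn", phi_q, z).unsqueeze(-1)
    return out / denom.clamp(min=1e-6)


# Usage example
B, N, D = 2, 128, 64
q, k, v = (torch.randn(B, N, D) for _ in range(3))
fmap = HedgehogFeatureMap(D)
y = linear_attention(q, k, v, fmap)               # (B, N, D)
```

A full implementation would additionally train the MLP feature maps so that the resulting attention weights mimic softmax attention, which the abstract identifies as the key to recovering standard Transformer quality; the sketch above omits that training step.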