Forgetting Transformer: Softmax-Attention mit einem Vergessens-Gate
Forgetting Transformer: Softmax Attention with a Forget Gate
March 3, 2025
Autoren: Zhixuan Lin, Evgenii Nikishin, Xu Owen He, Aaron Courville
cs.AI
Zusammenfassung
Ein wesentlicher Bestandteil moderner rekurrenter Sequenzmodelle ist das Vergessens-Tor. Während Transformer keine explizite rekurrente Form aufweisen, zeigen wir, dass ein Vergessens-Tor auf natürliche Weise in Transformer integriert werden kann, indem die nicht normalisierten Aufmerksamkeitswerte in einer datenabhängigen Weise heruntergewichtet werden. Wir nennen diesen Aufmerksamkeitsmechanismus „Forgetting Attention“ und das daraus resultierende Modell „Forgetting Transformer“ (FoX). Wir zeigen, dass FoX den Transformer bei der Sprachmodellierung mit langem Kontext, der Längenextrapolation und Downstream-Aufgaben mit kurzem Kontext übertrifft, während es bei Downstream-Aufgaben mit langem Kontext mit dem Transformer gleichauf liegt. Darüber hinaus ist es mit dem FlashAttention-Algorithmus kompatibel und benötigt keine Positions-Einbettungen. Mehrere Analysen, einschließlich des „Nadel-im-Heuhaufen“-Tests, zeigen, dass FoX auch die überlegenen Fähigkeiten des Transformers im Umgang mit langem Kontext im Vergleich zu rekurrenten Sequenzmodellen wie Mamba-2, HGRN2 und DeltaNet beibehält. Wir stellen außerdem ein „Pro“-Block-Design vor, das einige gängige architektonische Komponenten aus rekurrenten Sequenzmodellen integriert, und stellen fest, dass es die Leistung sowohl von FoX als auch des Transformers erheblich verbessert. Unser Code ist verfügbar unter https://github.com/zhixuan-lin/forgetting-transformer.
English
An essential component of modern recurrent sequence models is the forget
gate. While Transformers do not have an explicit recurrent form, we show that a
forget gate can be naturally incorporated into Transformers by down-weighting
the unnormalized attention scores in a data-dependent way. We name this
attention mechanism the Forgetting Attention and the resulting model the
Forgetting Transformer (FoX). We show that FoX outperforms the Transformer on
long-context language modeling, length extrapolation, and short-context
downstream tasks, while performing on par with the Transformer on long-context
downstream tasks. Moreover, it is compatible with the FlashAttention algorithm
and does not require any positional embeddings. Several analyses, including the
needle-in-the-haystack test, show that FoX also retains the Transformer's
superior long-context capabilities over recurrent sequence models such as
Mamba-2, HGRN2, and DeltaNet. We also introduce a "Pro" block design that
incorporates some common architectural components in recurrent sequence models
and find it significantly improves the performance of both FoX and the
Transformer. Our code is available at
https://github.com/zhixuan-lin/forgetting-transformer.Summary
AI-Generated Summary