Adaptieve Rekenkundige Snoei voor de Transformer met Vergeetmechanisme
Adaptive Computation Pruning for the Forgetting Transformer
April 9, 2025
Auteurs: Zhixuan Lin, Johan Obando-Ceron, Xu Owen He, Aaron Courville
cs.AI
Samenvatting
De recent voorgestelde Forgetting Transformer (FoX) integreert een forget-gate in softmax-attentie en heeft consequent betere of vergelijkbare prestaties laten zien in vergelijking met de standaard RoPE-gebaseerde Transformer. Opmerkelijk is dat veel aandachtskoppen in FoX de neiging hebben om snel te vergeten, waardoor hun uitvoer bij elke tijdstap voornamelijk afhankelijk is van de lokale context. Op basis van deze observatie stellen we Adaptive Computation Pruning (ACP) voor FoX voor, een methode die dynamisch berekeningen verwijdert die betrekking hebben op input-output-afhankelijkheden die sterk zijn verzwakt door de forget-gate. Dit wordt bereikt met behulp van een dynamisch ingestelde pruning-drempel die ervoor zorgt dat de verwijderde aandachtswaarden verwaarloosbaar blijven. We passen ACP toe bij het vooraf trainen van taalmmodellen met FoX en laten zien dat het consistent het aantal FLOPs in softmax-attentie met ongeveer 70% vermindert, ongeacht de modelgrootte en contextlengte, wat resulteert in een verbetering van de trainingsdoorvoer van ongeveer 10% tot 35%. Bovendien leveren langere contextlengtes grotere computationele besparingen op. Al deze snelheidsverbeteringen worden bereikt zonder enig prestatieverlies. We voeren ook verschillende analyses uit om dieper inzicht te bieden in onze methode, zoals het onderzoeken van de pruning-patronen en het analyseren van de verdeling van FLOP-besparingen over verschillende aandachtskoppen. Onze code is beschikbaar op https://github.com/zhixuan-lin/arctic-fox.
English
The recently proposed Forgetting Transformer (FoX) incorporates a forget gate
into softmax attention and has shown consistently better or on-par performance
compared to the standard RoPE-based Transformer. Notably, many attention heads
in FoX tend to forget quickly, causing their output at each timestep to rely
primarily on the local context. Based on this observation, we propose Adaptive
Computation Pruning (ACP) for FoX, a method that dynamically prunes
computations involving input-output dependencies that are strongly decayed by
the forget gate. This is achieved using a dynamically set pruning threshold
that ensures that the pruned attention weights remain negligible. We apply ACP
to language model pretraining with FoX and show it consistently reduces the
number of FLOPs in softmax attention by around 70% across different model sizes
and context lengths, resulting in a roughly 10% to 35% improvement in training
throughput. Furthermore, longer context lengths yield greater computational
savings. All these speed improvements are achieved without any performance
degradation. We also perform several analyses to provide deeper insights into
our method, such as examining the pruning patterns and analyzing the
distribution of FLOP savings across different attention heads. Our code is
available at https://github.com/zhixuan-lin/arctic-fox.Summary
AI-Generated Summary