Poda Adaptativa de Cálculo para el Transformador con Olvido
Adaptive Computation Pruning for the Forgetting Transformer
April 9, 2025
Autores: Zhixuan Lin, Johan Obando-Ceron, Xu Owen He, Aaron Courville
cs.AI
Resumen
El recientemente propuesto Transformador con Olvido (FoX) incorpora una puerta de olvido en la atención softmax y ha demostrado un rendimiento consistentemente mejor o similar en comparación con el Transformador estándar basado en RoPE. Notablemente, muchas cabezas de atención en FoX tienden a olvidar rápidamente, haciendo que su salida en cada paso de tiempo dependa principalmente del contexto local. Basándonos en esta observación, proponemos la Poda de Cómputo Adaptativa (ACP) para FoX, un método que poda dinámicamente los cálculos que involucran dependencias entrada-salida que son fuertemente decaídas por la puerta de olvido. Esto se logra utilizando un umbral de poda establecido dinámicamente que asegura que los pesos de atención podados permanezcan insignificantes. Aplicamos ACP al preentrenamiento de modelos de lenguaje con FoX y mostramos que reduce consistentemente el número de FLOPs en la atención softmax en aproximadamente un 70% en diferentes tamaños de modelos y longitudes de contexto, lo que resulta en una mejora de aproximadamente un 10% a 35% en el rendimiento del entrenamiento. Además, las longitudes de contexto más largas generan mayores ahorros computacionales. Todas estas mejoras de velocidad se logran sin ninguna degradación del rendimiento. También realizamos varios análisis para proporcionar una comprensión más profunda de nuestro método, como examinar los patrones de poda y analizar la distribución de los ahorros de FLOPs en diferentes cabezas de atención. Nuestro código está disponible en https://github.com/zhixuan-lin/arctic-fox.
English
The recently proposed Forgetting Transformer (FoX) incorporates a forget gate
into softmax attention and has shown consistently better or on-par performance
compared to the standard RoPE-based Transformer. Notably, many attention heads
in FoX tend to forget quickly, causing their output at each timestep to rely
primarily on the local context. Based on this observation, we propose Adaptive
Computation Pruning (ACP) for FoX, a method that dynamically prunes
computations involving input-output dependencies that are strongly decayed by
the forget gate. This is achieved using a dynamically set pruning threshold
that ensures that the pruned attention weights remain negligible. We apply ACP
to language model pretraining with FoX and show it consistently reduces the
number of FLOPs in softmax attention by around 70% across different model sizes
and context lengths, resulting in a roughly 10% to 35% improvement in training
throughput. Furthermore, longer context lengths yield greater computational
savings. All these speed improvements are achieved without any performance
degradation. We also perform several analyses to provide deeper insights into
our method, such as examining the pruning patterns and analyzing the
distribution of FLOP savings across different attention heads. Our code is
available at https://github.com/zhixuan-lin/arctic-fox.Summary
AI-Generated Summary