Poda Computacional Adaptativa para o Transformer com Esquecimento
Adaptive Computation Pruning for the Forgetting Transformer
April 9, 2025
Autores: Zhixuan Lin, Johan Obando-Ceron, Xu Owen He, Aaron Courville
cs.AI
Resumo
O recentemente proposto Transformer com Esquecimento (FoX) incorpora um portão de esquecimento na atenção softmax e tem demonstrado desempenho consistentemente melhor ou equivalente em comparação com o Transformer padrão baseado em RoPE. Notavelmente, muitas cabeças de atenção no FoX tendem a esquecer rapidamente, fazendo com que sua saída em cada passo de tempo dependa principalmente do contexto local. Com base nessa observação, propomos a Poda de Computação Adaptativa (ACP) para o FoX, um método que poda dinamicamente as computações envolvendo dependências entrada-saída que são fortemente atenuadas pelo portão de esquecimento. Isso é alcançado usando um limite de poda definido dinamicamente que garante que os pesos de atenção podados permaneçam insignificantes. Aplicamos o ACP ao pré-treinamento de modelos de linguagem com o FoX e mostramos que ele reduz consistentemente o número de FLOPs na atenção softmax em cerca de 70% em diferentes tamanhos de modelos e comprimentos de contexto, resultando em uma melhoria de aproximadamente 10% a 35% na taxa de processamento do treinamento. Além disso, comprimentos de contexto mais longos proporcionam maiores economias computacionais. Todas essas melhorias de velocidade são alcançadas sem qualquer degradação de desempenho. Também realizamos várias análises para fornecer insights mais profundos sobre nosso método, como examinar os padrões de poda e analisar a distribuição das economias de FLOPs entre diferentes cabeças de atenção. Nosso código está disponível em https://github.com/zhixuan-lin/arctic-fox.
English
The recently proposed Forgetting Transformer (FoX) incorporates a forget gate
into softmax attention and has shown consistently better or on-par performance
compared to the standard RoPE-based Transformer. Notably, many attention heads
in FoX tend to forget quickly, causing their output at each timestep to rely
primarily on the local context. Based on this observation, we propose Adaptive
Computation Pruning (ACP) for FoX, a method that dynamically prunes
computations involving input-output dependencies that are strongly decayed by
the forget gate. This is achieved using a dynamically set pruning threshold
that ensures that the pruned attention weights remain negligible. We apply ACP
to language model pretraining with FoX and show it consistently reduces the
number of FLOPs in softmax attention by around 70% across different model sizes
and context lengths, resulting in a roughly 10% to 35% improvement in training
throughput. Furthermore, longer context lengths yield greater computational
savings. All these speed improvements are achieved without any performance
degradation. We also perform several analyses to provide deeper insights into
our method, such as examining the pruning patterns and analyzing the
distribution of FLOP savings across different attention heads. Our code is
available at https://github.com/zhixuan-lin/arctic-fox.Summary
AI-Generated Summary