Imparare a saltare gli strati intermedi dei Transformer
Learning to Skip the Middle Layers of Transformers
June 26, 2025
Autori: Tim Lawson, Laurence Aitchison
cs.AI
Abstract
La computazione condizionale è una strategia popolare per rendere i Transformer più efficienti. I metodi esistenti spesso prendono di mira singoli moduli (ad esempio, strati di mixture-of-experts) o saltano strati in modo indipendente l'uno dall'altro. Tuttavia, la ricerca sull'interpretabilità ha dimostrato che gli strati intermedi dei Transformer presentano una maggiore ridondanza e che gli strati iniziali aggregrano informazioni nelle posizioni dei token. Guidati da queste intuizioni, proponiamo una nuova architettura che salta dinamicamente un numero variabile di strati partendo dal centro verso l'esterno. In particolare, un meccanismo di gate appreso determina se bypassare un intervallo simmetrico di blocchi centrali in base all'input, e un meccanismo di attenzione gated impedisce ai token successivi di prestare attenzione alle posizioni dei token saltate. Le norme residue sono controllate con uno schema 'sandwich' o 'perilayernorm' e la sparsità dei gate con una perdita di regolarizzazione adattativa. Avevamo l'obiettivo di ridurre i requisiti computazionali per token 'più semplici' e potenzialmente favorire una gerarchia rappresentativa multi-livello emergente, ma, alle scale investigate, il nostro approccio non raggiunge miglioramenti nel compromesso tra entropia incrociata di validazione e FLOPs stimati rispetto a baseline dense con meno strati. Rilasciamo il nostro codice su https://github.com/tim-lawson/skip-middle.
English
Conditional computation is a popular strategy to make Transformers more
efficient. Existing methods often target individual modules (e.g.,
mixture-of-experts layers) or skip layers independently of one another.
However, interpretability research has demonstrated that the middle layers of
Transformers exhibit greater redundancy, and that early layers aggregate
information into token positions. Guided by these insights, we propose a novel
architecture that dynamically skips a variable number of layers from the middle
outward. In particular, a learned gating mechanism determines whether to bypass
a symmetric span of central blocks based on the input, and a gated attention
mechanism prevents subsequent tokens from attending to skipped token positions.
Residual norms are controlled with a 'sandwich' or 'perilayernorm' scheme and
gate sparsity with an adaptive regularization loss. We had aimed to reduce
compute requirements for 'simpler' tokens and potentially foster an emergent
multi-level representational hierarchy but, at the scales investigated, our
approach does not achieve improvements in the trade-off between validation
cross-entropy and estimated FLOPs compared to dense baselines with fewer
layers. We release our code at https://github.com/tim-lawson/skip-middle.