Lernen, die mittleren Schichten von Transformern zu überspringen
Learning to Skip the Middle Layers of Transformers
June 26, 2025
Autoren: Tim Lawson, Laurence Aitchison
cs.AI
Zusammenfassung
Bedingte Berechnung ist eine verbreitete Strategie, um Transformer-Modelle effizienter zu gestalten. Bestehende Methoden zielen oft auf einzelne Module (z. B. Mixture-of-Experts-Schichten) ab oder überspringen Schichten unabhängig voneinander. Interpretationsstudien haben jedoch gezeigt, dass die mittleren Schichten von Transformern eine größere Redundanz aufweisen und dass frühe Schichten Informationen in Token-Positionen aggregieren. Aufbauend auf diesen Erkenntnissen schlagen wir eine neuartige Architektur vor, die dynamisch eine variable Anzahl von Schichten von der Mitte nach außen überspringt. Insbesondere bestimmt ein gelerntes Gating-Mechanismus basierend auf der Eingabe, ob ein symmetrischer Bereich zentraler Blöcke umgangen werden soll, und ein gated Attention-Mechanismus verhindert, dass nachfolgende Token übersprungene Token-Positionen berücksichtigen. Die Residuen-Normen werden durch ein „Sandwich“- oder „Perilayernorm“-Schema kontrolliert, und die Gate-Sparsity wird durch einen adaptiven Regularisierungsverlust gesteuert. Unser Ziel war es, den Rechenaufwand für „einfachere“ Token zu reduzieren und möglicherweise eine mehrstufige Repräsentationshierarchie zu fördern. In den untersuchten Skalierungen erreicht unser Ansatz jedoch keine Verbesserungen im Trade-off zwischen Validierungs-Kreuzentropie und geschätzten FLOPs im Vergleich zu dichten Baselines mit weniger Schichten. Unser Code ist unter https://github.com/tim-lawson/skip-middle verfügbar.
English
Conditional computation is a popular strategy to make Transformers more
efficient. Existing methods often target individual modules (e.g.,
mixture-of-experts layers) or skip layers independently of one another.
However, interpretability research has demonstrated that the middle layers of
Transformers exhibit greater redundancy, and that early layers aggregate
information into token positions. Guided by these insights, we propose a novel
architecture that dynamically skips a variable number of layers from the middle
outward. In particular, a learned gating mechanism determines whether to bypass
a symmetric span of central blocks based on the input, and a gated attention
mechanism prevents subsequent tokens from attending to skipped token positions.
Residual norms are controlled with a 'sandwich' or 'perilayernorm' scheme and
gate sparsity with an adaptive regularization loss. We had aimed to reduce
compute requirements for 'simpler' tokens and potentially foster an emergent
multi-level representational hierarchy but, at the scales investigated, our
approach does not achieve improvements in the trade-off between validation
cross-entropy and estimated FLOPs compared to dense baselines with fewer
layers. We release our code at https://github.com/tim-lawson/skip-middle.