Warum das Training von Transformatoren mit niedriger Präzision scheitert: Eine Analyse zu Flash Attention
Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
October 5, 2025
papers.authors: Haiquan Qiu, Quanming Yao
cs.AI
papers.abstract
Das Streben nach Recheneffizienz hat die Einführung von Niedrigpräzisionsformaten für das Training von Transformer-Modellen vorangetrieben. Dieser Fortschritt wird jedoch oft durch bekannte Trainingsinstabilitäten behindert. Diese Arbeit liefert die erste mechanistische Erklärung für einen langjährigen und ungelösten Fehlerfall, bei dem das Training mit Flash Attention in Niedrigpräzisionseinstellungen zu katastrophalen Verlustexplosionen führt. Unsere detaillierte Analyse zeigt, dass der Fehler kein zufälliges Artefakt ist, sondern durch zwei miteinander verflochtene Phänomene verursacht wird: das Auftreten ähnlicher niedrigrangiger Repräsentationen innerhalb des Aufmerksamkeitsmechanismus und den kumulativen Effekt von verzerrten Rundungsfehlern, die der Niedrigpräzisionsarithmetik innewohnen. Wir zeigen, wie diese Faktoren einen Teufelskreis der Fehlerakkumulation erzeugen, der Gewichtsaktualisierungen korrumpiert und letztlich die Trainingsdynamik zum Scheitern bringt. Um unsere Erkenntnisse zu validieren, führen wir eine minimale Modifikation der Flash Attention ein, die die Verzerrung in den Rundungsfehlern mildert. Diese einfache Änderung stabilisiert den Trainingsprozess, bestätigt unsere Analyse und bietet eine praktische Lösung für dieses anhaltende Problem.
English
The pursuit of computational efficiency has driven the adoption of
low-precision formats for training transformer models. However, this progress
is often hindered by notorious training instabilities. This paper provides the
first mechanistic explanation for a long-standing and unresolved failure case
where training with flash attention in low-precision settings leads to
catastrophic loss explosions. Our in-depth analysis reveals that the failure is
not a random artifact but caused by two intertwined phenomena: the emergence of
similar low-rank representations within the attention mechanism and the
compounding effect of biased rounding errors inherent in low-precision
arithmetic. We demonstrate how these factors create a vicious cycle of error
accumulation that corrupts weight updates, ultimately derailing the training
dynamics. To validate our findings, we introduce a minimal modification to the
flash attention that mitigates the bias in rounding errors. This simple change
stabilizes the training process, confirming our analysis and offering a
practical solution to this persistent problem.