Waarom training van transformers met lage precisie faalt: een analyse van Flash Attention
Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
October 5, 2025
Auteurs: Haiquan Qiu, Quanming Yao
cs.AI
Samenvatting
De zoektocht naar computationele efficiëntie heeft geleid tot de adoptie van
laagprecisieformaten voor het trainen van transformermodellen. Deze vooruitgang
wordt echter vaak belemmerd door beruchte trainingsinstabiliteiten. Dit artikel
biedt de eerste mechanistische verklaring voor een lang bestaand en onopgelost
faalgeval waarbij trainen met flash attention in laagprecisie-instellingen leidt
tot catastrofale verliesexplosies. Onze diepgaande analyse onthult dat het falen
geen willekeurig artefact is, maar wordt veroorzaakt door twee verweven
verschijnselen: het ontstaan van vergelijkbare laagrangrepresentaties binnen het
attention-mechanisme en het cumulatieve effect van bevooroordeelde
afrondingsfouten die inherent zijn aan laagprecisie-rekenkunde. We tonen aan hoe
deze factoren een vicieuze cirkel van foutaccumulatie creëren die
gewichtsupdates corrumpeert en uiteindelijk de trainingsdynamiek ontspoort. Om
onze bevindingen te valideren, introduceren we een minimale aanpassing aan de
flash attention die de bias in afrondingsfouten vermindert. Deze eenvoudige
wijziging stabiliseert het trainingsproces, bevestigt onze analyse en biedt een
praktische oplossing voor dit hardnekkige probleem.
English
The pursuit of computational efficiency has driven the adoption of
low-precision formats for training transformer models. However, this progress
is often hindered by notorious training instabilities. This paper provides the
first mechanistic explanation for a long-standing and unresolved failure case
where training with flash attention in low-precision settings leads to
catastrophic loss explosions. Our in-depth analysis reveals that the failure is
not a random artifact but caused by two intertwined phenomena: the emergence of
similar low-rank representations within the attention mechanism and the
compounding effect of biased rounding errors inherent in low-precision
arithmetic. We demonstrate how these factors create a vicious cycle of error
accumulation that corrupts weight updates, ultimately derailing the training
dynamics. To validate our findings, we introduce a minimal modification to the
flash attention that mitigates the bias in rounding errors. This simple change
stabilizes the training process, confirming our analysis and offering a
practical solution to this persistent problem.