Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
October 5, 2025
Authors: Haiquan Qiu, Quanming Yao
cs.AI
Abstract
The pursuit of computational efficiency has driven the adoption of
low-precision formats for training transformer models. However, this progress
is often hindered by notorious training instabilities. This paper provides the
first mechanistic explanation for a long-standing and unresolved failure case
where training with Flash Attention in low-precision settings leads to
catastrophic loss explosions. Our in-depth analysis reveals that the failure is
not a random artifact but is caused by two intertwined phenomena: the emergence of
similar low-rank representations within the attention mechanism and the
compounding effect of biased rounding errors inherent in low-precision
arithmetic. We demonstrate how these factors create a vicious cycle of error
accumulation that corrupts weight updates, ultimately derailing the training
dynamics. To validate our findings, we introduce a minimal modification to
Flash Attention that mitigates the bias in rounding errors. This simple change
stabilizes the training process, confirming our analysis and offering a
practical solution to this persistent problem.
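To make "biased rounding errors" concrete, here is a minimal Python sketch. It is a generic illustration of low-precision accumulation, not the authors' analysis and not their modification to Flash Attention. It emulates bfloat16 round-to-nearest-even with `struct` bit manipulation and sums one small positive constant many times. The helper names `to_bf16` and `accumulate_bf16`, the increment 0.001, and the 10,000 steps are hypothetical choices made for the demonstration. Because every addend has the same sign, the per-step rounding errors are not zero-mean. Once the spacing between adjacent bf16 values near the accumulator exceeds twice the increment, each addition rounds back to the old value, and the bf16 total stalls far below the float64 reference.

```python
import struct

# Hypothetical illustration of biased rounding in bf16 accumulation;
# not the paper's method or its Flash Attention fix.


def to_bf16(x: float) -> float:
    """Round a Python float to the nearest bfloat16 value (ties to even).

    bfloat16 keeps the top 16 bits of an IEEE float32: 1 sign bit,
    8 exponent bits and 7 stored mantissa bits. NaN/Inf handling is
    omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                     # lowest bit kept after truncation
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round to nearest even, drop low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate_bf16(values):
    """Sum values, rounding the running total to bf16 after every add."""
    acc = 0.0
    for v in values:
        # Once the bf16 spacing near acc exceeds 2 * v, acc + v rounds back
        # to acc: every later rounding error has the same sign and they
        # accumulate instead of cancelling.
        acc = to_bf16(acc + v)
    return acc


if __name__ == "__main__":
    step = to_bf16(0.001)      # the increment itself is bf16-representable
    values = [step] * 10_000
    reference = sum(values)    # float64 accumulation of the same inputs
    low = accumulate_bf16(values)
    print(f"float64 sum: {reference:.4f}")
    print(f"bf16 sum:    {low:.4f}")
```

The inputs are identical in both sums. Only the precision of the running total differs, so the whole gap comes from rounding that consistently pushes in one direction rather than from noise in the data.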