Scan and Snap: Understanding Training Dynamics and Token Composition in 1-layer Transformer
May 25, 2023
Authors: Yuandong Tian, Yiping Wang, Beidi Chen, Simon Du
cs.AI
Abstract
Transformer architecture has shown impressive performance in multiple
research domains and has become the backbone of many neural network models.
However, there is limited understanding of how it works. In particular, with a
simple predictive loss, how the representation emerges from the gradient
training dynamics remains a mystery. In this paper, for a 1-layer
transformer with one self-attention layer plus one decoder layer, we analyze
its SGD training dynamics for the task of next token prediction in a
mathematically rigorous manner. We open the black box of the dynamic process of
how the self-attention layer combines input tokens, and reveal the nature of
the underlying inductive bias. More specifically, under the assumptions of (a) no
positional encoding, (b) long input sequences, and (c) a decoder layer that learns
faster than the self-attention layer, we prove that self-attention acts as a
discriminative scanning algorithm: starting from uniform attention, it
gradually attends more to distinct key tokens for a specific next token to be
predicted, and pays less attention to common key tokens that occur across
different next tokens. Among distinct tokens, it progressively drops attention
weights, following the order of low to high co-occurrence between the key and
the query token in the training set. Interestingly, this procedure does not
lead to winner-takes-all, but decelerates due to a phase transition that
is controllable by the learning rates of the two layers, leaving an (almost)
fixed token combination. We verify this "scan and snap" dynamics on
synthetic and real-world data (WikiText).
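
To make the analyzed setting concrete, below is a minimal PyTorch sketch of the architecture the abstract describes: one self-attention layer plus one linear decoder layer, no positional encoding, trained with SGD on next-token prediction, with the decoder given a larger learning rate than the attention layer (assumption (c)). This is not the authors' code; the toy random data, the specific learning-rate values, and the attention-entropy probe are illustrative assumptions.

```python
# Minimal sketch (assumed setup, not the paper's implementation) of a 1-layer
# transformer: one self-attention layer + one decoder layer, next-token prediction.
import torch
import torch.nn as nn
import torch.nn.functional as F

class OneLayerTransformer(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)   # no positional encoding (assumption (a))
        self.attn = nn.MultiheadAttention(d_model, num_heads=1, batch_first=True)
        self.decoder = nn.Linear(d_model, vocab_size)    # the single decoder layer

    def forward(self, tokens):
        x = self.embed(tokens)                           # (batch, seq, d_model)
        q = x[:, -1:, :]                                 # query = last (current) token
        ctx, attn_w = self.attn(q, x, x)                 # self-attention combines input tokens
        logits = self.decoder(ctx.squeeze(1))            # predict the next token
        return logits, attn_w.squeeze(1)                 # attention weights over key tokens

vocab_size, seq_len, batch = 50, 128, 32                 # long input sequences (assumption (b))
model = OneLayerTransformer(vocab_size)

# Assumption (c): the decoder layer learns faster than the self-attention layer,
# modeled here by a larger SGD learning rate (values are arbitrary placeholders).
opt = torch.optim.SGD([
    {"params": model.decoder.parameters(), "lr": 1e-1},
    {"params": list(model.embed.parameters()) + list(model.attn.parameters()), "lr": 1e-2},
])

for step in range(200):
    tokens = torch.randint(0, vocab_size, (batch, seq_len))   # toy data for illustration only
    targets = torch.randint(0, vocab_size, (batch,))          # placeholder next tokens
    logits, attn_w = model(tokens)
    loss = F.cross_entropy(logits, targets)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 50 == 0:
        # Entropy of the attention distribution: it starts near uniform ("scan"),
        # drops as attention concentrates on distinct key tokens, and then
        # plateaus rather than collapsing to winner-takes-all ("snap").
        entropy = -(attn_w * (attn_w + 1e-9).log()).sum(-1).mean()
        print(f"step {step}: loss {loss.item():.3f}, attention entropy {entropy.item():.3f}")
```

On real data one would replace the random tokens with corpus sequences (e.g., WikiText) and track the attention weights of individual key tokens, grouped by their co-occurrence with the query token, to observe the order in which attention is dropped.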