スキャン&スナップ:1層Transformerにおけるトレーニングダイナミクスとトークン構成の理解
Scan and Snap: Understanding Training Dynamics and Token Composition in 1-layer Transformer
May 25, 2023
著者: Yuandong Tian, Yiping Wang, Beidi Chen, Simon Du
cs.AI
要旨
Transformerアーキテクチャは、複数の研究領域で印象的な性能を示し、多くのニューラルネットワークモデルの基盤となっています。しかし、その動作原理については限られた理解しかありません。特に、単純な予測損失を用いた場合、勾配訓練ダイナミクスからどのように表現が生まれるかは謎のままです。本論文では、1層のTransformer(1つのセルフアテンションレイヤーと1つのデコーダーレイヤーで構成)について、次のトークン予測タスクにおけるSGD訓練ダイナミクスを数学的に厳密に分析します。セルフアテンションレイヤーが入力トークンを組み合わせる動的プロセスのブラックボックスを開き、その背後にある帰納的バイアスの本質を明らかにします。具体的には、(a)位置エンコーディングなし、(b)長い入力シーケンス、(c)デコーダーレイヤーがセルフアテンションレイヤーよりも速く学習する、という仮定の下で、セルフアテンションが識別的スキャンアルゴリズムとして機能することを証明します。均一なアテンションから始まり、特定の次のトークンを予測するために、異なるキートークンにより多く注意を向け、異なる次のトークンにまたがって出現する共通のキートークンにはあまり注意を向けなくなります。異なるトークンの中では、キーとクエリトークンの共起頻度が低いものから高いものの順に、アテンションの重みを徐々に減らしていきます。興味深いことに、このプロセスは勝者総取りにはならず、2つのレイヤーの学習率によって制御可能な相転移によって減速し、(ほぼ)固定されたトークンの組み合わせを残します。この「スキャン&スナップ」ダイナミクスを、合成データと実世界のデータ(WikiText)で検証します。
English
Transformer architecture has shown impressive performance in multiple
research domains and has become the backbone of many neural network models.
However, there is limited understanding on how it works. In particular, with a
simple predictive loss, how the representation emerges from the gradient
training dynamics remains a mystery. In this paper, for 1-layer
transformer with one self-attention layer plus one decoder layer, we analyze
its SGD training dynamics for the task of next token prediction in a
mathematically rigorous manner. We open the black box of the dynamic process of
how the self-attention layer combines input tokens, and reveal the nature of
underlying inductive bias. More specifically, with the assumption (a) no
positional encoding, (b) long input sequence, and (c) the decoder layer learns
faster than the self-attention layer, we prove that self-attention acts as a
discriminative scanning algorithm: starting from uniform attention, it
gradually attends more to distinct key tokens for a specific next token to be
predicted, and pays less attention to common key tokens that occur across
different next tokens. Among distinct tokens, it progressively drops attention
weights, following the order of low to high co-occurrence between the key and
the query token in the training set. Interestingly, this procedure does not
lead to winner-takes-all, but decelerates due to a phase transition that
is controllable by the learning rates of the two layers, leaving (almost) fixed
token combination. We verify this \emph{scan and snap} dynamics on
synthetic and real-world data (WikiText).