스캔 앤 스냅: 1층 트랜스포머의 학습 동역학과 토큰 구성 이해
Scan and Snap: Understanding Training Dynamics and Token Composition in 1-layer Transformer
May 25, 2023
저자: Yuandong Tian, Yiping Wang, Beidi Chen, Simon Du
cs.AI
초록
Transformer 아키텍처는 여러 연구 분야에서 인상적인 성능을 보여주며 많은 신경망 모델의 핵심이 되었습니다. 그러나 그 작동 방식에 대한 이해는 여전히 제한적입니다. 특히, 단순한 예측 손실 함수를 사용할 때, 그레이디언트 훈련 역학을 통해 어떻게 표현이 형성되는지는 여전히 미스터리로 남아 있습니다. 본 논문에서는 하나의 self-attention 층과 하나의 디코더 층으로 구성된 1층 Transformer를 대상으로, 다음 토큰 예측 작업에 대한 SGD 훈련 역학을 수학적으로 엄밀하게 분석합니다. 우리는 self-attention 층이 입력 토큰을 결합하는 동적 과정의 블랙박스를 열고, 내재된 귀납적 편향의 본질을 밝힙니다. 더 구체적으로, (a) 위치 인코딩이 없고, (b) 입력 시퀀스가 길며, (c) 디코더 층이 self-attention 층보다 빠르게 학습한다는 가정 하에, self-attention이 차별적 스캐닝 알고리즘으로 작동함을 증명합니다: 균일한 주의에서 시작하여, 특정 다음 토큰을 예측하기 위해 점차적으로 구별되는 키 토큰에 더 주의를 기울이고, 다양한 다음 토큰에서 공통적으로 나타나는 키 토큰에는 덜 주의를 기울입니다. 구별되는 토큰들 중에서는, 키와 쿼리 토큰 간의 훈련 데이터셋에서의 공현 빈도가 낮은 순서부터 높은 순서로 점진적으로 주의 가중치를 감소시킵니다. 흥미롭게도, 이 과정은 승자독식으로 이어지지 않고, 두 층의 학습률에 의해 제어 가능한 위상 전환으로 인해 감속되며, (거의) 고정된 토큰 조합을 남깁니다. 우리는 이러한 \emph{스캔 및 스냅} 역학을 합성 데이터와 실제 데이터(WikiText)에서 검증합니다.
English
Transformer architecture has shown impressive performance in multiple
research domains and has become the backbone of many neural network models.
However, there is limited understanding on how it works. In particular, with a
simple predictive loss, how the representation emerges from the gradient
training dynamics remains a mystery. In this paper, for 1-layer
transformer with one self-attention layer plus one decoder layer, we analyze
its SGD training dynamics for the task of next token prediction in a
mathematically rigorous manner. We open the black box of the dynamic process of
how the self-attention layer combines input tokens, and reveal the nature of
underlying inductive bias. More specifically, with the assumption (a) no
positional encoding, (b) long input sequence, and (c) the decoder layer learns
faster than the self-attention layer, we prove that self-attention acts as a
discriminative scanning algorithm: starting from uniform attention, it
gradually attends more to distinct key tokens for a specific next token to be
predicted, and pays less attention to common key tokens that occur across
different next tokens. Among distinct tokens, it progressively drops attention
weights, following the order of low to high co-occurrence between the key and
the query token in the training set. Interestingly, this procedure does not
lead to winner-takes-all, but decelerates due to a phase transition that
is controllable by the learning rates of the two layers, leaving (almost) fixed
token combination. We verify this \emph{scan and snap} dynamics on
synthetic and real-world data (WikiText).