언어 모델링을 위한 비동기식 로컬-SGD 훈련
Asynchronous Local-SGD Training for Language Modeling
January 17, 2024
저자: Bo Liu, Rachita Chhaparia, Arthur Douillard, Satyen Kale, Andrei A. Rusu, Jiajun Shen, Arthur Szlam, Marc'Aurelio Ranzato
cs.AI
초록
로컬 확률적 경사 하강법(Local-SGD), 또는 연합 평균화(federated averaging)로도 불리는 이 방법은 각 장치가 통신당 하나 이상의 SGD 업데이트를 수행하는 분산 최적화 접근법이다. 본 연구는 언어 모델 학습을 위한 비동기식 Local-SGD의 실증적 연구를 제시한다. 즉, 각 작업자는 SGD 단계를 마치자마자 전역 매개변수를 업데이트한다. 우리는 작업자의 하드웨어 이질성, 모델 크기, 작업자 수, 그리고 최적화기가 학습 성능에 미치는 영향을 종합적으로 조사한다. 우리는 단순한 구현에서 비동기식 Local-SGD가 동기식 대비 더 많은 반복을 통해 수렴하며, 전역 모델 매개변수를 더 자주 업데이트함에도 불구하고 더 느리게 수렴함을 발견했다. 작업자 그래디언트가 오래된 경우 전역 매개변수에 대한 모멘텀 가속이 주요 문제로 확인되었다. 우리는 지연된 네스테로프 모멘텀 업데이트를 활용하고 작업자의 계산 속도에 기반하여 로컬 학습 단계를 조정하는 새로운 방법을 제안한다. 이 접근법은 C4 데이터셋에서 최대 1억 5천만 개의 매개변수를 가진 모델로 평가되었으며, 업데이트 단계당 혼란도(perplexity) 측면에서 동기식 Local-SGD와 동등한 성능을 보였고, 실제 소요 시간 측면에서는 이를 크게 능가했다.
English
Local stochastic gradient descent (Local-SGD), also referred to as federated
averaging, is an approach to distributed optimization where each device
performs more than one SGD update per communication. This work presents an
empirical study of {\it asynchronous} Local-SGD for training language models;
that is, each worker updates the global parameters as soon as it has finished
its SGD steps. We conduct a comprehensive investigation by examining how worker
hardware heterogeneity, model size, number of workers, and optimizer could
impact the learning performance. We find that with naive implementations,
asynchronous Local-SGD takes more iterations to converge than its synchronous
counterpart despite updating the (global) model parameters more frequently. We
identify momentum acceleration on the global parameters when worker gradients
are stale as a key challenge. We propose a novel method that utilizes a delayed
Nesterov momentum update and adjusts the workers' local training steps based on
their computation speed. This approach, evaluated with models up to 150M
parameters on the C4 dataset, matches the performance of synchronous Local-SGD
in terms of perplexity per update step, and significantly surpasses it in terms
of wall clock time.