DoReMi: 데이터 혼합 최적화를 통한 언어 모델 사전 학습 가속화
DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining
May 17, 2023
저자: Sang Michael Xie, Hieu Pham, Xuanyi Dong, Nan Du, Hanxiao Liu, Yifeng Lu, Percy Liang, Quoc V. Le, Tengyu Ma, Adams Wei Yu
cs.AI
초록
사전 학습 데이터 도메인(예: 위키백과, 책, 웹 텍스트)의 혼합 비율은 언어 모델(LM)의 성능에 큰 영향을 미칩니다. 본 논문에서는 도메인 재가중치 최소최대 최적화(Domain Reweighting with Minimax Optimization, DoReMi)를 제안합니다. 이 방법은 먼저 그룹 분포 강건 최적화(Group DRO)를 사용하여 작은 프록시 모델을 학습시켜, 다운스트림 작업에 대한 지식 없이도 도메인 가중치(혼합 비율)를 생성합니다. 그런 다음 이 도메인 가중치를 사용하여 데이터셋을 재샘플링하고, 더 큰 전체 크기의 모델을 학습시킵니다. 실험에서는 280M 파라미터의 프록시 모델에 DoReMi를 적용하여 8B 파라미터 모델(30배 더 큰)을 더 효율적으로 학습시키기 위한 도메인 가중치를 찾았습니다. The Pile 데이터셋에서 DoReMi는 특정 도메인의 가중치를 낮추더라도 모든 도메인에서 perplexity를 개선했습니다. DoReMi는 The Pile의 기본 도메인 가중치를 사용하여 학습한 베이스라인 모델 대비 평균 few-shot 다운스트림 정확도를 6.5% 향상시켰으며, 베이스라인 정확도에 도달하는 데 필요한 학습 단계를 2.6배 줄였습니다. GLaM 데이터셋에서 DoReMi는 다운스트림 작업에 대한 지식이 없음에도 불구하고, 다운스트림 작업에 맞춰 조정된 도메인 가중치를 사용한 성능과 동등한 결과를 보였습니다.
English
The mixture proportions of pretraining data domains (e.g., Wikipedia, books,
web text) greatly affect language model (LM) performance. In this paper, we
propose Domain Reweighting with Minimax Optimization (DoReMi), which first
trains a small proxy model using group distributionally robust optimization
(Group DRO) over domains to produce domain weights (mixture proportions)
without knowledge of downstream tasks. We then resample a dataset with these
domain weights and train a larger, full-sized model. In our experiments, we use
DoReMi on a 280M-parameter proxy model to find domain weights for training an
8B-parameter model (30x larger) more efficiently. On The Pile, DoReMi improves
perplexity across all domains, even when it downweights a domain. DoReMi
improves average few-shot downstream accuracy by 6.5% over a baseline model
trained using The Pile's default domain weights and reaches the baseline
accuracy with 2.6x fewer training steps. On the GLaM dataset, DoReMi, which has
no knowledge of downstream tasks, even matches the performance of using domain
weights tuned on downstream tasks.