Demons in the Detail: On Implementing Load Balancing Loss for Training Specialized Mixture-of-Expert Models
January 21, 2025
作者: Zihan Qiu, Zeyu Huang, Bo Zheng, Kaiyue Wen, Zekun Wang, Rui Men, Ivan Titov, Dayiheng Liu, Jingren Zhou, Junyang Lin
cs.AI
Abstract
This paper revisits the implementation of the load-balancing loss (LBL) when training Mixture-of-Experts (MoE) models. Specifically, the LBL for MoEs is defined as $N_E \sum_{i=1}^{N_E} f_i p_i$, where $N_E$ is the total number of experts, $f_i$ is the frequency with which expert $i$ is selected, and $p_i$ is the average gating score of expert $i$. Existing MoE training frameworks usually employ a parallel training strategy, so $f_i$ and the LBL are calculated within each micro-batch and then averaged across parallel groups. In practice, a micro-batch for training billion-scale LLMs contains very few sequences, so the micro-batch LBL is effectively a sequence-level loss: the router is pushed to distribute tokens evenly within each sequence. Under this strict constraint, even tokens from a domain-specific sequence (e.g., code) are routed uniformly to all experts, which inhibits expert specialization. In this work, we propose calculating the LBL over the global batch to relax this constraint. Because a global batch contains far more diverse sequences than a micro-batch, this encourages load balance at the corpus level rather than at the sequence level. Concretely, we introduce an extra communication step to synchronize $f_i$ across micro-batches and then use the synchronized values to compute the LBL. Through experiments on MoE-based LLMs (up to 42.8B total parameters and 400B tokens), we surprisingly find that the global-batch LBL strategy yields excellent gains in both pre-training perplexity and downstream tasks. Our analysis also reveals that the global-batch LBL greatly improves the domain specialization of MoE experts.
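To make the two LBL variants described in the abstract concrete, below is a minimal PyTorch-style sketch, not the authors' implementation: the function and argument names (e.g., `load_balancing_loss`, `gate_probs`, `expert_indices`) are hypothetical. It computes the micro-batch LBL and, when requested, averages $f_i$ across data-parallel ranks with an all-reduce before forming the global-batch LBL, assuming `torch.distributed` has been initialized.

```python
# Minimal sketch (not the authors' code) of micro-batch vs. global-batch
# load-balancing loss for a top-k MoE router.
import torch
import torch.distributed as dist


def load_balancing_loss(gate_probs: torch.Tensor,
                        expert_indices: torch.Tensor,
                        num_experts: int,
                        global_batch: bool = False) -> torch.Tensor:
    """LBL = N_E * sum_i f_i * p_i.

    gate_probs:     (num_tokens, num_experts) softmax scores from the router.
    expert_indices: (num_tokens, top_k) experts selected for each token.
    """
    # p_i: average gating score of expert i over the local micro-batch.
    p = gate_probs.mean(dim=0)                                  # (num_experts,)

    # f_i: fraction of routed token slots assigned to expert i.
    counts = torch.zeros_like(gate_probs).scatter_add_(
        1, expert_indices,
        torch.ones_like(expert_indices, dtype=gate_probs.dtype)
    ).sum(dim=0)                                                # (num_experts,)
    f = counts / counts.sum()

    if global_batch and dist.is_initialized():
        # Extra communication step: average f_i over all micro-batches in
        # the global batch, so balance is encouraged at the corpus level
        # rather than within each micro-batch.
        dist.all_reduce(f, op=dist.ReduceOp.SUM)
        f = f / dist.get_world_size()

    return num_experts * torch.sum(f * p)
```

In this sketch only $f_i$ is communicated; the gating scores $p_i$, through which gradients reach the router, stay local to each micro-batch, which is one way to realize the single extra synchronization step the abstract describes.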