Data Mixing Agent: Learning to Re-weight Domains for Continual Pre-training
July 21, 2025
作者: Kailai Yang, Xiao Liu, Lei Ji, Hao Li, Yeyun Gong, Peng Cheng, Mao Yang
cs.AI
Abstract
Continual pre-training on small-scale task-specific data is an effective
method for improving large language models in new target fields, yet it risks
catastrophic forgetting of their original capabilities. A common solution is to
re-weight training data mixtures from source and target fields on a domain
space to achieve balanced performance. Previous domain re-weighting strategies
rely on manually designed heuristics based on human intuition or empirical
results. In this work, we show that more general heuristics can be
parameterized by proposing Data Mixing Agent, the first model-based, end-to-end
framework that learns to re-weight domains. The agent learns generalizable
heuristics through reinforcement learning on large quantities of data mixing
trajectories with corresponding feedback from an evaluation environment.
Experiments in continual pre-training on math reasoning show that Data Mixing
Agent outperforms strong baselines in achieving balanced performance across
source and target field benchmarks. Furthermore, it generalizes well across
unseen source fields, target models, and domain spaces without retraining.
Direct application to the code generation field also indicates its adaptability
across target domains. Further analysis shows that the agent's heuristics align
well with human intuition and that it achieves superior model performance with
less source-field data.
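
The abstract does not spell out implementation details, but the core loop it describes can be sketched as follows. This is a minimal, hypothetical illustration of re-weighting domains during continual pre-training using feedback from an evaluation environment; the interfaces (`sample_batch`, `train_step`, `evaluate`, `propose_weights`) are assumptions for illustration, not the paper's actual API.

```python
import numpy as np

# Hypothetical sketch: an agent periodically re-weights the sampling
# distribution over data domains during continual pre-training, based on
# feedback from an evaluation environment (source- and target-field benchmarks).

def normalize(weights):
    """Project raw weights onto the probability simplex over domains."""
    w = np.clip(np.asarray(weights, dtype=float), 1e-8, None)
    return w / w.sum()

def continual_pretrain(model, domains, agent, steps, eval_every=100):
    """domains: dict mapping domain name -> dataset (assumed to expose sample_batch()).
    model: assumed to expose train_step(batch) and evaluate(benchmark_suite).
    agent: assumed to expose propose_weights(history) -> raw domain weights,
           e.g. a policy trained with RL on data mixing trajectories.
    """
    names = list(domains)
    weights = normalize(np.ones(len(names)))  # start from a uniform mixture
    history = []

    for step in range(steps):
        # Sample a training domain according to the current mixture weights,
        # draw a batch from it, and take one optimization step.
        domain = np.random.choice(names, p=weights)
        batch = domains[domain].sample_batch()   # assumed dataset API
        model.train_step(batch)                  # assumed model API

        if (step + 1) % eval_every == 0:
            # Feedback from the evaluation environment: balanced performance
            # requires tracking both source-field and target-field scores.
            feedback = {
                "source_score": model.evaluate("source_benchmarks"),  # assumed
                "target_score": model.evaluate("target_benchmarks"),  # assumed
            }
            history.append((weights.copy(), feedback))
            # The agent proposes updated domain weights from the observed
            # trajectory of (mixture, feedback) pairs.
            weights = normalize(agent.propose_weights(history))

    return model, history
```

In this sketch the agent replaces the hand-specified re-weighting schedules that the abstract attributes to prior work; the actual framework in the paper learns such a policy end-to-end with reinforcement learning over many mixing trajectories.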