MemDLM : Entraînement de DLM amélioré par la mémoire
MemDLM: Memory-Enhanced DLM Training
March 23, 2026
Auteurs: Zehua Pei, Hui-Ling Zhen, Weizhe Lin, Sinno Jialin Pan, Yunhe Wang, Mingxuan Yuan, Bei Yu
cs.AI
Résumé
Les modèles de langage par diffusion (DLM) présentent des avantages attractifs par rapport aux modèles auto-régressifs (AR), tels qu'un décodage parallèle par attention complète et une génération flexible. Cependant, ils souffrent d'un décalage notable entre l'entraînement et l'inférence : les DLM sont entraînés avec un objectif de prédiction masquée statique et en une seule étape, mais sont déployés via une trajectoire de débruîtage progressive en plusieurs étapes. Nous proposons MemDLM (DLM à mémoire renforcée), qui réduit cet écart en intégrant un processus de débruîtage simulé dans l'entraînement via une optimisation bi-niveau. Une boucle interne met à jour un ensemble de poids rapides, formant une mémoire paramétrique qui capture l'expérience de trajectoire locale de chaque échantillon, tandis qu'une boucle externe met à jour le modèle de base conditionné par cette mémoire. En déchargeant la pression de mémorisation des représentations de tokens vers les paramètres, MemDLM permet une convergence plus rapide et une perte d'entraînement réduite. De plus, la boucle interne peut être réactivée au moment de l'inférence comme étape d'adaptation, générant des gains supplémentaires pour la compréhension de contextes longs. Nous constatons que, lorsqu'elle est activée à l'inférence, cette mémoire paramétrique agit comme un mécanisme émergent de récupération intégré aux poids, aidant MemDLM à réduire davantage les goulots d'étranglement attentionnels au niveau des tokens dans des tâches de récupération difficiles de type "aiguille dans une botte de foin". Code : https://github.com/JarvisPei/MemDLM.
English
Diffusion Language Models (DLMs) offer attractive advantages over Auto-Regressive (AR) models, such as full-attention parallel decoding and flexible generation. However, they suffer from a notable train-inference mismatch: DLMs are trained with a static, single-step masked prediction objective, but deployed through a multi-step progressive denoising trajectory. We propose MemDLM (Memory-Enhanced DLM), which narrows this gap by embedding a simulated denoising process into training via Bi-level Optimization. An inner loop updates a set of fast weights, forming a Parametric Memory that captures the local trajectory experience of each sample, while an outer loop updates the base model conditioned on this memory. By offloading memorization pressure from token representations to parameters, MemDLM yields faster convergence and lower training loss. Moreover, the inner loop can be re-enabled at inference time as an adaptation step, yielding additional gains on long-context understanding. We find that, when activated at inference time, this Parametric Memory acts as an emergent in-weight retrieval mechanism, helping MemDLM further reduce token-level attention bottlenecks on challenging Needle-in-a-Haystack retrieval tasks. Code: https://github.com/JarvisPei/MemDLM.