Les jetons critiques comptent : l'estimation contrastive au niveau des jetons améliore la capacité de raisonnement des LLM.Critical Tokens Matter: Token-Level Contrastive Estimation Enhence LLM's
Reasoning Capability
Les grands modèles de langage (LLMs) ont montré des performances remarquables dans les tâches de raisonnement. Ils utilisent la génération de jetons autorégressive pour construire des trajectoires de raisonnement, permettant le développement d'une chaîne de pensée cohérente. Dans ce travail, nous explorons l'impact des jetons individuels sur les résultats finaux des tâches de raisonnement. Nous identifions l'existence de "jetons critiques" qui conduisent à des trajectoires de raisonnement incorrectes dans les LLMs. Plus précisément, nous constatons que les LLMs ont tendance à produire des résultats positifs lorsqu'ils sont contraints de décoder d'autres jetons au lieu des jetons critiques. Motivés par cette observation, nous proposons une nouvelle approche - cDPO - conçue pour reconnaître automatiquement et mener des récompenses au niveau du jeton pour les jetons critiques pendant le processus d'alignement. Plus précisément, nous développons une approche d'estimation contrastive pour identifier automatiquement les jetons critiques. Cela est réalisé en comparant la probabilité de génération des modèles positif et négatif. Pour ce faire, nous affinons séparément les modèles positif et négatif sur diverses trajectoires de raisonnement, leur permettant ainsi d'identifier les jetons critiques au sein des trajectoires incorrectes qui contribuent à des résultats erronés. De plus, pour aligner davantage le modèle avec les informations des jetons critiques pendant le processus d'alignement, nous étendons les algorithmes DPO conventionnels au niveau du jeton et utilisons la probabilité différentielle des modèles positif et négatif susmentionnés comme poids important pour l'apprentissage du DPO au niveau du jeton. Les résultats expérimentaux sur les référentiels GSM8K et MATH500 avec les modèles largement utilisés Llama-3 (8B et 70B) et deepseek-math (7B) démontrent l'efficacité de l'approche proposée cDPO.