GLoRe: Cuándo, dónde y cómo mejorar el razonamiento de los modelos de lenguaje grandes mediante refinamientos globales y locales
GLoRe: When, Where, and How to Improve LLM Reasoning via Global and Local Refinements
February 13, 2024
Autores: Alex Havrilla, Sharath Raparthy, Christoforus Nalmpantis, Jane Dwivedi-Yu, Maksym Zhuravinskyi, Eric Hambro, Roberta Railneau
cs.AI
Resumen
Los modelos de lenguaje de última generación pueden exhibir capacidades impresionantes de refinamiento de razonamiento en tareas de matemáticas, ciencias o programación. Sin embargo, trabajos recientes demuestran que incluso los mejores modelos tienen dificultades para identificar cuándo y dónde refinar sin acceso a retroalimentación externa. Los Modelos de Recompensa Basados en Resultados (ORMs, por sus siglas en inglés), entrenados para predecir la corrección de la respuesta final e indicar cuándo refinar, ofrecen una solución conveniente para decidir cuándo hacerlo. Los Modelos de Recompensa Basados en Procesos (PRMs, por sus siglas en inglés), entrenados para predecir la corrección de los pasos intermedios, pueden entonces usarse para indicar dónde refinar. Sin embargo, son costosos de entrenar, ya que requieren anotaciones humanas extensas. En este artículo, proponemos los Modelos de Recompensa Basados en Resultados Paso a Paso (SORMs, por sus siglas en inglés), que se entrenan únicamente con datos sintéticos para aproximar la recompensa futura esperada de la política óptima o \(V^{\star}\). Más específicamente, los SORMs se entrenan para predecir la corrección de la respuesta final cuando se muestrea la política actual muchas veces (en lugar de solo una vez, como en el caso de los ORMs). Nuestros experimentos muestran que los SORMs pueden detectar con mayor precisión los pasos de razonamiento incorrectos en comparación con los ORMs, mejorando así la precisión en tareas de refinamiento. Luego entrenamos modelos de refinamiento global, que toman solo la pregunta y una solución preliminar como entrada y predicen una solución corregida, y modelos de refinamiento local, que también toman como entrada una crítica que indica la ubicación del primer error de razonamiento. Generamos datos de entrenamiento para ambos modelos de manera sintética reutilizando los datos utilizados para entrenar el SORM. Encontramos que combinar refinamientos globales y locales, utilizando el ORM como un reranker, supera significativamente a cualquiera de los dos por separado, así como a una línea base de la mejor de tres muestras. Con esta estrategia, podemos mejorar la precisión de un modelo LLaMA-2 13B (ya ajustado con aprendizaje por refuerzo) en GSM8K del 53% al 65% cuando se muestrea de manera codiciosa.
English
State-of-the-art language models can exhibit impressive reasoning refinement
capabilities on math, science or coding tasks. However, recent work
demonstrates that even the best models struggle to identify when and
where to refine without access to external feedback. Outcome-based Reward
Models (ORMs), trained to predict correctness of the final answer
indicating when to refine, offer one convenient solution for deciding when to
refine. Process Based Reward Models (PRMs), trained to predict
correctness of intermediate steps, can then be used to indicate where to
refine. But they are expensive to train, requiring extensive human annotations.
In this paper, we propose Stepwise ORMs (SORMs) which are trained,
only on synthetic data, to approximate the expected future reward of the
optimal policy or V^{star}. More specifically, SORMs are trained to predict
the correctness of the final answer when sampling the current policy many times
(rather than only once as in the case of ORMs). Our experiments show that SORMs
can more accurately detect incorrect reasoning steps compared to ORMs, thus
improving downstream accuracy when doing refinements. We then train
global refinement models, which take only the question and a draft
solution as input and predict a corrected solution, and local
refinement models which also take as input a critique indicating the location
of the first reasoning error. We generate training data for both models
synthetically by reusing data used to train the SORM. We find combining global
and local refinements, using the ORM as a reranker, significantly outperforms
either one individually, as well as a best of three sample baseline. With this
strategy we can improve the accuracy of a LLaMA-2 13B model (already fine-tuned
with RL) on GSM8K from 53\% to 65\% when greedily sampled.