TART: Un módulo Transformer plug-and-play para razonamiento independiente de la tarea
TART: A plug-and-play Transformer module for task-agnostic reasoning
June 13, 2023
Autores: Kush Bhatia, Avanika Narayan, Christopher De Sa, Christopher Ré
cs.AI
Resumen
Los modelos de lenguaje de gran escala (LLMs, por sus siglas en inglés) exhiben habilidades de aprendizaje en contexto que permiten que el mismo modelo realice varias tareas sin necesidad de entrenamiento específico para cada una. En contraste, los enfoques tradicionales de adaptación, como el ajuste fino, modifican los modelos subyacentes para cada tarea específica. Sin embargo, el aprendizaje en contexto consistentemente tiene un rendimiento inferior a los enfoques de ajuste específico, incluso cuando se presentan los mismos ejemplos. Mientras que la mayoría de los enfoques existentes (por ejemplo, la ingeniería de prompts) se centran en las representaciones aprendidas por los LLMs para cerrar esta brecha de rendimiento, nuestro análisis revela que las representaciones de los LLMs contienen suficiente información para hacer buenas predicciones. Por ello, nos enfocamos en las habilidades de razonamiento de los LLMs y demostramos que esta brecha de rendimiento existe debido a su incapacidad para realizar tareas simples de razonamiento probabilístico. Esto plantea una pregunta intrigante: ¿Son los LLMs realmente capaces de aprender a razonar de manera independiente de la tarea? Respondemos afirmativamente a esta pregunta y proponemos TART, que mejora genéricamente las habilidades de razonamiento de un LLM utilizando un módulo de razonamiento basado en Transformers entrenado sintéticamente. TART entrena este módulo de razonamiento de manera independiente de la tarea utilizando únicamente tareas sintéticas de regresión logística y lo combina con un modelo preentrenado del mundo real sin necesidad de entrenamiento adicional. Con un único módulo de inferencia, TART mejora el rendimiento en diferentes familias de modelos (GPT-Neo, Pythia, BLOOM), tamaños de modelos (100M - 6B), tareas (14 tareas de clasificación binaria en NLP) e incluso en diferentes modalidades (audio y visión). Además, en el Benchmark RAFT, TART mejora el rendimiento de GPT-Neo (125M) de tal manera que supera a BLOOM (176B) y se encuentra dentro del 4% de GPT-3 (175B). Nuestro código y modelos están disponibles en https://github.com/HazyResearch/TART.
English
Large language models (LLMs) exhibit in-context learning abilities which
enable the same model to perform several tasks without any task-specific
training. In contrast, traditional adaptation approaches, such as fine-tuning,
modify the underlying models for each specific task. In-context learning,
however, consistently underperforms task-specific tuning approaches even when
presented with the same examples. While most existing approaches (e.g., prompt
engineering) focus on the LLM's learned representations to patch this
performance gap, our analysis actually reveal that LLM representations contain
sufficient information to make good predictions. As such, we focus on the LLM's
reasoning abilities and demonstrate that this performance gap exists due to
their inability to perform simple probabilistic reasoning tasks. This raises an
intriguing question: Are LLMs actually capable of learning how to reason in a
task-agnostic manner? We answer this in the affirmative and propose TART which
generically improves an LLM's reasoning abilities using a synthetically trained
Transformer-based reasoning module. TART trains this reasoning module in a
task-agnostic manner using only synthetic logistic regression tasks and
composes it with an arbitrary real-world pre-trained model without any
additional training. With a single inference module, TART improves performance
across different model families (GPT-Neo, Pythia, BLOOM), model sizes (100M -
6B), tasks (14 NLP binary classification tasks), and even across different
modalities (audio and vision). Additionally, on the RAFT Benchmark, TART
improves GPT-Neo (125M)'s performance such that it outperforms BLOOM (176B),
and is within 4% of GPT-3 (175B). Our code and models are available at
https://github.com/HazyResearch/TART .