TART: タスクに依存しない推論のためのプラグアンドプレイTransformerモジュール
TART: A plug-and-play Transformer module for task-agnostic reasoning
June 13, 2023
著者: Kush Bhatia, Avanika Narayan, Christopher De Sa, Christopher Ré
cs.AI
要旨
大規模言語モデル(LLM)は、コンテキスト内学習能力を示し、特定のタスクごとのトレーニングなしに複数のタスクを実行できる。一方、従来の適応手法(例えばファインチューニング)は、各タスクごとに基盤となるモデルを変更する。しかし、コンテキスト内学習は、同じ例が提示された場合でも、タスク固有のチューニング手法に一貫して劣る。既存のアプローチ(例:プロンプトエンジニアリング)の多くは、この性能差を埋めるためにLLMの学習済み表現に焦点を当てているが、我々の分析では、LLMの表現には良好な予測を行うための十分な情報が含まれていることが明らかになった。そのため、我々はLLMの推論能力に注目し、この性能差が単純な確率的推論タスクを実行できないことに起因することを示す。これにより、興味深い疑問が生じる:LLMは実際にタスクに依存しない方法で推論を学習できるのか?我々はこれを肯定し、TARTを提案する。TARTは、合成的にトレーニングされたTransformerベースの推論モジュールを使用して、LLMの推論能力を汎用的に向上させる。TARTは、この推論モジュールを合成ロジスティック回帰タスクのみを使用してタスクに依存しない方法でトレーニングし、任意の実世界の事前トレーニング済みモデルと追加のトレーニングなしに組み合わせる。単一の推論モジュールで、TARTは異なるモデルファミリー(GPT-Neo、Pythia、BLOOM)、モデルサイズ(100M~6B)、タスク(14のNLP二値分類タスク)、さらには異なるモダリティ(音声と視覚)にわたって性能を向上させる。さらに、RAFTベンチマークでは、TARTはGPT-Neo(125M)の性能を向上させ、BLOOM(176B)を上回り、GPT-3(175B)の4%以内に収まる。我々のコードとモデルはhttps://github.com/HazyResearch/TARTで公開されている。
English
Large language models (LLMs) exhibit in-context learning abilities which
enable the same model to perform several tasks without any task-specific
training. In contrast, traditional adaptation approaches, such as fine-tuning,
modify the underlying models for each specific task. In-context learning,
however, consistently underperforms task-specific tuning approaches even when
presented with the same examples. While most existing approaches (e.g., prompt
engineering) focus on the LLM's learned representations to patch this
performance gap, our analysis actually reveal that LLM representations contain
sufficient information to make good predictions. As such, we focus on the LLM's
reasoning abilities and demonstrate that this performance gap exists due to
their inability to perform simple probabilistic reasoning tasks. This raises an
intriguing question: Are LLMs actually capable of learning how to reason in a
task-agnostic manner? We answer this in the affirmative and propose TART which
generically improves an LLM's reasoning abilities using a synthetically trained
Transformer-based reasoning module. TART trains this reasoning module in a
task-agnostic manner using only synthetic logistic regression tasks and
composes it with an arbitrary real-world pre-trained model without any
additional training. With a single inference module, TART improves performance
across different model families (GPT-Neo, Pythia, BLOOM), model sizes (100M -
6B), tasks (14 NLP binary classification tasks), and even across different
modalities (audio and vision). Additionally, on the RAFT Benchmark, TART
improves GPT-Neo (125M)'s performance such that it outperforms BLOOM (176B),
and is within 4% of GPT-3 (175B). Our code and models are available at
https://github.com/HazyResearch/TART .