Grootschalige transfer learning voor tabelgegevens via taalmodellering
Large Scale Transfer Learning for Tabular Data via Language Modeling
June 17, 2024
Auteurs: Josh Gardner, Juan C. Perdomo, Ludwig Schmidt
cs.AI
Samenvatting
Tabelgegevens -- gestructureerde, heterogene, spreadsheet-achtige gegevens met rijen en kolommen -- worden in de praktijk veel gebruikt in verschillende domeinen. Echter, hoewel recente foundation models de noodzaak hebben verminderd om taakspecifieke datasets en voorspellers te ontwikkelen in domeinen zoals taalmodellering en computervisie, heeft dit transfer learning-paradigma niet een vergelijkbare impact gehad in het domein van tabelgegevens. In dit werk streven we ernaar deze kloof te verkleinen en presenteren we TabuLa-8B, een taalmodel voor tabelvoorspelling. We definiëren een proces voor het extraheren van een grote, hoogwaardige trainingsdataset uit het TabLib-corpus, waarbij we methoden voorstellen voor het filteren en kwaliteitscontrole van tabelgegevens. Met behulp van de resulterende dataset, die bestaat uit meer dan 1,6 miljard rijen uit 3,1 miljoen unieke tabellen, fine-tunen we een Llama 3-8B groot taalmodel (LLM) voor tabelgegevensvoorspelling (classificatie en gebinnde regressie) met behulp van een nieuw packing- en attentieschema voor tabelvoorspelling. Door evaluatie over een testsuite van 329 datasets, vinden we dat TabuLa-8B een zero-shot nauwkeurigheid heeft op onbekende tabellen die meer dan 15 procentpunten (pp) hoger is dan willekeurig gissen, een prestatie die niet mogelijk is met bestaande state-of-the-art tabelvoorspellingsmodellen (bijv. XGBoost, TabPFN). In de few-shot setting (1-32 shots), zonder enige fine-tuning op de doeldatasets, is TabuLa-8B 5-15 pp nauwkeuriger dan XGBoost en TabPFN-modellen die expliciet getraind zijn op gelijke, of zelfs tot 16x meer gegevens. We maken ons model, code en gegevens beschikbaar bij de publicatie van dit artikel.
English
Tabular data -- structured, heterogeneous, spreadsheet-style data with rows
and columns -- is widely used in practice across many domains. However, while
recent foundation models have reduced the need for developing task-specific
datasets and predictors in domains such as language modeling and computer
vision, this transfer learning paradigm has not had similar impact in the
tabular domain. In this work, we seek to narrow this gap and present TabuLa-8B,
a language model for tabular prediction. We define a process for extracting a
large, high-quality training dataset from the TabLib corpus, proposing methods
for tabular data filtering and quality control. Using the resulting dataset,
which comprises over 1.6B rows from 3.1M unique tables, we fine-tune a Llama
3-8B large language model (LLM) for tabular data prediction (classification and
binned regression) using a novel packing and attention scheme for tabular
prediction. Through evaluation across a test suite of 329 datasets, we find
that TabuLa-8B has zero-shot accuracy on unseen tables that is over 15
percentage points (pp) higher than random guessing, a feat that is not possible
with existing state-of-the-art tabular prediction models (e.g. XGBoost,
TabPFN). In the few-shot setting (1-32 shots), without any fine-tuning on the
target datasets, TabuLa-8B is 5-15 pp more accurate than XGBoost and TabPFN
models that are explicitly trained on equal, or even up to 16x more data. We
release our model, code, and data along with the publication of this paper.