Melhorando Modelos Mundiais de Transformadores para RL com Eficiência de Dados
Improving Transformer World Models for Data-Efficient RL
February 3, 2025
Autores: Antoine Dedieu, Joseph Ortiz, Xinghua Lou, Carter Wendelken, Wolfgang Lehrach, J Swaroop Guntupalli, Miguel Lazaro-Gredilla, Kevin Patrick Murphy
cs.AI
Resumo
Apresentamos uma abordagem ao RL baseado em modelos que alcança um novo estado da arte no desafiador benchmark Craftax-classic, um jogo de sobrevivência 2D de mundo aberto que requer que os agentes demonstrem uma ampla gama de habilidades gerais - como forte generalização, exploração profunda e raciocínio de longo prazo. Com uma série de escolhas de design cuidadosas voltadas para melhorar a eficiência da amostragem, nosso algoritmo de MBRL alcança uma recompensa de 67,4% após apenas 1 milhão de passos de ambiente, superando significativamente o DreamerV3, que alcança 53,2%, e, pela primeira vez, supera o desempenho humano de 65,0%. Nosso método começa construindo uma linha de base sem modelo de última geração, usando uma arquitetura de política inovadora que combina CNNs e RNNs. Em seguida, adicionamos três melhorias à configuração padrão de MBRL: (a) "Dyna com aquecimento", que treina a política em dados reais e imaginários, (b) "tokenizador de vizinho mais próximo" em patches de imagem, que melhora o esquema para criar os inputs do modelo de mundo transformador (TWM), e (c) "forçamento de professor em bloco", que permite ao TWM raciocinar conjuntamente sobre os tokens futuros do próximo passo de tempo.
English
We present an approach to model-based RL that achieves a new state of the art
performance on the challenging Craftax-classic benchmark, an open-world 2D
survival game that requires agents to exhibit a wide range of general abilities
-- such as strong generalization, deep exploration, and long-term reasoning.
With a series of careful design choices aimed at improving sample efficiency,
our MBRL algorithm achieves a reward of 67.4% after only 1M environment steps,
significantly outperforming DreamerV3, which achieves 53.2%, and, for the first
time, exceeds human performance of 65.0%. Our method starts by constructing a
SOTA model-free baseline, using a novel policy architecture that combines CNNs
and RNNs. We then add three improvements to the standard MBRL setup: (a) "Dyna
with warmup", which trains the policy on real and imaginary data, (b) "nearest
neighbor tokenizer" on image patches, which improves the scheme to create the
transformer world model (TWM) inputs, and (c) "block teacher forcing", which
allows the TWM to reason jointly about the future tokens of the next timestep.Summary
AI-Generated Summary