JaxMARL: Multi-Agent RL Environments in JAX
November 16, 2023
Authors: Alexander Rutherford, Benjamin Ellis, Matteo Gallici, Jonathan Cook, Andrei Lupu, Gardar Ingvarsson, Timon Willi, Akbir Khan, Christian Schroeder de Witt, Alexandra Souly, Saptarashmi Bandyopadhyay, Mikayel Samvelyan, Minqi Jiang, Robert Tjarko Lange, Shimon Whiteson, Bruno Lacerda, Nick Hawes, Tim Rocktaschel, Chris Lu, Jakob Nicolaus Foerster
cs.AI
Abstract
Benchmarks play an important role in the development of machine learning
algorithms. For example, research in reinforcement learning (RL) has been
heavily influenced by available environments and benchmarks. However, RL
environments are traditionally run on the CPU, limiting their scalability with
typical academic compute. Recent advancements in JAX have enabled the wider use
of hardware acceleration to overcome these computational hurdles, enabling
massively parallel RL training pipelines and environments. This is particularly
useful for multi-agent reinforcement learning (MARL) research. First, multiple
agents must be considered at each environment step, adding computational
burden; second, sample complexity is increased by non-stationarity,
decentralised partial observability, and other MARL
challenges. In this paper, we present JaxMARL, the first open-source code base
that combines ease of use with GPU-enabled efficiency and supports a large
number of commonly used MARL environments as well as popular baseline
algorithms. In terms of wall-clock time, our experiments show that our
JAX-based training pipeline is up to 12,500x faster per run than existing
approaches. This enables efficient and thorough evaluations, with the potential
to alleviate the field's evaluation crisis. We also introduce and
benchmark SMAX, a vectorised, simplified version of the popular StarCraft
Multi-Agent Challenge, which removes the need to run the StarCraft II game
engine. This not only enables GPU acceleration, but also provides a more
flexible MARL environment, unlocking the potential for self-play,
meta-learning, and other future applications in MARL. We provide code at
https://github.com/flairox/jaxmarl.
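The massively parallel training pipelines described above rest on JAX's ability to vectorise and compile pure environment functions so that thousands of environment instances step in a single GPU call. The sketch below illustrates that core idea with a deliberately toy multi-agent environment; the environment itself and all names in it are hypothetical and do not reflect the actual JaxMARL API.

```python
import jax
import jax.numpy as jnp

NUM_AGENTS = 3   # agents per environment (illustrative)
NUM_ENVS = 1024  # environments stepped in parallel

def env_step(state, actions):
    """Toy multi-agent step: state holds one position per agent,
    actions are per-agent position deltas. Purely illustrative."""
    new_state = state + actions
    # Each agent is rewarded for staying close to the origin.
    rewards = -jnp.abs(new_state)
    return new_state, rewards

# vmap vectorises the step over a leading batch-of-environments axis,
# and jit compiles the batched step into a single fused GPU kernel call.
batched_step = jax.jit(jax.vmap(env_step))

states = jnp.zeros((NUM_ENVS, NUM_AGENTS))
actions = jnp.ones((NUM_ENVS, NUM_AGENTS))
states, rewards = batched_step(states, actions)

print(states.shape)   # one step advanced all 1024 environments at once
print(rewards.shape)
```

Because the batched step is a pure, compiled function, it can be nested inside `jax.lax.scan` to run whole rollouts on-device, which is where the large wall-clock speedups over CPU-bound environment loops come from.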