A reinforcement learning project where agents learn to fight each other in a physics-based arena using JAX and MuJoCo.
This project implements a training pipeline for AI agents that fight each other in a physics-based arena. The agents are trained with the Soft Actor-Critic (SAC) algorithm, accelerated with JAX. The environment is built on the MuJoCo physics engine (via MuJoCo MJX) and features two four-legged agents attempting to push each other off a 2.5x2.5x1.5 platform.
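For orientation, here is a minimal sketch of what stepping such an MJX arena could look like. The model path, `qpos` layout, and platform bounds below are illustrative assumptions, not code taken from `arena.py`:

```python
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx

# Load the MuJoCo model and move it to device for MJX (path is an assumption).
model = mujoco.MjModel.from_xml_path("models/arena.xml")
mjx_model = mjx.put_model(model)

@jax.jit
def step(mjx_data, actions):
    # Write both agents' actuator controls, then advance the physics one step.
    mjx_data = mjx_data.replace(ctrl=actions)
    mjx_data = mjx.step(mjx_model, mjx_data)
    # Termination sketch: an agent counts as pushed off once its torso leaves
    # the 2.5x2.5 platform footprint or drops below the 1.5 platform top
    # (assumes a free joint whose x, y, z come first in qpos).
    x, y, z = mjx_data.qpos[0], mjx_data.qpos[1], mjx_data.qpos[2]
    pushed_off = (jnp.abs(x) > 1.25) | (jnp.abs(y) > 1.25) | (z < 1.5)
    return mjx_data, pushed_off

# Usage: create device-resident data and step with a zero action vector.
data = mjx.make_data(mjx_model)
data, done = step(data, jnp.zeros(model.nu))
```

Because the step function is pure, it can be wrapped in `jax.vmap` to simulate many arenas in parallel on one GPU, which is the main appeal of MJX here.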
- JAX with CUDA support
- MuJoCo and MuJoCo MJX
- Weights & Biases or TensorBoard (optional, for experiment tracking)
- See `requirements.txt` for full dependencies
```bash
# Clone the repository
git clone https://github.com/r-aristov/arena.git
cd arena

# Install other dependencies
pip install -r requirements.txt
```
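JAX itself is best installed to match your CUDA setup; the `jax[cuda12]` extra below is one common option (see the JAX installation docs for your exact version), and the second command is a quick check that the GPU is visible:

```bash
# Install JAX with CUDA support (adjust the extra to your CUDA version).
pip install -U "jax[cuda12]"

# Verify that JAX can see the GPU.
python -c "import jax; print(jax.devices())"
```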
To start training:
```bash
python sac_my_flax.py
```
To watch trained agents fight:
```bash
# Starts with pretrained agents if no parameters are specified
python arena.py

# Or specify custom agents:
python arena.py --agent0="path/to/your/agent0" --agent1="path/to/your/agent1"
```
Key training parameters can be adjusted in `sac_my_flax.py` (see the sketch after this list):
- Batch sizes
- Buffer size
- Learning rates
- Training steps
- Self-play parameters
- SAC hyperparameters
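For illustration only, such a configuration block might look like the following; the names and values are assumptions, not the actual contents of `sac_my_flax.py`:

```python
# Hypothetical configuration constants -- check sac_my_flax.py for the real names.
BATCH_SIZE = 256                  # transitions sampled per gradient step
BUFFER_SIZE = 1_000_000           # replay buffer capacity
POLICY_LR = 3e-4                  # policy network learning rate
Q_LR = 3e-4                       # q-network learning rate
TRAINING_STEPS = 1_000_000        # total gradient steps
GAMMA = 0.99                      # SAC discount factor
TAU = 0.005                       # target q-network soft-update rate
SELFPLAY_SNAPSHOT_EVERY = 50_000  # steps between opponent snapshot refreshes
```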
- `sac_my_flax.py` - Startup script with pipeline configuration
- `arena.py` - Environment implementation and visualization
- `worker.py` - Worker thread, responsible for simulation, observation gathering, and agent validation
- `buffer.py` - Replay buffer implementation
- `trainer.py` - Trainer thread; takes replays from the buffer and trains the q-network and policy network
- `agent.py` - Simple policy network implemented in Flax
- `q_network.py` - Simple q-network implemented in Flax
- `running_mean_std_jax.py` - Running mean/std implementation in JAX for observation normalization (sketched below)
- `models/` - MuJoCo model definitions
- `legacy-agents/` - Pretrained agents to use as reference and for validation
- `obs-norm/` - Precomputed mean and variance values for observation normalization
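As background on the last pieces: observation normalization of this kind typically maintains running mean and variance estimates and standardizes each observation with them. A minimal sketch of the idea in JAX follows; the actual API in `running_mean_std_jax.py` may differ:

```python
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class RunningMeanStd:
    mean: jnp.ndarray   # running per-dimension mean
    var: jnp.ndarray    # running per-dimension variance
    count: jnp.ndarray  # number of observations seen so far

def update(stats: RunningMeanStd, batch: jnp.ndarray) -> RunningMeanStd:
    # Chan et al. parallel variance update: merge batch stats into running stats.
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    batch_count = batch.shape[0]
    delta = batch_mean - stats.mean
    total = stats.count + batch_count
    new_mean = stats.mean + delta * batch_count / total
    m2 = stats.var * stats.count + batch_var * batch_count \
         + delta**2 * stats.count * batch_count / total
    return RunningMeanStd(mean=new_mean, var=m2 / total, count=total)

def normalize(stats: RunningMeanStd, obs: jnp.ndarray) -> jnp.ndarray:
    # Standardize; the epsilon guards against division by zero early on.
    return (obs - stats.mean) / jnp.sqrt(stats.var + 1e-8)
```

Keeping the statistics in a pytree dataclass makes them easy to pass through `jax.jit`-compiled functions alongside network parameters, which fits the precomputed mean/var files in `obs-norm/`.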