
RL-BLOX

This project provides modular implementations of various model-free and model-based RL algorithms, with both deep neural network-based and tabular representations of Q-values, policies, etc. that can be used interchangeably. The goal of this project is for the authors to learn by reimplementing RL algorithms, and eventually to provide an algorithmic toolbox for research purposes.

Caution

This library is still experimental and under development. Using it will not result in a good user experience: it is not well-documented, it is buggy, its interface is not clearly defined, and its most interesting features live in feature branches. We recommend not using it yet. If you are an RL developer and want to collaborate, feel free to contact us.

Design Principles

The implementation of this project follows three principles:

  1. Algorithms are functions!
  2. Algorithms are implemented in single files.
  3. Policies and value functions are data containers.
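To make these principles concrete, here is a minimal tabular Q-learning sketch in plain Python. It is an illustration of the design, not the actual RL-BLOX API: the algorithm is an ordinary function, and the value function is a passive data container that the function updates and returns.

```python
from dataclasses import dataclass, field


@dataclass
class TabularQ:
    """Principle 3: the value function is a data container, here a
    (state, action) -> value table with a default of 0.0."""
    table: dict = field(default_factory=dict)

    def value(self, state, action) -> float:
        return self.table.get((state, action), 0.0)


def q_learning_update(q: TabularQ, state, action, reward, next_state,
                      actions, alpha=0.1, gamma=0.99) -> TabularQ:
    """Principle 1: the algorithm is a function, here one tabular
    Q-learning step that updates the table in place and returns it."""
    best_next = max(q.value(next_state, a) for a in actions)
    td_target = reward + gamma * best_next
    q.table[(state, action)] = q.value(state, action) + alpha * (
        td_target - q.value(state, action)
    )
    return q


q = TabularQ()
q = q_learning_update(q, state=0, action=1, reward=1.0, next_state=2,
                      actions=[0, 1])
```

Because the algorithm is a function over data containers, swapping the tabular table for a neural network representation only changes the container, not the calling code.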

Dependencies

  1. Our environment interface is Gymnasium.
  2. We use JAX for everything.
  3. We use Chex to write reliable code.
  4. For optimization algorithms we use Optax.
  5. For probability distributions we use Distrax.
  6. For all neural networks we use Flax NNX.
  7. To save checkpoints we use Orbax.

Installation

git clone git@github.com:mlaux1/rl-blox.git

After cloning the repository, it is recommended to install the library in editable mode.

pip install -e .

To run the provided examples, use pip install -e '.[examples]'. To install development dependencies, use pip install -e '.[dev]'. To install all optional dependencies, use pip install -e '.[all]'.

Getting Started

RL-BLOX relies on Gymnasium's environment interface. The following example trains a policy with the SAC algorithm on Pendulum-v1.

import gymnasium as gym

from rl_blox.algorithms.model_free.sac import (
    create_sac_state,
    train_sac,
)

env_name = "Pendulum-v1"
env = gym.make(env_name)
seed = 1
verbose = 1
env = gym.wrappers.RecordEpisodeStatistics(env)

sac_state = create_sac_state(
    env,
    policy_hidden_nodes=[128, 128],
    policy_learning_rate=3e-4,
    q_hidden_nodes=[512, 512],
    q_learning_rate=1e-3,
    seed=seed,
)
sac_result = train_sac(
    env,
    sac_state.policy,
    sac_state.policy_optimizer,
    sac_state.q1,
    sac_state.q1_optimizer,
    sac_state.q2,
    sac_state.q2_optimizer,
    total_timesteps=11_000,
    buffer_size=11_000,
    gamma=0.99,
    learning_starts=5_000,
    verbose=verbose,
)
env.close()
policy, _, q1, _, _, q2, _, _, _ = sac_result

# Do something with the trained policy...
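As a sketch of what that last step might look like, here is a generic evaluation loop that measures the average undiscounted return of a policy. The call signature policy(obs) -> action is an assumption for illustration, not the documented RL-BLOX interface; a trivial stand-in environment mimicking the Gymnasium step API is included so the sketch is self-contained.

```python
class CountdownEnv:
    """Minimal stand-in with a Gymnasium-style reset/step API."""

    def reset(self, seed=None):
        self.t = 5
        return self.t, {}

    def step(self, action):
        self.t -= 1
        terminated = self.t == 0
        # obs, reward, terminated, truncated, info
        return self.t, 1.0, terminated, False, {}


def evaluate(policy, env, n_episodes=3):
    """Average undiscounted return over n_episodes rollouts."""
    total = 0.0
    for _ in range(n_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action = policy(obs)  # ASSUMED call signature
            obs, reward, terminated, truncated, _ = env.step(action)
            total += reward
            done = terminated or truncated
    return total / n_episodes


mean_return = evaluate(lambda obs: 0, CountdownEnv())
```

With the trained SAC policy, you would pass a real Gymnasium environment and the policy object from sac_result in place of the stand-ins above.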

API Documentation

You can build the Sphinx documentation with

pip install -e '.[doc]'
cd doc
make html

The HTML documentation will be available under doc/build/html/index.html.

Contributing

If you wish to report bugs, please use the issue tracker. If you would like to contribute to RL-BLOX, just open an issue or a pull request. The target branch for pull requests is the development branch, which is merged into master for new releases. If you have questions about the software, please ask them in the discussions section.

The recommended workflow to add a new feature, add documentation, or fix a bug is the following:

  • Push your changes to a branch (e.g. feature/x, doc/y, or fix/z) of your fork of the RL-BLOX repository.
  • Open a pull request to the main branch.

Direct pushes to the main branch are not allowed.

Testing

Run the tests with

pip install -e '.[dev]'
pytest

Releases

Semantic Versioning

Semantic versioning must be used:

  • The major version is incremented when the API changes in a backwards-incompatible way.
  • The minor version is incremented when new functionality is added in a backwards-compatible manner.
  • The patch version is incremented for bugfixes, documentation updates, etc.
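As a small illustration of these rules, the version bump can be expressed as a function (a hypothetical helper, not part of RL-BLOX):

```python
def bump(version: str, change: str) -> str:
    """Return the next semantic version for a given kind of change.

    change is one of 'major' (backwards-incompatible API change),
    'minor' (backwards-compatible new feature), or 'patch'
    (bugfix, documentation, etc.).
    """
    major, minor, patch = map(int, version.split("."))
    if change == "major":
        return f"{major + 1}.0.0"
    if change == "minor":
        return f"{major}.{minor + 1}.0"
    if change == "patch":
        return f"{major}.{minor}.{patch + 1}"
    raise ValueError(f"unknown change kind: {change}")


bump("1.4.2", "minor")  # -> "1.5.0"
```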

Funding

This library is currently developed at the Robotics Group of the University of Bremen together with the Robotics Innovation Center of the German Research Center for Artificial Intelligence (DFKI) in Bremen.
