fl0pedro/graphax: Cross-Country Elimination in JAX

Graphax

This package contains the implementation of the sparse cross-country elimination method for Automatic Differentiation (AD).

What the hell is Cross-Country Elimination?

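Cross-country elimination is an automatic differentiation (AD) technique that turns the AD algorithm itself into a design choice that can be optimized with respect to relevant quantities such as computational and memory cost. The Jacobian is obtained by eliminating the intermediate vertices of the function's computational graph one at a time; the elimination order determines the computational cost, and forward-mode and reverse-mode AD correspond to two particular orders. This allows AD algorithms to be tailored to the specific function we wish to differentiate. AlphaGrad is an example of automated AD algorithm discovery using Reinforcement Learning.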
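To make the idea concrete, here is a toy sketch of vertex elimination in plain Python. It is not how Graphax is implemented: the example function, the vertex numbering, and the helpers build_graph, eliminate_vertex and jacobian_by_elimination are hypothetical, and the edges carry scalar partial derivatives rather than Jacobian blocks. Eliminating a vertex j applies the chain rule to every predecessor/successor pair around it; once all intermediate vertices are gone, the remaining input-to-output edges hold the Jacobian.

```python
import math

def build_graph(x0, x1, x2):
    """Hypothetical linearized graph of f(x0, x1, x2) = sin(x0 * x1 * x2).

    Vertices 0, 1, 2 are inputs, 3 = x0 * x1, 4 = v3 * x2, 5 = sin(v4) is the output.
    edges[(k, j)] holds the local partial derivative d v_k / d v_j.
    """
    v3 = x0 * x1
    v4 = v3 * x2
    return {
        (3, 0): x1, (3, 1): x0,
        (4, 3): x2, (4, 2): v3,
        (5, 4): math.cos(v4),
    }

def eliminate_vertex(edges, j):
    """Eliminate intermediate vertex j; return the number of multiplications used."""
    preds = [i for (k, i) in edges if k == j]
    succs = [k for (k, i) in edges if i == j]
    mults = 0
    for k in succs:
        for i in preds:
            # Chain rule: bypass j by adding the path weight k <- j <- i to edge k <- i.
            edges[(k, i)] = edges.get((k, i), 0.0) + edges[(k, j)] * edges[(j, i)]
            mults += 1
    # Remove j together with all of its incident edges.
    for k in succs:
        del edges[(k, j)]
    for i in preds:
        del edges[(j, i)]
    return mults

def jacobian_by_elimination(order, x):
    edges = build_graph(*x)
    cost = sum(eliminate_vertex(edges, j) for j in order)
    # Only input -> output edges remain; they are the entries d v5 / d x_i.
    grad = [edges[(5, i)] for i in range(3)]
    return grad, cost

x = (1.0, 2.0, 3.0)
fwd_grad, fwd_cost = jacobian_by_elimination([3, 4], x)  # ascending order: forward mode
rev_grad, rev_cost = jacobian_by_elimination([4, 3], x)  # descending order: reverse mode
print(fwd_cost, rev_cost)  # 5 4
```

Both orders produce the same gradient (6 cos(6), 3 cos(6), 2 cos(6) at this point), but the reverse-mode order needs one multiplication fewer because the function has three inputs and a single output. Cross-country elimination searches over all such orders, not just these two.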

Another nice feature of cross-country elimination is that it can exploit the inherent static sparsity structure of Jacobians.

Installation

The package has the following dependencies:

  • jax
  • numpy
  • scipy
  • matplotlib

To install the package itself, run pip install -e . in the root directory of the repository.

Usage

The package exposes a function transformation called jacve, the counterpart of jax.jacfwd and jax.jacrev. It takes an additional keyword argument, order, which specifies the elimination order for cross-country elimination. Two predefined modes, fwd and rev, implement forward-mode and reverse-mode AD, respectively. jacve is fully compatible with jax.jit, jax.vmap, and even jax.jacfwd and jax.jacrev. For example, the following call differentiates f with respect to its four arguments using a custom cross-country elimination order:

graphax.jacve(f, order=[1,3,2,4], argnums=(0,1,2,3))(1., 1., 2., 7.)
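For comparison, the sketch below computes reference Jacobians with jax.jacfwd and jax.jacrev, using the same argnums and arguments as the call above. The function f is a hypothetical stand-in (the README does not define it), and graphax itself is not called. Since jacve is described as the equivalent of these two transformations, it would be expected to return the same Jacobian values for this f; the exact output structure of jacve is an assumption not confirmed here.

```python
import jax
import jax.numpy as jnp

# Hypothetical function of four scalar inputs with two outputs (not from the README).
def f(x, y, z, w):
    a = x * y
    b = jnp.sin(z) + a
    return b * w, jnp.log(a + w)

args = (1., 1., 2., 7.)
argnums = (0, 1, 2, 3)

# Forward- and reverse-mode Jacobians; jac[i][j] is d(output i) / d(argument j).
jac_fwd = jax.jacfwd(f, argnums=argnums)(*args)
jac_rev = jax.jacrev(f, argnums=argnums)(*args)

# Both modes agree; e.g. d log(x*y + w) / dx = y / (x*y + w) = 1/8 at these inputs.
print(jac_fwd[1][0], jac_rev[1][0])
```

Such reference values are a convenient sanity check when experimenting with custom elimination orders, since every valid order must reproduce the same Jacobian.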

Projects using Graphax

  • AlphaGrad
  • Synaptax
