This package contains the implementation of the sparse cross-country elimination method for Automatic Differentiation (AD).
Cross-country elimination is an AD technique that enables AD algorithm design with respect to relevant quantities such as computational and memory cost. It allows us to design AD algorithms tailored to the specific function we wish to differentiate. AlphaGrad is an example of automated AD algorithm discovery using reinforcement learning.
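The core operation behind cross-country elimination is vertex elimination on the linearized computational graph. Every path i → j → k through an intermediate vertex j is replaced by a direct edge i → k whose weight is the product of the two local partial derivatives, added to any existing i → k edge. The plain-Python sketch below illustrates this step independently of this package, for the hypothetical function f(x, y) = (sin(xy), xy + y). `eliminate_vertex` and `linearized_graph` are illustrative helpers, not part of the package API.

```python
import math


def eliminate_vertex(edges, j):
    """Eliminate intermediate vertex j from a linearized computational graph.

    `edges` maps (source, target) to the local partial derivative
    d target / d source. Each path i -> j -> k is folded into a direct
    edge i -> k, then every edge touching j is removed.
    """
    preds = [i for (i, t) in edges if t == j]
    succs = [k for (s, k) in edges if s == j]
    for i in preds:
        for k in succs:
            # Chain rule: accumulate the product of the two local partials.
            edges[(i, k)] = edges.get((i, k), 0.0) + edges[(i, j)] * edges[(j, k)]
    for i in preds:
        del edges[(i, j)]
    for k in succs:
        del edges[(j, k)]


def linearized_graph(x, y):
    """Hypothetical example: f(x, y) = (sin(x * y), x * y + y)."""
    v1 = x * y
    return {
        ("x", 1): y,              # v1 = x * y
        ("y", 1): x,
        (1, "f1"): math.cos(v1),  # f1 = sin(v1)
        (1, "f2"): 1.0,           # f2 = v1 + y
        ("y", "f2"): 1.0,
    }


jacobian = linearized_graph(0.5, 2.0)
eliminate_vertex(jacobian, 1)  # vertex 1 is the only intermediate vertex
for (inp, out), value in sorted(jacobian.items()):
    print(f"d{out}/d{inp} = {value:.4f}")
```

After vertex 1 is eliminated, only input-to-output edges remain, and these are exactly the entries of the Jacobian of f at (0.5, 2.0). When a graph has several intermediate vertices, the order in which they are eliminated does not change the resulting Jacobian, but it does change the number of multiplications required.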
Another useful property of cross-country elimination is that it can exploit the static sparsity structure inherent in many Jacobians, that is, entries that are structurally zero regardless of the input values.
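Because vertex elimination only creates an edge i → k when a path i → j → k exists, structurally zero Jacobian entries are never created or stored. The plain-Python sketch below, again independent of this package, uses the hypothetical elementwise function y_i = x_i sin(x_i). After all intermediate vertices are eliminated, only n edges remain (the diagonal of the Jacobian) instead of a dense n × n matrix. The helper names are illustrative only.

```python
import math


def eliminate_vertex(edges, j):
    """Fold every path i -> j -> k into an edge i -> k, then drop vertex j."""
    preds = [i for (i, t) in edges if t == j]
    succs = [k for (s, k) in edges if s == j]
    for i in preds:
        for k in succs:
            edges[(i, k)] = edges.get((i, k), 0.0) + edges[(i, j)] * edges[(j, k)]
    for i in preds:
        del edges[(i, j)]
    for k in succs:
        del edges[(j, k)]


def elementwise_graph(xs):
    """Hypothetical example: y_i = x_i * v_i with intermediate v_i = sin(x_i)."""
    edges = {}
    for i, x in enumerate(xs):
        edges[(("x", i), ("v", i))] = math.cos(x)  # d v_i / d x_i
        edges[(("v", i), ("y", i))] = x            # d y_i / d v_i
        edges[(("x", i), ("y", i))] = math.sin(x)  # direct d y_i / d x_i
    return edges


xs = [0.1 * i for i in range(1, 101)]
edges = elementwise_graph(xs)
for i in range(len(xs)):
    eliminate_vertex(edges, ("v", i))

# Only the n structurally nonzero (diagonal) entries are stored,
# instead of the n * n entries of a dense Jacobian.
print(len(edges), len(xs) ** 2)  # 100 10000
```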
The package has the following dependencies:

- jax
- numpy
- scipy
- matplotlib

To install the package, run `pip install -e .` in the root directory of the repository.
The package exposes `jacve`, the counterpart of `jax.jacfwd` and `jax.jacrev`. It takes an additional keyword argument, `order`, which specifies the elimination order used for cross-country elimination. Two built-in modes, `fwd` and `rev`, implement forward-mode and reverse-mode AD, respectively. `jacve` is fully compatible with `jax.jit`, `jax.vmap`, and even `jax.jacfwd` and `jax.jacrev`.
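The following example uses cross-country elimination with a custom elimination order, where `f` is a function of four arguments:

```python
import graphax
graphax.jacve(f, order=[1, 3, 2, 4], argnums=(0, 1, 2, 3))(1., 1., 2., 7.)
```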
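The effect of the `order` keyword can be understood through the cost of vertex elimination: eliminating vertex j costs one multiplication per (predecessor, successor) pair. In this view, eliminating the intermediate vertices in forward order corresponds to forward mode, and eliminating them in reverse order corresponds to reverse mode. The plain-Python sketch below counts multiplications for a small hypothetical graph with intermediate vertices numbered 1 to 4. It is not the `f` from the example above, and it assumes, without checking against `jacve`, that an order is simply a list of intermediate vertex indices. `example_graph` and `accumulate` are illustrative helpers.

```python
def eliminate_vertex(edges, j):
    """Fold every path i -> j -> k into an edge i -> k; return the multiplication count."""
    preds = [i for (i, t) in edges if t == j]
    succs = [k for (s, k) in edges if s == j]
    for i in preds:
        for k in succs:
            edges[(i, k)] = edges.get((i, k), 0.0) + edges[(i, j)] * edges[(j, k)]
    for i in preds:
        del edges[(i, j)]
    for k in succs:
        del edges[(j, k)]
    return len(preds) * len(succs)


def example_graph():
    """Hypothetical graph: x0, x1, x2 -> 1 -> 2 -> 3 -> 4 -> y0, y1, y2."""
    edges = {(1, 2): 2.0, (2, 3): 0.5, (3, 4): 3.0}
    for i in range(3):
        edges[(f"x{i}", 1)] = i + 1.0
        edges[(4, f"y{i}")] = i + 1.0
    return edges


def accumulate(order):
    """Eliminate vertices 1..4 in the given order; return (Jacobian edges, cost)."""
    edges = example_graph()
    cost = sum(eliminate_vertex(edges, j) for j in order)
    return edges, cost


orders = {
    "fwd": [1, 2, 3, 4],
    "rev": [4, 3, 2, 1],
    "cross-country": [1, 3, 2, 4],
}
results = {name: accumulate(order) for name, order in orders.items()}
for name, (jac, cost) in results.items():
    print(name, cost)
# fwd 18
# rev 18
# cross-country 16
```

On this graph both built-in orders need 18 multiplications and the cross-country order [1, 3, 2, 4] needs 16, while all three yield the same nine Jacobian entries. Eliminating the inner chain first with [2, 3, 1, 4] would reduce the cost to 14.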
- AlphaGrad
- Synaptax