Memory profiler for JAX
Create profiler logs
from jaxprof import JaxProfiler

profiler = JaxProfiler()

def some_jax_code():
    ...  # your JAX computation
    profiler.capture()  # capture a memory profile at this point
    ...
Or run the profiler in the background:
profiler.capture_in_background()
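The README does not describe how `capture()` and `capture_in_background()` work internally. Below is a rough sketch of the general pattern, assuming each capture writes one snapshot file. In a JAX program the writer could be JAX's own `jax.profiler.save_device_memory_profile`, which saves a device memory profile in pprof format, but that wiring is an assumption here and not documented jaxprof behavior. `SnapshotLoop`, `snapshot_once`, `interval`, `prefix` and the `memory_00000.prof` naming scheme are all hypothetical.

```python
import threading
from typing import Callable, List, Optional


class SnapshotLoop:
    """Hypothetical sketch (not jaxprof's code): one-off and periodic snapshots.

    `write_snapshot` is any callable that writes a snapshot to the given path,
    e.g. (assumed) jax.profiler.save_device_memory_profile.
    """

    def __init__(self, write_snapshot: Callable[[str], None],
                 interval: float = 1.0, prefix: str = "memory") -> None:
        self.write_snapshot = write_snapshot
        self.interval = interval      # seconds between background snapshots
        self.prefix = prefix          # file-name prefix for snapshot files
        self.paths: List[str] = []    # every path written so far, in order
        self._lock = threading.Lock()
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def snapshot_once(self) -> str:
        """Write one numbered snapshot and return its path."""
        # The lock keeps numbering consistent if the background thread
        # and the caller take snapshots at the same time.
        with self._lock:
            path = f"{self.prefix}_{len(self.paths):05d}.prof"
            self.write_snapshot(path)
            self.paths.append(path)
        return path

    def _loop(self) -> None:
        # Event.wait returns False on timeout (take a snapshot) and
        # True once stop() has set the event (leave the loop).
        while not self._stop.wait(self.interval):
            self.snapshot_once()

    def start(self) -> None:
        """Start taking snapshots on a daemon thread."""
        self._stop.clear()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the background thread to finish and wait for it."""
        self._stop.set()
        if self._thread is not None:
            self._thread.join()
            self._thread = None
```

Under these assumptions, `SnapshotLoop(jax.profiler.save_device_memory_profile, interval=0.5)` followed by `start()` and later `stop()` would leave a numbered series of pprof files to plot. Using `threading.Event.wait` as the sleep lets `stop()` end the loop immediately instead of waiting out a full interval.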
Generate plots from profiler logs
To list the plotting options, run:
python jaxprof.py --help