Side-by-side implementations of two different JAX frameworks (Haiku and Flax) and PyTorch on simple deep learning training and inference tasks. Currently implements MNIST, FashionMNIST, CIFAR10, and CIFAR100 training on MLPs and CNNs, and multi-host, model-parallel LLM inference for all OPT, T5, T5v1.1, UL2, GPT2, and GPTJ models.
```
git clone https://github.com/Sea-Snell/jax_v_pytorch.git
cd jax_v_pytorch
```
Install with conda (cpu, gpu, or tpu) or docker (gpu only).
install with conda (cpu):

```
conda env create -f environment.yml
conda activate jax_v_torch
```
install with conda (gpu):

```
conda env create -f environment.yml
conda activate jax_v_torch
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
```
install with conda (tpu):

```
conda env create -f environment.yml
conda activate jax_v_torch
pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
install with docker (gpu only):

- install docker and docker compose
- make sure to install nvidia-docker2 and the NVIDIA Container Toolkit

```
docker compose build
docker compose run jax_v_torch
```
And then, in the new container shell that pops up:

```
cd jax_v_torch
```
To run an implementation, navigate to any subfolder (for example `cd cifar_mnist/haiku/`) and run:

```
python main.py
```
Feel free to edit any configs in `main.py`, either by directly editing the file or with command line arguments. The config framework is micro-config.
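For illustration only (this is a generic sketch of the dataclass-style config pattern, not the repo's actual schema; micro-config's real API may differ), a config edit typically amounts to overriding a few fields:

```python
# Hypothetical config sketch; the field names (lr, bsize, epochs) are
# illustrative and may not match the repo's actual main.py.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class TrainConfig:
    lr: float = 1e-3    # learning rate
    bsize: int = 128    # batch size
    epochs: int = 10    # training epochs

config = TrainConfig()
# Override fields without touching the defaults:
config = replace(config, lr=3e-4, bsize=256)
```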
All implementations are meant to be identical modulo framework-specific differences.
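To give a sense of what "identical modulo framework-specific differences" means in practice, here is a minimal sketch (not code from this repo) of the same two-layer MLP expressed in Haiku and in Flax:

```python
import jax
import jax.numpy as jnp
import haiku as hk
import flax.linen as nn

# Haiku: define the network as a function, then hk.transform it.
def mlp_fn(x):
    x = hk.Linear(128)(x)
    x = jax.nn.relu(x)
    return hk.Linear(10)(x)

haiku_mlp = hk.transform(mlp_fn)

# Flax: define the network as a linen Module.
class FlaxMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        return nn.Dense(10)(x)

# Both yield a pure (params, input) -> output function.
rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 784))

haiku_params = haiku_mlp.init(rng, x)
haiku_out = haiku_mlp.apply(haiku_params, None, x)  # None: no rng needed

flax_mlp = FlaxMLP()
flax_params = flax_mlp.init(rng, x)
flax_out = flax_mlp.apply(flax_params, x)
```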
- `cifar_mnist/` - implements MNIST/FashionMNIST/CIFAR10/CIFAR100 training on both single and multiple devices (data parallel; see the pmap sketch after this list).
  - `pytorch/` - implemented in pytorch, single device
  - `flax/` - implemented in flax, single device
  - `flax_pmap/` - implemented in flax, multi device
  - `haiku/` - implemented in haiku, single device
  - `haiku_pmap/` - implemented in haiku, multi device
- `lm_inference/` - implements model-parallel, multi-host LLM inference for all OPT, T5, T5v1.1, UL2, GPT2, and GPTJ models (see the sharding sketch at the end of this section).
  - `flax/` - implemented in flax with Transformers, multi device, multi host
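As a rough illustration of what the `*_pmap/` variants in `cifar_mnist/` do (a minimal sketch with a toy linear model, not the repo's actual code), data parallelism in JAX replicates parameters across devices, shards the batch, and averages gradients with `jax.lax.pmean`:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    preds = x @ params["w"] + params["b"]   # toy linear model
    return jnp.mean((preds - y) ** 2)

def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients across all devices participating in the pmap.
    grads = jax.lax.pmean(grads, axis_name="batch")
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
    return params, loss

p_train_step = jax.pmap(train_step, axis_name="batch")

n_dev = jax.local_device_count()
params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
# Replicate params onto every device; shard the batch along a leading axis.
params = jax.device_put_replicated(params, jax.local_devices())
x = jnp.ones((n_dev, 8, 4))   # [devices, per-device batch, features]
y = jnp.ones((n_dev, 8, 1))
params, loss = p_train_step(params, x, y)
```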
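The model-parallel idea behind `lm_inference/` can be sketched on a single host with a toy matmul (this uses the modern `jax.sharding` interface; the repo's actual implementation shards pretrained Transformer weights across hosts and may use an older interface such as `pjit`):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh with a single "model" axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("model",))

# Partition the weight matrix column-wise across the "model" axis,
# and replicate the activations on every device.
w = jax.device_put(jnp.ones((512, 512)), NamedSharding(mesh, P(None, "model")))
x = jax.device_put(jnp.ones((8, 512)), NamedSharding(mesh, P(None, None)))

@jax.jit
def forward(x, w):
    return x @ w   # XLA keeps the output sharded column-wise

y = forward(x, w)
print(y.sharding)  # sharded over the "model" mesh axis
```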