# Fetch the repository into ./LMTemplate, enter it, and point PYTHONPATH at
# its src/ directory (this replaces any existing PYTHONPATH value).
git clone https://github.com/Sea-Snell/LMTemplate.git LMTemplate
cd LMTemplate
export PYTHONPATH="$PWD/src/"
# Install with conda (CPU, TPU, or GPU) or with Docker (CPU or GPU only).
# Install with conda (CPU):
# Build the LMTemplate environment from environment.yml and switch into it.
conda env create --file environment.yml
conda activate LMTemplate
# Install with conda (GPU):
# Create the base environment, then add GPU builds of PyTorch and JAX.
conda env create -f environment.yml
conda activate LMTemplate
# Run every conda install before any pip install: conda does not track
# packages installed by pip and can overwrite or break them if run afterwards.
conda install pytorch=1.11 cudatoolkit=11.3 -c pytorch
python -m pip install --upgrade pip
# -f adds this page of JAX CUDA release wheels to pip's package search.
python -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Install with conda (TPU):
# Build and enter the environment, refresh pip, then install the TPU build of
# JAX; --find-links points pip at the libtpu releases page named below.
conda env create --file environment.yml
conda activate LMTemplate
python -m pip install -U pip
python -m pip install --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]>=0.2.16"