Download checkpoints
- Download checkpoint.pt from https://drive.google.com/file/d/1U9bVcLXmjtP1XW8gLwSG4-ILgvdedabK/view?usp=sharing and place it in mlmi-lsgm/MNIST/
- Download checkpoint_nll.pt from https://drive.google.com/file/d/1NEbX-nSWDixtS8LoB-RmyKNW8td_aXd1/view?usp=sharing and place it in mlmi-lsgm/MNIST/vada/
- Download checkpoint.pt from https://drive.google.com/file/d/1zMAE9S0AmDLL8P8Qh5a8JTZGQFsbfu60/view?usp=sharing and place it in mlmi-lsgm/save_dir/vae/
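Because two of the files share the name checkpoint.pt, it is easy to put one in the wrong folder. The helper below is a minimal sketch (missing_checkpoints and EXPECTED_CHECKPOINTS are hypothetical names, not part of the repository). It checks that each file sits at the location listed above. Paths are resolved relative to the directory that contains the mlmi-lsgm checkout.

```python
from pathlib import Path

# Expected locations, copied from the download list above.
# Hypothetical helper: paths are relative to the directory that CONTAINS mlmi-lsgm.
EXPECTED_CHECKPOINTS = [
    Path("mlmi-lsgm/MNIST/checkpoint.pt"),
    Path("mlmi-lsgm/MNIST/vada/checkpoint_nll.pt"),
    Path("mlmi-lsgm/save_dir/vae/checkpoint.pt"),
]


def missing_checkpoints(base_dir="."):
    """Return the expected checkpoint paths (as POSIX strings) not found under base_dir."""
    base = Path(base_dir)
    return [p.as_posix() for p in EXPECTED_CHECKPOINTS if not (base / p).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        print("Missing checkpoints:")
        for path in missing:
            print("  " + path)
    else:
        print("All checkpoints found.")
```

Run it from the parent directory of mlmi-lsgm, before you cd into the repository.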
Set up environment
cd mlmi-lsgm
pip install -r requirements.txt
pip install blobfile
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
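The last command pins specific PyTorch builds. The requirements file may pull in a different torch version, so it can help to confirm which versions actually ended up installed. This is a small sketch that assumes Python 3.8+ for importlib.metadata. pin_mismatches is a hypothetical helper, and the pins are copied from the pip command above.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the pip install command above (hypothetical helper, not in the repo).
PINNED = {
    "torch": "1.8.0+cu111",
    "torchvision": "0.9.0+cu111",
    "torchaudio": "0.8.0",
}


def pin_mismatches(pins=PINNED):
    """Map each package whose installed version differs from its pin to the
    installed version, or to None if the package is not installed at all."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    print(pin_mismatches() or "All pinned packages match.")
```

An empty result means the pins took effect. To confirm that the GPU is visible, run torch.cuda.is_available() separately.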
Train VAE
bash bash/unix/train_vae.sh
Train LSGM (using a pretrained VAE)
bash bash/unix/train_vada_from_vae.sh
Train LSGM from an existing checkpoint
bash bash/unix/train_vada_from_vada.sh
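The three training commands run from inside mlmi-lsgm. The sketch below chains them with subprocess and stops at the first script that exits with an error. run_stages is a hypothetical helper. The ordering shown (train the VAE, then train LSGM on top of it) follows the section order above and is an assumption, not something the scripts enforce.

```python
import subprocess

# Script paths from the training steps above, relative to the mlmi-lsgm checkout.
TRAIN_VAE = "bash/unix/train_vae.sh"
TRAIN_LSGM_FROM_VAE = "bash/unix/train_vada_from_vae.sh"
TRAIN_LSGM_FROM_CHECKPOINT = "bash/unix/train_vada_from_vada.sh"


def run_stages(stages, repo_dir="mlmi-lsgm", dry_run=False):
    """Run each script with bash inside repo_dir, in order.

    check=True raises CalledProcessError at the first failing script,
    so later stages never start on a broken earlier stage.
    With dry_run=True the commands are only printed.
    """
    commands = [["bash", script] for script in stages]
    for cmd in commands:
        if dry_run:
            print("would run:", " ".join(cmd))
        else:
            subprocess.run(cmd, cwd=repo_dir, check=True)
    return commands


if __name__ == "__main__":
    # Assumed full pipeline: VAE first, then LSGM using that VAE.
    run_stages([TRAIN_VAE, TRAIN_LSGM_FROM_VAE], dry_run=True)
```

To skip VAE training and use the downloaded VAE checkpoint, pass only TRAIN_LSGM_FROM_VAE. To continue from an LSGM checkpoint, pass only TRAIN_LSGM_FROM_CHECKPOINT.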