This repository contains a PyTorch implementation of the Kalman Variational Autoencoder (K-VAE) described in "A Disentangled Recognition and Nonlinear Dynamics Model for Unsupervised Learning" (arXiv:1710.05741), a framework for unsupervised learning that separates the recognition of latent representations from a model of their temporal dynamics.
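At a high level, a VAE encodes each frame into a low-dimensional pseudo-observation, and a linear Gaussian state-space model captures the temporal dynamics of those pseudo-observations so that filtering and smoothing can be done in closed form. The sketch below is purely illustrative of that two-level structure; the class and method names are hypothetical and do not reflect this repository's actual API.

```python
import torch
import torch.nn as nn


class KVAESketch(nn.Module):
    """Illustrative two-level latent structure: frames -> a_t (VAE) -> z_t (LGSSM)."""

    def __init__(self, frame_dim=32 * 32, a_dim=2, z_dim=4):
        super().__init__()
        # Recognition model: each frame x_t is encoded into a low-dimensional
        # pseudo-observation a_t (mean and log-variance of a Gaussian).
        self.encoder = nn.Sequential(
            nn.Linear(frame_dim, 128), nn.ReLU(), nn.Linear(128, 2 * a_dim)
        )
        # Generative model: frames are reconstructed from a_t alone.
        self.decoder = nn.Sequential(
            nn.Linear(a_dim, 128), nn.ReLU(), nn.Linear(128, frame_dim)
        )
        # Linear Gaussian state-space model over the a_t sequence with latent
        # state z_t; in the paper the transition/emission matrices are mixed
        # by a dynamics parameter network (omitted here).
        self.A = nn.Parameter(torch.eye(z_dim))                  # z_t -> z_{t+1}
        self.C = nn.Parameter(0.05 * torch.randn(a_dim, z_dim))  # z_t -> a_t

    def encode(self, x):
        mean, logvar = self.encoder(x).chunk(2, dim=-1)
        return mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)

    def forward(self, frames):
        # frames: (batch, time, frame_dim)
        a = self.encode(frames)  # pseudo-observations a_{1:T}
        # A Kalman filter/smoother over a_{1:T} (not shown) gives the exact
        # posterior over z_{1:T}; the ELBO combines it with the VAE terms.
        return self.decoder(a), a


# Toy forward pass on a random video batch
model = KVAESketch()
video = torch.rand(8, 20, 32 * 32)  # (batch, time, flattened 32x32 frames)
recon, a = model(video)
print(recon.shape, a.shape)         # torch.Size([8, 20, 1024]) torch.Size([8, 20, 2])
```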
Follow these steps to set up the environment, install dependencies, and run the training and evaluation scripts.
Clone the repository with submodules:

```bash
git clone --recursive https://github.com/nkgvl/kalman-vae.git
cd kalman-vae
```
Create a new Conda environment and install the required packages:

```bash
conda create --name kvae-env python=3.11
conda activate kvae-env

# Install dependencies from conda-forge
conda install -c conda-forge opencv pygame pymunk

# Install other specific dependencies
conda install matplotlib~=3.8.0 numpy~=1.26.0 pandas~=2.1.1 Pillow~=10.0.1 tqdm~=4.65.0 wandb~=0.15.12

# For the PyTorch installation, refer to the official website (https://pytorch.org)
# to select the appropriate version and CUDA support for your system.
```
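After installing PyTorch, a quick check like the one below confirms the installed version and whether a CUDA device is visible:

```python
import torch

print(torch.__version__)           # installed PyTorch version
print(torch.cuda.is_available())   # True if a CUDA-capable GPU is usable
```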
Install the K-VAE package using pip:

```bash
pip install .
```
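As an optional sanity check, the installed package can be imported from Python. The module name used below (`kalman_vae`) is an assumption; check `setup.py`/`pyproject.toml` in the repository root for the actual import name.

```python
# NOTE: the import name is assumed -- verify it against setup.py / pyproject.toml.
import kalman_vae  # noqa: F401

print("K-VAE package imported successfully")
```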
Modify `examples/run_training.sh` and run the training script:

```bash
cd examples
bash run_training.sh
```
After training, modify `examples/run_evaluation.sh` and run the evaluation script to assess performance:

```bash
cd examples
bash run_evaluation.sh --checkpoint_dir [YOUR_CHECKPOINT_DIR] --epoch [EPOCH_NUMBER]
```
Evaluation videos and performance tables will be saved in the `videos/` and `tables/` directories under the specified checkpoint directory. For an example of the output, see the evaluation video `idx_2_mask_length_30.mp4`.
After completing the setup, you can use the K-VAE model for your research and experiments. Feel free to modify the training and evaluation scripts to explore different configurations.
- This implementation is inspired by the original paper "A Disentangled Recognition and Nonlinear Dynamics Model for Unsupervised Learning" and its original implementation.
- The dataset generation uses code from this repository.
This project is licensed under the MIT License - see the LICENSE file for details.