A diffusion transformer implementation in Flax
This repository implements, in Flax, the diffusion transformer (DiT) architecture proposed in Scalable Diffusion Models with Transformers. We test the architecture using the EDM parameterization introduced in Elucidating the Design Space of Diffusion-Based Generative Models.
Note
The architecture does not follow the original implementation exactly. For instance, we don't use label conditioning and just use a learnable positional encoding for the patches. The rest is fairly similar.
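For illustration, here is a minimal sketch in Flax of the two choices mentioned above: a patch embedding with a learnable positional encoding, and a transformer block modulated by the diffusion-time embedding (adaLN-Zero, as in the DiT paper). The module names, shapes, and initializers are assumptions for exposition, not the repository's exact code:

```python
import flax.linen as nn
import jax.numpy as jnp


class PatchEmbedding(nn.Module):
    """Patchifies an image and adds a learnable positional encoding."""

    patch_size: int
    embed_dim: int

    @nn.compact
    def __call__(self, x):
        # Non-overlapping patches via a strided convolution (NHWC inputs).
        x = nn.Conv(
            self.embed_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
        )(x)
        b, h, w, c = x.shape
        x = x.reshape(b, h * w, c)
        # Learnable positional encoding instead of a fixed sinusoidal one.
        pos_embedding = self.param(
            "pos_embedding", nn.initializers.normal(0.02), (1, h * w, c)
        )
        return x + pos_embedding


class DiTBlock(nn.Module):
    """A transformer block modulated by the time embedding (adaLN-Zero)."""

    embed_dim: int
    num_heads: int

    @nn.compact
    def __call__(self, x, time_embedding):
        # The time embedding (assumed shape (batch, embed_dim)) predicts
        # scale/shift/gate terms; zero initialization of the modulation
        # layer makes each block start out as the identity.
        modulation = nn.Dense(
            6 * self.embed_dim, kernel_init=nn.initializers.zeros
        )(nn.silu(time_embedding))
        s1, b1, g1, s2, b2, g2 = jnp.split(modulation[:, None, :], 6, axis=-1)

        # Attention branch with modulated layer norm and gated residual.
        h = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        h = h * (1.0 + s1) + b1
        h = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(h, h)
        x = x + g1 * h

        # MLP branch, again modulated and gated.
        h = nn.LayerNorm(use_bias=False, use_scale=False)(x)
        h = h * (1.0 + s2) + b2
        h = nn.gelu(nn.Dense(4 * self.embed_dim)(h))
        h = nn.Dense(self.embed_dim)(h)
        x = x + g2 * h
        return x
```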
The experiments folder contains a use case that trains an EDM (diffusion model) on MNIST-SDF.
To train the model, just execute:
```bash
cd experiments/mnist_sdf
python main.py
  --config=config.py
  --workdir=<dir>
  (--usewand)
```
Below are some samples drawn from the EDM using a DiT-B after training for 100 epochs. In my experiments, the UNet still works better, but that might just be due to the choice of hyperparameters.
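For reference, the EDM parameterization wraps the raw network F in the preconditioning of Karras et al., D(x; sigma) = c_skip(sigma) x + c_out(sigma) F(c_in(sigma) x, c_noise(sigma)), which keeps the network's inputs and training targets at roughly unit variance. A minimal sketch, assuming sigma_data = 0.5 as in the paper and a hypothetical net.apply(params, x, c_noise) interface:

```python
import jax.numpy as jnp


def edm_denoiser(net, params, x, sigma, sigma_data=0.5):
    # Preconditioning coefficients from "Elucidating the Design Space of
    # Diffusion-Based Generative Models" (Karras et al., 2022).
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / jnp.sqrt(sigma**2 + sigma_data**2)
    c_in = 1.0 / jnp.sqrt(sigma**2 + sigma_data**2)
    c_noise = 0.25 * jnp.log(sigma)
    # D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)
    return c_skip * x + c_out * net.apply(params, c_in * x, c_noise)
```

During training, D is regressed onto the clean image with a sigma-dependent weight, and sampling integrates the probability-flow ODE defined by D.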
To install the latest GitHub release, just call the following on the command line:
pip install git+https://github.com/dirmeier/diffusion-transformer@<RELEASE>
Simon Dirmeier <sfyrbnd @ pm me>