
DiT


A diffusion transformer implementation in Flax

About

This repository implements, in Flax, the diffusion transformer (DiT) architecture proposed in Scalable Diffusion Models with Transformers. We test the architecture using the EDM parameterization introduced in Elucidating the Design Space of Diffusion-Based Generative Models.
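To make the EDM parameterization concrete, below is a minimal NumPy sketch of the preconditioning coefficients from the EDM paper (Karras et al., 2022, Table 1). The default `sigma_data=0.5` is the paper's value; this repository's code may name or set things differently.

```python
import numpy as np

def edm_preconditioning(sigma, sigma_data=0.5):
    """EDM scaling coefficients (Karras et al., 2022, Table 1).

    The denoiser is D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise),
    where F is the raw network (here, the DiT).
    """
    c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
    c_out = sigma * sigma_data / np.sqrt(sigma ** 2 + sigma_data ** 2)
    c_in = 1.0 / np.sqrt(sigma ** 2 + sigma_data ** 2)
    c_noise = np.log(sigma) / 4.0
    return c_skip, c_out, c_in, c_noise
```

At low noise levels `c_skip` approaches 1 (the denoiser passes the input through); at high noise levels it approaches 0 (the network predicts the clean image from scratch).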

Note

The architecture does not follow the original implementation exactly. For instance, we don't use label conditioning, and we use a learnable positional encoding for the patches. The rest is fairly similar.
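For illustration, the patch tokenization plus a learnable positional encoding can be sketched as follows. This is a plain NumPy stand-in for what would be a Flax module with a trainable parameter; shapes and the patch size are illustrative, not taken from this repository.

```python
import numpy as np

def patchify(images, patch_size=4):
    """Split (B, H, W, C) images into non-overlapping flattened patches.

    Returns an array of shape (B, num_patches, patch_size * patch_size * C).
    """
    b, h, w, c = images.shape
    assert h % patch_size == 0 and w % patch_size == 0
    x = images.reshape(b, h // patch_size, patch_size, w // patch_size, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # group the two patch-grid axes together
    return x.reshape(b, (h // patch_size) * (w // patch_size), patch_size * patch_size * c)

# A learnable positional encoding is simply a trainable array added to the
# patch embeddings; here a randomly initialized NumPy array stands in for it.
images = np.random.randn(2, 32, 32, 1)
patches = patchify(images)  # shape (2, 64, 16)
pos_embedding = 0.02 * np.random.randn(1, patches.shape[1], patches.shape[2])
tokens = patches + pos_embedding
```

In the actual model the patches would additionally be projected to the transformer's embedding dimension before the positional encoding is added.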

Example usage

The experiments folder contains a use case that trains an EDM (diffusion model) on MNIST-SDF. To train the model, just execute:

```shell
cd experiments/mnist_sdf
python main.py \
  --config=config.py \
  --workdir=<dir> \
  (--usewand)
```

Below are some samples drawn from the EDM using a DiT-B backbone after training for 100 epochs. In my experiments, the UNet still works better, but that might just be down to the choice of hyperparameters.
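Drawing samples from an EDM walks down a decreasing noise-level schedule. As a sketch, here is the schedule from the EDM paper (Karras et al., 2022, eq. 5); the defaults are the paper's, not necessarily the values in this repository's config:

```python
import numpy as np

def karras_sigmas(n_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise-level schedule from the EDM paper (Karras et al., 2022, eq. 5)."""
    ramp = np.linspace(0.0, 1.0, n_steps)
    sigmas = (sigma_max ** (1 / rho)
              + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    return np.append(sigmas, 0.0)  # the final step denoises to sigma = 0
```

The exponent `rho` controls how densely steps cluster near low noise levels, where most of the image detail is formed.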

Installation

To install the latest release from GitHub, run the following on the command line:

```shell
pip install git+https://github.com/dirmeier/diffusion-transformer@<RELEASE>
```

Author

Simon Dirmeier sfyrbnd @ pm me
