GitHub - Lakonik/GMFlow: [ICML 2025] Gaussian Mixture Flow Matching Models (GMFlow)

Lakonik/GMFlow

Gaussian Mixture Flow Matching Models (GMFlow)

Official PyTorch implementation of the paper:

Gaussian Mixture Flow Matching Models [arXiv]
In ICML 2025
Hansheng Chen¹, Kai Zhang², Hao Tan², Zexiang Xu³, Fujun Luan², Leonidas Guibas¹, Gordon Wetzstein¹, Sai Bi²
¹Stanford University, ²Adobe Research, ³Hillbot

Highlights

GMFlow is an extension of diffusion/flow matching models.

  • Gaussian Mixture Output: GMFlow expands the network's output layer to predict a Gaussian mixture (GM) distribution of flow velocity. Standard diffusion/flow matching models are special cases of GMFlow with a single Gaussian component.

  • Precise Few-Step Sampling: GMFlow introduces novel GM-SDE and GM-ODE solvers that leverage analytic denoising distributions and velocity fields for precise few-step sampling.

  • Improved Classifier-Free Guidance (CFG): GMFlow introduces a probabilistic guidance scheme that mitigates the over-saturation issues of CFG and improves image generation quality.

  • Efficiency: GMFlow maintains similar training and inference costs to standard diffusion/flow matching models.
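
To make the first highlight concrete, here is a minimal sketch in plain Python (not the repository's code) of a one-dimensional Gaussian mixture over velocity. It computes the mixture mean and draws one sample. The names `gm_mean` and `gm_sample`, the example weights and means, and the shared standard deviation are all assumptions made for illustration. With a single component the mixture reduces to an ordinary Gaussian, which is the sense in which standard flow matching is a special case.

```python
import random

def gm_mean(weights, means):
    """Mean of a Gaussian mixture: the weight-averaged component means."""
    total = sum(weights)
    return sum(w * m for w, m in zip(weights, means)) / total

def gm_sample(weights, means, std, rng):
    """Draw one sample: pick a component by weight, then add Gaussian noise."""
    k = rng.choices(range(len(means)), weights=weights)[0]
    return rng.gauss(means[k], std)

# Hypothetical 1D mixture with K=2 components and a shared standard deviation
weights, means, std = [0.25, 0.75], [-1.0, 1.0], 0.1
mixture_mean = gm_mean(weights, means)  # 0.25 * (-1) + 0.75 * 1 = 0.5
sample = gm_sample(weights, means, std, random.Random(0))

# K=1: the mixture is a single Gaussian, so its mean is that component's mean
single_mean = gm_mean([1.0], [0.3])
```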

Installation

The code has been tested in the environment described as follows:

Other dependencies can be installed via pip install -r requirements.txt.

An example of installation commands is shown below (assuming you have already installed CUDA Toolkit and configured the environment variables):

# Create conda environment
conda create -y -n gmflow python=3.10 numpy=1.26 ninja
conda activate gmflow

# Go to https://pytorch.org/ to select the appropriate version
pip install torch torchvision

# Clone this repo and install other dependencies
git clone https://github.com/Lakonik/GMFlow && cd GMFlow
pip install -r requirements.txt

This codebase may work on Windows systems, but it has not been tested extensively.

GM-DiT ImageNet 256x256

Inference

We provide a Diffusers pipeline for easy inference. The following code demonstrates how to sample images from the pretrained GM-DiT model using the GM-ODE 2 solver and the GM-SDE 2 solver.

import torch
from huggingface_hub import snapshot_download
from lib.models.diffusions.schedulers import FlowEulerODEScheduler, GMFlowSDEScheduler
from lib.pipelines.gmdit_pipeline import GMDiTPipeline

# Currently the pipeline can only load local checkpoints, so we need to download the checkpoint first
ckpt = snapshot_download(repo_id='Lakonik/gmflow_imagenet_k8_ema')
pipe = GMDiTPipeline.from_pretrained(ckpt, variant='bf16', torch_dtype=torch.bfloat16)
pipe = pipe.to('cuda')

# Pick words that exist in ImageNet
words = ['jay', 'magpie']
class_ids = pipe.get_label_ids(words)

# Sample using GM-ODE 2 solver
pipe.scheduler = FlowEulerODEScheduler.from_config(pipe.scheduler.config)
generator = torch.manual_seed(42)
output = pipe(
    class_labels=class_ids,
    guidance_scale=0.45,
    num_inference_steps=32,
    num_inference_substeps=4,
    output_mode='mean',
    order=2,
    generator=generator)
for i, (word, image) in enumerate(zip(words, output.images)):
    image.save(f'{i:03d}_{word}_gmode2_step32.png')

# Sample using GM-SDE 2 solver (the first run may be slow due to CUDA compilation)
pipe.scheduler = GMFlowSDEScheduler.from_config(pipe.scheduler.config)
generator = torch.manual_seed(42)
output = pipe(
    class_labels=class_ids,
    guidance_scale=0.45,
    num_inference_steps=32,
    num_inference_substeps=1,
    output_mode='sample',
    order=2,
    generator=generator)
for i, (word, image) in enumerate(zip(words, output.images)):
    image.save(f'{i:03d}_{word}_gmsde2_step32.png')

The results will be saved under the current directory.
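
The two sampling calls above differ only in the scheduler, the number of substeps, and the output mode, so the shared part can be wrapped in a small helper. This is a sketch, not part of the repository: `sample_and_save` is a hypothetical name, and it relies only on the calls already shown above (`pipe.get_label_ids`, calling `pipe`, and `image.save`).

```python
def sample_and_save(pipe, words, tag, **pipe_kwargs):
    """Sample one image per word and save it as '{index}_{word}_{tag}.png'.

    `pipe` is assumed to behave like the GMDiTPipeline used above.
    """
    class_ids = pipe.get_label_ids(words)
    output = pipe(class_labels=class_ids, **pipe_kwargs)
    paths = []
    for i, (word, image) in enumerate(zip(words, output.images)):
        path = f'{i:03d}_{word}_{tag}.png'
        image.save(path)
        paths.append(path)
    return paths
```

After assigning the desired scheduler to `pipe.scheduler`, a call such as `sample_and_save(pipe, ['jay', 'magpie'], 'gmsde2_step32', guidance_scale=0.45, num_inference_steps=32, num_inference_substeps=1, output_mode='sample', order=2, generator=torch.manual_seed(42))` would reproduce the second run above.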

Before Training: Data Preparation

Download ILSVRC2012_img_train.tar and the metadata. Extract the downloaded archives according to the following folder tree (or use symlinks).

./
├── configs/
├── data/
│   └── imagenet/
│       ├── train/
│       │   ├── n01440764/
│       │   │   ├── n01440764_10026.JPEG
│       │   │   ├── n01440764_10027.JPEG
│       │   │   …
│       │   ├── n01443537/
│       │   …
│       ├── imagenet1000_clsidx_to_labels.txt
│       ├── train.txt
│       …
├── lib/
├── tools/
…
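
Before launching the preparation step, it can help to confirm that the layout matches the tree above. The following is a small sketch using only the standard library; `missing_imagenet_paths` is a hypothetical helper, and it checks only the entries that the tree explicitly names.

```python
from pathlib import Path

def missing_imagenet_paths(root='.'):
    """Return the expected paths (from the folder tree above) that are missing."""
    base = Path(root) / 'data' / 'imagenet'
    expected = [
        base / 'train',
        base / 'imagenet1000_clsidx_to_labels.txt',
        base / 'train.txt',
    ]
    missing = [str(p) for p in expected if not p.exists()]
    # train/ should contain one sub-folder per class (e.g. n01440764/)
    train = base / 'train'
    if train.is_dir() and not any(p.is_dir() for p in train.iterdir()):
        missing.append(str(train / '<class folders>'))
    return missing
```

Calling `missing_imagenet_paths()` from the repository root returns an empty list when every named entry is in place.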

Run the following command to prepare the ImageNet dataset using DDP on 1 node with 8 GPUs:

torchrun --nnodes=1 --nproc_per_node=8 tools/prepare_imagenet_dit.py

Training

Run the following command to train the model using DDP on 1 node with 8 GPUs:

torchrun --nnodes=1 --nproc_per_node=8 tools/train.py configs/gmflow_imagenet_k8_8gpus.py --launcher pytorch --diff_seed

Alternatively, you can start single-node DDP training from a Python script:

python train.py configs/gmflow_imagenet_k8_8gpus.py --gpu-ids 0 1 2 3 4 5 6 7

The config in gmflow_imagenet_k8_8gpus.py specifies a training batch size of 512 images per GPU and an inference batch size of 125 images per GPU. Training requires 32GB of VRAM per GPU, and the validation step requires an additional 8GB of VRAM per GPU. If you are using 32GB GPUs, you can disable the validation step by adding the --no-validate flag to the training command. Alternatively, you can edit the config file to adjust the batch sizes.
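
The numbers above can be worked through as simple arithmetic. The sketch below assumes the per-GPU batch sizes add up across the 8 GPUs and that the VRAM figures are additive as stated; `needs_no_validate` is a hypothetical helper, not a repository function.

```python
GPUS = 8
TRAIN_BATCH_PER_GPU = 512      # from the config description above
TRAIN_VRAM_GB = 32             # training alone
VALIDATION_EXTRA_VRAM_GB = 8   # additional memory for the validation step

# Total images per training step across all GPUs: 8 * 512 = 4096
global_batch = GPUS * TRAIN_BATCH_PER_GPU

# Memory needed per GPU when validation is enabled: 32 + 8 = 40
vram_with_validation = TRAIN_VRAM_GB + VALIDATION_EXTRA_VRAM_GB

def needs_no_validate(gpu_vram_gb):
    """True if the GPU fits training but not training plus validation."""
    return TRAIN_VRAM_GB <= gpu_vram_gb < vram_with_validation
```

Under these assumptions, a 32GB GPU falls in the range where `--no-validate` is needed, while a GPU with 40GB or more does not.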

By default, checkpoints will be saved into checkpoints/, logs will be saved into work_dirs/, and sampled images will be saved into viz/.

Resuming Training

If existing checkpoints are found, the training will automatically resume from the latest checkpoint.
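
Because resuming is automatic, it is worth knowing whether a run directory already holds a checkpoint before starting what is meant to be a fresh run. The sketch below is a rough pre-flight check only: it assumes the `checkpoints/<run name>/latest.pth` layout that appears in the evaluation command further down, and the trainer's actual lookup logic may differ. `has_latest_checkpoint` is a hypothetical name.

```python
from pathlib import Path

def has_latest_checkpoint(run_name, ckpt_root='checkpoints'):
    """True if <ckpt_root>/<run_name>/latest.pth exists (path layout assumed)."""
    return (Path(ckpt_root) / run_name / 'latest.pth').is_file()
```

For example, `has_latest_checkpoint('gmflow_imagenet_k8_8gpus')` can be run from the repository root.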

Tensorboard

The logs can be plotted using Tensorboard. Run the following command to start Tensorboard:

tensorboard --logdir work_dirs/

Evaluation

After training, to conduct a complete evaluation of the model under varying guidance scales, run the following command to start DDP evaluation on 1 node with 8 GPUs:

torchrun --nnodes=1 --nproc_per_node=8 tools/test.py configs/gmflow_imagenet_k8_test.py checkpoints/gmflow_imagenet_k8_8gpus/latest.pth --launcher pytorch --diff_seed

Alternatively, you can start single-node DDP evaluation from a Python script:

python test.py configs/gmflow_imagenet_k8_test.py checkpoints/gmflow_imagenet_k8_8gpus/latest.pth --gpu-ids 0 1 2 3 4 5 6 7

The config in gmflow_imagenet_k8_test.py specifies an inference batch size of 125 images per GPU, which requires 35GB of VRAM per GPU. You can edit the config file to adjust the batch size.

The evaluation results will be saved to where the checkpoint is located, and the sampled images will be saved into viz/.

Toy Model on 2D Checkerboard

We provide a minimal GMFlow trainer in train_toymodel.py for the toy model on the 2D checkerboard dataset. Run the following command to train the model:

python train_toymodel.py -k 64

This minimal trainer does not support transition loss or EMA. To reproduce the results in the paper, you can use the following command to start the full trainer:

python train.py configs/gmflow_checkerboard_k64.py --gpu-ids 0

This full trainer is not optimized for the simple 2D checkerboard dataset, so GPU usage may be inefficient.
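
For readers unfamiliar with the toy dataset, the following sketch draws points from a 2D checkerboard using only the standard library. It is an illustration, not the repository's data loader: the 4x4 grid, the [-2, 2] range, and the name `sample_checkerboard` are assumptions.

```python
import random

def sample_checkerboard(n, rng, cells=4, low=-2.0, high=2.0):
    """Draw n points from the alternating squares of a cells x cells grid."""
    size = (high - low) / cells
    points = []
    while len(points) < n:
        i, j = rng.randrange(cells), rng.randrange(cells)
        if (i + j) % 2 == 0:  # keep only every other square
            x = low + (i + rng.random()) * size
            y = low + (j + rng.random()) * size
            points.append((x, y))
    return points

points = sample_checkerboard(1000, random.Random(0))
```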

Essential Code

  • Training
    • train_toymodel.py: A simplified training script for the 2D checkerboard experiment.
    • gmflow.py: The forward_train method contains the full training loop.
  • Inference
    • gmdit_pipeline.py: Full sampling code in the style of Diffusers.
    • gmflow.py: The forward_test method contains the same full sampling loop.
  • Network
  • GM math operations
    • gmflow_ops: A complete library of analytical operations for GM and Gaussian distributions.
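
As a flavor of the kind of closed-form operation such a library provides, the sketch below multiplies two 1D Gaussian densities, which yields another Gaussian (up to normalization) with a precision-weighted mean. It is a generic textbook identity written in plain Python; the name `gaussian_product` is hypothetical and is not claimed to match the gmflow_ops API.

```python
def gaussian_product(mean_a, var_a, mean_b, var_b):
    """Mean and variance of the Gaussian proportional to N(mean_a, var_a) * N(mean_b, var_b)."""
    precision = 1.0 / var_a + 1.0 / var_b      # precisions add
    var = 1.0 / precision
    mean = var * (mean_a / var_a + mean_b / var_b)  # precision-weighted mean
    return mean, var

# Equal variances: the mean lands halfway and the variance halves
mean, var = gaussian_product(0.0, 1.0, 2.0, 1.0)
```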

Citation

@inproceedings{gmflow,
  title={Gaussian Mixture Flow Matching Models},
  author={Hansheng Chen and Kai Zhang and Hao Tan and Zexiang Xu and Fujun Luan and Leonidas Guibas and Gordon Wetzstein and Sai Bi},
  booktitle={ICML},
  year={2025},
}
