A simple re-implementation of inference-time scaling for Flux.1-Dev, as introduced in "Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps" by Ma et al. We implement the random search strategy to scale the inference compute budget.
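For intuition, below is a minimal sketch of what the random search looks like. The helper names (generate_image, verifier) are hypothetical stand-ins for the diffusion pipeline call and the verifier object; the actual implementation lives in main.py.

```python
import random

# Minimal sketch of random search over initial noise seeds.
# `generate_image` and `verifier` are hypothetical stand-ins; see main.py for the real thing.
def random_search(prompt, generate_image, verifier, search_rounds=4,
                  choice_of_metric="overall_score"):
    per_round_best = []
    for round_idx in range(1, search_rounds + 1):
        num_noises = 2 ** round_idx  # noise pool grows each round: 2, 4, 8, ...
        candidates = []
        for _ in range(num_noises):
            seed = random.randint(0, 2**31 - 1)
            image = generate_image(prompt, seed=seed)  # one full denoising run per candidate
            result = verifier.score(prompt, image)[choice_of_metric]  # {"score": ..., "explanation": ...}
            candidates.append({"seed": seed, "result": result, "image": image})
        # Keep the best candidate of this round; this is roughly what gets serialized to JSON.
        per_round_best.append(max(candidates, key=lambda c: c["result"]["score"]))
    return per_round_best
```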
Updates
🔥 15/02/2025: Support for structured generation with Qwen2.5 has been added (using outlines and pydantic) in this PR.
🔥 15/02/2025: Support for loading other pipelines has been added in this PR! The results section has been updated, too.
Make sure to install the dependencies: pip install -r requirements. The codebase was tested using a single H100 and two H100s (both 80GB variants).
By default, we use Gemini 2.0 Flash as the verifier. This requires a Gemini API key, passed via the GEMINI_API_KEY environment variable as shown below.
Now, fire up:
GEMINI_API_KEY=... python main.py --prompt="a tiny astronaut hatching from an egg on the moon" --num_prompts=None
If you want to use prompts from the data-is-better-together/open-image-preferences-v1-binarized dataset, you can just run:
GEMINI_API_KEY=... python main.py
After this is done executing, you should expect a folder named output with the following structure:
output/flux.1-dev/gemini/overall_score/20250215_141308$ tree
.
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@1_s@1039315023.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@1_s@77559330.json
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@1_s@77559330.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1046091514.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1388753168.json
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1388753168.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1527774201.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1632020675.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@1648932110.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@2033640094.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@2056028012.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@510118118.json
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@510118118.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@544879571.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@722867022.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@951309743.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@973580742.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1169137714.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1271234848.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1327836930.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1589777351.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1592595351.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1654773907.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1901647417.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1916603945.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@209448213.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@2104826872.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@532500803.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@710122236.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@744797903.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@754998363.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@823891989.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@836183088.json
└── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@836183088.png
Each JSON file should look like so:
{
  "prompt": "Photo of an athlete cat explaining it\u2019s latest scandal at a press conference to journalists.",
  "search_round": 4,
  "num_noises": 16,
  "best_noise_seed": 836183088,
  "best_score": {
    "score": 9.5,
    "explanation": "Considering all aspects, especially the high level of accuracy, creativity, and visual appeal, the overall score reflects the model's excellent performance in generating this image."
  },
  "choice_of_metric": "overall_score",
  "best_img_path": "output/gemini/overall_score/20250213_034054/prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@836183088.png"
}
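If you want to inspect these files programmatically, for example to find the best-scoring image across all rounds of a run, a small snippet along these lines works against the schema shown above (the output directory below is just an example):

```python
import glob
import json

# Point this at one of your own runs; the path below is only an example.
run_dir = "output/flux.1-dev/gemini/overall_score/20250215_141308"

records = []
for path in glob.glob(f"{run_dir}/prompt@*.json"):  # only the per-round JSONs, skips config.json
    with open(path) as f:
        records.append(json.load(f))

# Pick the overall best datapoint across all search rounds.
best = max(records, key=lambda r: r["best_score"]["score"])
print(best["search_round"], best["best_noise_seed"], best["best_score"]["score"])
print(best["best_img_path"])
```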
To limit the number of prompts, specify --num_prompts. By default, we use 2 prompts. Specify "--num_prompts=all" to use all.
The output directory should also contain a config.json, looking like so:
{
  "max_new_tokens": 300,
  "use_low_gpu_vram": false,
  "choice_of_metric": "overall_score",
  "verifier_to_use": "gemini",
  "torch_dtype": "bf16",
  "height": 1024,
  "width": 1024,
  "max_sequence_length": 512,
  "guidance_scale": 3.5,
  "num_inference_steps": 50,
  "pipeline_config_path": "configs/flux.1_dev.json",
  "search_rounds": 4,
  "prompt": "an anime illustration of a wiener schnitzel",
  "num_prompts": null
}
Once the results are generated, process them by running:
python process_results.py --path=path_to_the_output_dir
This should output a collage of the best images generated in each search round, grouped by the same prompt.
The pipeline to use is controlled via the --pipeline_config_path CLI arg. By default, it uses configs/flux.1_dev.json. You can either modify this one or create your own JSON file to experiment with different pipelines. We provide some predefined configs for Flux.1-Dev, PixArt-Sigma, SDXL, and SD v1.5 in the configs directory.
The above-mentioned pipelines are already supported. To add your own, you will need to make a few modifications to the codebase.
By default, we use 4 search_rounds and start with a noise pool size of 2. Each search round scales up the pool size like so: 2 ** current_search_round (with indexing starting from 1). This is where the "scale" in inference-time scaling comes from. You can increase the compute budget by specifying a larger search_rounds.
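For example, the per-round noise pool sizes work out as follows (matching the num_noises values reported in the result JSONs):

```python
>>> [2 ** r for r in range(1, 5)]  # search_rounds=4 (the default)
[2, 4, 8, 16]
>>> [2 ** r for r in range(1, 7)]  # --search_rounds=6
[2, 4, 8, 16, 32, 64]
```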
For each search round, we serialize the images and best datapoint (characterized by the best eval score) in a JSON file.
For other supported CLI args, run python main.py -h.
If you don't want to use Gemini, you can use Qwen2.5 as an option. Simply specify --verifier_to_use=qwen for this. Below is a complete command that uses SDXL-base:
python main.py --verifier_to_use="qwen" --pipeline_config_path=configs/sdxl.json --prompt="Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists." --num_prompts=None --search_rounds=6
Sample search JSON
{
  "prompt": "Photo of an athlete cat explaining it\u2019s latest scandal at a press conference to journalists.",
  "search_round": 6,
  "num_noises": 64,
  "best_noise_seed": 576119280,
  "best_score": {
    "score": 9.2,
    "explanation": "The image excels in multiple aspects, combining imagery, creativity, visual quality, and thematic resonance."
  },
  "choice_of_metric": "overall_score",
  "best_img_path": "output/sdxl-base/qwen/overall_score/20250216_100512/prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@6_s@576119280.png"
}
Results
Result image for the prompt: "Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists."
Important
This setup was tested on 2 H100s. If you want to do this on a single GPU, specify --use_low_gpu_vram.
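For example, a single-GPU run could look something like this (the flag combination below is just illustrative):

GEMINI_API_KEY=... python main.py --use_low_gpu_vram --prompt="a tiny astronaut hatching from an egg on the moon" --num_prompts=None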
You can also bring in your own verifier by implementing a so-called Verifier class, following the structure of either GeminiVerifier or QwenVerifier. You will then have to make adjustments to the places in the codebase where the verifier is selected and invoked.
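As a rough illustration only (the method names and return format below are assumptions; mirror GeminiVerifier or QwenVerifier for the interface the codebase actually expects), a custom verifier might be sketched like this:

```python
# Rough sketch of a custom verifier. Method names and return format are assumptions;
# follow GeminiVerifier/QwenVerifier for the exact interface expected by the codebase.
class MyVerifier:
    def __init__(self, choice_of_metric="overall_score"):
        self.choice_of_metric = choice_of_metric

    def score(self, prompt, image):
        # Call your grading model or heuristic here and return per-metric results,
        # e.g. {"overall_score": {"score": 8.5, "explanation": "..."}}.
        raise NotImplementedError
```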
By default, we use "overall_score" as the metric to obtain the best samples in each search round. You can change it by specifying --choice_of_metric
. Supported values are:
- "accuracy_to_prompt"
- "creativity_and_originality"
- "visual_quality_and_realism"
- "consistency_and_cohesion"
- "emotional_or_thematic_resonance"
- "overall_score"
If you're experimenting with a new verifier, you can relax these choices.
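For example, to rank candidates by thematic resonance instead of the overall score, you could run something like:

GEMINI_API_KEY=... python main.py --prompt="a tiny astronaut hatching from an egg on the moon" --choice_of_metric="emotional_or_thematic_resonance" --num_prompts=None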
The verifier prompt that is used during grading/verification is specified in this file. The prompt is a slightly modified version of the one specified in Figure 16 of the paper (Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps). You are welcome to experiment with a different prompt.
Both searches were performed with "overall_score" as the metric. Below is an example comparing the outputs of two different metrics, "overall_score" vs. "emotional_or_thematic_resonance", for the prompt: "a tiny astronaut hatching from an egg on the moon":
PixArt-Sigma
Result image for the prompt: "A person playing saxophone."
Result image for the prompt: "Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists."
SDXL-base
Result image for the prompt: "Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists."
- Thanks to Willis Ma for all the guidance and pair-coding.
- Thanks to Hugging Face for supporting the compute.
- Thanks to Google for providing Gemini credits.