
Inference-time scaling of Flux beyond denoising steps.


amitness/tt-scale-flux


tt-scale-flux

A simple re-implementation of inference-time scaling for Flux.1-Dev, as introduced in Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps by Ma et al. We implement the random search strategy to scale the inference compute budget.

Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.

Updates

🔥 15/02/2025: Support for structured generation with Qwen2.5 has been added (using outlines and pydantic) in this PR.

🔥 15/02/2025: Support to load other pipelines has been added in this PR! Result section has been updated, too.

Getting started

Make sure to install the dependencies: pip install -r requirements.txt. The codebase was tested on a single H100 and on two H100s (both 80GB variants).

By default, we use Gemini 2.0 Flash as the verifier. This requires two things:

  • GEMINI_API_KEY (obtain it from here).
  • google-genai Python library.

Now, fire up:

GEMINI_API_KEY=... python main.py --prompt="a tiny astronaut hatching from an egg on the moon" --num_prompts=None

If you want to use prompts from the data-is-better-together/open-image-preferences-v1-binarized dataset, you can just run:

GEMINI_API_KEY=... python main.py

After this is done executing, you should expect a folder named output with the following structure:

output/flux.1-dev/gemini/overall_score/20250215_141308$ tree 
.
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@1_s@1039315023.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@1_s@77559330.json
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@1_s@77559330.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1046091514.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1388753168.json
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1388753168.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1527774201.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@2_s@1632020675.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@1648932110.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@2033640094.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@2056028012.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@510118118.json
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@510118118.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@544879571.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@722867022.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@951309743.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@3_s@973580742.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1169137714.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1271234848.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1327836930.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1589777351.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1592595351.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1654773907.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1901647417.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@1916603945.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@209448213.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@2104826872.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@532500803.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@710122236.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@744797903.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@754998363.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@823891989.png
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@836183088.json
└── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@836183088.png
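
The file names in the listing appear to encode a truncated prompt, a prompt hash, the search round (i@) and the noise seed (s@). The sketch below parses such names with a regular expression. The pattern is inferred from the listing only, not taken from the repository's code, and parse_output_name and OutputName are hypothetical names introduced for illustration.

```python
import re
from typing import NamedTuple, Optional

# Pattern inferred from the directory listing; this is an assumption,
# not the naming logic used by the repository itself.
_NAME_RE = re.compile(
    r"^prompt@(?P<slug>.+)_hash@(?P<hash>[0-9a-f]+)"
    r"_i@(?P<round>\d+)_s@(?P<seed>\d+)\.(?P<ext>png|json)$"
)


class OutputName(NamedTuple):
    slug: str          # truncated, underscore-joined prompt text
    prompt_hash: str   # short hash that follows "hash@"
    search_round: int  # value that follows "i@"
    seed: int          # value that follows "s@"
    ext: str           # "png" or "json"


def parse_output_name(name: str) -> Optional[OutputName]:
    """Return the parsed fields, or None if the name does not match."""
    m = _NAME_RE.match(name)
    if m is None:
        return None
    return OutputName(
        m["slug"], m["hash"], int(m["round"]), int(m["seed"]), m["ext"]
    )
```

Names that do not follow the pattern, such as config.json, return None, so the function can be used to filter a directory listing before grouping files by round.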

Each JSON file should look like so:

{
    "prompt": "Photo of an athlete cat explaining it\u2019s latest scandal at a press conference to journalists.",
    "search_round": 4,
    "num_noises": 16,
    "best_noise_seed": 836183088,
    "best_score": {
        "score": 9.5,
        "explanation": "Considering all aspects, especially the high level of accuracy, creativity, and visual appeal, the overall score reflects the model's excellent performance in generating this image."
    },
    "choice_of_metric": "overall_score",
    "best_img_path": "output/gemini/overall_score/20250213_034054/prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@4_s@836183088.png"
}
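
To inspect a finished run, the per-round JSON files can be summarised with a few lines of Python. This is a minimal sketch, assuming the directory holds results for a single prompt and that every JSON file has the fields shown above. best_per_round is a hypothetical helper and is not part of the repository.

```python
import json
from pathlib import Path


def best_per_round(output_dir):
    """Map search_round -> (best_noise_seed, score) for one output directory.

    Assumes one prompt per directory; with several prompts, later files
    would overwrite earlier ones that share the same round number.
    """
    results = {}
    # Only the per-round files start with "prompt@"; config.json is skipped.
    for path in sorted(Path(output_dir).glob("prompt@*.json")):
        with path.open(encoding="utf-8") as f:
            record = json.load(f)
        results[record["search_round"]] = (
            record["best_noise_seed"],
            record["best_score"]["score"],
        )
    return results
```

Calling best_per_round on the timestamped output directory returns a dictionary keyed by round, which makes it easy to check whether the best score improves as the noise pool grows.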

To limit the number of prompts, specify --num_prompts. By default, we use 2 prompts. Specify "--num_prompts=all" to use all.

The output directory should also contain a config.json, looking like so:

{
  "max_new_tokens": 300,
  "use_low_gpu_vram": false,
  "choice_of_metric": "overall_score",
  "verifier_to_use": "gemini",
  "torch_dtype": "bf16",
  "height": 1024,
  "width": 1024,
  "max_sequence_length": 512,
  "guidance_scale": 3.5,
  "num_inference_steps": 50,
  "pipeline_config_path": "configs/flux.1_dev.json",
  "search_rounds": 4,
  "prompt": "an anime illustration of a wiener schnitzel",
  "num_prompts": null
}

Once the results are generated, process the results by running:

python process_results.py --path=path_to_the_output_dir

This should output a collage of the best images generated in each search round, grouped by the same prompt.

Controlling the pipeline checkpoint and __call__() args

This is controlled via the --pipeline_config_path CLI arg. By default, it uses configs/flux.1_dev.json. You can either modify this file or create your own JSON file to experiment with different pipelines. We provide predefined configs for Flux.1-Dev, PixArt-Sigma, SDXL, and SD v1.5 in the configs directory.

The above-mentioned pipelines are already supported. To add your own, you need to make modifications to:

Controlling the "scale"

By default, we use 4 search_rounds and start with a noise pool size of 2. Each search round scales up the pool size like so: 2 ** current_search_round (with indexing starting from 1). This is where the "scale" in inference-time scaling comes from. You can increase the compute budget by specifying a larger value for --search_rounds.
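
The pool-size rule can be checked with a short calculation. This is an illustrative sketch; noise_pool_sizes and total_generations are hypothetical helper names, not functions from the repository.

```python
from typing import List


def noise_pool_sizes(search_rounds: int) -> List[int]:
    """Pool size for each round, with rounds indexed from 1."""
    return [2 ** r for r in range(1, search_rounds + 1)]


def total_generations(search_rounds: int) -> int:
    """Total number of images generated across all rounds."""
    return sum(noise_pool_sizes(search_rounds))


print(noise_pool_sizes(4))   # [2, 4, 8, 16]
print(total_generations(4))  # 30
```

With the default of 4 rounds the last round evaluates 16 noises, which matches the "num_noises": 16 entry in the round-4 JSON file, and 6 rounds end with a pool of 64.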

For each search round, we save the generated images and serialize the best datapoint (the one with the best eval score) to a JSON file.
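
The random search strategy can be sketched as a loop over rounds. This is a simplified illustration under stated assumptions, not the repository's implementation: generate and score are caller-supplied stand-ins for the diffusion pipeline and the verifier, fresh seeds are drawn in every round, and the seed range is arbitrary.

```python
import random
from typing import Any, Callable, Dict, List, Optional


def random_search(
    prompt: str,
    search_rounds: int,
    generate: Callable[[str, int], Any],   # stand-in for the image pipeline
    score: Callable[[str, Any], float],    # stand-in for the verifier
    rng: Optional[random.Random] = None,
) -> List[Dict[str, Any]]:
    """Run random search and return the best datapoint of each round."""
    if rng is None:
        rng = random.Random()
    history = []
    for search_round in range(1, search_rounds + 1):
        num_noises = 2 ** search_round
        # Assumption: every round samples a fresh pool of seeds.
        seeds = [rng.randrange(2 ** 31) for _ in range(num_noises)]
        scored = [(score(prompt, generate(prompt, s)), s) for s in seeds]
        best_score, best_seed = max(scored)
        history.append(
            {
                "prompt": prompt,
                "search_round": search_round,
                "num_noises": num_noises,
                "best_noise_seed": best_seed,
                "best_score": best_score,
            }
        )
    return history
```

Passing the pipeline and verifier in as callables keeps the search logic independent of any particular model, so the loop can be exercised with cheap dummy functions before spending GPU time.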

For other supported CLI args, run python main.py -h.

Controlling the verifier

If you don't want to use Gemini, you can use Qwen2.5 instead. Simply specify --verifier_to_use=qwen. Below is a complete command that uses SDXL-base:

python main.py --verifier_to_use="qwen" --pipeline_config_path=configs/sdxl.json --prompt="Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists." --num_prompts=None --search_rounds=6
Sample search JSON
{
    "prompt": "Photo of an athlete cat explaining it\u2019s latest scandal at a press conference to journalists.",
    "search_round": 6,
    "num_noises": 64,
    "best_noise_seed": 576119280,
    "best_score": {
        "score": 9.2,
        "explanation": "The image excels in multiple aspects, combining imagery, creativity, visual quality, and thematic resonance."
    },
    "choice_of_metric": "overall_score",
    "best_img_path": "output/sdxl-base/qwen/overall_score/20250216_100512/prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@6_s@576119280.png"
}
  
Results
Result
scandal_cat
Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.
  

Important

This setup was tested on 2 H100s. If you want to do this on a single GPU, specify --use_low_gpu_vram.

You can also bring in your own verifier by implementing a so-called Verifier class following the structure of either of GeminiVerifier or QwenVerifier. You will then have to make adjustments to the following places:

By default, we use "overall_score" as the metric to obtain the best samples in each search round. You can change it by specifying --choice_of_metric. Supported values are:

  • "accuracy_to_prompt"
  • "creativity_and_originality"
  • "visual_quality_and_realism"
  • "consistency_and_cohesion"
  • "emotional_or_thematic_resonance"
  • "overall_score"

If you're experimenting with a new verifier, you can relax these choices.
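
To make the role of the metric concrete, here is a minimal sketch of selecting the best candidate by a chosen metric. Everything in it is hypothetical: the ConstantVerifier class, its score method, and the assumption that a verifier returns a score and an explanation for every supported metric are illustrations and do not reflect the actual GeminiVerifier or QwenVerifier interface.

```python
from typing import Any, Dict, List

SUPPORTED_METRICS = (
    "accuracy_to_prompt",
    "creativity_and_originality",
    "visual_quality_and_realism",
    "consistency_and_cohesion",
    "emotional_or_thematic_resonance",
    "overall_score",
)


class ConstantVerifier:
    """Hypothetical stand-in verifier that gives every metric the same score."""

    def score(self, prompt: str, image: Any) -> Dict[str, Dict[str, Any]]:
        # Assumed output shape: one {"score", "explanation"} entry per metric.
        return {
            metric: {"score": 5.0, "explanation": "placeholder"}
            for metric in SUPPORTED_METRICS
        }


def pick_best(
    candidates: List[Dict[str, Any]],
    choice_of_metric: str = "overall_score",
) -> Dict[str, Any]:
    """Return the candidate with the highest score for the chosen metric."""
    if choice_of_metric not in SUPPORTED_METRICS:
        raise ValueError(f"Unsupported metric: {choice_of_metric!r}")
    return max(candidates, key=lambda c: c["scores"][choice_of_metric]["score"])
```

Each candidate here is a dictionary with a "scores" entry in the shape returned by the stand-in verifier. A verifier that reports different metrics would only need a different SUPPORTED_METRICS tuple.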

The verifier prompt used during grading/verification is specified in this file. It is a slightly modified version of the prompt shown in Figure 16 of the paper (Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps). You are welcome to experiment with a different prompt.

More results

Result
Manga
a bustling manga street, devoid of vehicles, detailed with vibrant colors and dynamic
line work, characters in the background adding life and movement, under a soft golden
hour light, with rich textures and a lively atmosphere, high resolution, sharp focus
Alice
Alice in a vibrant, dreamlike digital painting inside the Nemo Nautilus submarine.
wiener_schnitzel
an anime illustration of a wiener schnitzel
  

Both searches were performed with "overall_score" as the metric. Below is an example comparing the outputs of two different metrics, "overall_score" and "emotional_or_thematic_resonance", for the prompt "a tiny astronaut hatching from an egg on the moon":

Metric Result
"overall_score" overall
"emotional_or_thematic_resonance" Alicet

Results from other models

PixArt-Sigma
Result
saxophone
A person playing saxophone.
scandal_cat
Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.

SD v1.5
Result
saxophone
a photo of an astronaut riding a horse on mars
  
SDXL-base
Result
scandal_cat
Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.
  

Acknowledgements

  • Thanks to Willis Ma for all the guidance and pair-coding.
  • Thanks to Hugging Face for supporting the compute.
  • Thanks to Google for providing Gemini credits.
