# Finetrainers v0.2.0 🧪
## New trainers

- Channel-concatenated control conditioning for Wan2.1 and CogView4
| Wan image-conditioning on T2V model |
| --- |
| wan-t2v-image-conditioning.mp4 |

| CogView4 control conditioning (Edit + Canny) |
| --- |
The training works by adding extra input channels to the patch embedding layer (referred to as the "control injection" layer in finetrainers) to mix conditioning features into the latent stream. This architectural choice is very common and has appeared in many models before: CogVideoX-I2V, HunyuanVideo-I2V, Alibaba's Fun Control models, etc. Because the approach is both popular and simple, it is a good fit for a standalone trainer.
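Zero-initializing the new input columns means the expanded projection initially ignores the concatenated control channels, so training starts exactly from the base model's behavior. Below is a minimal sketch of that expansion for a plain `nn.Linear` (an illustration only; the actual finetrainers helper is `_expand_linear_with_zeroed_weights`, and its implementation may differ):

```python
import torch
import torch.nn as nn

def expand_linear_with_zeroed_weights(linear: nn.Linear, new_in_features: int) -> nn.Linear:
    # Sketch: widen a Linear layer's input dimension, zeroing the new columns.
    expanded = nn.Linear(new_in_features, linear.out_features, bias=linear.bias is not None)
    with torch.no_grad():
        expanded.weight.zero_()
        # Copy the pretrained weights into the first `in_features` columns. The
        # remaining (control) columns stay zero, so at initialization the
        # control latents contribute nothing to the layer's output.
        expanded.weight[:, : linear.in_features].copy_(linear.weight)
        if linear.bias is not None:
            expanded.bias.copy_(linear.bias)
    return expanded
```

The inference example below expands the CogView4 patch embedding the same way before loading the control LoRA: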
```python
import torch
from diffusers import CogView4Pipeline
from diffusers.utils import load_image
from finetrainers.models.utils import _expand_linear_with_zeroed_weights
from finetrainers.patches import load_lora_weights
from finetrainers.patches.dependencies.diffusers.control import control_channel_concat

dtype = torch.bfloat16
device = torch.device("cuda")
generator = torch.Generator().manual_seed(0)

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=dtype)

# Expand the patch embedding to accept the concatenated control channels.
# The new input columns are zero-initialized, so the base model's behavior
# is unchanged until the control LoRA is loaded.
in_channels = pipe.transformer.config.in_channels
patch_channels = pipe.transformer.patch_embed.proj.in_features
pipe.transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
    pipe.transformer.patch_embed.proj, new_in_features=2 * patch_channels
)

load_lora_weights(pipe, "finetrainers/CogView4-6B-Edit-LoRA-v0", "cogview4-lora")
pipe.set_adapters("cogview4-lora", 0.9)
pipe.to(device)

prompt = "Make the image look like it's from an ancient Egyptian mural."
control_image = load_image("examples/training/control/cogview4/omni_edit/validation_dataset/0.png")
height, width = 1024, 1024

with torch.no_grad():
    latents = pipe.prepare_latents(1, in_channels, height, width, dtype, device, generator)

    # Encode the control image into latents with the same VAE normalization
    # used during training.
    control_image = pipe.image_processor.preprocess(control_image, height=height, width=width)
    control_image = control_image.to(device=device, dtype=dtype)
    control_latents = pipe.vae.encode(control_image).latent_dist.sample(generator=generator)
    control_latents = (control_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

    # Concatenate the control latents onto `hidden_states` along the channel
    # dimension for every transformer forward pass during denoising.
    with control_channel_concat(pipe.transformer, ["hidden_states"], [control_latents], dims=[1]):
        image = pipe(prompt, latents=latents, num_inference_steps=30, generator=generator).images[0]

image.save("output.png")
```
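Note that `control_channel_concat` is a context manager: while it is active, the control latents are concatenated to `hidden_states` along the given dimension on every denoising step, and once the context exits the transformer behaves as before.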
## New models supported
- FLUX.1-dev
- Wan2.1 I2V
Find example training configs here.
## Attention

Support for multiple attention providers for training and inference: PyTorch native, `flash-attn`, `sageattention`, `xformers`, and `flex`. See the docs for more details.
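For the PyTorch-native path, kernel selection can be illustrated with the standard `torch.nn.attention.sdpa_kernel` API (plain PyTorch shown for illustration; the finetrainers provider configuration itself is described in the docs):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict scaled_dot_product_attention to the flash-attention kernel. The
# other providers (flash-attn, sageattention, xformers, flex) are selected
# through finetrainers' own configuration rather than this PyTorch API.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```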
## Other major changes

- Better regional compilation support (see the sketch below)
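Regional compilation applies `torch.compile` to each repeated transformer block instead of the whole model, which greatly reduces cold-start compile time. A minimal sketch (not the finetrainers implementation; `transformer_blocks` is the conventional attribute name for the repeated blocks in diffusers transformer models and is assumed here):

```python
import torch
from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")

# Compile each repeated block individually ("regional" compilation) rather
# than wrapping the entire transformer in a single torch.compile call.
blocks = pipe.transformer.transformer_blocks
for i in range(len(blocks)):
    blocks[i] = torch.compile(blocks[i], fullgraph=True)
```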
## What's Changed
- Update project showcase by @a-r-r-o-w in #355
- Flux ModelSpec by @a-r-r-o-w in #358
- Pytorch regional compilation by @a-r-r-o-w in #361
- [Doc] Fix a typo of `flux.md` by @DarkSharpness in #363
- [Fix] Raise ValueError proactively before some confusing errors occur due to wrong input image size by @DarkSharpness in #364
- Improve webdataset caption loading by @a-r-r-o-w in #365
- fix string matching for blocks by @neph1 in #360
- Bump ruff version by @a-r-r-o-w in #367
- Channel-concatenated Control Trainer by @a-r-r-o-w in #310
- Fix #352: FSDP2 argument typo by @a-r-r-o-w in #370
- Support Wan I2V; Better regional compile support by @a-r-r-o-w in #375
- chore: save all weights with step-specific directories by @Leojc in #379
- fix: lora loading for final validation by @Leojc in #382
- Fix posterior computation and control tests by @a-r-r-o-w in #384
- Support flash/flex/xformers/sage attention by @a-r-r-o-w in #377
## New Contributors
- @DarkSharpness made their first contribution in #363
Full Changelog: v0.1.0...v0.2.0