This issue tracks the performance burndown of the 3D autoencoder (VAE) for the Wan2.1 video generation pipeline.
This model differs from the SDXL VAE primarily in that it operates on slices of a video latent input, looping over the "frame" dimension and processing each "latent frame" separately. We also see significantly different dispatches formed (many more Conv3d ops, with different shapes), and the target precision is bf16. There is more work to be done on optimizing this export: currently, the export process unrolls the loop over video frame slices into a static number of repetitions matching the number of frames. We should probably instead emit one entrypoint that processes a single frame, and another that performs initialization, an scf.for loop over a dynamic number of input frames, and postprocessing. It is difficult to emit this accurately with the turbine dynamo export stack.
That being said, I have run benchmarks on a target configuration (512x512 output, 81 frames encode, 21 frames decode) and have preliminary results for VAE encode which follow:
Benchmark command and output (note: latencies are affected by runtime tracing):
```
HIP_VISIBLE_DEVICES=1 IREE_PY_RUNTIME=tracy iree-benchmark-module \
  --module=wan2_1_vae_512x512_gfx942.vmfb \
  --input=@vae_encode_input.npy \
  --function=encode \
  --device=hip://0 \
  --parameters=model=wan2_1_vae_bf16.irpa \
  --benchmark_repetitions=3
```

```
-- Using Tracy runtime (IREE_PY_RUNTIME=tracy)
2025-05-14T16:57:17+00:00
Running /home/eagarvey/shark-ai/.venv/lib/python3.12/site-packages/iree/_runtime_libs_tracy/iree-benchmark-module
Run on (128 X 3762.99 MHz CPU s)
CPU Caches:
  L1 Data 32 KiB (x128)
  L1 Instruction 32 KiB (x128)
  L2 Unified 1024 KiB (x128)
  L3 Unified 32768 KiB (x16)
Load Average: 3.68, 3.89, 5.81
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
--------------------------------------------------------------------------------------------------
Benchmark                                   Time             CPU   Iterations UserCounters...
--------------------------------------------------------------------------------------------------
BM_encode/process_time/real_time         2595 ms         4572 ms            1 items_per_second=0.385282/s
BM_encode/process_time/real_time         2597 ms         4651 ms            1 items_per_second=0.385042/s
BM_encode/process_time/real_time         2599 ms         4709 ms            1 items_per_second=0.384763/s
BM_encode/process_time/real_time_mean    2597 ms         4644 ms            3 items_per_second=0.385029/s
BM_encode/process_time/real_time_median  2597 ms         4651 ms            3 items_per_second=0.385042/s
BM_encode/process_time/real_time_stddev  1.76 ms         68.4 ms            3 items_per_second=260.177u/s
BM_encode/process_time/real_time_cv      0.07 %          1.47 %             3 items_per_second=0.07%
```
This is the nn.Module we are exporting through iree-turbine AOT: orig_vae.py#L506
This is the export entrypoint for VAE: export.py#L268-L271
Line 270 instantiates the nn.Module and sample inputs, and line 271 feeds them into the generalized export function.
Currently, the attention dispatch in VAE decode is performing below expectations. Opinions are requested as to whether this should be improved in the compiler or whether we should try to export better attention shapes AOT. @Groverkss
Artifacts required for reproducing results:
- Weights: https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/diffusion/wan2_1_14b_480p/wan2_1_vae_bf16.irpa
- VMFB: https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/diffusion/wan2_1_14b_480p/wan2_1_vae_512x512_gfx942.vmfb
- Sample inputs: https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/diffusion/wan2_1_14b_480p/vae_encode_input.npy
- MLIR (optional): https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/diffusion/wan2_1_14b_480p/wan2_1_vae_512x512.mlir
You may also use the Azure CLI to download them as a batch:
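One possible invocation is sketched below; the exact command is not captured here, and the account/container names are inferred from the blob URLs above, so verify them before use:

```
# Sketch: account and container names ("sharkpublic") inferred from the artifact URLs above.
az storage blob download-batch \
  --account-name sharkpublic \
  --source sharkpublic \
  --pattern "halo-models/diffusion/wan2_1_14b_480p/*" \
  --destination .
```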
Compile command used:
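(The exact compile command is not captured in this text. A representative invocation, assuming a pip-installed iree-compile targeting ROCm/gfx942 and the MLIR and VMFB file names from the artifacts above, might look like the sketch below.)

```
# Illustrative only: flags and file names are assumptions, not the recorded command.
iree-compile wan2_1_vae_512x512.mlir \
  --iree-hal-target-backends=rocm \
  --iree-hip-target=gfx942 \
  -o wan2_1_vae_512x512_gfx942.vmfb
```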
Encode (81 frames, 512x512, bf16):
Benchmark command and output (note: latencies are affected by runtime tracing): see the encode benchmark command and results above.
Top dispatches (screenshot of Tracy results):
Tracy profile: DDL link
Decode (1 frame, 512x512, bf16):
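No decode benchmark command is captured in this text; presumably the same module can be benchmarked by switching the entrypoint, along the lines of the sketch below (the decode input file name is hypothetical):

```
# Hypothetical: mirrors the encode invocation above; vae_decode_input.npy is not among the published artifacts.
HIP_VISIBLE_DEVICES=1 iree-benchmark-module \
  --module=wan2_1_vae_512x512_gfx942.vmfb \
  --input=@vae_decode_input.npy \
  --function=decode \
  --device=hip://0 \
  --parameters=model=wan2_1_vae_bf16.irpa \
  --benchmark_repetitions=3
```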
Attention seems to be responsible for the performance issues with Wan2.1 VAE decode.
Top dispatches (screenshot of Tracy results):
Additional Notes
Compiler version:
Torch version (affects exported MLIR):
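The version fields above were left blank in this capture. For reference, they can be recorded with commands like the following, assuming a pip-installed IREE compiler and the project virtualenv:

```
# Record the toolchain versions referenced above.
iree-compile --version
python -c "import torch; print(torch.__version__)"
```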