[Mistral] Performance degradation with VMFB containing prefill functions of multiple batch sizes · Issue #20836 · iree-org/iree · GitHub

Open
pravg-amd opened this issue May 16, 2025 · 4 comments
Labels
bug 🐞 Something isn't working

Comments

pravg-amd commented May 16, 2025

What happened?

When a single VMFB contains prefill functions for multiple batch sizes (2 and 8), there is a performance degradation when running the prefill function with batch size 8, compared to running a VMFB that contains only the batch size 8 prefill function.

This degradation is not visible with other batch size combinations such as (4, 8).

prefill_bs_4_8.txt
prefill_bs_2_8.txt
prefill_bs_8.txt

Steps to reproduce your issue

The MLIR files with prefill batch sizes (2, 8), (4, 8), and (8) are attached to this ticket.

Generate the VMFB using the following command:

iree-compile prefill_bs_2_8.mlir \
    --iree-hal-target-device=hip \
    --iree-hip-target=gfx942 \
    --iree-opt-level=O3  \
    --iree-hal-indirect-command-buffers=true  \
    --iree-stream-resource-memory-model=discrete  \
    --iree-hal-memoization=true \
    -o quark_mistral_nemo.vmfb

Run the benchmark for prefill_bs8 (on SharkMI300x-3):

iree-benchmark-module \
    --device=hip://2 \
    --device_allocator=caching \
    --module=quark_mistral_nemo.vmfb \
    --parameters=model=/data/Mistral-Nemo-Instruct-2407-FP8/quark_mistral_nemo.irpa \
    --function=prefill_bs8 \
    --input=8x1024xsi64 \
    --input=8xsi64 \
    --input=8x32xsi64 \
    --input=1024x2621440xf8E4M3FNUZ \
    --benchmark_repetitions=5

With prefill_bs_8.mlir / prefill_bs_4_8.mlir

BM_prefill_bs8/process_time/real_time               639 ms          639 ms            1 items_per_second=1.56409/s
BM_prefill_bs8/process_time/real_time               640 ms          640 ms            1 items_per_second=1.56353/s
BM_prefill_bs8/process_time/real_time               639 ms          639 ms            1 items_per_second=1.56485/s
BM_prefill_bs8/process_time/real_time               639 ms          640 ms            1 items_per_second=1.56472/s
BM_prefill_bs8/process_time/real_time               639 ms          640 ms            1 items_per_second=1.56405/s
BM_prefill_bs8/process_time/real_time_mean          639 ms          640 ms            5 items_per_second=1.56425/s
BM_prefill_bs8/process_time/real_time_median        639 ms          640 ms            5 items_per_second=1.56409/s
BM_prefill_bs8/process_time/real_time_stddev      0.221 ms        0.371 ms            5 items_per_second=540.428u/s
BM_prefill_bs8/process_time/real_time_cv           0.03 %          0.06 %             5 items_per_second=0.03%

With prefill_bs_2_8.mlir

BM_prefill_bs8/process_time/real_time               873 ms          873 ms            1 items_per_second=1.14508/s
BM_prefill_bs8/process_time/real_time               874 ms          875 ms            1 items_per_second=1.14362/s
BM_prefill_bs8/process_time/real_time               874 ms          874 ms            1 items_per_second=1.14405/s
BM_prefill_bs8/process_time/real_time               873 ms          874 ms            1 items_per_second=1.14492/s
BM_prefill_bs8/process_time/real_time               874 ms          874 ms            1 items_per_second=1.14404/s
BM_prefill_bs8/process_time/real_time_mean          874 ms          874 ms            5 items_per_second=1.14434/s
BM_prefill_bs8/process_time/real_time_median        874 ms          874 ms            5 items_per_second=1.14405/s
BM_prefill_bs8/process_time/real_time_stddev      0.478 ms        0.574 ms            5 items_per_second=626.199u/s
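
A quick consistency check on the two runs (mean real time 639 ms vs 874 ms): the batch size 8 prefill is roughly 1.37x slower when the bs2 variant is present in the same VMFB. A minimal sketch, using only the mean latencies reported above:

```python
# Quantify the regression from the reported mean real-time latencies (ms).
bs8_only_ms = 639.0  # mean with prefill_bs_8.mlir / prefill_bs_4_8.mlir
bs2_8_ms = 874.0     # mean with prefill_bs_2_8.mlir

slowdown = bs2_8_ms / bs8_only_ms                # ~1.37x
throughput_drop = 1.0 - bs8_only_ms / bs2_8_ms   # ~27% fewer items/s

print(f"slowdown: {slowdown:.2f}x, throughput drop: {throughput_drop:.1%}")
```

This matches the reported throughput figures (1.56425/s vs 1.14434/s).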

What component(s) does this issue relate to?

Compiler

Version information

IREE compiler version 3.5.0rc20250514 @ d63e15e

Additional context

Steps to download the model and irpa files are available here:

https://gist.github.com/pravg-amd/1b9f3e3c3abcb6f2c35fdc10a09db09d

pravg-amd commented May 16, 2025

Initial analysis

As part of the DeduplicateExecutables pass, the following dispatch is rewritten from:

    %1520 = flow.dispatch @prefill_bs8$async_dispatch_805::@prefill_bs8$async_dispatch_805_matmul_like_Dx131072x5120_f16xf16xf32[%1519, %13](%1519, %1518, %__hoisted_tensor_131072x5120xf16_578, %13) : (index, tensor<?x5120xf16>{%13}, tensor<131072x5120xf16>, index) -> tensor<?x131072xf16>{%13}


to

    %1520 = flow.dispatch @prefill_bs2$async_dispatch_805::@prefill_bs2$async_dispatch_805_matmul_like_Dx131072x5120_f16xf16xf32[%1519, %13](%1519, %1518, %__hoisted_tensor_131072x5120xf16_578, %13) : (index, tensor<?x5120xf16>{%13}, tensor<131072x5120xf16>, index) -> tensor<?x131072xf16>{%13}
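
DeduplicateExecutables collapses structurally identical executables and retargets every dispatch site to the surviving copy. Because dispatch_805 has a dynamic M dimension, the bs2 and bs8 copies look identical, so the first-seen (bs2) copy survives along with whatever lowering configuration is later chosen for it. A minimal sketch of that retargeting mechanism (names and keys here are illustrative, not IREE's actual implementation):

```python
# Illustrative sketch of executable deduplication: structurally identical
# executables collapse to the first-seen copy, and later dispatch sites are
# retargeted to it. This is not IREE's implementation, just the idea.
def deduplicate(dispatches):
    # dispatches: list of (executable_name, structural_key) pairs
    canonical = {}   # structural_key -> surviving executable name
    retargeted = []
    for name, key in dispatches:
        canonical.setdefault(key, name)
        retargeted.append(canonical[key])
    return retargeted

# The bs2 and bs8 dispatch_805 bodies are structurally identical (the M
# dimension is dynamic in both), so the bs8 site is retargeted to the bs2
# copy.
dispatches = [
    ("prefill_bs2$async_dispatch_805", "matmul_like_Dx131072x5120_f16xf16xf32"),
    ("prefill_bs8$async_dispatch_805", "matmul_like_Dx131072x5120_f16xf16xf32"),
]
print(deduplicate(dispatches))
# ['prefill_bs2$async_dispatch_805', 'prefill_bs2$async_dispatch_805']
```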

At the HAL level, for the prefill_bs_2_8 case, the workgroup counts are (512, 4, z):

    %ordinal_204 = hal.executable.export.ordinal target(@module_linked::@rocm_hsaco_fb::@prefill_bs2$async_dispatch_805_matmul_like_Dx131072x5120_f16xf16xf32) : index
    %147 = arith.divsi %10, %c64 : index
    hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe_203 : !hal.executable)[%ordinal_204] workgroups([%c512, %c4, %147]) constants([%55, %57, %c-267351040_i32, %53]) bindings([
      (%transient_buffer_3 : !hal.buffer)[%c0, %51],
      (%__hoisted_tensor_131072x5120xf16 : !hal.buffer)[%c0, %c10739587072],
      (%transient_buffer : !hal.buffer)[%c0, %32]
    ]) flags("None")

At the HAL level, for the prefill_bs_4_8 case, the workgroup counts are (1024, 2, z):

    %ordinal_204 = hal.executable.export.ordinal target(@module_linked::@rocm_hsaco_fb::@prefill_bs4$async_dispatch_805_matmul_like_Dx131072x5120_f16xf16xf32) : index
    %147 = arith.divsi %10, %c128 : index
    hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe_203 : !hal.executable)[%ordinal_204] workgroups([%c1024, %c2, %147]) constants([%55, %57, %c-267351040_i32, %53]) bindings([
      (%transient_buffer_3 : !hal.buffer)[%c0, %51],
      (%__hoisted_tensor_131072x5120xf16 : !hal.buffer)[%c0, %c10739587072],
      (%transient_buffer : !hal.buffer)[%c0, %32]
    ]) flags("None")
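
Reading the two surviving grids back against the dispatch shape (N = 131072 static, M dynamic): for the prefill_bs8 inputs above, M is presumably 8 × 1024 = 8192 tokens. Under that assumption (and the grid-to-dimension mapping is itself a guess for illustration), the bs2-derived configuration launches twice as many workgroups, each covering a 64-wide slice of the dynamic dimension instead of 128:

```python
# Compare the two surviving grid configurations for dispatch_805
# (matmul_like_Dx131072x5120). Only the (x, y) counts and the divisor that
# produces z are taken from the HAL IR above; M = 8192 is an assumption
# based on the 8x1024 token input of the prefill_bs8 benchmark.
def grid(wg_x, wg_y, dyn_divisor, m_dynamic):
    z = m_dynamic // dyn_divisor
    return (wg_x, wg_y, z), wg_x * wg_y * z

m = 8 * 1024
bs2_grid, bs2_total = grid(512, 4, 64, m)    # prefill_bs_2_8 HAL IR
bs4_grid, bs4_total = grid(1024, 2, 128, m)  # prefill_bs_4_8 HAL IR
print(bs2_grid, bs2_total)  # (512, 4, 128) 262144
print(bs4_grid, bs4_total)  # (1024, 2, 64) 131072
```

If this reading is right, the deduplicated module pays a launch-granularity penalty on the bs8 path, consistent with the observed latency gap.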

Verified by disabling the DeduplicateExecutables pass that the performance is recovered, though this increases the VMFB size.

CC: @kumardeepakamd @MaheshRavishankar @pdhirajkumarprasad

benvanik (Collaborator) commented

Good find! This looks like a case that specialization should be able to handle - the analysis information derived from the dispatch sites should be present during executable configuration, but AFAIK today that's not really used by codegen. It'd be good to check if what's required to specialize is there (--iree-hal-dump-executable-sources-to= should show it). This same situation would arise if a single input function dispatched with different sizes, and this particular case of globbing things together just happens to definitely show it.

benvanik (Collaborator) commented

(also, great triage! thanks for digging in!)

amd-vivekag (Contributor) commented

The following packages must be installed to generate the irpa file:

torch
pytest
