VAE compilation failure with aggressive fusion enabled · Issue #20875 · iree-org/iree · GitHub

VAE compilation failure with aggressive fusion enabled #20875


Open
PhaneeshB opened this issue May 21, 2025 · 4 comments
Labels
bug 🐞 Something isn't working codegen Shared code generation infrastructure and dialects

Comments

@PhaneeshB
Contributor
PhaneeshB commented May 21, 2025

What happened?

The sharktank_nightly CI is failing for the FLUX and SDXL VAE models with the following error for multiple dispatches:

.../shark-ai/shortfin/vae_flux_small.mlir:65:26: error: function 'decode$async_dispatch_3_elementwise_32x262144_f32' uses 1048704 bytes of shared memory; exceeded the limit of 65536 bytes
    %result0, %result1 = torch.aten.var_mean.correction %155, %156, %int0_22, %true : !torch.vtensor<[1,32,16,16384],f32>, !torch.list<int>, !torch.int, !torch.bool -> !torch.vtensor<[1,32,1,1],f32>, !torch.vtensor<[1,32,1,1],f32>
                         ^
.../shark-ai/shortfin/vae_flux_small.mlir:65:26: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
    %result0, %result1 = torch.aten.var_mean.correction %155, %156, %int0_22, %true : !torch.vtensor<[1,32,16,16384],f32>, !torch.list<int>, !torch.int, !torch.bool -> !torch.vtensor<[1,32,1,1],f32>, !torch.vtensor<[1,32,1,1],f32>

minimal repro

module @module {
  util.global private @__auto.decoder.conv_in.weight = #stream.parameter.named<"model"::"decoder.conv_in.weight"> : tensor<512x16x3x3xf32>
  util.global private @__auto.decoder.conv_in.bias = #stream.parameter.named<"model"::"decoder.conv_in.bias"> : tensor<512xf32>
  func.func @decode(%arg0: !torch.vtensor<[1,32,16,16384],f32>) -> !torch.vtensor<[1,32,16,16384],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %int2_20 = torch.constant.int 2
    %int3_21 = torch.constant.int 3
    %156 = torch.prim.ListConstruct %int2_20, %int3_21 : (!torch.int, !torch.int) -> !torch.list<int>
    %int0_22 = torch.constant.int 0
    %true = torch.constant.bool true
    %result0, %result1 = torch.aten.var_mean.correction %arg0, %156, %int0_22, %true : !torch.vtensor<[1,32,16,16384],f32>, !torch.list<int>, !torch.int, !torch.bool -> !torch.vtensor<[1,32,1,1],f32>, !torch.vtensor<[1,32,1,1],f32>
    %float9.999990e-07 = torch.constant.float 9.9999999999999995E-7
    %int1_23 = torch.constant.int 1
    %157 = torch.aten.add.Scalar %result0, %float9.999990e-07, %int1_23 : !torch.vtensor<[1,32,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,32,1,1],f32>
    %158 = torch.aten.rsqrt %157 : !torch.vtensor<[1,32,1,1],f32> -> !torch.vtensor<[1,32,1,1],f32>
    %int1_24 = torch.constant.int 1
    %159 = torch.aten.sub.Tensor %arg0, %result1, %int1_24 : !torch.vtensor<[1,32,16,16384],f32>, !torch.vtensor<[1,32,1,1],f32>, !torch.int -> !torch.vtensor<[1,32,16,16384],f32>
    %160 = torch.aten.mul.Tensor %159, %158 : !torch.vtensor<[1,32,16,16384],f32>, !torch.vtensor<[1,32,1,1],f32> -> !torch.vtensor<[1,32,16,16384],f32>
    return %160 : !torch.vtensor<[1,32,16,16384],f32>
  }
}

repro compile command

iree-compile <vae_flux_repro.mlir>  -o=test.vmfb --iree-hal-target-device=hip --iree-hip-target=gfx942 --iree-dispatch-creation-enable-aggressive-fusion

complete compile command from CI

iree-compile <mlir file> --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary \
-o=test.vmfb --iree-hal-target-device=hip --iree-hip-target=gfx942 \
--iree-execution-model=async-external --iree-global-opt-propagate-transposes=1 \
--iree-opt-const-eval=0 --iree-opt-outer-dim-concat=1 \
--iree-opt-aggressively-propagate-transposes=1 --iree-codegen-llvmgpu-use-vector-distribution=1 \
--iree-llvmgpu-enable-prefetch=1 --iree-opt-data-tiling=0 \
--iree-vm-target-truncate-unsupported-floats \
--iree-dispatch-creation-enable-aggressive-fusion \
--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' 

Note: both the full VAE model and the minimal repro compile error-free without the --iree-dispatch-creation-enable-aggressive-fusion flag.
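A quick sanity check of the numbers in the error message: the fused dispatch keeps one full 262144-element f32 row in workgroup memory, which alone is 16x over the gfx942 limit (the extra 128 bytes in the reported 1048704 are presumably allocator padding):

```python
# Back-of-the-envelope check of the reported shared-memory usage.
row_elems = 262144          # inner dimension of the 32x262144 tensor
row_bytes = row_elems * 4   # f32 = 4 bytes per element
limit = 65536               # max_workgroup_memory_bytes for gfx942

assert row_bytes == 1048576     # the error reports 1048704 (this + 128 bytes)
assert row_bytes > limit        # 16x over the 64 KiB workgroup memory limit
```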


@PhaneeshB PhaneeshB added the bug 🐞 Something isn't working label May 21, 2025
@PhaneeshB
Contributor Author

Without --iree-dispatch-creation-enable-aggressive-fusion, two dispatches are created, each containing two linalg.generic ops. With the flag, the dispatches are fused into one, which leads to the error.
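For reference, the two dispatches together implement a group-norm-style normalization. A scaled-down NumPy sketch of what they compute (shapes reduced from 32x262144 for illustration; eps matches the 1e-6 constant in the repro):

```python
import numpy as np

def decode_norm(x, eps=1e-6):
    n = x.shape[1]
    # dispatch_0: row-sum reduction, then subtract the mean elementwise
    centered = x - x.sum(axis=1, keepdims=True) / n
    # dispatch_1: sum-of-squares reduction over the centered values,
    # then scale elementwise by rsqrt(var + eps)
    var = (centered * centered).sum(axis=1, keepdims=True) / n
    return centered / np.sqrt(var + eps)

x = np.random.default_rng(0).standard_normal((32, 1024)).astype(np.float32)
y = decode_norm(x)
# same result as torch.aten.var_mean.correction (correction=0) + the elementwise tail
ref = (x - x.mean(axis=1, keepdims=True)) / np.sqrt(x.var(axis=1, keepdims=True) + 1e-6)
assert np.allclose(y, ref, atol=1e-4)
```

Note that dispatch_1's reduction consumes the centered output of dispatch_0, which is why fusing them forces the intermediate full-row tensor to live inside a single dispatch.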

Flow-IR without flag:

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#device_target_hip = #hal.device.target<"hip", [#executable_target_rocm_hsaco_fb]> : !hal.device
module @module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
  util.global private @__device_0 = #device_target_hip
  flow.executable private @decode$async_dispatch_0 {
    flow.executable.export public @decode$async_dispatch_0_elementwise_32x262144_f32 workgroups() -> (index, index, index) {
      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice 
      flow.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @decode$async_dispatch_0_elementwise_32x262144_f32(%arg0: !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>>, %arg1: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>) {
        %cst = arith.constant 0.000000e+00 : f32
        %cst_0 = arith.constant 2.621440e+05 : f32
        %0 = iree_tensor_ext.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>> -> tensor<32x262144xf32>
        %1 = tensor.empty() : tensor<32xf32>
        %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32xf32>) -> tensor<32xf32>
        %3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<32x262144xf32>) outs(%2 : tensor<32xf32>) {
        ^bb0(%in: f32, %out: f32):
          %6 = arith.addf %in, %out : f32
          linalg.yield %6 : f32
        } -> tensor<32xf32>
        %4 = tensor.empty() : tensor<32x262144xf32>
        %5 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%0, %3 : tensor<32x262144xf32>, tensor<32xf32>) outs(%4 : tensor<32x262144xf32>) {
        ^bb0(%in: f32, %in_1: f32, %out: f32):
          %6 = arith.divf %in_1, %cst_0 : f32
          %7 = arith.subf %in, %6 : f32
          linalg.yield %7 : f32
        } -> tensor<32x262144xf32>
        iree_tensor_ext.dispatch.tensor.store %5, %arg1, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : tensor<32x262144xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
        return
      }
    }
  }
  flow.executable private @decode$async_dispatch_1 {
    flow.executable.export public @decode$async_dispatch_1_elementwise_32x262144_f32 workgroups() -> (index, index, index) {
      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice 
      flow.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @decode$async_dispatch_1_elementwise_32x262144_f32(%arg0: !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>>, %arg1: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>) {
        %cst = arith.constant 0.000000e+00 : f32
        %cst_0 = arith.constant 2.621440e+05 : f32
        %cst_1 = arith.constant 9.99999997E-7 : f32
        %0 = iree_tensor_ext.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>> -> tensor<32x262144xf32>
        %1 = tensor.empty() : tensor<32xf32>
        %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32xf32>) -> tensor<32xf32>
        %3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<32x262144xf32>) outs(%2 : tensor<32xf32>) {
        ^bb0(%in: f32, %out: f32):
          %6 = arith.mulf %in, %in : f32
          %7 = arith.addf %6, %out : f32
          linalg.yield %7 : f32
        } -> tensor<32xf32>
        %4 = tensor.empty() : tensor<32x262144xf32>
        %5 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%0, %3 : tensor<32x262144xf32>, tensor<32xf32>) outs(%4 : tensor<32x262144xf32>) {
        ^bb0(%in: f32, %in_2: f32, %out: f32):
          %6 = arith.divf %in_2, %cst_0 : f32
          %7 = arith.addf %6, %cst_1 : f32
          %8 = math.rsqrt %7 : f32
          %9 = arith.mulf %in, %8 : f32
          linalg.yield %9 : f32
        } -> tensor<32x262144xf32>
        iree_tensor_ext.dispatch.tensor.store %5, %arg1, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : tensor<32x262144xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
        return
      }
    }
  }
  util.func public @decode$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
    %0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<1x32x16x16384xf32>
    %1 = flow.tensor.reshape %0 : tensor<1x32x16x16384xf32> -> tensor<32x262144xf32>
    %2 = flow.dispatch @decode$async_dispatch_0::@decode$async_dispatch_0_elementwise_32x262144_f32(%1) : (tensor<32x262144xf32>) -> tensor<32x262144xf32>
    %3 = flow.dispatch @decode$async_dispatch_1::@decode$async_dispatch_1_elementwise_32x262144_f32(%2) : (tensor<32x262144xf32>) -> tensor<32x262144xf32>
    %4 = flow.tensor.reshape %3 : tensor<32x262144xf32> -> tensor<1x32x16x16384xf32>
    %5 = hal.tensor.barrier join(%4 : tensor<1x32x16x16384xf32>) => %arg2 : !hal.fence
    %6 = hal.tensor.export %5 : tensor<1x32x16x16384xf32> -> !hal.buffer_view
    util.return %6 : !hal.buffer_view
  }
  util.func public @decode(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %0 = util.null : !hal.fence
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    %1 = util.call @decode$async(%arg0, %0, %fence) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
    util.return %1 : !hal.buffer_view
  }
}

Flow-IR with flag:

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#device_target_hip = #hal.device.target<"hip", [#executable_target_rocm_hsaco_fb]> : !hal.device
module @module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
  util.global private @__device_0 = #device_target_hip
  flow.executable private @decode$async_dispatch_0 {
    flow.executable.export public @decode$async_dispatch_0_elementwise_32x262144_f32 workgroups() -> (index, index, index) {
      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice 
      flow.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @decode$async_dispatch_0_elementwise_32x262144_f32(%arg0: !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>>, %arg1: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>) {
        %cst = arith.constant 0.000000e+00 : f32
        %cst_0 = arith.constant 2.621440e+05 : f32
        %cst_1 = arith.constant 9.99999997E-7 : f32
        %0 = iree_tensor_ext.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>> -> tensor<32x262144xf32>
        %1 = tensor.empty() : tensor<32xf32>
        %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32xf32>) -> tensor<32xf32>
        %3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<32x262144xf32>) outs(%2 : tensor<32xf32>) {
        ^bb0(%in: f32, %out: f32):
          %8 = arith.addf %in, %out : f32
          linalg.yield %8 : f32
        } -> tensor<32xf32>
        %4 = tensor.empty() : tensor<32x262144xf32>
        %5 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%0, %3 : tensor<32x262144xf32>, tensor<32xf32>) outs(%4 : tensor<32x262144xf32>) {
        ^bb0(%in: f32, %in_2: f32, %out: f32):
          %8 = arith.divf %in_2, %cst_0 : f32
          %9 = arith.subf %in, %8 : f32
          linalg.yield %9 : f32
        } -> tensor<32x262144xf32>
        %6 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%5 : tensor<32x262144xf32>) outs(%2 : tensor<32xf32>) {
        ^bb0(%in: f32, %out: f32):
          %8 = arith.mulf %in, %in : f32
          %9 = arith.addf %8, %out : f32
          linalg.yield %9 : f32
        } -> tensor<32xf32>
        %7 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%5, %6 : tensor<32x262144xf32>, tensor<32xf32>) outs(%4 : tensor<32x262144xf32>) {
        ^bb0(%in: f32, %in_2: f32, %out: f32):
          %8 = arith.divf %in_2, %cst_0 : f32
          %9 = arith.addf %8, %cst_1 : f32
          %10 = math.rsqrt %9 : f32
          %11 = arith.mulf %in, %10 : f32
          linalg.yield %11 : f32
        } -> tensor<32x262144xf32>
        iree_tensor_ext.dispatch.tensor.store %7, %arg1, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : tensor<32x262144xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
        return
      }
    }
  }
  util.func public @decode$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
    %0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<1x32x16x16384xf32>
    %1 = flow.tensor.reshape %0 : tensor<1x32x16x16384xf32> -> tensor<32x262144xf32>
    %2 = flow.dispatch @decode$async_dispatch_0::@decode$async_dispatch_0_elementwise_32x262144_f32(%1) : (tensor<32x262144xf32>) -> tensor<32x262144xf32>
    %3 = flow.tensor.reshape %2 : tensor<32x262144xf32> -> tensor<1x32x16x16384xf32>
    %4 = hal.tensor.barrier join(%3 : tensor<1x32x16x16384xf32>) => %arg2 : !hal.fence
    %5 = hal.tensor.export %4 : tensor<1x32x16x16384xf32> -> !hal.buffer_view
    util.return %5 : !hal.buffer_view
  }
  util.func public @decode(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %0 = util.null : !hal.fence
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    %1 = util.call @decode$async(%arg0, %0, %fence) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
    util.return %1 : !hal.buffer_view
  }
}

@PhaneeshB
Contributor Author
PhaneeshB commented May 27, 2025

The fused dispatch errors out in GPUCheckResourceUsagePass due to this massive memref alloc:

%alloc = memref.alloc() : memref<1x262144xf32, #gpu.address_space<workgroup>>

which can be traced back to a tensor.empty op created inside the scf.forall loop by TileAndDistributeToWorkgroupsUsingForallOpPass.

IR Dump After TileAndDistributeToWorkgroupsUsingForallOpPass:

func.func @decode$async_dispatch_0_elementwise_32x262144_f32() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [1024, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}>} {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 2.621440e+05 : f32
  %cst_1 = arith.constant 9.99999997E-7 : f32
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>>
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
  %2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>> -> tensor<32x262144xf32>

  %3 = tensor.empty() : tensor<32x262144xf32>
  %4 = scf.forall (%arg0) in (32) shared_outs(%arg1 = %3) -> (tensor<32x262144xf32>) {
    %extracted_slice = tensor.extract_slice %2[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
    %extracted_slice_2 = tensor.extract_slice %2[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
    %5 = tensor.empty() : tensor<1xf32>
    %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1xf32>) -> tensor<1xf32>
    %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2 : tensor<1x262144xf32>) outs(%6 : tensor<1xf32>) attrs =  {lowering_config = #iree_gpu.lowering_config<{partial_reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
    ^bb0(%in: f32, %out: f32):
      %19 = arith.addf %in, %out : f32
      linalg.yield %19 : f32
    } -> tensor<1xf32>

    %8 = tensor.empty() : tensor<1x262144xf32>

    %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %7 : tensor<1x262144xf32>, tensor<1xf32>) outs(%8 : tensor<1x262144xf32>) attrs =  {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]]}>} {
    ^bb0(%in: f32, %in_6: f32, %out: f32):
      %19 = arith.divf %in_6, %cst_0 : f32
      %20 = arith.subf %in, %19 : f32
      linalg.yield %20 : f32
    } -> tensor<1x262144xf32>
    %10 = tensor.empty() : tensor<1xf32>
    %11 = linalg.fill ins(%cst : f32) outs(%10 : tensor<1xf32>) -> tensor<1xf32>
    %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%9 : tensor<1x262144xf32>) outs(%11 : tensor<1xf32>) attrs =  {lowering_config = #iree_gpu.lowering_config<{partial_reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
    ^bb0(%in: f32, %out: f32):
      %19 = arith.mulf %in, %in : f32
      %20 = arith.addf %19, %out : f32
      linalg.yield %20 : f32
    } -> tensor<1xf32>
    %extracted_slice_3 = tensor.extract_slice %2[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
    %extracted_slice_4 = tensor.extract_slice %2[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
    %13 = tensor.empty() : tensor<1xf32>
    %14 = linalg.fill ins(%cst : f32) outs(%13 : tensor<1xf32>) -> tensor<1xf32>
    %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_4 : tensor<1x262144xf32>) outs(%14 : tensor<1xf32>) attrs =  {lowering_config = #iree_gpu.lowering_config<{partial_reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
    ^bb0(%in: f32, %out: f32):
      %19 = arith.addf %in, %out : f32
      linalg.yield %19 : f32
    } -> tensor<1xf32>
    %16 = tensor.empty() : tensor<1x262144xf32>
    %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_3, %15 : tensor<1x262144xf32>, tensor<1xf32>) outs(%16 : tensor<1x262144xf32>) attrs =  {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]]}>} {
    ^bb0(%in: f32, %in_6: f32, %out: f32):
      %19 = arith.divf %in_6, %cst_0 : f32
      %20 = arith.subf %in, %19 : f32
      linalg.yield %20 : f32
    } -> tensor<1x262144xf32>
    %extracted_slice_5 = tensor.extract_slice %arg1[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
    %18 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17, %12 : tensor<1x262144xf32>, tensor<1xf32>) outs(%extracted_slice_5 : tensor<1x262144xf32>) attrs =  {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]]}>} {
    ^bb0(%in: f32, %in_6: f32, %out: f32):
      %19 = arith.divf %in_6, %cst_0 : f32
      %20 = arith.addf %19, %cst_1 : f32
      %21 = math.rsqrt %20 : f32
      %22 = arith.mulf %in, %21 : f32
      linalg.yield %22 : f32
    } -> tensor<1x262144xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %18 into %arg1[%arg0, 0] [1, 262144] [1, 1] : tensor<1x262144xf32> into tensor<32x262144xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
  iree_tensor_ext.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : tensor<32x262144xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
  return
}

@IanWood1
Contributor

This looked related to #20895 but the revert didn't fix the issue.

cc @pashu123

@pashu123 pashu123 added the codegen Shared code generation infrastructure and dialects label May 27, 2025
@pashu123
Contributor

This looked related to #20895 but the revert didn't fix the issue.

cc @pashu123

This is not related. It is a limitation of tileDispatchUsingForall: instead of using the correct extract_slices of the iter_arg, it creates a new tensor.empty inside the loop.
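The limitation can be illustrated with a host-side NumPy analogy (a sketch, not IREE internals): both loops below compute the same result, but the first materializes a fresh per-iteration temporary, which is what the stray tensor.empty becomes after bufferization (the 1 MiB workgroup alloc), while the second writes directly through a slice of the output.

```python
import numpy as np

def tile_with_fresh_temp(x):
    out = np.empty_like(x)
    for i in range(x.shape[0]):
        # per-iteration scratch buffer -- analogous to the tensor.empty
        # created inside the scf.forall loop
        tmp = np.empty(x.shape[1], dtype=x.dtype)
        tmp[:] = x[i] - x[i].mean()
        out[i] = tmp
    return out

def tile_into_iter_arg(x):
    out = np.empty_like(x)
    for i in range(x.shape[0]):
        # compute directly into the output slice -- analogous to using an
        # extract_slice of the loop's iter_arg, requiring no extra buffer
        out[i] = x[i] - x[i].mean()
    return out

x = np.arange(12, dtype=np.float32).reshape(3, 4)
assert np.array_equal(tile_with_fresh_temp(x), tile_into_iter_arg(x))
```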
