[GPU] compilation failure for alternative bwd grouped conv · Issue #20498 · iree-org/iree · GitHub

[GPU] compilation failure for alternative bwd grouped conv #20498

Closed
zjgarvey opened this issue Apr 8, 2025 · 6 comments · Fixed by #20611
Comments

@zjgarvey
Contributor
zjgarvey commented Apr 8, 2025

What happened?

Two similar sets of IR for backward grouped convolution are provided.

In the first, the grouped-dim expansion and collapse are placed immediately around the conv, after the filter's spatial dims are flipped and dLdy is padded; this version compiles. In the second, the expand/collapse ops sit at the function boundaries, and the filter flipping and padding are applied after expansion; this version fails to compile.

This IR Compiles

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d6, d2 + d7, d3, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d5, d6, d7, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
module @module {
  util.func public @conv_2d_float32_input_backward_128x24x48x384_nhwc_384x1x3x128_fhwc_nhwf_1x1s_0x1p_1x1d_3g$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub, preprocessing_pipeline = #util.preprocessing_pipeline<"iree-preprocessing-make-single-dispatch">} {
    %cst = arith.constant 0.000000e+00 : f32
    %c2 = arith.constant 2 : index
    %0 = hal.tensor.import wait(%arg2) => %arg0 : !hal.buffer_view -> tensor<128x24x48x384xf32>
    %1 = hal.tensor.import wait(%arg2) => %arg1 : !hal.buffer_view -> tensor<384x1x3x128xf32>
    %2 = tensor.empty() : tensor<384x1x3x128xf32>
    %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<384x1x3x128xf32>) -> tensor<384x1x3x128xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1 : tensor<384x1x3x128xf32>) outs(%3 : tensor<384x1x3x128xf32>) {
    ^bb0(%in: f32, %out: f32):
      %10 = linalg.index 0 : index
      %11 = linalg.index 1 : index
      %12 = linalg.index 2 : index
      %13 = linalg.index 3 : index
      %14 = arith.subi %c2, %12 : index
      %extracted = tensor.extract %1[%10, %11, %14, %13] : tensor<384x1x3x128xf32>
      linalg.yield %extracted : f32
    } -> tensor<384x1x3x128xf32>
    %padded = tensor.pad %0 low[0, 0, 1, 0] high[0, 0, 1, 0] {
    ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index):
      tensor.yield %cst : f32
    } : tensor<128x24x48x384xf32> to tensor<128x24x50x384xf32>
    %expanded = tensor.expand_shape %padded [[0], [1], [2], [3, 4]] output_shape [128, 24, 50, 3, 128] : tensor<128x24x50x384xf32> into tensor<128x24x50x3x128xf32>
    %expanded_0 = tensor.expand_shape %4 [[0, 1], [2], [3], [4]] output_shape [3, 128, 1, 3, 128] : tensor<384x1x3x128xf32> into tensor<3x128x1x3x128xf32>
    %5 = tensor.empty() : tensor<128x24x48x3x128xf32>
    %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x24x48x3x128xf32>) -> tensor<128x24x48x3x128xf32>
    %7 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%expanded, %expanded_0 : tensor<128x24x50x3x128xf32>, tensor<3x128x1x3x128xf32>) outs(%6 : tensor<128x24x48x3x128xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %10 = arith.mulf %in, %in_1 : f32
      %11 = arith.addf %out, %10 : f32
      linalg.yield %11 : f32
    } -> tensor<128x24x48x3x128xf32>
    %collapsed = tensor.collapse_shape %7 [[0], [1], [2], [3, 4]] : tensor<128x24x48x3x128xf32> into tensor<128x24x48x384xf32>
    %8 = hal.tensor.barrier join(%collapsed : tensor<128x24x48x384xf32>) => %arg3 : !hal.fence
    %9 = hal.tensor.export %8 : tensor<128x24x48x384xf32> -> !hal.buffer_view
    util.return %9 : !hal.buffer_view
  }
  util.func public @conv_2d_float32_input_backward_128x24x48x384_nhwc_384x1x3x128_fhwc_nhwf_1x1s_0x1p_1x1d_3g(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %0 = util.null : !hal.fence
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    %1 = util.call @conv_2d_float32_input_backward_128x24x48x384_nhwc_384x1x3x128_fhwc_nhwf_1x1s_0x1p_1x1d_3g$async(%arg0, %arg1, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
    util.return %1 : !hal.buffer_view
  }
}

This IR Fails To Compile

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d6, d2 + d7, d3, d5)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d5, d6, d7, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
module @module {
  util.func public @conv_2d_float32_input_backward_128x24x48x384_nhwc_384x1x3x128_fhwc_nhwf_1x1s_0x1p_1x1d_3g$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.fence, %arg3: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub, preprocessing_pipeline = #util.preprocessing_pipeline<"iree-preprocessing-make-single-dispatch">} {
    %cst = arith.constant 0.000000e+00 : f32
    %c2 = arith.constant 2 : index
    %0 = hal.tensor.import wait(%arg2) => %arg0 : !hal.buffer_view -> tensor<128x24x48x384xf32>
    %1 = hal.tensor.import wait(%arg2) => %arg1 : !hal.buffer_view -> tensor<384x1x3x128xf32>
    %expanded = tensor.expand_shape %0 [[0], [1], [2], [3, 4]] output_shape [128, 24, 48, 3, 128] : tensor<128x24x48x384xf32> into tensor<128x24x48x3x128xf32>
    %expanded_0 = tensor.expand_shape %1 [[0, 1], [2], [3], [4]] output_shape [3, 128, 1, 3, 128] : tensor<384x1x3x128xf32> into tensor<3x128x1x3x128xf32>
    %2 = tensor.empty() : tensor<3x128x1x3x128xf32>
    %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<3x128x1x3x128xf32>) -> tensor<3x128x1x3x128xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%expanded_0 : tensor<3x128x1x3x128xf32>) outs(%3 : tensor<3x128x1x3x128xf32>) {
    ^bb0(%in: f32, %out: f32):
      %10 = linalg.index 0 : index
      %11 = linalg.index 1 : index
      %12 = linalg.index 2 : index
      %13 = linalg.index 3 : index
      %14 = linalg.index 4 : index
      %15 = arith.subi %c2, %13 : index
      %extracted = tensor.extract %expanded_0[%10, %11, %12, %15, %14] : tensor<3x128x1x3x128xf32>
      linalg.yield %extracted : f32
    } -> tensor<3x128x1x3x128xf32>
    %padded = tensor.pad %expanded low[0, 0, 1, 0, 0] high[0, 0, 1, 0, 0] {
    ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index):
      tensor.yield %cst : f32
    } : tensor<128x24x48x3x128xf32> to tensor<128x24x50x3x128xf32>
    %5 = tensor.empty() : tensor<128x24x48x3x128xf32>
    %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x24x48x3x128xf32>) -> tensor<128x24x48x3x128xf32>
    %7 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%padded, %4 : tensor<128x24x50x3x128xf32>, tensor<3x128x1x3x128xf32>) outs(%6 : tensor<128x24x48x3x128xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %10 = arith.mulf %in, %in_1 : f32
      %11 = arith.addf %out, %10 : f32
      linalg.yield %11 : f32
    } -> tensor<128x24x48x3x128xf32>
    %collapsed = tensor.collapse_shape %7 [[0], [1], [2], [3, 4]] : tensor<128x24x48x3x128xf32> into tensor<128x24x48x384xf32>
    %8 = hal.tensor.barrier join(%collapsed : tensor<128x24x48x384xf32>) => %arg3 : !hal.fence
    %9 = hal.tensor.export %8 : tensor<128x24x48x384xf32> -> !hal.buffer_view
    util.return %9 : !hal.buffer_view
  }
  util.func public @conv_2d_float32_input_backward_128x24x48x384_nhwc_384x1x3x128_fhwc_nhwf_1x1s_0x1p_1x1d_3g(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %0 = util.null : !hal.fence
    %c-1_i32 = arith.constant -1 : i32
    %c0 = arith.constant 0 : index
    %device_0 = hal.devices.get %c0 : !hal.device
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
    %1 = util.call @conv_2d_float32_input_backward_128x24x48x384_nhwc_384x1x3x128_fhwc_nhwf_1x1s_0x1p_1x1d_3g$async(%arg0, %arg1, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
    util.return %1 : !hal.buffer_view
  }
}

Steps to reproduce your issue

Try to compile both examples with:

iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942
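
For example, a full invocation might look like the following (the output filename is a placeholder; failing.mlir is the saved reproducer named in the error below):

iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 failing.mlir -o failing.vmfb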

What component(s) does this issue relate to?

Compiler

Version information

pip install:

iree-base-compiler 3.4.0rc20250408

Additional context

No response

zjgarvey added the bug 🐞 Something isn't working label Apr 8, 2025
@rkayaith
Member

The compile fails with:

failing.mlir:6:3: error: 'func.func' op uses 592944 bytes of shared memory; exceeded the limit of 65536 bytes

The IR with --compile-to=executable-sources looks like:

func.func @conv_2d_float32_input_backward_128x24x48x384_nhwc_384x1x3x128_fhwc_nhwf_1x1s_0x1p_1x1d_3g$async_dispatch_0_conv_128x24x48x3x128x128x3_f32() {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %c2 = arith.constant 2 : index
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<128x24x48x384xf32>>
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<384x1x3x128xf32>>
  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<128x24x48x384xf32>>
  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [128, 24, 48, 384], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<128x24x48x384xf32>> -> tensor<128x24x48x384xf32>
  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [384, 1, 3, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<384x1x3x128xf32>> -> tensor<384x1x3x128xf32>
  %expanded = tensor.expand_shape %3 [[0], [1], [2], [3, 4]] output_shape [128, 24, 48, 3, 128] : tensor<128x24x48x384xf32> into tensor<128x24x48x3x128xf32>
  %padded = tensor.pad %expanded low[0, 0, 1, 0, 0] high[0, 0, 1, 0, 0] {
  ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
    tensor.yield %cst : f32
  } : tensor<128x24x48x3x128xf32> to tensor<128x24x50x3x128xf32>
  %expanded_0 = tensor.expand_shape %4 [[0, 1], [2], [3], [4]] output_shape [3, 128, 1, 3, 128] : tensor<384x1x3x128xf32> into tensor<3x128x1x3x128xf32>
  %5 = tensor.empty() : tensor<3x128x1x3x128xf32>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%5 : tensor<3x128x1x3x128xf32>) {
  ^bb0(%out: f32):
    %10 = linalg.index 0 : index
    %11 = linalg.index 2 : index
    %12 = linalg.index 1 : index
    %13 = arith.addi %11, %12 : index
    %14 = linalg.index 3 : index
    %15 = linalg.index 4 : index
    %16 = arith.subi %c2, %14 : index
    %extracted = tensor.extract %expanded_0[%10, %13, %c0, %16, %15] : tensor<3x128x1x3x128xf32>
    linalg.yield %extracted : f32
  } -> tensor<3x128x1x3x128xf32>
  %collapsed = tensor.collapse_shape %6 [[0], [1, 2], [3], [4]] : tensor<3x128x1x3x128xf32> into tensor<3x128x3x128xf32>
  %7 = tensor.empty() : tensor<128x24x48x3x128xf32>
  %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<128x24x48x3x128xf32>) -> tensor<128x24x48x3x128xf32>
  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2 + d6, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d5, d6, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%padded, %collapsed : tensor<128x24x50x3x128xf32>, tensor<3x128x3x128xf32>) outs(%8 : tensor<128x24x48x3x128xf32>) {
  ^bb0(%in: f32, %in_2: f32, %out: f32):
    %10 = arith.mulf %in, %in_2 : f32
    %11 = arith.addf %out, %10 : f32
    linalg.yield %11 : f32
  } -> tensor<128x24x48x3x128xf32>
  %collapsed_1 = tensor.collapse_shape %9 [[0], [1], [2], [3, 4]] : tensor<128x24x48x3x128xf32> into tensor<128x24x48x384xf32>
  flow.dispatch.tensor.store %collapsed_1, %2, offsets = [0, 0, 0, 0], sizes = [128, 24, 48, 384], strides = [1, 1, 1, 1] : tensor<128x24x48x384xf32> -> !flow.dispatch.tensor<writeonly:tensor<128x24x48x384xf32>>
  return
}
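
(A dispatch-source dump like the one above can be produced with an invocation along these lines; the input filename is a placeholder:)

iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --compile-to=executable-sources failing.mlir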

And for reference here's what the working IR looks like: https://gist.github.com/rkayaith/d322dd974cdc41fe3a3ea2d62981ce1c#file-2-working-exec-source-mlir-L11

The compilation IR dump of the failing IR is here: https://gist.github.com/rkayaith/d322dd974cdc41fe3a3ea2d62981ce1c#file-3-failing-ir-after-all-mlir

@rkayaith
Member

The %collapsed = tensor.collapse_shape op in the generic(collapse_shape(generic(...))) chain seems like it could be fused during BubbleExpandShapes, but currently isn't.

If I hack FoldReshapeWithGenericOpByCollapsing into the pass, we end up with this IR, which does compile:

func.func @conv_2d_float32_input_backward_128x24x48x384_nhwc_384x1x3x128_fhwc_nhwf_1x1s_0x1p_1x1d_3g$async_dispatch_0_conv_128x24x48x3x128x128x3_f32() {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %c2 = arith.constant 2 : index
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<128x24x48x384xf32>>
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<384x1x3x128xf32>>
  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<128x24x48x384xf32>>
  %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [128, 24, 48, 384], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<128x24x48x384xf32>> -> tensor<128x24x48x384xf32>
  %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [384, 1, 3, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<384x1x3x128xf32>> -> tensor<384x1x3x128xf32>
  %expanded = tensor.expand_shape %3 [[0], [1], [2], [3, 4]] output_shape [128, 24, 48, 3, 128] : tensor<128x24x48x384xf32> into tensor<128x24x48x3x128xf32>
  %padded = tensor.pad %expanded low[0, 0, 1, 0, 0] high[0, 0, 1, 0, 0] {
  ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
    tensor.yield %cst : f32
  } : tensor<128x24x48x3x128xf32> to tensor<128x24x50x3x128xf32>
  %expanded_0 = tensor.expand_shape %4 [[0, 1], [2], [3], [4]] output_shape [3, 128, 1, 3, 128] : tensor<384x1x3x128xf32> into tensor<3x128x1x3x128xf32>
  %5 = tensor.empty() : tensor<3x128x3x128xf32>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%5 : tensor<3x128x3x128xf32>) {
  ^bb0(%out: f32):
    %10 = linalg.index 0 : index
    %11 = linalg.index 1 : index
    %12 = linalg.index 2 : index
    %13 = linalg.index 3 : index
    %14 = arith.subi %c2, %12 : index
    %extracted = tensor.extract %expanded_0[%10, %11, %c0, %14, %13] : tensor<3x128x1x3x128xf32>
    linalg.yield %extracted : f32
  } -> tensor<3x128x3x128xf32>
  %7 = tensor.empty() : tensor<128x24x48x3x128xf32>
  %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<128x24x48x3x128xf32>) -> tensor<128x24x48x3x128xf32>
  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2 + d6, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d5, d6, d4)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%padded, %6 : tensor<128x24x50x3x128xf32>, tensor<3x128x3x128xf32>) outs(%8 : tensor<128x24x48x3x128xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %10 = arith.mulf %in, %in_1 : f32
    %11 = arith.addf %out, %10 : f32
    linalg.yield %11 : f32
  } -> tensor<128x24x48x3x128xf32>
  %collapsed = tensor.collapse_shape %9 [[0], [1], [2], [3, 4]] : tensor<128x24x48x3x128xf32> into tensor<128x24x48x384xf32>
  flow.dispatch.tensor.store %collapsed, %2, offsets = [0, 0, 0, 0], sizes = [128, 24, 48, 384], strides = [1, 1, 1, 1] : tensor<128x24x48x384xf32> -> !flow.dispatch.tensor<writeonly:tensor<128x24x48x384xf32>>
  return
}

But the API for adding this pattern includes some other patterns which we may not want: https://github.com/llvm/llvm-project/blob/747d4a952bf7ed4adec72ddf3c9038aeff4fe8ee/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp#L2259-L2264
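
For reference, a minimal sketch of what calling that populate API from a pass might look like (this is not the actual IREE pass code, and the permissive control function is an assumption for illustration; a real one would encode dispatch-formation heuristics like those in SinkReshapes.cpp):

// Hypothetical sketch: register the fold-reshape-by-collapsing patterns
// (which include FoldReshapeWithGenericOpByCollapsing), along with the
// other patterns that populate function adds.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

static void addReshapeCollapsingPatterns(mlir::RewritePatternSet &patterns) {
  // Approve every fold; a real control function would be selective.
  mlir::linalg::ControlFusionFn controlFn =
      [](mlir::OpOperand *fusedOperand) { return true; };
  mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
}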

@nirvedhmeshram any thoughts on how to proceed here?

@nirvedhmeshram
Contributor
nirvedhmeshram commented Apr 10, 2025

The pipeline is missing a pattern to do FoldReshapeWithGenericOpByCollapsing. Note that it is only missing because this case does not go down the im2col path (it is a group conv); once we enable im2col for group convs, it will find the pattern here
https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp#L138-L139
and the issue should resolve itself, I think.
Also note that the dispatch creation passes do call the pattern here: https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/DispatchCreation/SinkReshapes.cpp#L190
but there we have a control function that is blocking this. @rkayaith, as an experiment you could hack that control function to match the one in the ConvolutionToIGEMM pass and see if that also solves the problem. After that, we can decide whether to add some passes/patterns to our single-dispatch creation pipeline to help here.

zjgarvey added a commit to iree-org/iree-turbine that referenced this issue Apr 10, 2025

@rkayaith
Member

Not yet, but with group convs going down the IGEMM path, this will get resolved.

@nirvedhmeshram
Contributor
nirvedhmeshram commented Apr 21, 2025

No, we put it on hold until we enable grouped conv with IGEMM, as that might solve the problem, or at the very least change the issues.

KyleHerndon pushed a commit to KyleHerndon/iree that referenced this issue May 7, 2025
This adds support for converting group convs to im2col, allowing them to
go down the IGEMM path.

Group dimensions are parallel iterator dims that index into the image,
filter, and output. For im2col they are treated as a batch dimension.

This also fixes iree-org#20498
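
(For illustration, in the indexing maps from the failing IR above, d3 is that group dimension; it appears as a parallel dim in all three operands:)

#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d6, d2 + d7, d3, d5)> // image: d3 selects the group
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d5, d6, d7, d4)>           // filter: d3 selects the group
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>           // output: d3 selects the group

Since d3 indexes every operand as a parallel iterator, im2col can treat it like an extra batch dimension.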