-
Notifications
You must be signed in to change notification settings - Fork 699
VAE compilation failure with aggressive fusion enabled #20875
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Comments
Flow IR without the `--iree-dispatch-creation-enable-aggressive-fusion` flag: #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#device_target_hip = #hal.device.target<"hip", [#executable_target_rocm_hsaco_fb]> : !hal.device
module @module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
util.global private @__device_0 = #device_target_hip
flow.executable private @decode$async_dispatch_0 {
flow.executable.export public @decode$async_dispatch_0_elementwise_32x262144_f32 workgroups() -> (index, index, index) {
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
flow.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @decode$async_dispatch_0_elementwise_32x262144_f32(%arg0: !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>>, %arg1: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>) {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 2.621440e+05 : f32
%0 = iree_tensor_ext.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>> -> tensor<32x262144xf32>
%1 = tensor.empty() : tensor<32xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32xf32>) -> tensor<32xf32>
%3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<32x262144xf32>) outs(%2 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%6 = arith.addf %in, %out : f32
linalg.yield %6 : f32
} -> tensor<32xf32>
%4 = tensor.empty() : tensor<32x262144xf32>
%5 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%0, %3 : tensor<32x262144xf32>, tensor<32xf32>) outs(%4 : tensor<32x262144xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%6 = arith.divf %in_1, %cst_0 : f32
%7 = arith.subf %in, %6 : f32
linalg.yield %7 : f32
} -> tensor<32x262144xf32>
iree_tensor_ext.dispatch.tensor.store %5, %arg1, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : tensor<32x262144xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
return
}
}
}
flow.executable private @decode$async_dispatch_1 {
flow.executable.export public @decode$async_dispatch_1_elementwise_32x262144_f32 workgroups() -> (index, index, index) {
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
flow.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @decode$async_dispatch_1_elementwise_32x262144_f32(%arg0: !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>>, %arg1: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>) {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 2.621440e+05 : f32
%cst_1 = arith.constant 9.99999997E-7 : f32
%0 = iree_tensor_ext.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>> -> tensor<32x262144xf32>
%1 = tensor.empty() : tensor<32xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32xf32>) -> tensor<32xf32>
%3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<32x262144xf32>) outs(%2 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%6 = arith.mulf %in, %in : f32
%7 = arith.addf %6, %out : f32
linalg.yield %7 : f32
} -> tensor<32xf32>
%4 = tensor.empty() : tensor<32x262144xf32>
%5 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%0, %3 : tensor<32x262144xf32>, tensor<32xf32>) outs(%4 : tensor<32x262144xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%6 = arith.divf %in_2, %cst_0 : f32
%7 = arith.addf %6, %cst_1 : f32
%8 = math.rsqrt %7 : f32
%9 = arith.mulf %in, %8 : f32
linalg.yield %9 : f32
} -> tensor<32x262144xf32>
iree_tensor_ext.dispatch.tensor.store %5, %arg1, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : tensor<32x262144xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
return
}
}
}
util.func public @decode$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<1x32x16x16384xf32>
%1 = flow.tensor.reshape %0 : tensor<1x32x16x16384xf32> -> tensor<32x262144xf32>
%2 = flow.dispatch @decode$async_dispatch_0::@decode$async_dispatch_0_elementwise_32x262144_f32(%1) : (tensor<32x262144xf32>) -> tensor<32x262144xf32>
%3 = flow.dispatch @decode$async_dispatch_1::@decode$async_dispatch_1_elementwise_32x262144_f32(%2) : (tensor<32x262144xf32>) -> tensor<32x262144xf32>
%4 = flow.tensor.reshape %3 : tensor<32x262144xf32> -> tensor<1x32x16x16384xf32>
%5 = hal.tensor.barrier join(%4 : tensor<1x32x16x16384xf32>) => %arg2 : !hal.fence
%6 = hal.tensor.export %5 : tensor<1x32x16x16384xf32> -> !hal.buffer_view
util.return %6 : !hal.buffer_view
}
util.func public @decode(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%0 = util.null : !hal.fence
%c-1_i32 = arith.constant -1 : i32
%c0 = arith.constant 0 : index
%device_0 = hal.devices.get %c0 : !hal.device
%fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
%1 = util.call @decode$async(%arg0, %0, %fence) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
%status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
util.return %1 : !hal.buffer_view
}
}
Flow-IR with flag: #executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#device_target_hip = #hal.device.target<"hip", [#executable_target_rocm_hsaco_fb]> : !hal.device
module @module attributes {stream.affinity.default = #hal.device.affinity<@__device_0>} {
util.global private @__device_0 = #device_target_hip
flow.executable private @decode$async_dispatch_0 {
flow.executable.export public @decode$async_dispatch_0_elementwise_32x262144_f32 workgroups() -> (index, index, index) {
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
flow.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @decode$async_dispatch_0_elementwise_32x262144_f32(%arg0: !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>>, %arg1: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>) {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 2.621440e+05 : f32
%cst_1 = arith.constant 9.99999997E-7 : f32
%0 = iree_tensor_ext.dispatch.tensor.load %arg0, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>> -> tensor<32x262144xf32>
%1 = tensor.empty() : tensor<32xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<32xf32>) -> tensor<32xf32>
%3 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<32x262144xf32>) outs(%2 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%8 = arith.addf %in, %out : f32
linalg.yield %8 : f32
} -> tensor<32xf32>
%4 = tensor.empty() : tensor<32x262144xf32>
%5 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%0, %3 : tensor<32x262144xf32>, tensor<32xf32>) outs(%4 : tensor<32x262144xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%8 = arith.divf %in_2, %cst_0 : f32
%9 = arith.subf %in, %8 : f32
linalg.yield %9 : f32
} -> tensor<32x262144xf32>
%6 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%5 : tensor<32x262144xf32>) outs(%2 : tensor<32xf32>) {
^bb0(%in: f32, %out: f32):
%8 = arith.mulf %in, %in : f32
%9 = arith.addf %8, %out : f32
linalg.yield %9 : f32
} -> tensor<32xf32>
%7 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%5, %6 : tensor<32x262144xf32>, tensor<32xf32>) outs(%4 : tensor<32x262144xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%8 = arith.divf %in_2, %cst_0 : f32
%9 = arith.addf %8, %cst_1 : f32
%10 = math.rsqrt %9 : f32
%11 = arith.mulf %in, %10 : f32
linalg.yield %11 : f32
} -> tensor<32x262144xf32>
iree_tensor_ext.dispatch.tensor.store %7, %arg1, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : tensor<32x262144xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
return
}
}
}
util.func public @decode$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<1x32x16x16384xf32>
%1 = flow.tensor.reshape %0 : tensor<1x32x16x16384xf32> -> tensor<32x262144xf32>
%2 = flow.dispatch @decode$async_dispatch_0::@decode$async_dispatch_0_elementwise_32x262144_f32(%1) : (tensor<32x262144xf32>) -> tensor<32x262144xf32>
%3 = flow.tensor.reshape %2 : tensor<32x262144xf32> -> tensor<1x32x16x16384xf32>
%4 = hal.tensor.barrier join(%3 : tensor<1x32x16x16384xf32>) => %arg2 : !hal.fence
%5 = hal.tensor.export %4 : tensor<1x32x16x16384xf32> -> !hal.buffer_view
util.return %5 : !hal.buffer_view
}
util.func public @decode(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%0 = util.null : !hal.fence
%c-1_i32 = arith.constant -1 : i32
%c0 = arith.constant 0 : index
%device_0 = hal.devices.get %c0 : !hal.device
%fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence
%1 = util.call @decode$async(%arg0, %0, %fence) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
%status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
util.return %1 : !hal.buffer_view
}
}
The fused dispatch errors out, and the failure can be traced back to the IR shown in the following dump. IR Dump After TileAndDistributeToWorkgroupsUsingForallOpPass: func.func @decode$async_dispatch_0_elementwise_32x262144_f32() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [1024, 1, 1] subgroup_size = 64, {gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = false, no_reduce_shared_memory_bank_conflicts = false, use_igemm_convolution = false>}>} {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 2.621440e+05 : f32
%cst_1 = arith.constant 9.99999997E-7 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
%2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x262144xf32>> -> tensor<32x262144xf32>
%3 = tensor.empty() : tensor<32x262144xf32>
%4 = scf.forall (%arg0) in (32) shared_outs(%arg1 = %3) -> (tensor<32x262144xf32>) {
%extracted_slice = tensor.extract_slice %2[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
%extracted_slice_2 = tensor.extract_slice %2[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
%5 = tensor.empty() : tensor<1xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1xf32>) -> tensor<1xf32>
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2 : tensor<1x262144xf32>) outs(%6 : tensor<1xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{partial_reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: f32, %out: f32):
%19 = arith.addf %in, %out : f32
linalg.yield %19 : f32
} -> tensor<1xf32>
%8 = tensor.empty() : tensor<1x262144xf32>
%9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %7 : tensor<1x262144xf32>, tensor<1xf32>) outs(%8 : tensor<1x262144xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]]}>} {
^bb0(%in: f32, %in_6: f32, %out: f32):
%19 = arith.divf %in_6, %cst_0 : f32
%20 = arith.subf %in, %19 : f32
linalg.yield %20 : f32
} -> tensor<1x262144xf32>
%10 = tensor.empty() : tensor<1xf32>
%11 = linalg.fill ins(%cst : f32) outs(%10 : tensor<1xf32>) -> tensor<1xf32>
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%9 : tensor<1x262144xf32>) outs(%11 : tensor<1xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{partial_reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: f32, %out: f32):
%19 = arith.mulf %in, %in : f32
%20 = arith.addf %19, %out : f32
linalg.yield %20 : f32
} -> tensor<1xf32>
%extracted_slice_3 = tensor.extract_slice %2[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
%extracted_slice_4 = tensor.extract_slice %2[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
%13 = tensor.empty() : tensor<1xf32>
%14 = linalg.fill ins(%cst : f32) outs(%13 : tensor<1xf32>) -> tensor<1xf32>
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_4 : tensor<1x262144xf32>) outs(%14 : tensor<1xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{partial_reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]], workgroup = [1, 0]}>} {
^bb0(%in: f32, %out: f32):
%19 = arith.addf %in, %out : f32
linalg.yield %19 : f32
} -> tensor<1xf32>
%16 = tensor.empty() : tensor<1x262144xf32>
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_3, %15 : tensor<1x262144xf32>, tensor<1xf32>) outs(%16 : tensor<1x262144xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]]}>} {
^bb0(%in: f32, %in_6: f32, %out: f32):
%19 = arith.divf %in_6, %cst_0 : f32
%20 = arith.subf %in, %19 : f32
linalg.yield %20 : f32
} -> tensor<1x262144xf32>
%extracted_slice_5 = tensor.extract_slice %arg1[%arg0, 0] [1, 262144] [1, 1] : tensor<32x262144xf32> to tensor<1x262144xf32>
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17, %12 : tensor<1x262144xf32>, tensor<1xf32>) outs(%extracted_slice_5 : tensor<1x262144xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 4096], subgroup_basis = [[1, 16], [0, 1]], thread = [0, 4], thread_basis = [[1, 64], [0, 1]]}>} {
^bb0(%in: f32, %in_6: f32, %out: f32):
%19 = arith.divf %in_6, %cst_0 : f32
%20 = arith.addf %19, %cst_1 : f32
%21 = math.rsqrt %20 : f32
%22 = arith.mulf %in, %21 : f32
linalg.yield %22 : f32
} -> tensor<1x262144xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %18 into %arg1[%arg0, 0] [1, 262144] [1, 1] : tensor<1x262144xf32> into tensor<32x262144xf32>
}
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
iree_tensor_ext.dispatch.tensor.store %4, %1, offsets = [0, 0], sizes = [32, 262144], strides = [1, 1] : tensor<32x262144xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x262144xf32>>
return
}
|
Uh oh!
There was an error while loading. Please reload this page.
What happened?
The sharktank_nightly CI is failing for the FLUX and SDXL VAE models with the following error for multiple dispatches:
minimal repro
repro compile command
complete compile command from CI
Note: both the full VAE model and the minimal repro compile without errors when this flag is not passed:
--iree-dispatch-creation-enable-aggressive-fusion
Steps to reproduce your issue
No response
What component(s) does this issue relate to?
No response
Version information
No response
Additional context
No response
The text was updated successfully, but these errors were encountered: