Fills/dispatches when padding not getting folded into consumers/producers #11049
Comments
I think that is addressed by setting the flag.
Yep, just trying to increase visibility and have something to point at for people asking what's up with the bubbles in the pipeline :)
I need to resurrect that PR... Will do that and add some notes.
ESRGAN suffers from this as well and would benefit from padding propagation into consumers. Here is an example where the padding sits between elementwise producers and a conv consumer (IR excerpt):

// and other elementwise ops producing %189/%192/etc
%198 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%197 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
^bb0(%in: f32, %out: f32):
%1177 = arith.cmpf ugt, %in, %cst_701 : f32
%1178 = arith.select %1177, %in, %cst_701 : f32
%1179 = arith.select %1177, %cst_701, %in : f32
%1180 = arith.truncf %cst_702 : f64 to f32
%1181 = arith.mulf %1179, %1180 : f32
%1182 = arith.addf %1178, %1181 : f32
linalg.yield %1182 : f32
} -> tensor<1x32x90x62xf32>
%inserted_slice_919 = tensor.insert_slice %189 into %15[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x160x90x62xf32>
%inserted_slice_920 = tensor.insert_slice %192 into %inserted_slice_919[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
%inserted_slice_921 = tensor.insert_slice %195 into %inserted_slice_920[0, 96, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
%inserted_slice_922 = tensor.insert_slice %198 into %inserted_slice_921[0, 128, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
%padded_923 = tensor.pad %inserted_slice_922 low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
tensor.yield %cst_701 : f32
} : tensor<1x160x90x62xf32> to tensor<1x160x92x64xf32>
%199 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_581 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x32x90x62xf32>
%200 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_923, %cst_582 : tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) outs(%199 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
%201 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%200 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
^bb0(%in: f32, %out: f32):
%1177 = arith.cmpf ugt, %in, %cst_701 : f32
%1178 = arith.select %1177, %in, %cst_701 : f32
%1179 = arith.select %1177, %cst_701, %in : f32
%1180 = arith.truncf %cst_702 : f64 to f32
%1181 = arith.mulf %1179, %1180 : f32
%1182 = arith.addf %1178, %1181 : f32
linalg.yield %1182 : f32
} -> tensor<1x32x90x62xf32>

Today we end up with this, which requires the %39 splat and the %41 dispatch copy:

%32 = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%31, %cst_4) : (tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) -> tensor<1x32x90x62xf32> =
(%arg3: !flow.dispatch.tensor<readonly:tensor<1x160x92x64xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x160x3x3xf32>>, %arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
%cst_351 = arith.constant 0.000000e+00 : f32
%cst_352 = arith.constant 0.199999988 : f32
%cst_353 = arith.constant dense<[[-0.00735217752, -0.029075671, -0.0011687536, -0.0265800748, -0.016661156, -0.0216491632, -0.0427877456, -0.0533559099, -0.0249305591, -0.0207087267, -0.0253318828, -0.0515014119, -0.0422265045, -0.0368615724, 0.00198965892, -0.0221594162, -0.0266306344, -0.0617676973, -0.0261138938, -0.00482901605, -0.0400608778, -0.0137573751, -0.00975679792, -0.0443469957, -0.0315653086, -0.0245542042, -0.0320154652, -6.253720e-02, -0.0274252892, 0.00514560752, -0.0166819859, -0.0136556849]]> : tensor<1x32xf32>
%2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 160, 92, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x160x92x64xf32>> -> tensor<1x160x92x64xf32>
%2036 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 160, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x160x3x3xf32>> -> tensor<32x160x3x3xf32>
%2037 = tensor.empty() : tensor<1x32x90x62xf32>
%2038 = linalg.fill ins(%cst_351 : f32) outs(%2037 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
%2039 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%2035, %2036 : tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) outs(%2038 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
%2040 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2039, %cst_353 : tensor<1x32x90x62xf32>, tensor<1x32xf32>) outs(%2037 : tensor<1x32x90x62xf32>) {
^bb0(%in: f32, %in_354: f32, %out: f32):
%2041 = arith.addf %in, %in_354 : f32
%2042 = arith.cmpf ugt, %2041, %cst_351 : f32
%2043 = arith.select %2042, %2041, %cst_351 : f32
%2044 = arith.select %2042, %cst_351, %2041 : f32
%2045 = arith.mulf %2044, %cst_352 : f32
%2046 = arith.addf %2043, %2045 : f32
linalg.yield %2046 : f32
} -> tensor<1x32x90x62xf32>
flow.dispatch.tensor.store %2040, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
flow.return
} count(%arg3: index, %arg4: index, %arg5: index, %arg6: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6
flow.return %x, %y, %z : index, index, index
}
%33 = tensor.empty() : tensor<1x192x90x62xf32>
%34 = flow.tensor.update %4, %33[%c0, %c0, %c0, %c0] : tensor<1x64x90x62xf32> -> %33 as tensor<1x192x90x62xf32>
%35 = flow.tensor.update %8, %34[%c0, %c64, %c0, %c0] : tensor<1x32x90x62xf32> -> %34 as tensor<1x192x90x62xf32>
%36 = flow.tensor.update %15, %35[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %35 as tensor<1x192x90x62xf32>
%37 = flow.tensor.update %23, %36[%c0, %c128, %c0, %c0] : tensor<1x32x90x62xf32> -> %36 as tensor<1x192x90x62xf32>
%38 = flow.tensor.update %32, %37[%c0, %c160, %c0, %c0] : tensor<1x32x90x62xf32> -> %37 as tensor<1x192x90x62xf32>
%39 = flow.tensor.splat %cst : tensor<1x192x92x64xf32>
%40 = flow.tensor.reshape %38 : tensor<1x192x90x62xf32> -> tensor<192x90x62xf32>
%41 = flow.dispatch.workgroups[%c192, %c90, %c62](%40, %39) : (tensor<192x90x62xf32>, tensor<1x192x92x64xf32>) -> %39 =
(%arg3: !flow.dispatch.tensor<readonly:tensor<192x90x62xf32>>, %arg4: !flow.dispatch.tensor<readwrite:tensor<1x192x92x64xf32>>) {
%2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0], sizes = [192, 90, 62], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x90x62xf32>> -> tensor<192x90x62xf32>
flow.dispatch.tensor.store %2035, %arg4, offsets = [0, 0, 1, 1], sizes = [1, 192, 90, 62], strides = [1, 1, 1, 1] : tensor<192x90x62xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x192x92x64xf32>>
flow.return
} count(%arg3: index, %arg4: index, %arg5: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5
flow.return %x, %y, %z : index, index, index
}
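For illustration only, here is a hand-written sketch (not compiler output) of what the conv dispatch could look like if the pad were folded into it. %concat and %filter are stand-ins for the real SSA values, and the bias/leaky-relu epilogue is omitted for brevity; the point is that the dispatch consumes the unpadded 1x160x90x62 concat result and materializes the zero border itself, so the flow.tensor.splat and the copy-only dispatch above go away:

// Hand-written sketch, not compiler output: %concat is the unpadded 1x160x90x62
// concatenation result and %filter the 32x160x3x3 weights. The pad lives inside
// the dispatch, so no flow.tensor.splat / copy dispatch is needed at the stream level.
%conv = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%concat, %filter) : (tensor<1x160x90x62xf32>, tensor<32x160x3x3xf32>) -> tensor<1x32x90x62xf32> =
    (%arg3: !flow.dispatch.tensor<readonly:tensor<1x160x90x62xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x160x3x3xf32>>, %arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
  %zero = arith.constant 0.000000e+00 : f32
  %input = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 160, 90, 62], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x160x90x62xf32>> -> tensor<1x160x90x62xf32>
  %weights = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 160, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x160x3x3xf32>> -> tensor<32x160x3x3xf32>
  // The padding is materialized (or folded into boundary handling) by codegen inside
  // the dispatch instead of as a separate fill + copy at the stream level.
  %padded = tensor.pad %input low[0, 0, 1, 1] high[0, 0, 1, 1] {
  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
    tensor.yield %zero : f32
  } : tensor<1x160x90x62xf32> to tensor<1x160x92x64xf32>
  %empty = tensor.empty() : tensor<1x32x90x62xf32>
  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
  %result = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded, %weights : tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) outs(%acc : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
  flow.dispatch.tensor.store %result, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
  flow.return
} count(%arg3: index, %arg4: index, %arg5: index, %arg6: index) -> (index, index, index) {
  %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6
  flow.return %x, %y, %z : index, index, index
}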
Looking at all the splats - from both this and #6972 - we're doing 2004094976 bytes (2GB!) of memset(0)s in ESRGAN. As an example, in just the last few stages of ESRGAN we're memset(0)'ing 100MB.
I tried the
(out of date, but still something that should be verified eventually) |
Was looking at tests/e2e/models/resnet50_fake_weights.mlir and noticed that there are still a lot of fills/slow dispatch-based memcpys (~18 fills/dispatches, with a unique executable for each because of the unique sizes). This adds quite a bit of latency to the system as the chain fill -> dispatch that just does a memcpy -> actual consumer is serialized. Thankfully we can run the fill concurrently with the producer, but that is a large additional transient value we need to allocate/keep live, and it still adds an extra 33% baseline latency ([producer|fill] -> pad dispatch -> consumer vs. producer -> consumer). 23% of the dispatches we compile/store in the binary/execute at runtime are these pads, and a ~25% savings on that would be awesome. Now that we have some latency-sensitive models with convs (where I think we end up with the most pads), getting rid of this noise will help keep focus on the actual codegen improvements and not the dispatch scheduling.

I think #9194 was supposed to prevent this, but there's also a draft #10184 that may have intended to do it. Fixing this would let us finally close the old #2783. Feel free to close as a dupe or consider this a ping with an easily available reproducer :)
What this looks like during execution is (with dispatch_6 as the serialized pad):
Ideally we'd just see dispatch_9 -> dispatch_7 (matmul -> conv) with no intervening ops.
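For readers without the model handy, here is a minimal hand-written sketch of the pattern being described (hypothetical shapes and value names, not taken from resnet50_fake_weights.mlir):

// Hypothetical shapes/names, for illustration only.
func.func @pad_between_producer_and_conv(%producer: tensor<1x64x56x56xf32>, %filter: tensor<64x64x3x3xf32>) -> tensor<1x64x56x56xf32> {
  %zero = arith.constant 0.000000e+00 : f32
  // Today this pad becomes its own dispatch: a fill of the 1x64x58x58 buffer plus a
  // copy of %producer into its interior, serialized between the producer and the conv.
  %padded = tensor.pad %producer low[0, 0, 1, 1] high[0, 0, 1, 1] {
  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
    tensor.yield %zero : f32
  } : tensor<1x64x56x56xf32> to tensor<1x64x58x58xf32>
  // Ideally the pad is folded into the conv dispatch (or into the producer's store),
  // leaving just producer -> conv at the stream level.
  %empty = tensor.empty() : tensor<1x64x56x56xf32>
  %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
  %conv = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded, %filter : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%acc : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
  return %conv : tensor<1x64x56x56xf32>
}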