Fuse multi-consumer insert_slices into dispatch regions as in-place operations. #11102
Certain models use a pattern where they produce a result and then insert it into multiple tensors. Today these are lowered to transient allocations + DMA copies, but for such basic tensor broadcast operations it would be better for the producer dispatch to perform multiple writes. Doing so would reduce transient memory requirements (the N copies of the output would be placed directly into their destinations, eliminating the transient memory) and break the serialized dependency chain (subsequent dispatches could begin executing earlier since they wouldn't need to wait for the copy to complete).

This may be related to #10840 (multi-result fusion).

From ESRGAN, %189 is produced and then inserted into 4 different tensors. Since all of the target tensors (%7, %11, %15, etc.) exist prior to all of this work, we could instead have the %189 dispatch insert directly into all 4 of them (by taking them as read/write IO).
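A sketch of what that could look like, with hypothetical argument names and shapes borrowed from the ESRGAN excerpt in the comments below (illustrative only, not the output of any existing pass): the producer dispatch takes each destination tensor as a read/write binding and stores its single result into the appropriate slice of each, with no transient tensor or DMA copy in between.

(%in: !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>>,
 %dst0: !flow.dispatch.tensor<readwrite:tensor<1x128x90x62xf32>>,
 %dst1: !flow.dispatch.tensor<readwrite:tensor<1x160x90x62xf32>>) {
  // ... compute %result : tensor<1x32x90x62xf32> as usual ...
  // Store the same value in-place into a slice of each destination:
  flow.dispatch.tensor.store %result, %dst0, offsets = [0, 96, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x128x90x62xf32>>
  flow.dispatch.tensor.store %result, %dst1, offsets = [0, 96, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x160x90x62xf32>>
  flow.return
}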
Comments
The original torch code from https://github.com/nod-ai/SHARK/pull/418/files#diff-f98c3ce6646546dce80ee6d6eca9f9537efcc70a300dd8df3b0f128e8bbb4316R40-R46:

def forward(self, x):
    x1 = self.lrelu(self.conv1(x))
    x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
    x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
    x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
    x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
    return x5 * 0.2 + x
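For context, each torch.cat here typically lowers to a chain of tensor.insert_slice ops along dim 1 before dispatch formation, so the same produced value (x1, then x2, x3) gets inserted into several destination tensors. A rough sketch of the first concat with illustrative channel counts (64 for x, 32 per conv output, consistent with the IR below):

%empty = tensor.empty() : tensor<1x96x90x62xf32>
%cat0 = tensor.insert_slice %x into %empty[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x96x90x62xf32>
%cat = tensor.insert_slice %x1 into %cat0[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x96x90x62xf32>

x1 is then inserted again into the wider concats feeding conv3, conv4, and conv5, which is exactly the multi-consumer insert_slice pattern this issue is about.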
Shady idea I could hack together: a flow-level pass that runs after dispatch workgroup formation, adds new results to the dispatch, and replicates the flow.dispatch.tensor.store ops to point at them. This wouldn't generalize to conditional stores or anything, and it may make codegen bufferization unhappy (not sure how much it relies on single stores for allocation placement). e.g.:

%1529 = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%1528, %cst_262) : (tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) -> tensor<1x32x90x62xf32> =
(%arg3: !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>>, %arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
%cst_351 = arith.constant 0.000000e+00 : f32
%cst_352 = arith.constant 0.199999988 : f32
%cst_353 = arith.constant dense<[[-0.104562163, -0.0794978737, -0.139019474, -0.0678559243, -0.0663451776, -0.0395833179, -0.139937162, -0.0967885255, -0.119102009, -0.138187289, -0.0833081305, -0.106967404, 0.0852515175, -0.0525256135, -0.090108551, -0.036612425, -0.113223538, -0.153768227, -0.13075842, -0.066075474, -0.129493013, -0.0539637394, -0.0388106443, -0.158874199, -0.0255966205, -0.13540563, -0.0318158343, -0.0988952666, -0.110548653, -0.160866946, -0.043045409, -0.114123404]]> : tensor<1x32xf32>
%2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 96, 92, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>> -> tensor<1x96x92x64xf32>
%2036 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 96, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>> -> tensor<32x96x3x3xf32>
%2037 = tensor.empty() : tensor<1x32x90x62xf32>
%2038 = linalg.fill ins(%cst_351 : f32) outs(%2037 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
%2039 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%2035, %2036 : tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) outs(%2038 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
%2040 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2039, %cst_353 : tensor<1x32x90x62xf32>, tensor<1x32xf32>) outs(%2037 : tensor<1x32x90x62xf32>) {
^bb0(%in: f32, %in_354: f32, %out: f32):
%2041 = arith.addf %in, %in_354 : f32
%2042 = arith.cmpf ugt, %2041, %cst_351 : f32
%2043 = arith.select %2042, %2041, %cst_351 : f32
%2044 = arith.select %2042, %cst_351, %2041 : f32
%2045 = arith.mulf %2044, %cst_352 : f32
%2046 = arith.addf %2043, %2045 : f32
linalg.yield %2046 : f32
} -> tensor<1x32x90x62xf32>
flow.dispatch.tensor.store %2040, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
flow.return
} count(%arg3: index, %arg4: index, %arg5: index, %arg6: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6
flow.return %x, %y, %z : index, index, index
}
%1532 = flow.tensor.update %1529, %1531[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1531 as tensor<1x128x90x62xf32>
%1538 = flow.tensor.update %1529, %1537[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1537 as tensor<1x160x90x62xf32>
%1545 = flow.tensor.update %1529, %1544[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1544 as tensor<1x192x90x62xf32>

which after the pass would become:

%1529:3 = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%1528, %cst_262) : (tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) -> (tensor<1x32x90x62xf32>, tensor<1x32x90x62xf32>, tensor<1x32x90x62xf32>) =
(%arg3: !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>>,
%arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>, %arg6: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>, %arg7: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
%cst_351 = arith.constant 0.000000e+00 : f32
%cst_352 = arith.constant 0.199999988 : f32
%cst_353 = arith.constant dense<[[-0.104562163, -0.0794978737, -0.139019474, -0.0678559243, -0.0663451776, -0.0395833179, -0.139937162, -0.0967885255, -0.119102009, -0.138187289, -0.0833081305, -0.106967404, 0.0852515175, -0.0525256135, -0.090108551, -0.036612425, -0.113223538, -0.153768227, -0.13075842, -0.066075474, -0.129493013, -0.0539637394, -0.0388106443, -0.158874199, -0.0255966205, -0.13540563, -0.0318158343, -0.0988952666, -0.110548653, -0.160866946, -0.043045409, -0.114123404]]> : tensor<1x32xf32>
%2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 96, 92, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>> -> tensor<1x96x92x64xf32>
%2036 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 96, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>> -> tensor<32x96x3x3xf32>
%2037 = tensor.empty() : tensor<1x32x90x62xf32>
%2038 = linalg.fill ins(%cst_351 : f32) outs(%2037 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
%2039 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%2035, %2036 : tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) outs(%2038 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
// ... %2040 = linalg.generic (bias add + leaky ReLU), identical to the original dispatch ...
flow.dispatch.tensor.store %2040, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
flow.dispatch.tensor.store %2040, %arg6, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
flow.dispatch.tensor.store %2040, %arg7, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
flow.return
} count(%arg8: index, %arg9: index, %arg10: index, %arg11: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg8, %arg9, %arg10, %arg11
flow.return %x, %y, %z : index, index, index
}

Like this, the flow.tensor.update -> in-place storage of the write-only outputs would be able to place all the allocations.
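The three flow.tensor.update consumers would then presumably be rewritten mechanically to use the per-result values:

%1532 = flow.tensor.update %1529#0, %1531[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1531 as tensor<1x128x90x62xf32>
%1538 = flow.tensor.update %1529#1, %1537[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1537 as tensor<1x160x90x62xf32>
%1545 = flow.tensor.update %1529#2, %1544[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1544 as tensor<1x192x90x62xf32>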
I believe we may still need some improvements here, but it'd be better done with more recent examples (if we even get them anymore from frontends on models we care about).