Fuse multi-consumer insert_slices into dispatch regions as in-place operations. #11102


Closed
benvanik opened this issue Nov 9, 2022 · 3 comments
Labels: codegen (Shared code generation infrastructure and dialects), performance ⚡ (Performance/optimization related work across the compiler and runtime)

Comments

benvanik (Collaborator) commented Nov 9, 2022

Certain models use a pattern where they produce a result and then insert it into multiple tensors. Today these inserts are lowered to transient allocations + DMA copies, but for such basic tensor broadcast operations it'd be better for the producer dispatch to perform the multiple writes itself. Doing so would reduce transient memory requirements (the N copies of the output would be placed directly into their destinations, eliminating the transient allocation) and break the serialized dependency chain (subsequent dispatches could begin executing earlier as they wouldn't need to wait for the copies to complete).

This may be related to #10840 (multi-result fusion).
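
Distilled, the pattern is a single producer whose result feeds several tensor.insert_slice ops with different destination tensors, each of which currently materializes as a copy. A minimal sketch (hypothetical values and shapes, not taken from a model):

    %r = linalg.generic ... -> tensor<1x64x90x62xf32>
    // Today: each of these inserts lowers to a transient allocation + DMA copy.
    %a = tensor.insert_slice %r into %dst0[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x96x90x62xf32>
    %b = tensor.insert_slice %r into %dst1[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x128x90x62xf32>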

From ESRGAN, note that %189 is produced and then inserted into 4 different tensors:

    %189 = linalg.generic {indexing_maps = [#map2, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%188, %173 : tensor<1x64x90x62xf32>, tensor<1x64x90x62xf32>) outs(%0 : tensor<1x64x90x62xf32>) {
    ^bb0(%in: f32, %in_2018: f32, %out: f32):
      %1177 = arith.addf %in, %in_2018 : f32
      linalg.yield %1177 : f32
    } -> tensor<1x64x90x62xf32>
    %padded_911 = tensor.pad %189 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_701 : f32
    } : tensor<1x64x90x62xf32> to tensor<1x64x92x64xf32>
    %190 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_587 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x90x62xf32>
    %191 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_911, %cst_588 : tensor<1x64x92x64xf32>, tensor<32x64x3x3xf32>) outs(%190 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %192 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%191 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
    %inserted_slice_912 = tensor.insert_slice %189 into %7[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x96x90x62xf32>
    %inserted_slice_913 = tensor.insert_slice %192 into %inserted_slice_912[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x96x90x62xf32>
    %padded_914 = tensor.pad %inserted_slice_913 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_701 : f32
    } : tensor<1x96x90x62xf32> to tensor<1x96x92x64xf32>
    %193 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_585 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x90x62xf32>
    %194 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_914, %cst_586 : tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) outs(%193 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %195 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%194 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
    %inserted_slice_915 = tensor.insert_slice %189 into %11[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x128x90x62xf32>
    %inserted_slice_916 = tensor.insert_slice %192 into %inserted_slice_915[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x128x90x62xf32>
    %inserted_slice_917 = tensor.insert_slice %195 into %inserted_slice_916[0, 96, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x128x90x62xf32>
    %padded_918 = tensor.pad %inserted_slice_917 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_701 : f32
    } : tensor<1x128x90x62xf32> to tensor<1x128x92x64xf32>
    %196 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_583 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x90x62xf32>
    %197 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_918, %cst_584 : tensor<1x128x92x64xf32>, tensor<32x128x3x3xf32>) outs(%196 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %198 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%197 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
    %inserted_slice_919 = tensor.insert_slice %189 into %15[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x160x90x62xf32>
    %inserted_slice_920 = tensor.insert_slice %192 into %inserted_slice_919[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
    %inserted_slice_921 = tensor.insert_slice %195 into %inserted_slice_920[0, 96, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
    %inserted_slice_922 = tensor.insert_slice %198 into %inserted_slice_921[0, 128, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x160x90x62xf32>
    %padded_923 = tensor.pad %inserted_slice_922 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_701 : f32
    } : tensor<1x160x90x62xf32> to tensor<1x160x92x64xf32>
    %199 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_581 : tensor<32xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x32x90x62xf32>
    %200 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded_923, %cst_582 : tensor<1x160x92x64xf32>, tensor<32x160x3x3xf32>) outs(%199 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
    %201 = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%200 : tensor<1x32x90x62xf32>) outs(%3 : tensor<1x32x90x62xf32>) {
    ^bb0(%in: f32, %out: f32):
      %1177 = arith.cmpf ugt, %in, %cst_701 : f32
      %1178 = arith.select %1177, %in, %cst_701 : f32
      %1179 = arith.select %1177, %cst_701, %in : f32
      %1180 = arith.truncf %cst_702 : f64 to f32
      %1181 = arith.mulf %1179, %1180 : f32
      %1182 = arith.addf %1178, %1181 : f32
      linalg.yield %1182 : f32
    } -> tensor<1x32x90x62xf32>
    %inserted_slice_924 = tensor.insert_slice %189 into %19[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x192x90x62xf32>
    %inserted_slice_925 = tensor.insert_slice %192 into %inserted_slice_924[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x192x90x62xf32>
    %inserted_slice_926 = tensor.insert_slice %195 into %inserted_slice_925[0, 96, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x192x90x62xf32>
    %inserted_slice_927 = tensor.insert_slice %198 into %inserted_slice_926[0, 128, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x192x90x62xf32>
    %inserted_slice_928 = tensor.insert_slice %201 into %inserted_slice_927[0, 160, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x192x90x62xf32>

Since all of the target tensors (%7, %11, %15, etc.) exist prior to all of this work, we could instead have the %189 dispatch insert into all 4 of them directly (by taking them as read/write I/O).
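
A rough sketch of what that could look like for the %189 dispatch (illustrative only: the SSA names are made up and the tied-result syntax is approximate; the key idea is the readwrite bindings and one flow.dispatch.tensor.store per destination):

    %fused = flow.dispatch.workgroups[...](%188, %173, %7, %11, %15, %19) : (...) -> (%7 as tensor<1x96x90x62xf32>, %11 as tensor<1x128x90x62xf32>, %15 as tensor<1x160x90x62xf32>, %19 as tensor<1x192x90x62xf32>) =
        (%a: !flow.dispatch.tensor<readonly:tensor<1x64x90x62xf32>>, %b: !flow.dispatch.tensor<readonly:tensor<1x64x90x62xf32>>,
         %dst0: !flow.dispatch.tensor<readwrite:tensor<1x96x90x62xf32>>, %dst1: !flow.dispatch.tensor<readwrite:tensor<1x128x90x62xf32>>,
         %dst2: !flow.dispatch.tensor<readwrite:tensor<1x160x90x62xf32>>, %dst3: !flow.dispatch.tensor<readwrite:tensor<1x192x90x62xf32>>) {
      // %sum is the addf linalg.generic from above, computed once.
      %sum = ...
      // One store per destination, all into channels [0, 64):
      flow.dispatch.tensor.store %sum, %dst0, offsets = [0, 0, 0, 0], sizes = [1, 64, 90, 62], strides = [1, 1, 1, 1] : tensor<1x64x90x62xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x96x90x62xf32>>
      flow.dispatch.tensor.store %sum, %dst1, offsets = [0, 0, 0, 0], sizes = [1, 64, 90, 62], strides = [1, 1, 1, 1] : tensor<1x64x90x62xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x128x90x62xf32>>
      flow.dispatch.tensor.store %sum, %dst2, offsets = [0, 0, 0, 0], sizes = [1, 64, 90, 62], strides = [1, 1, 1, 1] : tensor<1x64x90x62xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x160x90x62xf32>>
      flow.dispatch.tensor.store %sum, %dst3, offsets = [0, 0, 0, 0], sizes = [1, 64, 90, 62], strides = [1, 1, 1, 1] : tensor<1x64x90x62xf32> -> !flow.dispatch.tensor<readwrite:tensor<1x192x90x62xf32>>
      flow.return
    }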

benvanik added the codegen and performance ⚡ labels Nov 9, 2022
benvanik (Collaborator, Author) commented Nov 9, 2022

The original torch code from https://github.com/nod-ai/SHARK/pull/418/files#diff-f98c3ce6646546dce80ee6d6eca9f9537efcc70a300dd8df3b0f128e8bbb4316R40-R46:

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
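
Each torch.cat along dim 1 is what becomes a chain of tensor.insert_slice ops in the IR above; e.g. torch.cat((x, x1), 1) lowers roughly to (a sketch using the shapes from the ESRGAN snippet):

    %empty = tensor.empty() : tensor<1x96x90x62xf32>
    // x occupies channels [0, 64), x1 occupies channels [64, 96).
    %cat0 = tensor.insert_slice %x into %empty[0, 0, 0, 0] [1, 64, 90, 62] [1, 1, 1, 1] : tensor<1x64x90x62xf32> into tensor<1x96x90x62xf32>
    %cat = tensor.insert_slice %x1 into %cat0[0, 64, 0, 0] [1, 32, 90, 62] [1, 1, 1, 1] : tensor<1x32x90x62xf32> into tensor<1x96x90x62xf32>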

benvanik (Collaborator, Author) commented Nov 9, 2022

Shady idea I could hack together: a flow-level pass that runs after dispatch workgroup formation, adds new results to a dispatch, and replicates its flow.dispatch.tensor.store ops to point at them. This wouldn't generalize to conditional stores or anything, and it may make codegen bufferization unhappy (not sure how much it relies on single stores for allocation placement).

e.g.:

    %1529 = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%1528, %cst_262) : (tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) -> tensor<1x32x90x62xf32> =
        (%arg3: !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>>, %arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
      %cst_351 = arith.constant 0.000000e+00 : f32
      %cst_352 = arith.constant 0.199999988 : f32
      %cst_353 = arith.constant dense<[[-0.104562163, -0.0794978737, -0.139019474, -0.0678559243, -0.0663451776, -0.0395833179, -0.139937162, -0.0967885255, -0.119102009, -0.138187289, -0.0833081305, -0.106967404, 0.0852515175, -0.0525256135, -0.090108551, -0.036612425, -0.113223538, -0.153768227, -0.13075842, -0.066075474, -0.129493013, -0.0539637394, -0.0388106443, -0.158874199, -0.0255966205, -0.13540563, -0.0318158343, -0.0988952666, -0.110548653, -0.160866946, -0.043045409, -0.114123404]]> : tensor<1x32xf32>
      %2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 96, 92, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>> -> tensor<1x96x92x64xf32>
      %2036 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 96, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>> -> tensor<32x96x3x3xf32>
      %2037 = tensor.empty() : tensor<1x32x90x62xf32>
      %2038 = linalg.fill ins(%cst_351 : f32) outs(%2037 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
      %2039 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%2035, %2036 : tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) outs(%2038 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
      %2040 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2039, %cst_353 : tensor<1x32x90x62xf32>, tensor<1x32xf32>) outs(%2037 : tensor<1x32x90x62xf32>) {
      ^bb0(%in: f32, %in_354: f32, %out: f32):
        %2041 = arith.addf %in, %in_354 : f32
        %2042 = arith.cmpf ugt, %2041, %cst_351 : f32
        %2043 = arith.select %2042, %2041, %cst_351 : f32
        %2044 = arith.select %2042, %cst_351, %2041 : f32
        %2045 = arith.mulf %2044, %cst_352 : f32
        %2046 = arith.addf %2043, %2045 : f32
        linalg.yield %2046 : f32
      } -> tensor<1x32x90x62xf32>
      flow.dispatch.tensor.store %2040, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
      flow.return
    } count(%arg3: index, %arg4: index, %arg5: index, %arg6: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6
      flow.return %x, %y, %z : index, index, index
    }
    %1532 = flow.tensor.update %1529, %1531[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1531 as tensor<1x128x90x62xf32>
    %1538 = flow.tensor.update %1529, %1537[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1537 as tensor<1x160x90x62xf32>
    %1545 = flow.tensor.update %1529, %1544[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1544 as tensor<1x192x90x62xf32>

->

    %1529:3 = flow.dispatch.workgroups[%c1, %c32, %c90, %c62](%1528, %cst_262) : (tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) -> (tensor<1x32x90x62xf32>, tensor<1x32x90x62xf32>, tensor<1x32x90x62xf32>) =
        (%arg3: !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>>, %arg4: !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>>,
        %arg5: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>, %arg6: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>, %arg7: !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>) {
      %cst_351 = arith.constant 0.000000e+00 : f32
      %cst_352 = arith.constant 0.199999988 : f32
      %cst_353 = arith.constant dense<[[-0.104562163, -0.0794978737, -0.139019474, -0.0678559243, -0.0663451776, -0.0395833179, -0.139937162, -0.0967885255, -0.119102009, -0.138187289, -0.0833081305, -0.106967404, 0.0852515175, -0.0525256135, -0.090108551, -0.036612425, -0.113223538, -0.153768227, -0.13075842, -0.066075474, -0.129493013, -0.0539637394, -0.0388106443, -0.158874199, -0.0255966205, -0.13540563, -0.0318158343, -0.0988952666, -0.110548653, -0.160866946, -0.043045409, -0.114123404]]> : tensor<1x32xf32>
      %2035 = flow.dispatch.tensor.load %arg3, offsets = [0, 0, 0, 0], sizes = [1, 96, 92, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x96x92x64xf32>> -> tensor<1x96x92x64xf32>
      %2036 = flow.dispatch.tensor.load %arg4, offsets = [0, 0, 0, 0], sizes = [32, 96, 3, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>> -> tensor<32x96x3x3xf32>
      %2037 = tensor.empty() : tensor<1x32x90x62xf32>
      %2038 = linalg.fill ins(%cst_351 : f32) outs(%2037 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
      %2039 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%2035, %2036 : tensor<1x96x92x64xf32>, tensor<32x96x3x3xf32>) outs(%2038 : tensor<1x32x90x62xf32>) -> tensor<1x32x90x62xf32>
      %2040 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2039, %cst_353 : tensor<1x32x90x62xf32>, tensor<1x32xf32>) outs(%2037 : tensor<1x32x90x62xf32>) {
      ^bb0(%in: f32, %in_354: f32, %out: f32):
        %2041 = arith.addf %in, %in_354 : f32
        %2042 = arith.cmpf ugt, %2041, %cst_351 : f32
        %2043 = arith.select %2042, %2041, %cst_351 : f32
        %2044 = arith.select %2042, %cst_351, %2041 : f32
        %2045 = arith.mulf %2044, %cst_352 : f32
        %2046 = arith.addf %2043, %2045 : f32
        linalg.yield %2046 : f32
      } -> tensor<1x32x90x62xf32>
      flow.dispatch.tensor.store %2040, %arg5, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
      flow.dispatch.tensor.store %2040, %arg6, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
      flow.dispatch.tensor.store %2040, %arg7, offsets = [0, 0, 0, 0], sizes = [1, 32, 90, 62], strides = [1, 1, 1, 1] : tensor<1x32x90x62xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x32x90x62xf32>>
      flow.return
    } count(%arg3: index, %arg4: index, %arg5: index, %arg6: index) -> (index, index, index) {
      %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg3, %arg4, %arg5, %arg6
      flow.return %x, %y, %z : index, index, index
    }
    %1532 = flow.tensor.update %1529#0, %1531[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1531 as tensor<1x128x90x62xf32>
    %1538 = flow.tensor.update %1529#1, %1537[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1537 as tensor<1x160x90x62xf32>
    %1545 = flow.tensor.update %1529#2, %1544[%c0, %c96, %c0, %c0] : tensor<1x32x90x62xf32> -> %1544 as tensor<1x192x90x62xf32>

Written like this, the flow.tensor.update -> in-place storage of the write-only outputs would allow allocation placement to put every result directly into its destination.

benvanik (Collaborator, Author) commented

I believe we may still need some improvements here, but it'd be better to do this with more recent examples (if we even get them anymore from frontends on models we care about).

benvanik closed this as not planned Apr 29, 2025