We have opportunities to fuse the mmt4d op with its consumers, but there are performance issues. This happens when we simplify 1D pack/unpack to expand_shape/collapse_shape ops. Generally that is a good thing, because they become metadata ops. This issue describes what's happening in mmt4d fusion.
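For instance, a 1-D pack with a unit inner tile moves no data, so it can be rewritten as a pure reshape. A minimal sketch of the kind of rewrite involved (shapes are made up for illustration):

```mlir
// A 1-D pack with a unit inner tile only adds a unit dimension...
%pack = tensor.pack %src inner_dims_pos = [0] inner_tiles = [1]
    into %dest : tensor<768xf32> -> tensor<768x1xf32>

// ...so it can be simplified to a metadata-only reshape:
%expanded = tensor.expand_shape %src [[0, 1]]
    : tensor<768xf32> into tensor<768x1xf32>
```

Below is a snippet from the GPT2 model: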
```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @mmt4d_fusion(%arg0: tensor<1x768x1x1xf32>, %arg1: tensor<192x768x16x1xf32>, %arg2: tensor<1x192x1x16xf32>) -> tensor<1x192x1x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<1x192x1x16xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x192x1x16xf32>) -> tensor<1x192x1x16xf32>
  %2 = linalg.mmt4d ins(%arg0, %arg1 : tensor<1x768x1x1xf32>, tensor<192x768x16x1xf32>) outs(%1 : tensor<1x192x1x16xf32>) -> tensor<1x192x1x16xf32>
  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2, %arg2 : tensor<1x192x1x16xf32>, tensor<1x192x1x16xf32>) outs(%0 : tensor<1x192x1x16xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %4 = arith.addf %in, %in_0 : f32
    linalg.yield %4 : f32
  } -> tensor<1x192x1x16xf32>
  return %3 : tensor<1x192x1x16xf32>
}
```
There are a couple of issues in the current pipeline:

- The Mmt4dTilingExpert pipeline does not work well with the current TilingConfig; it only considers three tiling levels. We should make it more compatible with TilingConfig.
- The codegen assumes that only the leading parallel dims are shared between producer and consumer. That is not the case in this example: the result indexing map of the mmt4d op is `(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)`, while the indexing maps of the generic op are identity (i.e., `(d0, d1, d2, d3) -> (d0, d1, d2, d3)`). This leads to a bug when setting multiple lowering_configs, because the dim mapping between the two ops is not taken into account in the method; see the affine-map sketch after this list.
- LLVMCPUTileAndFuse should ignore `iree_codegen.ukernel.generic` ops. The iteration domain information is gone after a linalg op is converted to a `ukernel.generic` op, so we are not able to tile and fuse the ukernel op; a schematic of the converted op appears at the end of this section.
- If we want to enable ukernels in fusion cases, we need to revisit when to convert the mmt4d op to a ukernel op. It can happen either after distribution or after the first level of TileAndFuse. The former could introduce a big stack buffer, because the parallel dimensions are only tiled for distribution; the latter needs more investigation.
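To make the second issue concrete, here are the two maps side by side (the dim-name annotations follow the usual mmt4d convention and are added here for illustration). Only the first two dims line up positionally; the remaining shared dims are offset by the reduction dim:

```mlir
// linalg.mmt4d iteration space: (M1, N1, K1, M0, N0, K0).
// Its result map drops the reduction dims d2 (K1) and d5 (K0):
#mmt4d_result = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>

// The elementwise consumer iterates directly over the 4-D result,
// so all of its maps are identity:
#generic_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Consumer d2 corresponds to producer d3 (M0), and consumer d3 to
// producer d4 (N0). Assuming the shared parallel dims are exactly the
// leading ones is therefore wrong for this pair of ops.
```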
(A third option for the ukernel conversion timing is specialization: if the dispatch is a fusion case, go down the codegen path; otherwise, go down the ukernel path. I haven't explored this path, so no further comments.)
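For reference on the LLVMCPUTileAndFuse issue, this is roughly what the mmt4d op becomes after ukernel conversion (a sketch only; the operand list and attributes of the real lowering are abbreviated here). The op carries just shaped operands and scalar parameters; there are no indexing maps or iterator types left for the tiling interfaces to query:

```mlir
// Schematic result of converting linalg.mmt4d to a ukernel call.
%result = iree_codegen.ukernel.generic "iree_uk_mmt4d"
    ins(%arg0, %arg1 : tensor<1x768x1x1xf32>, tensor<192x768x16x1xf32>)
    outs(%fill : tensor<1x192x1x16xf32>)
    (%m, %n, %k, %flags : index, index, index, i32)
    -> tensor<1x192x1x16xf32>
// Unlike linalg.mmt4d, there is no iteration domain attached to this
// op, so LLVMCPUTileAndFuse cannot tile it or fuse consumers into it.
```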