[DispatchCreation] Make truncate operations fuse with producers. by MaheshRavishankar · Pull Request #19847 · iree-org/iree

[DispatchCreation] Make truncate operations fuse with producers. #19847

Draft · wants to merge 4 commits into main

@@ -111,8 +111,21 @@ void ElementwiseOpFusionPass::runOnOperation() {
[&](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();
if (!producer || !consumer) {
return false;
}

// If `intraDispatch` is true, make sure that the producer and consumer are
// inside a dispatch.
if (intraDispatch &&
IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
return false;
}

if (!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
// If `intraDispatch` is false, make sure that the producer and consumer
// are outside a dispatch.
if (!intraDispatch &&
!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
return false;
}

@@ -131,11 +144,14 @@ void ElementwiseOpFusionPass::runOnOperation() {
if (operands.size() >= kIreeMaxOperandCount)
return false;

return areFusableAsElementwiseOps(context, fusedOperand,
fuseMultiReduction);
ElementwiseOpsFusabilityOptions options;
options.fuseMultiReduction = fuseMultiReduction;
options.fuseTruncateOps = fuseTruncateOps;
return areFusableAsElementwiseOps(context, fusedOperand, options);
};

RewritePatternSet linalgFusionPatterns(context);
linalgFusionPatterns.insert<GatherFusionPattern>(context);
linalg::populateElementwiseOpsFusionPatterns(linalgFusionPatterns,
fuseElementwiseOpsControlFn);

@@ -158,7 +174,6 @@ void ElementwiseOpFusionPass::runOnOperation() {
RewritePatternSet linalgExtFusionPatterns(context);
IREE::LinalgExt::populateFuseLinalgExtOpsWithTransposes(
linalgExtFusionPatterns, foldTransposeControlFn);
linalgExtFusionPatterns.insert<GatherFusionPattern>(context);
if (failed(applyPatternsGreedily(
getOperation(), std::move(linalgExtFusionPatterns), rewriteConfig))) {
getOperation()->emitOpError(
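For context, the rewrite this control function now permits can be sketched as follows. This is a minimal, hypothetical example (function name and shapes invented, not taken from this PR): a bit-truncate-like linalg.generic producer feeding an elementwise consumer, which elementwise fusion may collapse into a single generic once `fuse-truncate-ops` is enabled.

#map = affine_map<(d0, d1) -> (d0, d1)>
util.func public @trunc_then_add(%a: tensor<8x16xf32>, %b: tensor<8x16xf16>) -> tensor<8x16xf16> {
  %init = tensor.empty() : tensor<8x16xf16>
  // Bit-truncate-like producer (f32 -> f16).
  %t = linalg.generic {indexing_maps = [#map, #map],
                       iterator_types = ["parallel", "parallel"]}
      ins(%a : tensor<8x16xf32>) outs(%init : tensor<8x16xf16>) {
  ^bb0(%in: f32, %out: f16):
    %0 = arith.truncf %in : f32 to f16
    linalg.yield %0 : f16
  } -> tensor<8x16xf16>
  // Elementwise consumer with two `ins` operands; with `fuse-truncate-ops`
  // the truncate producer above becomes fusable into this generic.
  %r = linalg.generic {indexing_maps = [#map, #map, #map],
                       iterator_types = ["parallel", "parallel"]}
      ins(%t, %b : tensor<8x16xf16>, tensor<8x16xf16>) outs(%init : tensor<8x16xf16>) {
  ^bb0(%in: f16, %in_0: f16, %out: f16):
    %1 = arith.addf %in, %in_0 : f16
    linalg.yield %1 : f16
  } -> tensor<8x16xf16>
  util.return %r : tensor<8x16xf16>
}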
@@ -196,13 +196,19 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
continue;
}

// 7. Skip dequantization-like `producer` ops as we would rather fuse
// 7. Skip bit-extend-like `producer` ops as we would rather fuse
// by cloning the producer instead of multi-use fusion.
if (IREE::LinalgExt::isBitExtendOp(producer)) {
return;
}

// 8. All uses from `producer` -> `consumer` need to be fusable.
// 8. Skip bit-truncate-like `producer` ops as we would rather fuse
// these operations with their producers.
if (IREE::LinalgExt::isBitTruncateOp(producer)) {
return;
}

// 9. All uses from `producer` -> `consumer` need to be fusable.
// Without this the `producer` is still live, and there is no
// advantage to do the fusion.
if (llvm::any_of(getAllUsesInConsumer(producer, genericOp),
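To illustrate the intent, here is a hedged, hypothetical sketch (invented shapes and names; the arith ops stand in for arbitrary elementwise consumers): %t truncates a matmul result and has two uses. Multi-use fusion now leaves it alone, so later fusion can instead place the truncate in the same dispatch as the matmul that produces its operand, rather than materializing the wider f32 intermediate.

#map = affine_map<(d0, d1) -> (d0, d1)>
util.func public @multi_use_truncate(%a: tensor<8x8xf32>, %b: tensor<8x8xf32>,
    %c: tensor<8x8xf16>, %d: tensor<8x8xf16>) -> (tensor<8x8xf16>, tensor<8x8xf16>) {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<8x8xf32>
  %acc = linalg.fill ins(%cst : f32) outs(%empty : tensor<8x8xf32>) -> tensor<8x8xf32>
  %mm = linalg.matmul ins(%a, %b : tensor<8x8xf32>, tensor<8x8xf32>)
      outs(%acc : tensor<8x8xf32>) -> tensor<8x8xf32>
  %init = tensor.empty() : tensor<8x8xf16>
  // Bit-truncate-like op with two uses: multi-use fusion skips it so it can
  // later be fused with the matmul producing its operand.
  %t = linalg.generic {indexing_maps = [#map, #map],
                       iterator_types = ["parallel", "parallel"]}
      ins(%mm : tensor<8x8xf32>) outs(%init : tensor<8x8xf16>) {
  ^bb0(%in: f32, %out: f16):
    %0 = arith.truncf %in : f32 to f16
    linalg.yield %0 : f16
  } -> tensor<8x8xf16>
  %u0 = arith.addf %t, %c : tensor<8x8xf16>
  %u1 = arith.mulf %t, %d : tensor<8x8xf16>
  util.return %u0, %u1 : tensor<8x8xf16>, tensor<8x8xf16>
}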
19 changes: 17 additions & 2 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
@@ -15,7 +15,7 @@
namespace mlir::iree_compiler::DispatchCreation {

bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
bool fuseMultiReduction) {
ElementwiseOpsFusabilityOptions options) {
Operation *producerOp = fusedOperand->get().getDefiningOp();
Operation *consumerOp = fusedOperand->getOwner();
if (!producerOp)
@@ -73,6 +73,20 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
if (!linalgConsumerOp) {
return false;
}

// Do not fuse bit-truncate-like operations with their consumers, unless
// the consumer has only one `ins` operand and is an elementwise operation.
// The elementwise property implies that the `outs` operand is not a real
// use (it is typically a `tensor.empty`), so the core condition is that
// the consumer has only one "real" operand.
if (!options.fuseTruncateOps &&
IREE::LinalgExt::isBitTruncateOp(producerOp) &&
!(linalgConsumerOp.getNumLoops() ==
linalgConsumerOp.getNumParallelLoops() &&
linalgConsumerOp.getNumDpsInputs() == 1)) {
return false;
}

// If the producer has a single use (this op), only fuse if
// - 1) The consumer op is all parallel loops. The parallelism of the consumer
// can be used as a way to amortize cost of redundant computation
@@ -86,7 +100,8 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
.isPermutation()) {
return false;
}
if (!fuseMultiReduction && linalgConsumerOp.getNumReductionLoops() != 1) {
if (!options.fuseMultiReduction &&
linalgConsumerOp.getNumReductionLoops() != 1) {
return false;
}
if (linalg::isaContractionOpInterface(linalgConsumerOp) ||
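A minimal, hypothetical sketch of the consumer shape the exception above allows (invented names and shapes): an all-parallel generic with a single `ins` operand, here a transpose, can still absorb a truncate-like producer even when `fuseTruncateOps` is unset.

#id = affine_map<(d0, d1) -> (d0, d1)>
#tr = affine_map<(d0, d1) -> (d1, d0)>
util.func public @trunc_into_transpose(%a: tensor<8x16xf32>) -> tensor<16x8xf16> {
  %e0 = tensor.empty() : tensor<8x16xf16>
  // Bit-truncate-like producer.
  %t = linalg.generic {indexing_maps = [#id, #id],
                       iterator_types = ["parallel", "parallel"]}
      ins(%a : tensor<8x16xf32>) outs(%e0 : tensor<8x16xf16>) {
  ^bb0(%in: f32, %out: f16):
    %0 = arith.truncf %in : f32 to f16
    linalg.yield %0 : f16
  } -> tensor<8x16xf16>
  %e1 = tensor.empty() : tensor<16x8xf16>
  // All-parallel consumer with one "real" (ins) operand: fusable with the
  // truncate even without `fuseTruncateOps`.
  %r = linalg.generic {indexing_maps = [#id, #tr],
                       iterator_types = ["parallel", "parallel"]}
      ins(%t : tensor<8x16xf16>) outs(%e1 : tensor<16x8xf16>) {
  ^bb0(%in: f16, %out: f16):
    linalg.yield %in : f16
  } -> tensor<16x8xf16>
  util.return %r : tensor<16x8xf16>
}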
8 changes: 7 additions & 1 deletion compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
@@ -16,7 +16,13 @@ namespace mlir::iree_compiler::DispatchCreation {

/// Return true if the producer and consumer of `operand` are fusable
/// using the elementwise op fusion transformation.
struct ElementwiseOpsFusabilityOptions {
// Control fusion with consumer that has multiple reduction dimensions.
bool fuseMultiReduction = false;
// Control fusion with producer that is a truncate-like operation.
bool fuseTruncateOps = false;
};
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
bool fuseMultiReduction);
ElementwiseOpsFusabilityOptions options);

} // namespace mlir::iree_compiler::DispatchCreation
15 changes: 13 additions & 2 deletions compiler/src/iree/compiler/DispatchCreation/Passes.cpp
@@ -129,7 +129,9 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
.addPass([]() {
return DispatchCreation::createElementwiseOpFusionPass(
ElementwiseOpFusionPassOptions{
clEnableElementWiseFuseMultiReduction});
/*intraDispatch=*/false,
/*fuseMultiReduction=*/clEnableElementWiseFuseMultiReduction,
/*fuseTruncateOps=*/false});
})
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
@@ -148,7 +150,9 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
.addPass([]() {
return DispatchCreation::createElementwiseOpFusionPass(
ElementwiseOpFusionPassOptions{
clEnableElementWiseFuseMultiReduction});
/*intraDispatch=*/false,
/*fuseMultiReduction=*/clEnableElementWiseFuseMultiReduction,
/*fuseTruncateOps=*/false});
})
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
@@ -225,6 +229,13 @@ static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
clEnableFusePaddingIntoLinalgConsumerOps,
clEnableFusePaddingIntoLinalgProducerOps});
})
// Fuse elementwise operations that are inside a dispatch, if possible.
.addPass([&]() {
return DispatchCreation::createElementwiseOpFusionPass(
ElementwiseOpFusionPassOptions{/*intraDispatch=*/true,
/*fuseMultiReduction=*/false,
/*fuseTruncateOps=*/true});
})
// Clone all producers into the dispatch region to prepare for being
// isolated from above. This enables running additional transformations
// afterwards that would need the full dispatch content but don't want to
6 changes: 5 additions & 1 deletion compiler/src/iree/compiler/DispatchCreation/Passes.td
@@ -43,8 +43,12 @@ def ElementwiseOpFusionPass :
Pass<"iree-dispatch-creation-elementwise-op-fusion", ""> {
let summary = "Fuse elementwise operations.";
let options = [
Option<"intraDispatch", "intra-dispatch", "bool",
/*default=*/"false", "Fuse operations within a dispatch only (default is to fuse only operations outside of a dispatch)">,
Option<"fuseMultiReduction", "fuse-multi-reduction", "bool",
/*default=*/"true", "Fuse ops that have multiple reduction iterators">
/*default=*/"true", "Fuse ops that have multiple reduction iterators">,
Option<"fuseTruncateOps", "fuse-truncate-ops", "bool",
/*default=*/"false", "Fuse producer truncate-like operations with consumers">,
];
let dependentDialects = [
"mlir::affine::AffineDialect",
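Assuming the standard MLIR textual pass-option syntax, the new options could be exercised standalone with something like the following hypothetical invocation (not part of this PR):

// RUN: iree-opt %s \
// RUN:   --iree-dispatch-creation-elementwise-op-fusion="intra-dispatch=true fuse-multi-reduction=false fuse-truncate-ops=true"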
@@ -49,7 +49,8 @@ struct SinkReshapesPass final
/// we just approximate it (and try to be optimistic)
static bool isFusableUsingTileAndFuse(Operation *producer,
Operation *consumer) {
return llvm::isa_and_nonnull<IREE::LinalgExt::LinalgFusionOpInterface,
return IREE::LinalgExt::isBitTruncateOp(producer) ||
llvm::isa_and_nonnull<IREE::LinalgExt::LinalgFusionOpInterface,
linalg::LinalgOp, tensor::UnPackOp,
IREE::Encoding::UnsetEncodingOp>(producer);
}
@@ -42,6 +42,7 @@ iree_lit_test_suite(
"pad_fusion_with_consumer.mlir",
"pad_fusion_with_producer.mlir",
"pipeline_tests.mlir",
"pipeline_tests_aggressive.mlir",
"set_encoding.mlir",
"sink_reshapes.mlir",
"split_reduction.mlir",
@@ -40,6 +40,7 @@ iree_lit_test_suite(
"pad_fusion_with_consumer.mlir"
"pad_fusion_with_producer.mlir"
"pipeline_tests.mlir"
"pipeline_tests_aggressive.mlir"
"set_encoding.mlir"
"sink_reshapes.mlir"
"split_reduction.mlir"
@@ -113,7 +113,8 @@ util.func @mixed_conv(%arg0 : tensor<2x130x130x16xf16>, %arg1 : tensor<3x3x16x32
util.return %truncf : tensor<2x128x128x320xf16>
}
// CHECK-LABEL: func public @mixed_conv(
// CHECK: flow.dispatch.workgroups
// CHECK: flow.dispatch.workgroups
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.workgroups
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: outs(%[[FILL]] :
@@ -0,0 +1,89 @@
// RUN: iree-opt --iree-dispatch-creation-pipeline --iree-dispatch-creation-enable-aggressive-fusion --split-input-file --mlir-print-local-scope %s | FileCheck %s

util.func public @truncate_fusion(%arg0: tensor<2x64x64x320xi8>, %arg1: tensor<2x66x66x640xi8>, %arg2: tensor<3x3x640x640xi8>, %arg3: tensor<640xi32>, %arg4: tensor<640xf32>, %arg5: tensor<640x320xi8>, %arg6: tensor<640xi32>, %arg7: tensor<640xf32>) -> tensor<2x640x64x64xf16> {
%c0_i32 = arith.constant 0 : i32
%0 = tensor.empty() : tensor<2x64x64x320xi8>
%1 = tensor.empty() : tensor<2x64x64x640xi32>
%2 = linalg.fill ins(%c0_i32 : i32) outs(%1 : tensor<2x64x64x640xi32>) -> tensor<2x64x64x640xi32>
%3 = tensor.empty() : tensor<2x64x64x640xf32>
%4 = tensor.empty() : tensor<2x640x64x64xf16>
%5 = tensor.empty() : tensor<2x64x64x640xf16>
%6 = tensor.empty() : tensor<2x64x64x320xf16>
%7 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg1, %arg2 : tensor<2x66x66x640xi8>, tensor<3x3x640x640xi8>) outs(%2 : tensor<2x64x64x640xi32>) -> tensor<2x64x64x640xi32>
%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%7, %arg3 : tensor<2x64x64x640xi32>, tensor<640xi32>) outs(%1 : tensor<2x64x64x640xi32>) {
^bb0(%in: i32, %in_0: i32, %out: i32):
%19 = arith.addi %in, %in_0 : i32
linalg.yield %19 : i32
} -> tensor<2x64x64x640xi32>
%9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<2x64x64x640xi32>) outs(%3 : tensor<2x64x64x640xf32>) {
^bb0(%in: i32, %out: f32):
%19 = arith.sitofp %in : i32 to f32
linalg.yield %19 : f32
} -> tensor<2x64x64x640xf32>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%9, %arg4 : tensor<2x64x64x640xf32>, tensor<640xf32>) outs(%3 : tensor<2x64x64x640xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%19 = arith.mulf %in, %in_0 : f32
linalg.yield %19 : f32
} -> tensor<2x64x64x640xf32>
%11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%10 : tensor<2x64x64x640xf32>) outs(%5 : tensor<2x64x64x640xf16>) {
^bb0(%in: f32, %out: f16):
%19 = arith.truncf %in : f32 to f16
linalg.yield %19 : f16
} -> tensor<2x64x64x640xf16>
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d4, d3)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>], iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]} ins(%arg0, %arg5 : tensor<2x64x64x320xi8>, tensor<640x320xi8>) outs(%2 : tensor<2x64x64x640xi32>) {
^bb0(%in: i8, %in_0: i8, %out: i32):
%19 = arith.extsi %in : i8 to i32
%20 = arith.extsi %in_0 : i8 to i32
%21 = arith.muli %19, %20 : i32
%22 = arith.addi %out, %21 : i32
linalg.yield %22 : i32
} -> tensor<2x64x64x640xi32>
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%12, %arg6 : tensor<2x64x64x640xi32>, tensor<640xi32>) outs(%1 : tensor<2x64x64x640xi32>) {
^bb0(%in: i32, %in_0: i32, %out: i32):
%19 = arith.addi %in, %in_0 : i32
linalg.yield %19 : i32
} -> tensor<2x64x64x640xi32>
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%13 : tensor<2x64x64x640xi32>) outs(%3 : tensor<2x64x64x640xf32>) {
^bb0(%in: i32, %out: f32):
%19 = arith.sitofp %in : i32 to f32
linalg.yield %19 : f32
} -> tensor<2x64x64x640xf32>
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%14, %arg7 : tensor<2x64x64x640xf32>, tensor<640xf32>) outs(%3 : tensor<2x64x64x640xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%19 = arith.mulf %in, %in_0 : f32
linalg.yield %19 : f32
} -> tensor<2x64x64x640xf32>
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%15 : tensor<2x64x64x640xf32>) outs(%5 : tensor<2x64x64x640xf16>) {
^bb0(%in: f32, %out: f16):
%19 = arith.truncf %in : f32 to f16
linalg.yield %19 : f16
} -> tensor<2x64x64x640xf16>
%17 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%16, %11 : tensor<2x64x64x640xf16>, tensor<2x64x64x640xf16>) outs(%5 : tensor<2x64x64x640xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%19 = arith.addf %in, %in_0 : f16
linalg.yield %19 : f16
} -> tensor<2x64x64x640xf16>
%18 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%17 : tensor<2x64x64x640xf16>) outs(%4 : tensor<2x640x64x64xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<2x640x64x64xf16>
util.return %18 : tensor<2x640x64x64xf16>
}

// CHECK-LABEL: func public @truncate_fusion
// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.workgroups
// CHECK: %[[MUL:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: outs(%{{.*}} : tensor<8192x640xi32>)
// CHECK: %[[TRUNC0:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[MUL]]
// CHECK-SAME: outs(%{{.*}} : tensor<8192x640xf16>)
// CHECK: flow.dispatch.tensor.store %[[TRUNC0]]
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.workgroups
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {{.*}} -> tensor<2x64x64x640xi32>
// CHECK: %[[TRUNC1:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%{{[a-zA-Z0-9]+}}, %[[CONV]]
// CHECK-SAME: outs(%{{.*}} : tensor<2x640x64x64xf16>)
// CHECK: flow.dispatch.tensor.store %[[TRUNC1]]
// CHECK: return %[[DISPATCH1]]