[DataTiling] Add matmul_k option to SetEncoding pass. by pashu123 · Pull Request #20529 · iree-org/iree
Merged · 5 commits · Apr 14, 2025
17 changes: 15 additions & 2 deletions compiler/src/iree/compiler/DispatchCreation/Passes.cpp
@@ -16,6 +16,7 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"

namespace mlir::iree_compiler::DispatchCreation {
//===----------------------------------------------------------------------===//
// Command Line Options
//===----------------------------------------------------------------------===//
@@ -79,7 +80,16 @@ static llvm::cl::opt<bool> clHoistEncodingsForConstExpr(
"--iree-opt-data-tiling=false must be set as wells"),
llvm::cl::init(true));

namespace mlir::iree_compiler::DispatchCreation {
static llvm::cl::opt<DispatchCreation::EncodingOptions> clSetEncodingStrategy(
"iree-dispatch-creation-set-encoding-strategy",
llvm::cl::desc("Set the encoding strategy for operations."),
llvm::cl::values(
clEnumValN(
DispatchCreation::EncodingOptions::Generic, "generic",
"Using EncodingAttr which encodes as much information as possible"),
clEnumValN(DispatchCreation::EncodingOptions::MatmulK, "matmulk",
"Only encodes the reduction dimenesions in the encoding.")),
llvm::cl::init(DispatchCreation::EncodingOptions::Generic));

//===----------------------------------------------------------------------===//
// Utilities
@@ -244,7 +254,10 @@ addDispatchRegionCreationPasses(OpPassManager &passManager,
// Set encodings on all eligible ops. All ops should be in compiler
// formed dispatch regions, so encodings will be placed inside of the
// dispatch regions with the data-tiled op.
.addPass(createSetEncodingPass)
.addPass([&]() {
return DispatchCreation::createSetEncodingPass(
DispatchCreation::SetEncodingPassOptions{clSetEncodingStrategy});
})
// SetEncodingOps should not be in the same dispatch as the data-tiled
// op, so hoist them out of their current dispatch regions. Also, bubble
// SetEncodingOps through special operations like bit-extending ops and
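For reference, the new strategy can be selected from the command line once this lands. A minimal invocation sketch, assuming a tool built with this patch (input.mlir is a placeholder):

iree-opt --iree-dispatch-creation-set-encoding-strategy=matmulk input.mlir

The option defaults to generic, so existing pipelines keep their current behavior unless matmulk is requested.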
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/DispatchCreation/Passes.h
@@ -17,6 +17,8 @@

namespace mlir::iree_compiler::DispatchCreation {

enum class EncodingOptions { MatmulK, Generic };

//===----------------------------------------------------------------------===//
// Pipelines
//===----------------------------------------------------------------------===//
19 changes: 17 additions & 2 deletions compiler/src/iree/compiler/DispatchCreation/Passes.td
@@ -298,14 +298,29 @@ def HoistEncodingOpsPass :
}


def SetEncodingPass :
InterfacePass<"iree-dispatch-creation-set-encoding", "mlir::FunctionOpInterface"> {
def SetEncodingPass : InterfacePass<"iree-dispatch-creation-set-encoding",
"mlir::FunctionOpInterface"> {
let summary = "Introduces tensor encoding for flow dispatch regions.";
let dependentDialects = [
"mlir::linalg::LinalgDialect",
"IREE::Flow::FlowDialect",
"IREE::Encoding::IREEEncodingDialect",
];
let options = [
Option<
"encodingOption", "encoding-option",
"mlir::iree_compiler::DispatchCreation::EncodingOptions",
/*default=*/
"mlir::iree_compiler::DispatchCreation::EncodingOptions::Generic",
"Select the type of encoding options to add.",
[{::llvm::cl::values(
clEnumValN(
mlir::iree_compiler::DispatchCreation::EncodingOptions::MatmulK,
"matmulk", "Only encodes reduction dimensions in the encoding."),
clEnumValN(
mlir::iree_compiler::DispatchCreation::EncodingOptions::Generic,
"default", "Uses EncodingAttr which encodes as much information as possible."))}]>,
];
}
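The encoding-option defined above is also available in textual pass-pipeline syntax; the updated lit test later in this PR exercises exactly this form (input.mlir is a placeholder for the test's %s):

iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-encoding{encoding-option=matmulk}))" input.mlir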

def ConvertEncodingToFlowPass :
58 changes: 44 additions & 14 deletions compiler/src/iree/compiler/DispatchCreation/SetEncoding.cpp
@@ -31,16 +31,16 @@ namespace mlir::iree_compiler::DispatchCreation {
#include "iree/compiler/DispatchCreation/Passes.h.inc"

using IREE::Encoding::EncodingAttr;
using IREE::Encoding::MatmulKAttr;

//===---------------------------------------------------------------------===//
// Utility functions
//===---------------------------------------------------------------------===//

Value setEncoding(OpBuilder &builder, Location loc, Value source,
EncodingAttr encodingAttr) {
auto sourceType = cast<RankedTensorType>(source.getType());
auto resultType = RankedTensorType::get(
sourceType.getShape(), sourceType.getElementType(), encodingAttr);
static Value setEncoding(OpBuilder &builder, Location loc, Value source,
Attribute encodingAttr) {
auto resultType =
cast<RankedTensorType>(source.getType()).cloneWithEncoding(encodingAttr);
return builder.create<IREE::Encoding::SetEncodingOp>(loc, resultType, source);
};

@@ -163,11 +163,13 @@ class SetContractionOpEncoding final
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
public:
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
explicit SetContractionOpEncoding(MLIRContext *ctx)
: OpInterfaceRewritePattern<linalg::LinalgOp>(ctx) {}
explicit SetContractionOpEncoding(MLIRContext *ctx, EncodingOptions &option)
: OpInterfaceRewritePattern<linalg::LinalgOp>(ctx),
encodingOption(option) {}

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {

if (!linalgOp.hasPureTensorSemantics()) {
return failure();
}
@@ -228,14 +230,39 @@ class SetContractionOpEncoding final

auto opType = IREE::Encoding::EncodingOpType::matmul;
auto setEncodingWrapper = [&](Value src, int64_t operandIndex) -> Value {
auto encoding =
EncodingAttr::get(linalgOp.getContext(), operandIndex, opType,
elemTypes, maps, iterationSizes);
MLIRContext *ctx = linalgOp.getContext();
Attribute encoding;
switch (encodingOption) {
case EncodingOptions::Generic: {
encoding = EncodingAttr::get(ctx, operandIndex, opType, elemTypes, maps,
iterationSizes);
break;
}
case EncodingOptions::MatmulK: {
SmallVector<int32_t> kDims;
AffineMap indexingMap = maps[operandIndex];
auto cDims = linalg::inferContractionDims(linalgOp);
for (auto k : cDims->k) {
std::optional<unsigned> dimIdx =
indexingMap.getResultPosition(rewriter.getAffineDimExpr(k));
if (!dimIdx) {
continue;
}
kDims.push_back(dimIdx.value());
}
encoding = MatmulKAttr::get(ctx, kDims);
break;
}
default: {
assert(false && "Unsupported encoding option");
return Value();
}
}
return setEncoding(rewriter, loc, src, encoding);
};
Value encodedLhs = setEncodingWrapper(lhs, IREE::Encoding::MATMUL_LHS);
Value encodedRhs = setEncodingWrapper(rhs, IREE::Encoding::MATMUL_RHS);
Value encodedOut = setEncodingWrapper(out, IREE::Encoding::MATMUL_RESULT);
auto encodedLhs = setEncodingWrapper(lhs, IREE::Encoding::MATMUL_LHS);
auto encodedRhs = setEncodingWrapper(rhs, IREE::Encoding::MATMUL_RHS);
auto encodedOut = setEncodingWrapper(out, IREE::Encoding::MATMUL_RESULT);
Value opTiled = clone(rewriter, linalgOp, encodedOut.getType(),
ValueRange{encodedLhs, encodedRhs, encodedOut})
->getResult(0);
@@ -248,6 +275,9 @@ class SetContractionOpEncoding final
rewriter.replaceOp(linalgOp, result);
return success();
}

private:
EncodingOptions encodingOption;
};
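To make the MatmulK case above concrete: for each operand, the pattern walks the contraction's reduction loops and records the position of each loop in that operand's indexing map, skipping loops the operand does not index. A self-contained C++ sketch of that projection, with plain vectors standing in for AffineMap results (helper names are illustrative, not the MLIR API); it reproduces the k_dims values the tests below expect for linalg.matmul:

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Model an indexing map as the loop indices it reads, in result order.
// For linalg.matmul over loops (d0, d1, d2), with d2 the reduction loop:
//   LHS (d0, d2) -> {0, 2}, RHS (d2, d1) -> {2, 1}, OUT (d0, d1) -> {0, 1}.
static std::optional<unsigned> getResultPosition(const std::vector<int64_t> &map,
                                                 int64_t loop) {
  for (unsigned i = 0, e = map.size(); i < e; ++i)
    if (map[i] == loop)
      return i;
  return std::nullopt;
}

// Mirrors the MatmulK branch: collect the operand-local positions of all
// reduction loops; loops absent from the map are skipped.
static std::vector<int32_t> computeKDims(const std::vector<int64_t> &map,
                                         const std::vector<int64_t> &kLoops) {
  std::vector<int32_t> kDims;
  for (int64_t k : kLoops)
    if (std::optional<unsigned> pos = getResultPosition(map, k))
      kDims.push_back(*pos);
  return kDims;
}

int main() {
  const std::vector<int64_t> kLoops = {2}; // d2 is the only reduction loop.
  assert((computeKDims({0, 2}, kLoops) == std::vector<int32_t>{1})); // LHS: [1]
  assert((computeKDims({2, 1}, kLoops) == std::vector<int32_t>{0})); // RHS: [0]
  assert((computeKDims({0, 1}, kLoops).empty()));                    // OUT: []
  return 0;
}

This matches the MATMULK FileCheck expectations in the updated test: #iree_encoding.matmul_k<k_dims = [1]> for the LHS, [0] for the RHS, and [] for the output.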

/// Pattern to fold a `linalg.fill` -> `iree_encoding.set_encoding`
@@ -281,7 +311,7 @@ struct SetEncodingPass final : impl::SetEncodingPassBase<SetEncodingPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<SetContractionOpEncoding>(context);
patterns.add<SetContractionOpEncoding>(context, encodingOption.getValue());
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldFillWithSetEncoding>(context);
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -1,32 +1,36 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-encoding))" %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-encoding))" %s | FileCheck %s --check-prefixes=CHECK-ALL,CHECK
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-set-encoding{encoding-option=matmulk}))" %s | FileCheck %s --check-prefixes=CHECK-ALL,MATMULK

util.func public @matmul_f32f32f32(%arg0 : tensor<100x250xf32>, %arg1 : tensor<250x500xf32>,
%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<100x250xf32>, tensor<250x500xf32>)
outs(%arg2 : tensor<100x500xf32>) -> tensor<100x500xf32>
util.return %0 : tensor<100x500xf32>
}
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
// CHECK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
// CHECK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
// CHECK: util.func public @matmul_f32f32f32(
// CHECK-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
// CHECK: %[[LHS:.+]] = iree_encoding.set_encoding %[[ARG0]]
// CHECK-SAME: tensor<100x250xf32, #[[LHS_ENCODING]]>
// CHECK: %[[RHS:.+]] = iree_encoding.set_encoding %[[ARG1]]
// CHECK-SAME: tensor<250x500xf32, #[[RHS_ENCODING]]>
// CHECK: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
// CHECK-SAME: tensor<100x500xf32, #[[OUT_ENCODING]]>
// CHECK: %[[MATMUL:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<100x500xf32, #[[OUT_ENCODING]]> -> tensor<100x500xf32>
// CHECK: util.return %[[RESULT]]
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
// CHECK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
// CHECK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [100, 500, 250]>
// MATMULK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = [1]>
// MATMULK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = [0]>
// MATMULK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = []>
// CHECK-ALL: util.func public @matmul_f32f32f32(
// CHECK-ALL-SAME: %[[ARG0:.+]]: tensor<100x250xf32>
// CHECK-ALL-SAME: %[[ARG1:.+]]: tensor<250x500xf32>
// CHECK-ALL-SAME: %[[ARG2:.+]]: tensor<100x500xf32>
// CHECK-ALL: %[[LHS:.+]] = iree_encoding.set_encoding %[[ARG0]]
// CHECK-ALL-SAME: tensor<100x250xf32, #[[LHS_ENCODING]]>
// CHECK-ALL: %[[RHS:.+]] = iree_encoding.set_encoding %[[ARG1]]
// CHECK-ALL-SAME: tensor<250x500xf32, #[[RHS_ENCODING]]>
// CHECK-ALL: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
// CHECK-ALL-SAME: tensor<100x500xf32, #[[OUT_ENCODING]]>
// CHECK-ALL: %[[MATMUL:.+]] = linalg.matmul
// CHECK-ALL-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-ALL-SAME: outs(%[[OUTS]] :
// CHECK-ALL: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<100x500xf32, #[[OUT_ENCODING]]> -> tensor<100x500xf32>
// CHECK-ALL: util.return %[[RESULT]]

// -----

Expand Down Expand Up @@ -72,27 +76,30 @@ util.func public @matmul_f32f32f32_parallel_reduce_parallel(%arg0 : tensor<32x12
} -> tensor<4096x32xf32>
util.return %0 : tensor<4096x32xf32>
}
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
// CHECK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
// CHECK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
// CHECK: util.func public @matmul_f32f32f32_parallel_reduce_parallel(
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x128xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<128x4096xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<4096x32xf32>
// CHECK: %[[LHS:.+]] = iree_encoding.set_encoding %[[ARG0]]
// CHECK-SAME: tensor<32x128xf32, #[[LHS_ENCODING]]>
// CHECK: %[[RHS:.+]] = iree_encoding.set_encoding %[[ARG1]]
// CHECK-SAME: tensor<128x4096xf32, #[[RHS_ENCODING]]>
// CHECK: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
// CHECK-SAME: tensor<4096x32xf32, #[[OUT_ENCODING]]>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<4096x32xf32, #[[OUT_ENCODING]]> -> tensor<4096x32xf32>
// CHECK: util.return %[[RESULT]]
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
// CHECK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
// CHECK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]], iteration_sizes = [32, 128, 4096]>
// MATMULK-DAG: #[[LHS_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = [1]>
// MATMULK-DAG: #[[RHS_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = [0]>
// MATMULK-DAG: #[[OUT_ENCODING:.+]] = #iree_encoding.matmul_k<k_dims = []>
// CHECK-ALL: util.func public @matmul_f32f32f32_parallel_reduce_parallel(
// CHECK-ALL-SAME: %[[ARG0:.+]]: tensor<32x128xf32>
// CHECK-ALL-SAME: %[[ARG1:.+]]: tensor<128x4096xf32>
// CHECK-ALL-SAME: %[[ARG2:.+]]: tensor<4096x32xf32>
// CHECK-ALL: %[[LHS:.+]] = iree_encoding.set_encoding %[[ARG0]]
// CHECK-ALL-SAME: tensor<32x128xf32, #[[LHS_ENCODING]]>
// CHECK-ALL: %[[RHS:.+]] = iree_encoding.set_encoding %[[ARG1]]
// CHECK-ALL-SAME: tensor<128x4096xf32, #[[RHS_ENCODING]]>
// CHECK-ALL: %[[OUTS:.+]] = iree_encoding.set_encoding %[[ARG2]]
// CHECK-ALL-SAME: tensor<4096x32xf32, #[[OUT_ENCODING]]>
// CHECK-ALL: %[[MATMUL:.+]] = linalg.generic
// CHECK-ALL-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-ALL-SAME: outs(%[[OUTS]] :
// CHECK-ALL: %[[RESULT:.+]] = iree_encoding.unset_encoding %[[MATMUL]] : tensor<4096x32xf32, #[[OUT_ENCODING]]> -> tensor<4096x32xf32>
// CHECK-ALL: util.return %[[RESULT]]

// -----

Expand Down
17 changes: 15 additions & 2 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -61,6 +61,17 @@ static llvm::cl::opt<DemotionOption> clDemoteContractionInputsToBF16Strategy(
clEnumValN(DemotionOption::None, "none", "Demote no contraction ops.")),
llvm::cl::init(DemotionOption::None));

static llvm::cl::opt<DispatchCreation::EncodingOptions> clSetEncodingStrategy(
"iree-global-opt-set-encoding-strategy",
llvm::cl::desc("Set the encoding strategy for operations."),
llvm::cl::values(
clEnumValN(
DispatchCreation::EncodingOptions::Generic, "generic",
"Using EncodingAttr which encodes as much information as possible"),
clEnumValN(DispatchCreation::EncodingOptions::MatmulK, "matmulk",
"Only encodes the reduction dimenesions in the encoding.")),
llvm::cl::init(DispatchCreation::EncodingOptions::Generic));

static llvm::cl::opt<bool> clWarnOnUninitializedValues(
"iree-global-opt-enable-warn-on-uninitialized-values",
llvm::cl::desc("Warn on some classes of uses of uninitialized values."),
@@ -175,8 +186,10 @@ void buildGlobalOptimizationPassPipeline(

// Enable data tiling after they are in a canonical form.
if (transformOptions.options.dataTiling) {
FunctionLikeNest(mainPassManager)
.addPass(DispatchCreation::createSetEncodingPass);
FunctionLikeNest(mainPassManager).addPass([&]() {
return DispatchCreation::createSetEncodingPass(
DispatchCreation::SetEncodingPassOptions{clSetEncodingStrategy});
});
// TODO(hanchung): Make data-tiling passes be FunctionOpInterface pass, so
// we can use `FunctionLikNest` here.
if (clEnableEarlyMaterialization) {
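The GlobalOptimization pipeline receives the same knob under its own flag. A minimal sketch, assuming a tool that links this pipeline and registers the flag (input.mlir is a placeholder):

iree-opt --iree-global-opt-set-encoding-strategy=matmulk input.mlir

Like the dispatch-creation flag, it defaults to generic, so data-tiled compilations are unchanged unless matmulk is opted into.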