8000 [Im2Col] Support converting group convs to im2col by rkayaith · Pull Request #20611 · iree-org/iree · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[Im2Col] Support converting group convs to im2col #20611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
collectDimExprs({binOpExpr.getLHS(), binOpExpr.getRHS()}, out);
Original file line number Diff line number Diff line change
Expand Up @@ -62,54 +62,38 @@ static SmallVector<int64_t> getBasisFromShape(ArrayRef<int64_t> shape) {
return basis;
}

// Collect all AffineDimExprs from an AffineExpr.
static void collectDimExprs(ArrayRef<AffineExpr> exprs,
DenseSet<AffineExpr> &out) {
for (auto &expr : exprs) {
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
out.insert(dimExpr);
} else if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
} else {
LLVM_DEBUG(llvm::dbgs()
<< "Non-dimension expression found: " << expr << "\n");
}
}
}

// Computes `inputKPerm` that maps the input spatial and channel dimension order
// to filter's.
static SmallVector<int64_t> computeInputKPerm(AffineMap inputMap,
AffineMap filterMap) {
DenseSet<AffineExpr> inputDimsSet;
DenseSet<AffineExpr> filterDimsSet;
collectDimExprs(inputMap.getResults(), inputDimsSet);
collectDimExprs(filterMap.getResults(), filterDimsSet);

// Get shared dims from input and filter in order of appearance.
Copy link
Member Author
@rkayaith rkayaith Apr 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous logic included dimensions shared between the input and filter, but group dimensions are included in that and we don't want them here.

SmallVector<AffineExpr> inputSharedDims;
SmallVector<AffineExpr> filterSharedDims;
for (AffineExpr expr : inputMap.getResults()) {
expr.walk([&](AffineExpr dimExpr) {
if (filterDimsSet.contains(dimExpr)) {
inputSharedDims.push_back(dimExpr);
static SmallVector<int64_t>
computeInputKPerm(AffineMap inputMap, AffineMap filterMap,
const mlir::linalg::ConvolutionDimensions &convDims) {
// Get reduction dims from input and filter in order of appearance.
auto reductionDims =
llvm::concat<const unsigned>(convDims.inputChannel, convDims.filterLoop);
SmallVector<int64_t> inputReductionDims;
for (AffineExpr dimExpr : inputMap.getResults()) {
for (unsigned reductionDim : reductionDims) {
if (dimExpr.isFunctionOfDim(reductionDim)) {
inputReductionDims.push_back(reductionDim);
}
});
}
}
for (AffineExpr expr : filterMap.getResults()) {
expr.walk([&](AffineExpr dimExpr) {
if (inputDimsSet.contains(dimExpr)) {
filterSharedDims.push_back(dimExpr);
SmallVector<int64_t> filterReductionDims;
for (AffineExpr dimExpr : filterMap.getResults()) {
for (unsigned reductionDim : reductionDims) {
if (dimExpr.isFunctionOfDim(reductionDim)) {
filterReductionDims.push_back(reductionDim);
}
});
}
}

// Compute the permutation that maps inputSharedDims to filterSharedDims.
SmallVector<int64_t> inputKPerm;
for (AffineExpr filterExpr : filterSharedDims) {
auto it = llvm::find(inputSharedDims, filterExpr);
assert(it != inputSharedDims.end() &&
for (int64_t dim : filterReductionDims) {
auto it = llvm::find(inputReductionDims, dim);
assert(it != inputReductionDims.end() &&
"Filter dimension not found in input shared dimensions");
inputKPerm.push_back(std::distance(inputSharedDims.begin(), it));
inputKPerm.push_back(std::distance(inputReductionDims.begin(), it));
}
return inputKPerm;
}
Expand Down Expand Up @@ -211,18 +195,20 @@ class ConvertConvGeneric final
rewriter.getIndexAttr(filterShape[maybeDim.value()]));
}

// Shape of the resulting tensor from im2col.
SmallVector<int64_t> colTensorShape;
SmallVector<int64_t> batchPos;
for (auto batch : convDims.batch) {
std::optional<int64_t> maybeBatch = inputMap.getResultPosition(
getAffineDimExpr(batch, inputMap.getContext()));
if (!maybeBatch) {
return rewriter.notifyMatchFailure(linalgOp,
"Failed to infer batch shape.");
}
batchPos.push_back(maybeBatch.value());
colTensorShape.push_back(inputShape[maybeBatch.value()]);
// Batch dims for the im2col also include the depth/group dimensions of the
// conv.
auto im2colBatchIterDims =
llvm::to_vector(llvm::concat<unsigned>(convDims.depth, convDims.batch));
SmallVector<int64_t> batchPos(im2colBatchIterDims.size());
for (int64_t convDim : im2colBatchIterDims) {
AffineExpr convDimExpr = getAffineDimExpr(convDim, getContext());
int64_t im2colInputDim = inputMap.getResultPosition(convDimExpr).value();

AffineExpr igemmDimExpr = igemmConvDetails.convToIgemmDimMap.at(convDim);
int64_t igemmInputDim = igemmConvDetails.getIgemmInputImageMap()
.getResultPosition(igemmDimExpr)
.value();
batchPos[igemmInputDim] = im2colInputDim;
Copy link
Contributor
@nirvedhmeshram nirvedhmeshram Apr 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this relying on, say, a dim like "d3" in the original map staying d3 in the new map? Wondering if this is a fairly safe assumption, and if not, should we guard against a wrong mapping if that doesn't happen.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, I think with linalg.generic it should be possible to come up with a case where they don't match. let me try and come up with a test+fix

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've fixed this here by including the dimension mapping in igemmConvDetails and using that to map to the correct dimension: 9f5a51c

}

SmallVector<int64_t> mPos;
Expand All @@ -236,7 +222,6 @@ class ConvertConvGeneric final
for (auto [idx, e] : llvm::enumerate(outputMap.getResults())) {
if (e.isFunctionOfDim(outputImage)) {
mShape.push_back(outputShape[idx]);
colTensorShape.push_back(outputShape[idx]);
}
}
}
Expand All @@ -251,12 +236,11 @@ class ConvertConvGeneric final
}
// The index at which the reduction dimension bounds starts in
// igemmLoopBounds.
int64_t reductionBoundIndex = convDims.batch.size() +
convDims.outputImage.size() +
convDims.outputChannel.size();
int64_t reductionBoundIndex =
convDims.batch.size() + convDims.depth.size() +
convDims.outputImage.size() + convDims.outputChannel.size();
SmallVector<int64_t> kShape(igemmLoopBounds.begin() + reductionBoundIndex,
igemmLoopBounds.end());
colTensorShape.insert(colTensorShape.end(), kShape.begin(), kShape.end());

SmallVector<OpFoldResult> mBasis =
getAsIndexOpFoldResult(getContext(), getBasisFromShape(mShape));
Expand All @@ -266,9 +250,17 @@ class ConvertConvGeneric final
SmallVector<OpFoldResult> kOffset(kBasis.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> mOffset(mBasis.size(), rewriter.getIndexAttr(0));

SmallVector<int64_t> inputKPerm = computeInputKPerm(inputMap, filterMap);
SmallVector<int64_t> inputKPerm =
computeInputKPerm(inputMap, filterMap, convDims);

auto loc = linalgOp.getLoc();
// Shape of the resulting tensor from im2col.
SmallVector<int64_t> colTensorShape;
for (int64_t dim : batchPos) {
colTensorShape.push_back(inputShape[dim]);
}
colTensorShape.append(mShape);
colTensorShape.append(kShape);
Value colTensor = rewriter.create<tensor::EmptyOp>(
loc, colTensorShape, inputType.getElementType());
Value img2ColTensor =
Expand Down
CE3F
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,111 @@ util.func public @conv_1d_nhc_chf(%arg0: tensor<1x3x2xf32>, %arg1: tensor<2x2x2x
// CHECK-SAME: input_k_perm = [1, 0]
// CHECK-SAME: ins({{.*}} : tensor<1x3x2xf32>)
// CHECK-SAME: outs({{.*}} : tensor<1x2x4xf32>) -> tensor<1x2x4xf32>

// -----

util.func public @conv_2d_nhwgc_gfhwc(%arg0: tensor<2x10x10x7x4xf32>, %arg1: tensor<7x16x3x3x4xf32>, %arg2: tensor<2x8x8x7x16xf32>) -> tensor<2x8x8x7x16xf32> {
%0 = linalg.conv_2d_nhwgc_gfhwc
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1: tensor<2x10x10x7x4xf32>, tensor<7x16x3x3x4xf32>)
outs(%arg2: tensor<2x8x8x7x16xf32>) -> tensor<2x8x8x7x16xf32>
util.return %0 : tensor<2x8x8x7x16xf32>
}
// n h w g f c
// CHECK-DAG: #[[LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2, d5)>
// CHECK-DAG: #[[RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// CHECK-DAG: #[[OUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>
// CHECK: util.func public @conv_2d_nhwgc_gfhwc(
// CHECK-SAME: %[[IMG:.+]]: [[IMG_T:tensor<2x10x10x7x4xf32>]]
// CHECK-SAME: %[[FIL:.+]]: [[FIL_T:tensor<7x16x3x3x4xf32>]]
// CHECK-SAME: %[[OUT:.+]]: [[OUT_T:tensor<2x8x8x7x16xf32>]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : [[LHS_T:tensor<2x7x8x8x36xf32>]]
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
// CHECK-SAME: m_offset = [0, 0] * [8, 1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0, 3] m_pos = [1, 2] k_pos = [4]
// CHECK-SAME: input_k_perm = [0, 1, 2]
// CHECK-SAME: ins(%[[IMG]] : [[IMG_T]])
// CHECK-SAME: outs(%[[EMPTY]] : [[LHS_T]])
// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FIL]] {{\[}}[0], [1], [2, 3, 4]] : [[FIL_T]] into [[RHS_T:tensor<7x16x36xf32>]]
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[LHS_MAP]], #[[RHS_MAP]], #[[OUT_MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IM2COL]], %[[COLLAPSED]] : [[LHS_T]], [[RHS_T]])
// CHECK-SAME: outs(%[[OUT]] : [[OUT_T]]) {
// CHECK: }
// CHECK: util.return %[[MATMUL]]

// -----

util.func public @conv_2d_ngchw_fgchw(%arg0: tensor<2x7x4x10x10xf32>, %arg1: tensor<16x7x4x3x3xf32>, %arg2: tensor<2x7x16x8x8xf32>) -> tensor<2x7x16x8x8xf32> {
%0 = linalg.conv_2d_ngchw_fgchw
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1: tensor<2x7x4x10x10xf32>, tensor<16x7x4x3x3xf32>)
outs(%arg2: tensor<2x7x16x8x8xf32>) -> tensor<2x7x16x8x8xf32>
util.return %0 : tensor<2x7x16x8x8xf32>
}
// n g f h w c
// CHECK-DAG: #[[LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5)>
// CHECK-DAG: #[[RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4, d5)>
// CHECK-DAG: #[[OUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>
// CHECK: util.func public @conv_2d_ngchw_fgchw(
// CHECK-SAME: %[[IMG:.+]]: [[IMG_T:tensor<2x7x4x10x10xf32>]]
// CHECK-SAME: %[[FIL:.+]]: [[FIL_T:tensor<16x7x4x3x3xf32>]]
// CHECK-SAME: %[[OUT:.+]]: [[OUT_T:tensor<2x7x16x8x8xf32>]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : [[RHS_T:tensor<2x7x8x8x36xf32>]]
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
// CHECK-SAME: m_offset = [0, 0] * [8, 1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [0, 1] m_pos = [3, 4] k_pos = [2]
// CHECK-SAME: input_k_perm = [0, 1, 2]
// CHECK-SAME: ins(%[[IMG]] : [[IMG_T]])
// CHECK-SAME: outs(%[[EMPTY]] : [[LHS_T]])
// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FIL]] {{\[}}[0], [1], [2, 3, 4]] : [[FIL_T]] into [[LHS_T:tensor<16x7x36xf32>]]
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[LHS_MAP]], #[[RHS_MAP]], #[[OUT_MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[COLLAPSED]], %[[IM2COL]] : [[LHS_T]], [[RHS_T]])
// CHECK-SAME: outs(%[[OUT]] : [[OUT_T]]) {
// CHECK: }
// CHECK: util.return %[[MATMUL]]

// -----
// n g h w f c kh kw
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d0, d2, d3, d4)>
// Output has 'n' and 'g' dimensions transposed.
util.func public @conv_2d_ngchw_fgchw_gnfhw(%arg0: tensor<2x7x4x10x10xf32>, %arg1: tensor<16x7x4x3x3xf32>, %arg2: tensor<7x2x16x8x8xf32>) -> tensor<7x2x16x8x8xf32> {
%0 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
} ins(%arg0, %arg1 : tensor<2x7x4x10x10xf32>, tensor<16x7x4x3x3xf32>) outs(%arg2 : tensor<7x2x16x8x8xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.mulf %in, %in_0 : f32
%2 = arith.addf %out, %1 : f32
linalg.yield %2 : f32
} -> tensor<7x2x16x8x8xf32>
util.return %0 : tensor<7x2x16x8x8xf32>
}
// g n f h w c
// CHECK-DAG: #[[LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d5)>
// CHECK-DAG: #[[RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4, d5)>
// CHECK-DAG: #[[OUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>
// CHECK: util.func public @conv_2d_ngchw_fgchw_gnfhw(
// CHECK-SAME: %[[IMG:.+]]: [[IMG_T:tensor<2x7x4x10x10xf32>]]
// CHECK-SAME: %[[FIL:.+]]: [[FIL_T:tensor<16x7x4x3x3xf32>]]
// CHECK-SAME: %[[OUT:.+]]: [[OUT_T:tensor<7x2x16x8x8xf32>]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : [[RHS_T:tensor<2x7x8x8x36xf32>]]
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: batch_pos = [0, 1] m_pos = [3, 4] k_pos = [2]
// CHECK-SAME: ins(%[[IMG]] : [[IMG_T]])
// CHECK-SAME: outs(%[[EMPTY]] : [[RHS_T]])
// CHECK-DAG: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FIL]] {{\[}}[0], [1], [2, 3, 4]] : [[FIL_T]] into [[LHS_T:tensor<16x7x36xf32>]]
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[LHS_MAP]], #[[RHS_MAP]], #[[OUT_MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[COLLAPSED]], %[[IM2COL]] : [[LHS_T]], [[RHS_T]])
// CHECK-SAME: outs(%[[OUT]] : [[OUT_T]]) {
// CHECK: }
// CHECK: util.return %[[MATMUL]]
Loading
Loading
0