[Im2Col] Support converting group convs to im2col by rkayaith · Pull Request #20611 · iree-org/iree · GitHub

[Im2Col] Support converting group convs to im2col #20611


Merged
rkayaith merged 4 commits into iree-org:main from the group-conv-im2col branch
Apr 23, 2025

Conversation

rkayaith
Member
@rkayaith rkayaith commented Apr 23, 2025

This adds support for converting group convs to im2col, allowing them to go down the IGEMM path.

Group dimensions are parallel iterator dims that index into the image, filter, and output. For im2col they are treated as a batch dimension.

This also fixes #20498
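
To illustrate the idea in the description, here is a minimal sketch (not IREE code) of a 1-D grouped convolution computed two ways: directly, and as im2col followed by a GEMM that is batched over both the batch and group dimensions. The struct `GroupConv1D`, the function names, the tensor layouts, and the unit-stride, no-padding setup are all assumptions made for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical 1-D grouped convolution (stride 1, no padding).
// Field names are illustrative and do not come from IREE.
struct GroupConv1D {
  int n, w, g, c, f, kw; // batch, width, groups, in-ch/group, out-ch/group, kernel width
  int ow() const { return w - kw + 1; } // output width
  int k() const { return kw * c; }      // flattened reduction size
};

// Direct grouped conv. in: [n][w][g][c], fil: [g][f][kw][c], out: [n][ow][g][f].
std::vector<float> directConv(const GroupConv1D &p, const std::vector<float> &in,
                              const std::vector<float> &fil) {
  std::vector<float> out(p.n * p.ow() * p.g * p.f, 0.0f);
  for (int n = 0; n < p.n; ++n)
    for (int x = 0; x < p.ow(); ++x)
      for (int g = 0; g < p.g; ++g)
        for (int f = 0; f < p.f; ++f)
          for (int kx = 0; kx < p.kw; ++kx)
            for (int c = 0; c < p.c; ++c)
              out[((n * p.ow() + x) * p.g + g) * p.f + f] +=
                  in[((n * p.w + x + kx) * p.g + g) * p.c + c] *
                  fil[((g * p.f + f) * p.kw + kx) * p.c + c];
  return out;
}

// im2col with the group dim treated like a batch dim: lhs is [n][g][ow][kw*c].
std::vector<float> im2col(const GroupConv1D &p, const std::vector<float> &in) {
  std::vector<float> lhs(p.n * p.g * p.ow() * p.k());
  for (int n = 0; n < p.n; ++n)
    for (int g = 0; g < p.g; ++g)
      for (int x = 0; x < p.ow(); ++x)
        for (int kx = 0; kx < p.kw; ++kx)
          for (int c = 0; c < p.c; ++c)
            lhs[((n * p.g + g) * p.ow() + x) * p.k() + kx * p.c + c] =
                in[((n * p.w + x + kx) * p.g + g) * p.c + c];
  return lhs;
}

// GEMM batched over (n, g). The filter [g][f][kw][c] is read as [g][f][k].
std::vector<float> batchedGemm(const GroupConv1D &p, const std::vector<float> &lhs,
                               const std::vector<float> &fil) {
  std::vector<float> out(p.n * p.ow() * p.g * p.f, 0.0f);
  for (int n = 0; n < p.n; ++n)
    for (int g = 0; g < p.g; ++g)
      for (int x = 0; x < p.ow(); ++x)
        for (int f = 0; f < p.f; ++f)
          for (int k = 0; k < p.k(); ++k)
            out[((n * p.ow() + x) * p.g + g) * p.f + f] +=
                lhs[((n * p.g + g) * p.ow() + x) * p.k() + k] *
                fil[(g * p.f + f) * p.k() + k];
  return out;
}

// Fills both operands with small integers and compares the two paths exactly.
bool igemmMatchesDirect(const GroupConv1D &p) {
  std::vector<float> in(p.n * p.w * p.g * p.c), fil(p.g * p.f * p.kw * p.c);
  for (std::size_t i = 0; i < in.size(); ++i)
    in[i] = static_cast<float>(static_cast<int>(i % 7) - 3);
  for (std::size_t i = 0; i < fil.size(); ++i)
    fil[i] = static_cast<float>(static_cast<int>(i % 5) - 2);
  return directConv(p, in, fil) == batchedGemm(p, im2col(p, in), fil);
}
```

Because the group index selects matching slices of the image, filter, and output, the GEMM never mixes data across groups, which is why treating it as one more batch dimension is sufficient in this sketch.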

collectDimExprs(inputMap.getResults(), inputDimsSet);
collectDimExprs(filterMap.getResults(), filterDimsSet);

// Get shared dims from input and filter in order of appearance.
Member Author
@rkayaith rkayaith Apr 23, 2025

The previous logic collected every dimension shared between the input and filter. That set also contains the group dimensions, which we don't want here.
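
One way to picture the distinction, as a hedged sketch rather than the code in this PR: a dimension shared by the input and filter is a group dimension if it also indexes the output, and a reduction dimension otherwise. The helper `splitSharedDims`, the `SharedDims` struct, and the use of plain integer dimension ids are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using Dims = std::vector<int>;

static bool contains(const Dims &dims, int d) {
  return std::find(dims.begin(), dims.end(), d) != dims.end();
}

// Hypothetical result type: shared dims split into reduction and group dims.
struct SharedDims {
  Dims reduction;
  Dims group;
};

// Walks the input dims in order of appearance. A dim that the filter also
// uses is classified as a group dim if the output uses it too (assumption
// for this sketch), and as a reduction dim otherwise.
SharedDims splitSharedDims(const Dims &input, const Dims &filter,
                           const Dims &output) {
  SharedDims result;
  for (int d : input) {
    if (!contains(filter, d))
      continue;
    if (contains(output, d))
      result.group.push_back(d);
    else
      result.reduction.push_back(d);
  }
  return result;
}
```

For example, with an assumed numbering of n=0, oh=1, ow=2, g=3, f=4, kh=5, kw=6, c=7, calling `splitSharedDims({0, 1, 5, 2, 6, 3, 7}, {3, 4, 5, 6, 7}, {0, 1, 2, 3, 4})` puts kh, kw, and c in `reduction` and g in `group`.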

@rkayaith
Member Author

On MI300X this results in a roughly 85% to 94% reduction in execution time on the configurations below (benchmarked using boo_driver):

convbfp16 -n 2 -c 896 -H 59 -W 91 -k 896 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -m conv -g 16 -F 1 -t 1              -5033.66 (-94.3%)
convbfp16 -n 2 -c 448 -H 118 -W 182 -k 448 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -m conv -g 8 -F 1 -t 1             -9760.99 (-91.4%)
convbfp16 -n 2 -c 224 -H 470 -W 725 -k 224 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -m conv -g 4 -F 1 -t 1             -11450.01 (-93.5%)
convbfp16 -n 2 -c 224 -H 235 -W 363 -k 224 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -m conv -g 4 -F 1 -t 1             -11380.67 (-94.2%)
convbfp16 -n 2 -c 448 -H 235 -W 363 -k 448 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -m conv -g 8 -F 1 -t 1             -9761.35 (-90.8%)
convbfp16 -n 2 -c 896 -H 118 -W 182 -k 896 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -m conv -g 16 -F 1 -t 1            -5039.59 (-93.7%)
convbfp16 -n 2 -c 2016 -H 59 -W 91 -k 2016 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 --in_layout NHWC --fil_layout NHWC --out_layout NHWC -m conv -g 36 -F 1 -t 1            -1504.16 (-84.7%)

@rkayaith rkayaith marked this pull request as ready for review April 23, 2025 04:03
@rkayaith rkayaith requested a review from nirvedhmeshram April 23, 2025 04:04
@nirvedhmeshram nirvedhmeshram requested a review from yzhang93 April 23, 2025 16:20
Contributor
@nirvedhmeshram nirvedhmeshram left a comment

Nice work! I have a few minor questions and comments.

// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
// CHECK-SAME: m_offset = [0, 0] * [8, 1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [3, 0] m_pos = [1, 2] k_pos = [4]
Contributor

I am wondering why the group dimension is bubbled out to the front like this. To me, [0, 3] seems more natural.

Member Author

It's trivial to change; I can put the batch dims first if preferred. Does this materially affect the code generation? I was curious whether this affects performance, but in my earlier tests it didn't.

Contributor

I think it creates a transposed access pattern when writing the results back. It might matter if we start doing something fancy there, but you are right that it might not make a difference now. Let's change it to [0, 3] so that we at least try to keep the dimensions in their original sequence, within the limits of what im2col needs, which is (batch, M, reductions).

Member Author

I've changed the ordering here: 97e70b2

int64_t igemmInputDim = igemmConvDetails.getIgemmInputImageMap()
                            .getResultPosition(dimExpr)
                            .value();
batchPos[igemmInputDim] = im2colInputDim;
Contributor
@nirvedhmeshram nirvedhmeshram Apr 23, 2025

Is this relying on a dim like "d3" in the original map staying d3 in the new map? I'm wondering whether that is a fairly safe assumption and, if not, whether we should guard against a wrong mapping when it doesn't hold.

Member Author

Good point. I think with linalg.generic it should be possible to come up with a case where they don't match. Let me try to come up with a test and a fix.

Member Author

I've fixed this here by including the dimension mapping in igemmConvDetails and using that to map to the correct dimension: 9f5a51c
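
The concern in this thread can be sketched as follows. This is an illustration under assumptions, not the change in 9f5a51c: affine maps are reduced to lists of integer dim ids, `-1` stands for a compound expression such as `oh + kh`, and `computeBatchPos`, `resultPosition`, and `convToIgemmDim` are hypothetical names. The point is that each conv dim is first translated to its IGEMM dim through an explicit mapping, rather than assuming the two numberings coincide.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Position of `dim` in a list of map results, or -1 if it does not appear.
int resultPosition(const std::vector<int> &results, int dim) {
  for (int i = 0; i < static_cast<int>(results.size()); ++i)
    if (results[i] == dim)
      return i;
  return -1;
}

// convInputResults: conv dim used at each input-tensor position (-1 if the
//                   position is a compound expression).
// igemmLhsResults:  IGEMM dim used at each position of the im2col result.
// convToIgemmDim:   explicit mapping for the batch-like conv dims
//                   (hypothetical stand-in for the mapping discussed above).
// Returns batchPos, where batchPos[i] is the input-tensor position feeding
// the i-th batch dim of the im2col result; entries stay -1 if unresolved.
std::vector<int> computeBatchPos(const std::vector<int> &convInputResults,
                                 const std::vector<int> &igemmLhsResults,
                                 const std::map<int, int> &convToIgemmDim,
                                 int numBatchDims) {
  std::vector<int> batchPos(numBatchDims, -1);
  for (int pos = 0; pos < static_cast<int>(convInputResults.size()); ++pos) {
    auto it = convToIgemmDim.find(convInputResults[pos]);
    if (it == convToIgemmDim.end())
      continue; // not a batch-like dim
    int lhsPos = resultPosition(igemmLhsResults, it->second);
    if (lhsPos >= 0 && lhsPos < numBatchDims)
      batchPos[lhsPos] = pos;
  }
  return batchPos;
}
```

In the example `computeBatchPos({0, -1, -1, 3, 7}, {0, 1, 2, 3, 5}, {{0, 0}, {3, 1}}, 2)`, conv dim 3 (the group) corresponds to IGEMM dim 1. Looking up conv dim 3 directly in the IGEMM results would instead land on position 3, which in this made-up numbering is an M dim.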

Contributor
@yzhang93 yzhang93 left a comment

Overall LGTM. Just some minor comments.

// CHECK-SAME: %[[IMG:.+]]: [[IMG_T:tensor<1x10x10x7x4xf32>]]
// CHECK-SAME: %[[FIL:.+]]: [[FIL_T:tensor<7x16x3x3x4xf32>]]
// CHECK-SAME: %[[OUT:.+]]: [[OUT_T:tensor<1x8x8x7x16xf32>]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : [[LHS_T:tensor<7x1x8x8x36xf32>]]
Contributor

Does it require the g dimension to be in front? Does it work when the n dimension is larger than 1? It would be better to modify one of these two tests to cover a case where n > 1.

Member Author

The dimension order and size don't matter. I've updated the tests to use a non-unit batch size.
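
As a sanity check on the CHECK lines quoted at the start of this thread, the LHS type `tensor<7x1x8x8x36xf32>` can be reproduced by hand. The sketch below assumes the image is laid out as (N, H, W, G, C) and the filter as (G, F, KH, KW, C), with unit stride and dilation and no padding. These are inferred from the quoted shapes rather than stated in the PR, and the function names are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Standard output-size formula for a convolution without padding.
int64_t convOutputSize(int64_t in, int64_t kernel, int64_t stride = 1,
                       int64_t dilation = 1) {
  return (in - dilation * (kernel - 1) - 1) / stride + 1;
}

// Shape of the im2col result with the group dim in front, as in the quoted
// test: (G, N, OH, OW, KH * KW * C).
// image is assumed to be (N, H, W, G, C), filter (G, F, KH, KW, C).
std::array<int64_t, 5> im2colLhsShape(const std::array<int64_t, 5> &image,
                                      const std::array<int64_t, 5> &filter) {
  int64_t n = image[0], g = image[3], c = image[4];
  int64_t kh = filter[2], kw = filter[3];
  return {g, n, convOutputSize(image[1], kh), convOutputSize(image[2], kw),
          kh * kw * c};
}
```

With the image shape 1x10x10x7x4 and filter shape 7x16x3x3x4, the spatial sizes come out to 10 - 3 + 1 = 8 and the reduction size to 3 * 3 * 4 = 36, matching the quoted LHS type.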

Contributor
@nirvedhmeshram nirvedhmeshram left a comment

LGTM

Contributor
@yzhang93 yzhang93 left a comment

Thanks for the changes and improvement! LGTM.

@rkayaith rkayaith merged commit b86ed92 into iree-org:main Apr 23, 2025
41 checks passed
@rkayaith rkayaith deleted the group-conv-im2col branch April 23, 2025 21:29
KyleHerndon pushed a commit to KyleHerndon/iree that referenced this pull request May 7, 2025
Development

Successfully merging this pull request may close these issues.

[GPU] compilation failure for alternative bwd grouped conv
3 participants