[Im2Col] Support converting group convs to im2col #20611
collectDimExprs(inputMap.getResults(), inputDimsSet);
collectDimExprs(filterMap.getResults(), filterDimsSet);

// Get shared dims from input and filter in order of appearance.
The previous logic collected every dimension shared between the input and the filter, but group dimensions are also shared by both, and we don't want them here.
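To make the distinction concrete, here is a minimal sketch in plain C++. It is not the IREE implementation: `DimSet`, `intersectDims`, `subtractDims`, and `classifyConvDims` are hypothetical, and each indexing map is reduced to the set of loop dimensions it reads. Dimensions shared by the input and filter include the group dimension; one way to separate them is to note that group dims also index the output, while reduction dims do not.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <set>

// Hypothetical model: an indexing map is reduced to the set of loop
// dimensions (d0, d1, ...) that appear anywhere in its results.
using DimSet = std::set<int>;

DimSet intersectDims(const DimSet &a, const DimSet &b) {
  DimSet result;
  std::set_intersection(a.begin(), a.end(), b.begin(), b.end(),
                        std::inserter(result, result.begin()));
  return result;
}

DimSet subtractDims(const DimSet &a, const DimSet &b) {
  DimSet result;
  std::set_difference(a.begin(), a.end(), b.begin(), b.end(),
                      std::inserter(result, result.begin()));
  return result;
}

struct ConvDimClasses {
  DimSet groupDims;     // Index the input, filter, and output.
  DimSet reductionDims; // Index the input and filter, but not the output.
};

// Hypothetical helper: split the input/filter shared dims into group
// dims and reduction dims using the output map.
ConvDimClasses classifyConvDims(const DimSet &input, const DimSet &filter,
                                const DimSet &output) {
  DimSet shared = intersectDims(input, filter);
  assert(!shared.empty() && "a conv should share dims between input and filter");
  ConvDimClasses classes;
  classes.groupDims = intersectDims(shared, output);
  classes.reductionDims = subtractDims(shared, output);
  return classes;
}
```

For the group conv in the tests (image 1x10x10x7x4, filter 7x16x3x3x4, output 1x8x8x7x16), an assumed loop numbering d0=n, d1=oh, d2=ow, d3=g, d4=f, d5=kh, d6=kw, d7=c gives input dims {0, 1, 2, 3, 5, 6, 7}, filter dims {3, 4, 5, 6, 7}, and output dims {0, 1, 2, 3, 4}. The sketch then reports {3} as the group dims and {5, 6, 7} as the reduction dims.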
On MI300X, this reduces execution time by ~90% on a number of configurations (benchmarked using boo_driver).
Nice work! I have a few minor questions and comments.
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
// CHECK-SAME: m_offset = [0, 0] * [8, 1] k_offset = [0] * [1]
// CHECK-SAME: batch_pos = [3, 0] m_pos = [1, 2] k_pos = [4]
I'm wondering why the group dimension is bubbled out to the front like this; to me, [0, 3] seems more natural.
It's trivial to change; I can put the batch dims first if preferred. Does this materially affect code generation? I was curious whether it affects performance, but in my earlier tests it didn't.
I think it creates a transposed access pattern when writing the results back. It might matter if we start doing something fancy there. You're right that it might not make a difference now, but let's change it to [0, 3] so that at least we try to keep the dimensions in the same order, within the limits of what im2col needs, which is (batch, M, reduction).
I've changed the ordering here: 97e70b2
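As a worked example of the (batch, M, reduction) layout discussed above, the C++ sketch below computes the im2col result shape for a given batch_pos ordering. `im2colResultShape` is a hypothetical helper, not an IREE API. It assumes, consistently with the test's 36 = 3 * 3 * 4, that K collapses the kernel window together with the image dims listed in k_pos, and that the M dims follow the output spatial sizes.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical sketch: shape of an im2col result laid out as
// (batch..., M..., K), where K collapses the kernel window and the
// image dims listed in kPos.
std::vector<int64_t> im2colResultShape(const std::vector<int64_t> &imageShape,
                                       const std::vector<int64_t> &batchPos,
                                       const std::vector<int64_t> &outputSpatial,
                                       const std::vector<int64_t> &kernelSize,
                                       const std::vector<int64_t> &kPos) {
  std::vector<int64_t> shape;
  // Batch dims (including group dims) keep the order given by batchPos.
  for (int64_t pos : batchPos) {
    assert(pos >= 0 && pos < static_cast<int64_t>(imageShape.size()));
    shape.push_back(imageShape[pos]);
  }
  // M dims follow the output spatial dims.
  shape.insert(shape.end(), outputSpatial.begin(), outputSpatial.end());
  // K is the product of the kernel window and the reduced image dims.
  int64_t k = 1;
  for (int64_t size : kernelSize)
    k *= size;
  for (int64_t pos : kPos) {
    assert(pos >= 0 && pos < static_cast<int64_t>(imageShape.size()));
    k *= imageShape[pos];
  }
  shape.push_back(k);
  return shape;
}
```

With the 1x10x10x7x4 image, an 8x8 output, a 3x3 kernel, and k_pos = [4], batch_pos = [3, 0] yields 7x1x8x8x36 (the tensor.empty shape in the test snippet), while [0, 3] yields 1x7x8x8x36, keeping the batch dim ahead of the group dim as in the original image layout.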
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv_to_im2col.mlir
int64_t igemmInputDim = igemmConvDetails.getIgemmInputImageMap()
                            .getResultPosition(dimExpr)
                            .value();
batchPos[igemmInputDim] = im2colInputDim;
Is this relying on a dim like "d3" in the original map staying d3 in the new map? I'm wondering whether that's a safe assumption, and if not, whether we should guard against a wrong mapping when it doesn't hold.
Good point. I think with linalg.generic it should be possible to come up with a case where they don't match. Let me try to come up with a test and a fix.
I've fixed this by including the dimension mapping in igemmConvDetails and using it to map to the correct dimension: 9f5a51c
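The failure mode raised in this thread can be illustrated with a small C++ sketch. It is not the code from 9f5a51c; `MapResults`, `resultPosition`, `igemmInputPosition`, and `convToIgemmDim` are hypothetical. An indexing map is modeled as the list of loop dims its results read, and a conv loop dim is first translated through an explicit conv-to-IGEMM dimension mapping before its result position is looked up, instead of assuming that d3 stays d3.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of an indexing map whose results are plain dims:
// mapResults[i] is the loop dim read by result i.
using MapResults = std::vector<int64_t>;

// Position of `dim` among the map results, if present. This plays the
// same role as getResultPosition in the snippet above.
std::optional<int64_t> resultPosition(const MapResults &mapResults,
                                      int64_t dim) {
  for (int64_t i = 0; i < static_cast<int64_t>(mapResults.size()); ++i) {
    if (mapResults[i] == dim)
      return i;
  }
  return std::nullopt;
}

// Translate a conv loop dim into the IGEMM loop space before the lookup,
// so the answer stays correct when the two loop numberings differ.
std::optional<int64_t>
igemmInputPosition(const MapResults &igemmInputMap,
                   const std::vector<int64_t> &convToIgemmDim,
                   int64_t convDim) {
  assert(convDim >= 0 &&
         convDim < static_cast<int64_t>(convToIgemmDim.size()));
  return resultPosition(igemmInputMap, convToIgemmDim[convDim]);
}
```

For example, if the IGEMM loops swap the first two conv dims (convToIgemmDim = {1, 0, 2, 3, 4}) and the IGEMM input map reads {d1, d0, d2, d3, d4}, conv dim 0 is correctly found at result 0, whereas looking up d0 directly would wrongly return result 1.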
Overall LGTM. Just some minor comments.
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConvToIm2ColOp.cpp
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv_to_im2col.mlir
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/conv_to_im2col.mlir
// CHECK-SAME: %[[IMG:.+]]: [[IMG_T:tensor<1x10x10x7x4xf32>]]
// CHECK-SAME: %[[FIL:.+]]: [[FIL_T:tensor<7x16x3x3x4xf32>]]
// CHECK-SAME: %[[OUT:.+]]: [[OUT_T:tensor<1x8x8x7x16xf32>]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : [[LHS_T:tensor<7x1x8x8x36xf32>]]
Does it require the g dimension to be in front? Does it work when the n dimension is larger than 1? It would be better to modify one of these two tests to cover a case where n > 1.
The dimension order and size don't matter. I've updated the tests to use a non-unit batch size.
LGTM
Thanks for the changes and improvement! LGTM.
This adds support for converting group convs to im2col, allowing them to go down the IGEMM path. Group dimensions are parallel iterator dims that index into the image, filter, and output. For im2col they are treated as a batch dimension. This also fixes iree-org#20498
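The claim that group dims can be treated as batch dims for im2col can be checked numerically. The C++ sketch below is a simplified 1-D model (stride 1, no padding, layouts chosen to mirror the NHWGC image and GFHWC filter in the tests); the names `GroupConvDims`, `directGroupConv`, and `im2colGroupConv` are hypothetical, and this is not the IREE lowering. It computes a grouped convolution directly and via im2col with (n, g) as batch dims followed by one matmul per batch, and the two results agree, including for n > 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical 1-D grouped conv, stride 1, no padding:
// image (N, W, G, C), filter (G, F, KW, C), output (N, OW, G, F).
struct GroupConvDims {
  int n, w, g, c, f, kw;
  int ow() const { return w - kw + 1; }
};

using Tensor = std::vector<float>;

// Direct grouped convolution: each group reduces only over its own channels.
Tensor directGroupConv(const GroupConvDims &d, const Tensor &img,
                       const Tensor &fil) {
  Tensor out(static_cast<std::size_t>(d.n) * d.ow() * d.g * d.f, 0.0f);
  for (int n = 0; n < d.n; ++n)
    for (int ow = 0; ow < d.ow(); ++ow)
      for (int g = 0; g < d.g; ++g)
        for (int f = 0; f < d.f; ++f) {
          float acc = 0.0f;
          for (int kw = 0; kw < d.kw; ++kw)
            for (int c = 0; c < d.c; ++c)
              acc += img[((n * d.w + ow + kw) * d.g + g) * d.c + c] *
                     fil[((g * d.f + f) * d.kw + kw) * d.c + c];
          out[((n * d.ow() + ow) * d.g + g) * d.f + f] = acc;
        }
  return out;
}

// im2col with (n, g) as batch dims, ow as M, and kk = kw * C + c as K,
// followed by one (M x K) * (K x F) matmul per (n, g) batch.
Tensor im2colGroupConv(const GroupConvDims &d, const Tensor &img,
                       const Tensor &fil) {
  const int m = d.ow(), k = d.kw * d.c;
  // col is laid out as [n][g][ow][kk].
  Tensor col(static_cast<std::size_t>(d.n) * d.g * m * k);
  for (int n = 0; n < d.n; ++n)
    for (int g = 0; g < d.g; ++g)
      for (int ow = 0; ow < m; ++ow)
        for (int kw = 0; kw < d.kw; ++kw)
          for (int c = 0; c < d.c; ++c)
            col[((n * d.g + g) * m + ow) * k + kw * d.c + c] =
                img[((n * d.w + ow + kw) * d.g + g) * d.c + c];
  Tensor out(static_cast<std::size_t>(d.n) * m * d.g * d.f, 0.0f);
  for (int n = 0; n < d.n; ++n)
    for (int g = 0; g < d.g; ++g)
      for (int ow = 0; ow < m; ++ow)
        for (int f = 0; f < d.f; ++f) {
          float acc = 0.0f;
          // Filter element (g, f, kw, c) sits at (g * F + f) * K + kk.
          for (int kk = 0; kk < k; ++kk)
            acc += col[((n * d.g + g) * m + ow) * k + kk] *
                   fil[(g * d.f + f) * k + kk];
          out[((n * m + ow) * d.g + g) * d.f + f] = acc;
        }
  return out;
}
```

Because the group index selects both the image channels and the filter slice, each (n, g) pair becomes an independent matmul, which is why the group dim can sit in batch_pos alongside n in any order.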