-
Notifications
You must be signed in to change notification settings - Fork 702
[Im2Col] Support converting group convs to im2col #20611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,54 +62,38 @@ static SmallVector<int64_t> getBasisFromShape(ArrayRef<int64_t> shape) { | |
return basis; | ||
} | ||
|
||
// Collect all AffineDimExprs from an AffineExpr. | ||
static void collectDimExprs(ArrayRef<AffineExpr> exprs, | ||
DenseSet<AffineExpr> &out) { | ||
for (auto &expr : exprs) { | ||
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { | ||
out.insert(dimExpr); | ||
} else if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(expr)) { | ||
collectDimExprs({binOpExpr.getLHS(), binOpExpr.getRHS()}, out); | ||
} else { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "Non-dimension expression found: " << expr << "\n"); | ||
} | ||
} | ||
} | ||
|
||
// Computes `inputKPerm` that maps the input spatial and channel dimension order | ||
// to filter's. | ||
static SmallVector<int64_t> computeInputKPerm(AffineMap inputMap, | ||
AffineMap filterMap) { | ||
DenseSet<AffineExpr> inputDimsSet; | ||
DenseSet<AffineExpr> filterDimsSet; | ||
collectDimExprs(inputMap.getResults(), inputDimsSet); | ||
collectDimExprs(filterMap.getResults(), filterDimsSet); | ||
|
||
// Get shared dims from input and filter in order of appearance. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The previous logic collected every dimension shared between the input and the filter, but that set also includes the group dimensions, which we don't want here. |
||
SmallVector<AffineExpr> inputSharedDims; | ||
SmallVector<AffineExpr> filterSharedDims; | ||
for (AffineExpr expr : inputMap.getResults()) { | ||
expr.walk([&](AffineExpr dimExpr) { | ||
if (filterDimsSet.contains(dimExpr)) { | ||
inputSharedDims.push_back(dimExpr); | ||
static SmallVector<int64_t> | ||
computeInputKPerm(AffineMap inputMap, AffineMap filterMap, | ||
const mlir::linalg::ConvolutionDimensions &convDims) { | ||
// Get reduction dims from input and filter in order of appearance. | ||
auto reductionDims = | ||
llvm::concat<const unsigned>(convDims.inputChannel, convDims.filterLoop); | ||
SmallVector<int64_t> inputReductionDims; | ||
for (AffineExpr dimExpr : inputMap.getResults()) { | ||
for (unsigned reductionDim : reductionDims) { | ||
if (dimExpr.isFunctionOfDim(reductionDim)) { | ||
inputReductionDims.push_back(reductionDim); | ||
} | ||
}); | ||
} | ||
} | ||
for (AffineExpr expr : filterMap.getResults()) { | ||
expr.walk([&](AffineExpr dimExpr) { | ||
if (inputDimsSet.contains(dimExpr)) { | ||
filterSharedDims.push_back(dimExpr); | ||
SmallVector<int64_t> filterReductionDims; | ||
for (AffineExpr dimExpr : filterMap.getResults()) { | ||
for (unsigned reductionDim : reductionDims) { | ||
if (dimExpr.isFunctionOfDim(reductionDim)) { | ||
filterReductionDims.push_back(reductionDim); | ||
} | ||
}); | ||
} | ||
} | ||
|
||
// Compute the permutation that maps inputSharedDims to filterSharedDims. | ||
SmallVector<int64_t> inputKPerm; | ||
for (AffineExpr filterExpr : filterSharedDims) { | ||
auto it = llvm::find(inputSharedDims, filterExpr); | ||
assert(it != inputSharedDims.end() && | ||
for (int64_t dim : filterReductionDims) { | ||
auto it = llvm::find(inputReductionDims, dim); | ||
assert(it != inputReductionDims.end() && | ||
"Filter dimension not found in input shared dimensions"); | ||
inputKPerm.push_back(std::distance(inputSharedDims.begin(), it)); | ||
inputKPerm.push_back(std::distance(inputReductionDims.begin(), it)); | ||
} | ||
return inputKPerm; | ||
} | ||
|
@@ -211,18 +195,20 @@ class ConvertConvGeneric final | |
rewriter.getIndexAttr(filterShape[maybeDim.value()])); | ||
} | ||
|
||
// Shape of the resulting tensor from im2col. | ||
SmallVector<int64_t> colTensorShape; | ||
SmallVector<int64_t> batchPos; | ||
for (auto batch : convDims.batch) { | ||
std::optional<int64_t> maybeBatch = inputMap.getResultPosition( | ||
getAffineDimExpr(batch, inputMap.getContext())); | ||
if (!maybeBatch) { | ||
return rewriter.notifyMatchFailure(linalgOp, | ||
"Failed to infer batch shape."); | ||
} | ||
batchPos.push_back(maybeBatch.value()); | ||
colTensorShape.push_back(inputShape[maybeBatch.value()]); | ||
// Batch dims for the im2col also include the depth/group dimensions of the | ||
// conv. | ||
auto im2colBatchIterDims = | ||
llvm::to_vector(llvm::concat<unsigned>(convDims.depth, convDims.batch)); | ||
SmallVector<int64_t> batchPos(im2colBatchIterDims.size()); | ||
for (int64_t convDim : im2colBatchIterDims) { | ||
AffineExpr convDimExpr = getAffineDimExpr(convDim, getContext()); | ||
int64_t im2colInputDim = inputMap.getResultPosition(convDimExpr).value(); | ||
|
||
AffineExpr igemmDimExpr = igemmConvDetails.convToIgemmDimMap.at(convDim); | ||
int64_t igemmInputDim = igemmConvDetails.getIgemmInputImageMap() | ||
.getResultPosition(igemmDimExpr) | ||
.value(); | ||
batchPos[igemmInputDim] = im2colInputDim; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this relying on, say, a dim like "d3" in the original map staying d3 in the new map? I'm wondering whether this is a fairly safe assumption and, if not, whether we should guard against a wrong mapping when that doesn't happen. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. good point, I think with There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've fixed this here by including the dimension mapping in |
||
} | ||
|
||
SmallVector<int64_t> mPos; | ||
|
@@ -236,7 +222,6 @@ class ConvertConvGeneric final | |
for (auto [idx, e] : llvm::enumerate(outputMap.getResults())) { | ||
if (e.isFunctionOfDim(outputImage)) { | ||
mShape.push_back(outputShape[idx]); | ||
colTensorShape.push_back(outputShape[idx]); | ||
} | ||
} | ||
} | ||
|
@@ -251,12 +236,11 @@ class ConvertConvGeneric final | |
} | ||
// The index at which the reduction dimension bounds starts in | ||
// igemmLoopBounds. | ||
int64_t reductionBoundIndex = convDims.batch.size() + | ||
convDims.outputImage.size() + | ||
convDims.outputChannel.size(); | ||
int64_t reductionBoundIndex = | ||
convDims.batch.size() + convDims.depth.size() + | ||
convDims.outputImage.size() + convDims.outputChannel.size(); | ||
SmallVector<int64_t> kShape(igemmLoopBounds.begin() + reductionBoundIndex, | ||
igemmLoopBounds.end()); | ||
colTensorShape.insert(colTensorShape.end(), kShape.begin(), kShape.end()); | ||
|
||
SmallVector<OpFoldResult> mBasis = | ||
getAsIndexOpFoldResult(getContext(), getBasisFromShape(mShape)); | ||
|
@@ -266,9 +250,17 @@ class ConvertConvGeneric final | |
SmallVector<OpFoldResult> kOffset(kBasis.size(), rewriter.getIndexAttr(0)); | ||
SmallVector<OpFoldResult> mOffset(mBasis.size(), rewriter.getIndexAttr(0)); | ||
|
||
SmallVector<int64_t> inputKPerm = computeInputKPerm(inputMap, filterMap); | ||
SmallVector<int64_t> inputKPerm = | ||
computeInputKPerm(inputMap, filterMap, convDims); | ||
|
||
auto loc = linalgOp.getLoc(); | ||
// Shape of the resulting tensor from im2col. | ||
SmallVector<int64_t> colTensorShape; | ||
for (int64_t dim : batchPos) { | ||
colTensorShape.push_back(inputShape[dim]); | ||
} | ||
colTensorShape.append(mShape); | ||
colTensorShape.append(kShape); | ||
Value colTensor = rewriter.create<tensor::EmptyOp>( | ||
loc, colTensorShape, inputType.getElementType()); | ||
Value img2ColTensor = | ||
|
Uh oh!
There was an error while loading. Please reload this page.