Handle MemRefRegion::compute return value in loop fusion pass (NFC).
PiperOrigin-RevId: 236685849
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 2e84d3c..1e4e020 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -1177,7 +1177,9 @@
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
MemRefRegion region(srcStoreOpInst->getLoc());
- region.compute(srcStoreOpInst, dstLoopDepth);
+ bool validRegion = region.compute(srcStoreOpInst, dstLoopDepth);
+ (void)validRegion;
+ assert(validRegion && "unexpected memref region failure");
SmallVector<int64_t, 4> newShape;
std::vector<SmallVector<int64_t, 4>> lbs;
SmallVector<int64_t, 8> lbDivisors;
@@ -1304,7 +1306,11 @@
// Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
auto *srcStoreOpInst = srcNode->stores.front();
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
- srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0);
+ if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute MemRefRegion for source operation\n.");
+ return false;
+ }
SmallVector<int64_t, 4> srcShape;
// Query 'srcWriteRegion' for the 'srcShape' and 'srcNumElements' written
// by 'srcStoreOpInst' at depth 'dstLoopDepth'.
@@ -1319,7 +1325,11 @@
assert(dstStoreOps.size() == 1);
auto *dstStoreOpInst = dstStoreOps[0];
MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc());
- dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0);
+ if (!dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute MemRefRegion for dest operation\n.");
+ return false;
+ }
SmallVector<int64_t, 4> dstShape;
// Query 'dstWriteRegion' for the 'dstShape' and 'dstNumElements' written
// by 'dstStoreOpInst' at depth 'dstLoopDepth'.
@@ -1444,7 +1454,12 @@
// Compute src loop nest write region size.
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
- srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0);
+ if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute MemRefRegion for source operation\n.");
+ return false;
+ }
+
Optional<int64_t> maybeSrcWriteRegionSizeBytes =
srcWriteRegion.getRegionSize();
if (!maybeSrcWriteRegionSizeBytes.hasValue())
@@ -1528,8 +1543,10 @@
// nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
// nest at loop depth 'i'
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
- sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
- &sliceStates[i - 1]);
+ if (!sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
+ &sliceStates[i - 1]))
+ continue;
+
Optional<int64_t> maybeSliceWriteRegionSizeBytes =
sliceWriteRegion.getRegionSize();
if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
@@ -1594,8 +1611,10 @@
return false;
}
- assert(bestDstLoopDepth.hasValue() &&
- "expected to have a value per logic above");
+ if (!bestDstLoopDepth.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
+ return false;
+ }
// Set dstLoopDepth based on best values from search.
*dstLoopDepth = bestDstLoopDepth.getValue();