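Handle MemRefRegion::compute return value in loop fusion pass (NFC).

The result of MemRefRegion::compute was previously ignored. The call sites
now check it: one asserts on success, the others return false or skip the
candidate loop depth, and a missing best fusion depth now returns false
instead of asserting. Behavior is unchanged only when every region
computation succeeds.
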
PiperOrigin-RevId: 236685849
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 2e84d3c..1e4e020 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -1177,7 +1177,9 @@
 
   // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
   MemRefRegion region(srcStoreOpInst->getLoc());
-  region.compute(srcStoreOpInst, dstLoopDepth);
+  bool validRegion = region.compute(srcStoreOpInst, dstLoopDepth);
+  (void)validRegion;
+  assert(validRegion && "unexpected memref region failure");
   SmallVector<int64_t, 4> newShape;
   std::vector<SmallVector<int64_t, 4>> lbs;
   SmallVector<int64_t, 8> lbDivisors;
@@ -1304,7 +1306,11 @@
   // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
   auto *srcStoreOpInst = srcNode->stores.front();
   MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
-  srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0);
+  if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for source operation\n.");
+    return false;
+  }
   SmallVector<int64_t, 4> srcShape;
   // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
   // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
@@ -1319,7 +1325,11 @@
   assert(dstStoreOps.size() == 1);
   auto *dstStoreOpInst = dstStoreOps[0];
   MemRefRegion dstWriteRegion(dstStoreOpInst->getLoc());
-  dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0);
+  if (!dstWriteRegion.compute(dstStoreOpInst, /*loopDepth=*/0)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for dest operation\n.");
+    return false;
+  }
   SmallVector<int64_t, 4> dstShape;
   // Query 'dstWriteRegion' for 'dstShape' and 'dstNumElements'.
   // by 'dstStoreOpInst' at depth 'dstLoopDepth'.
@@ -1444,7 +1454,12 @@
 
   // Compute src loop nest write region size.
   MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
-  srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0);
+  if (!srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Unable to compute MemRefRegion for source operation\n.");
+    return false;
+  }
+
   Optional<int64_t> maybeSrcWriteRegionSizeBytes =
       srcWriteRegion.getRegionSize();
   if (!maybeSrcWriteRegionSizeBytes.hasValue())
@@ -1528,8 +1543,10 @@
     // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
     // nest at loop depth 'i'
     MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
-    sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
-                             &sliceStates[i - 1]);
+    if (!sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
+                                  &sliceStates[i - 1]))
+      continue;
+
     Optional<int64_t> maybeSliceWriteRegionSizeBytes =
         sliceWriteRegion.getRegionSize();
     if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
@@ -1594,8 +1611,10 @@
     return false;
   }
 
-  assert(bestDstLoopDepth.hasValue() &&
-         "expected to have a value per logic above");
+  if (!bestDstLoopDepth.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
+    return false;
+  }
 
   // Set dstLoopDepth based on best values from search.
   *dstLoopDepth = bestDstLoopDepth.getValue();