Refactor / improve replaceAllMemRefUsesWith

Refactor replaceAllMemRefUsesWith by splitting it into two methods: the new
method performs the replacement on a single op and is used by the existing
one, which replaces all uses.
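
A minimal sketch of the split, under stated assumptions. It uses a toy IR in which memrefs are plain strings and ops are a hypothetical `Op` struct. `LogicalResult`, `success()`, `failure()`, `succeeded()` and `failed()` are local stand-ins for MLIR's helpers. The real MLIR signatures, which also take an index remap and extra operands, are not reproduced. Naming both methods as overloads of replaceAllMemRefUsesWith is also an assumption of the sketch. The single-op overload does the rewrite. The all-uses overload checks every user first and only then calls it, following the check-then-replace order described in the list below.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-ins for mlir::LogicalResult and its helpers (sketch only).
struct LogicalResult {
  bool ok;
};
inline LogicalResult success() { return {true}; }
inline LogicalResult failure() { return {false}; }
inline bool succeeded(LogicalResult r) { return r.ok; }
inline bool failed(LogicalResult r) { return !r.ok; }

// Hypothetical toy op: `dereferences` marks load/store-like ops whose memref
// operand can be rewritten (an assumption of this sketch).
struct Op {
  std::vector<std::string> operands;
  bool dereferences;
};

static bool usesMemRef(const Op &op, const std::string &memRef) {
  for (const std::string &operand : op.operands)
    if (operand == memRef)
      return true;
  return false;
}

// Single-op method: rewrite the uses of `oldMemRef` in `op` only.
LogicalResult replaceAllMemRefUsesWith(const std::string &oldMemRef,
                                       const std::string &newMemRef, Op &op) {
  if (usesMemRef(op, oldMemRef) && !op.dereferences)
    return failure(); // Non-dereferencing use: leave the op untouched.
  for (std::string &operand : op.operands)
    if (operand == oldMemRef)
      operand = newMemRef;
  return success();
}

// All-uses method, built on top of the single-op one.
LogicalResult replaceAllMemRefUsesWith(const std::string &oldMemRef,
                                       const std::string &newMemRef,
                                       std::vector<Op> &ops) {
  // Pass 1: fail before mutating anything if some use is non-dereferencing.
  for (const Op &op : ops)
    if (usesMemRef(op, oldMemRef) && !op.dereferences)
      return failure();
  // Pass 2: every per-op replacement is now known to succeed.
  for (Op &op : ops) {
    LogicalResult res = replaceAllMemRefUsesWith(oldMemRef, newMemRef, op);
    assert(succeeded(res) && "checked in pass 1");
    (void)res;
  }
  return success();
}
```

Suppose `A` has one dereferencing user and one non-dereferencing user. The all-uses call then returns `failure()` and neither op is modified, so a client can bail out with the IR still consistent.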

- Make both methods return LogicalResult instead of bool.

- Earlier, when replacement failed (due to non-dereferencing uses of the
  memref), the ops that had already been processed would have been
  replaced, leaving the IR in an inconsistent state. Now, a first pass is
  made over all ops to check for non-dereferencing uses, and only then is
  the replacement performed. No existing test cases were affected because
  all clients of this method were already checking for non-dereferencing
  uses before calling it (for other reasons). This isn't true for a use
  case in another upcoming PR (scalar replacement); clients can now bail
  out with consistent IR when replaceAllMemRefUsesWith fails. Add a test
  case.

- Multiple dereferencing uses of the same memref in a single op are
  possible (although we have no such use cases or scenarios), but this
  has always been unsupported. Add an assertion for this.

- Minor fix to another test case (pipeline-data-transfer).
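
The assertion for multiple dereferencing uses in one op can be sketched in the same toy style. This assumes a hypothetical `Op` that lists the memrefs it dereferences (think of a copy-like op with a source and a destination buffer) and a hypothetical helper `replaceSingleDereferencingUse`. It illustrates the check, not the MLIR code.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy op listing the memrefs it dereferences, e.g. a copy-like
// op with a source and a destination buffer.
struct Op {
  std::vector<std::string> memRefOperands;
};

// Hypothetical helper: replace the single dereferencing use of `oldMemRef`
// in `op`. Returns true if a use was found and rewritten.
bool replaceSingleDereferencingUse(Op &op, const std::string &oldMemRef,
                                   const std::string &newMemRef) {
  auto numUses = std::count(op.memRefOperands.begin(),
                            op.memRefOperands.end(), oldMemRef);
  // Multiple dereferencing uses of the same memref in one op are
  // unsupported; the assertion makes that limitation explicit.
  assert(numUses <= 1 &&
         "multiple dereferencing uses in a single op not supported");
  (void)numUses; // Avoid an unused-variable warning under NDEBUG.
  auto it = std::find(op.memRefOperands.begin(), op.memRefOperands.end(),
                      oldMemRef);
  if (it == op.memRefOperands.end())
    return false;
  *it = newMemRef;
  return true;
}
```

In a debug build, an op such as `Op{{"buf", "buf"}}` (same memref as source and destination) would trip the assertion rather than be partially rewritten.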

Signed-off-by: Uday Bondhugula <[email protected]>

Closes tensorflow/mlir#87

PiperOrigin-RevId: 265808183
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 46713dc..a17481f 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -952,12 +952,13 @@
                         ? AffineMap()
                         : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs);
   // Replace all users of 'oldMemRef' with 'newMemRef'.
-  bool ret =
+  LogicalResult res =
       replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                                /*extraOperands=*/outerIVs,
                                /*domInstFilter=*/&*forOp.getBody()->begin());
-  assert(ret && "replaceAllMemrefUsesWith should always succeed here");
-  (void)ret;
+  assert(succeeded(res) &&
+         "replaceAllMemrefUsesWith should always succeed here");
+  (void)res;
   return newMemRef;
 }