Refactor / improve replaceAllMemRefUsesWith
Split replaceAllMemRefUsesWith into two methods: a new method that
performs the replacement on a single op, and the existing method, which
now uses the new one.
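The split can be pictured on a toy model. This is a sketch, not MLIR's code. `Op`, `Operand`, and the string-named memrefs are hypothetical stand-ins for real IR, the `LogicalResult` struct only mimics MLIR's type of that name, and `replaceMemRefUsesInOp` / `replaceAllMemRefUses` are illustrative names. The single-op function rewrites one op; the all-uses entry point becomes a loop that delegates to it. On its own this loop can still fail partway through, which is the issue the second bullet below addresses.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in mimicking mlir::LogicalResult; not MLIR's definition.
struct LogicalResult { bool ok; };
inline LogicalResult success() { return {true}; }
inline LogicalResult failure() { return {false}; }
inline bool succeeded(LogicalResult r) { return r.ok; }
inline bool failed(LogicalResult r) { return !r.ok; }

// Toy IR (hypothetical): an operand names a memref and records whether the
// op dereferences it (e.g. a load/store) or merely passes it along.
struct Operand {
  std::string memref;
  bool dereferencing;
};
struct Op {
  std::vector<Operand> operands;
};

// Single-op entry point: rewrites the uses of 'oldMemRef' in one op.
LogicalResult replaceMemRefUsesInOp(Op &op, const std::string &oldMemRef,
                                    const std::string &newMemRef) {
  for (const Operand &o : op.operands)
    if (o.memref == oldMemRef && !o.dereferencing)
      return failure();  // Only dereferencing uses can be rewritten.
  for (Operand &o : op.operands)
    if (o.memref == oldMemRef)
      o.memref = newMemRef;
  return success();
}

// All-uses entry point, now a thin loop over the single-op version.
LogicalResult replaceAllMemRefUses(std::vector<Op> &ops,
                                   const std::string &oldMemRef,
                                   const std::string &newMemRef) {
  for (Op &op : ops)
    if (failed(replaceMemRefUsesInOp(op, oldMemRef, newMemRef)))
      return failure();
  return success();
}
```

Returning a LogicalResult instead of a bool lets callers use the `succeeded()` / `failed()` helpers, as the LoopFusion.cpp hunk does.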
- Make both methods return LogicalResult instead of bool.
- Previously, when replacement failed (due to a non-dereferencing use of
the memref), the ops already processed would have been rewritten,
leaving the IR in an inconsistent state. Now, a first pass over all ops
checks for non-dereferencing uses, and replacement is performed only if
that check passes. No test cases were affected because all clients of
this method were already checking for non-dereferencing uses before
calling it (for other reasons). This isn't true for a use case in an
upcoming PR (scalar replacement); clients can now bail out with
consistent IR when replaceAllMemRefUsesWith fails. Add a test case.
- Multiple dereferencing uses of the same memref in a single op are
possible (we have no such use cases), but this has always been
unsupported. Add an assertion for this.
- Minor fix to another pipeline-data-transfer test case.
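The validate-then-rewrite order from the second bullet, together with the single-use assertion from the third, can be sketched on the same kind of toy model. Everything here is hypothetical (the `Op`/`Operand` types, the string-named memrefs, the helper names, and the `LogicalResult` stand-in); it illustrates the idea and is not MLIR's implementation.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in mimicking mlir::LogicalResult; not MLIR's definition.
struct LogicalResult { bool ok; };
inline LogicalResult success() { return {true}; }
inline LogicalResult failure() { return {false}; }
inline bool succeeded(LogicalResult r) { return r.ok; }
inline bool failed(LogicalResult r) { return !r.ok; }

// Toy IR (hypothetical): each operand names a memref and says whether the
// op dereferences it.
struct Operand {
  std::string memref;
  bool dereferencing;
};
struct Op {
  std::vector<Operand> operands;
};

// Does 'op' use 'memref' in a way that cannot be remapped?
static bool hasNonDereferencingUse(const Op &op, const std::string &memref) {
  for (const Operand &o : op.operands)
    if (o.memref == memref && !o.dereferencing)
      return true;
  return false;
}

// Rewrites the (single) dereferencing use of 'oldMemRef' in 'op'.
static void replaceInOp(Op &op, const std::string &oldMemRef,
                        const std::string &newMemRef) {
  int uses = 0;
  for (Operand &o : op.operands) {
    if (o.memref != oldMemRef)
      continue;
    ++uses;
    o.memref = newMemRef;
  }
  // At most one dereferencing use of the memref per op is supported.
  assert(uses <= 1 && "multiple dereferencing uses of a memref in one op");
  (void)uses;
}

LogicalResult replaceAllMemRefUses(std::vector<Op> &ops,
                                   const std::string &oldMemRef,
                                   const std::string &newMemRef) {
  // Pass 1: validate every op before modifying any of them.
  for (const Op &op : ops)
    if (hasNonDereferencingUse(op, oldMemRef))
      return failure();  // Nothing has been modified yet.
  // Pass 2: every use is dereferencing, so the rewrite cannot fail.
  for (Op &op : ops)
    replaceInOp(op, oldMemRef, newMemRef);
  return success();
}
```

The only early return happens before any op is touched, so a caller that gets a failure still holds unmodified, consistent IR and can simply give up.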
Signed-off-by: Uday Bondhugula <[email protected]>
Closes tensorflow/mlir#87
PiperOrigin-RevId: 265808183
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 46713dc..a17481f 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -952,12 +952,13 @@
? AffineMap()
: b.getAffineMap(outerIVs.size() + rank, 0, remapExprs);
// Replace all users of 'oldMemRef' with 'newMemRef'.
- bool ret =
+ LogicalResult res =
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
/*extraOperands=*/outerIVs,
/*domInstFilter=*/&*forOp.getBody()->begin());
- assert(ret && "replaceAllMemrefUsesWith should always succeed here");
- (void)ret;
+ assert(succeeded(res) &&
+ "replaceAllMemrefUsesWith should always succeed here");
+ (void)res;
return newMemRef;
}