[MLIR][Affine-loop-fusion] Fix a bug in affine-loop-fusion pass when there are non-affine operations
When there is a mix of affine load/store and non-affine operations (e.g. std.load, std.store),
affine-loop-fusion ignores the presence of the non-affine ops and can thus change the program semantics.
For example, consider a program of three affine loops operating on the same memref, where the second loop uses std.load and std.store:
```
affine.for
  affine.store %1
affine.for
  std.load %1
  std.store %1
affine.for
  affine.load %1
  affine.store %1
```
affine-loop-fusion produces the following result, which changes the program semantics:
```
affine.for
  std.load %1
  std.store %1
affine.for
  affine.store %1
  affine.load %1
  affine.store %1
```
This patch fixes the problem by checking for non-affine users of the source node's memrefs that lie between the source and destination nodes being fused.
Differential Revision: https://reviews.llvm.org/D82158
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 6a7a88e..f71ff2a 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -948,6 +948,65 @@
return newMemRef;
}
+/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
+/// 'dstId'), if there is any non-affine operation accessing 'memref', return
+/// true. Otherwise, return false.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       Value memref,
+                                       MemRefDependenceGraph *mdg) {
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+  // For each MemRefDependenceGraph's node that is between 'srcNode' and
+  // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
+  // non-affine operation in the node accesses the 'memref'.
+  for (auto &idAndNode : mdg->nodes) {
+    Operation *op = idAndNode.second.op;
+    // Only inspect operations strictly between 'srcNode' and 'dstNode'.
+    if (!srcNode->op->isBeforeInBlock(op) || !op->isBeforeInBlock(dstNode->op))
+      continue;
+    // Walk inside the operation to find any use of the memref.
+    // Interrupt the walk if found.
+    auto walkResult = op->walk([&](Operation *user) {
+      // Skip affine memref accesses; only non-affine users matter here.
+      if (isMemRefDereferencingOp(*user))
+        return WalkResult::advance();
+      // 'user' uses 'memref' iff 'memref' is one of its operands; this is
+      // cheaper than a linear scan over all of memref.getUsers().
+      if (llvm::is_contained(user->getOperands(), memref))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    if (walkResult.wasInterrupted())
+      return true;
+  }
+  return false;
+}
+
+/// Check whether any memref value accessed in node 'srcId' has a non-affine
+/// user that lies between node 'srcId' and node 'dstId' (exclusive of
+/// 'srcId' and 'dstId'). Returns true if such a user exists, false
+/// otherwise.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       MemRefDependenceGraph *mdg) {
+  // Collect memref values used by operations in node 'srcId'.
+  auto *srcNode = mdg->getNode(srcId);
+  llvm::SmallDenseSet<Value, 2> memRefValues;
+  srcNode->op->walk([&](Operation *op) {
+    // Do not collect the operands of an affine.for itself; the ops nested
+    // inside it are still visited by the walk.
+    if (isa<AffineForOp>(op))
+      return;
+    for (Value v : op->getOperands())
+      if (v.getType().isa<MemRefType>())
+        memRefValues.insert(v);
+  });
+  // Look for non-affine users of each collected memref on the path.
+  for (Value memref : memRefValues)
+    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
+      return true;
+  return false;
+}
+
// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
// may write to multiple memrefs but it is required that only one of them,
// 'srcLiveOutStoreOp', has output edges.
@@ -1008,6 +1067,12 @@
// TODO(andydavis) Check the shape and lower bounds here too.
if (srcNumElements != dstNumElements)
return false;
+
+ // Return false if 'memref' is used by a non-affine operation that is
+ // between node 'srcId' and node 'dstId'.
+ if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+ return false;
+
return true;
}
@@ -1793,6 +1858,12 @@
}
if (storeMemrefs.size() != 1)
return false;
+
+ // Skip if a memref value in one node is used by a non-affine memref
+ // access that lies between 'dstNode' and 'sibNode'.
+ if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
+ hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
+ return false;
return true;
};