[MLIR][Affine-loop-fusion] Fix a bug in affine-loop-fusion pass when there are non-affine operations

When there is a mix of affine load/store and non-affine operations (e.g. std.load, std.store),
affine-loop-fusion ignores the presence of non-affine ops, thus changing the program semantics.

For example, consider a program with three affine loops operating on the same memref, one of which uses std.load and std.store:
```
affine.for
  affine.store %1
affine.for
  std.load %1
  std.store %1
affine.for
  affine.load %1
  affine.store %1
```
affine-loop-fusion will produce the following result, which changes the program semantics:
```
affine.for
  std.load %1
  std.store %1
affine.for
  affine.store %1
  affine.load %1
  affine.store %1
```

This patch fixes the above problem by checking for non-affine users of the memref that lie between the source and destination nodes of interest.

Differential Revision: https://reviews.llvm.org/D82158
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 6a7a88e..f71ff2a 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -948,6 +948,65 @@
   return newMemRef;
 }
 
+/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
+/// 'dstId'), if there is any non-affine operation accessing 'memref', return
+/// true. Otherwise, return false.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       Value memref,
+                                       MemRefDependenceGraph *mdg) {
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+  Value::user_range users = memref.getUsers();
+  // For each MemRefDependenceGraph's node that is between 'srcNode' and
+  // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
+  // non-affine operation in the node accesses the 'memref'.
+  for (auto &idAndNode : mdg->nodes) {
+    Operation *op = idAndNode.second.op;
+    // Only consider nodes whose op lies between 'srcNode' and 'dstNode'.
+    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
+      // Walk the node's op and everything nested in it looking for a use of
+      // the memref; the walk is interrupted as soon as one is found.
+      auto walkResult = op->walk([&](Operation *user) {
+        // Skip ops accepted by isMemRefDereferencingOp (the affine accesses).
+        if (isMemRefDereferencingOp(*user))
+          return WalkResult::advance();
+        // Any remaining op that is a user of the memref is a non-affine user.
+        if (llvm::is_contained(users, user))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      });
+      if (walkResult.wasInterrupted())
+        return true;
+    }
+  }
+  return false;
+}
+
+/// Check whether any memref value used in node 'srcId' has a non-affine user
+/// that is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
+/// 'dstNode'). Returns true if such a user exists, false otherwise.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       MemRefDependenceGraph *mdg) {
+  // Collect memref values in node 'srcId'.
+  auto *srcNode = mdg->getNode(srcId);
+  llvm::SmallDenseSet<Value, 2> memRefValues;
+  srcNode->op->walk([&](Operation *op) {
+    // Skip affine.for ops themselves; their own operands are not collected.
+    if (isa<AffineForOp>(op))
+      return WalkResult::advance();
+    for (Value v : op->getOperands())
+      // Collect memref values only.
+      if (v.getType().isa<MemRefType>())
+        memRefValues.insert(v);
+    return WalkResult::advance();
+  });
+  // Report a hit as soon as any collected memref has a non-affine user
+  // between node 'srcId' and node 'dstId'.
+  for (Value memref : memRefValues)
+    if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
+      return true;
+  return false;
+}
+
 // Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
 // may write to multiple memrefs but it is required that only one of them,
 // 'srcLiveOutStoreOp', has output edges.
@@ -1008,6 +1067,12 @@
   // TODO(andydavis) Check the shape and lower bounds here too.
   if (srcNumElements != dstNumElements)
     return false;
+
+  // Return false if 'memref' is used by a non-affine operation that is
+  // between node 'srcId' and node 'dstId'.
+  if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+    return false;
+
   return true;
 }
 
@@ -1793,6 +1858,12 @@
       }
       if (storeMemrefs.size() != 1)
         return false;
+
+      // Skip if a memref value in one node is used by a non-affine memref
+      // access that lies between 'dstNode' and 'sibNode'.
+      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
+          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
+        return false;
       return true;
     };