Loop fusion improvements:
*) After a private memref buffer is created for a fused loop nest, dependences on the old memref are reduced, which can open up fusion opportunities. In these cases, users of the old memref are added back to the worklist to be reconsidered for fusion.
*) Fixed a bug in the fusion insertion point dependence check, where dependence edges on the memref being privatized were incorrectly excluded from the check.

PiperOrigin-RevId: 232477853
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 77e5a6a..d7e1b61 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -365,21 +365,20 @@
   // Computes and returns an insertion point instruction, before which the
   // the fused <srcId, dstId> loop nest can be inserted while preserving
   // dependences. Returns nullptr if no such insertion point is found.
-  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId,
-                                              Value *memrefToSkip) {
+  Instruction *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
     if (outEdges.count(srcId) == 0)
       return getNode(dstId)->inst;
 
     // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
     SmallPtrSet<Instruction *, 2> srcDepInsts;
     for (auto &outEdge : outEdges[srcId])
-      if (outEdge.id != dstId && outEdge.value != memrefToSkip)
+      if (outEdge.id != dstId)
         srcDepInsts.insert(getNode(outEdge.id)->inst);
 
     // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
     SmallPtrSet<Instruction *, 2> dstDepInsts;
     for (auto &inEdge : inEdges[dstId])
-      if (inEdge.id != srcId && inEdge.value != memrefToSkip)
+      if (inEdge.id != srcId)
         dstDepInsts.insert(getNode(inEdge.id)->inst);
 
     Instruction *srcNodeInst = getNode(srcId)->inst;
@@ -1366,18 +1365,24 @@
 struct GreedyFusion {
 public:
   MemRefDependenceGraph *mdg;
-  SmallVector<unsigned, 4> worklist;
+  SmallVector<unsigned, 8> worklist;
+  llvm::SmallDenseSet<unsigned, 16> worklistSet;
 
   GreedyFusion(MemRefDependenceGraph *mdg) : mdg(mdg) {
     // Initialize worklist with nodes from 'mdg'.
+    // TODO(andydavis) Add a priority queue for prioritizing nodes by different
+    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
     worklist.resize(mdg->nodes.size());
     std::iota(worklist.begin(), worklist.end(), 0);
+    worklistSet.insert(worklist.begin(), worklist.end());
   }
 
   void run(unsigned localBufSizeThreshold, Optional<unsigned> fastMemorySpace) {
     while (!worklist.empty()) {
       unsigned dstId = worklist.back();
       worklist.pop_back();
+      worklistSet.erase(dstId);
+
       // Skip if this node was removed (fused into another node).
       if (mdg->nodes.count(dstId) == 0)
         continue;
@@ -1437,8 +1442,8 @@
 
           // Compute an instruction list insertion point for the fused loop
           // nest which preserves dependences.
-          Instruction *insertPointInst = mdg->getFusedLoopNestInsertionPoint(
-              srcNode->id, dstNode->id, memref);
+          Instruction *insertPointInst =
+              mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
           if (insertPointInst == nullptr)
             continue;
 
@@ -1516,6 +1521,22 @@
             if (mdg->canRemoveNode(srcNode->id)) {
               mdg->removeNode(srcNode->id);
               srcNode->inst->erase();
+            } else {
+              // Add remaining users of 'oldMemRef' back on the worklist (if not
+              // already there), as its replacement with a local/private memref
+              // has reduced dependences on 'oldMemRef' which may have created
+              // new fusion opportunities.
+              if (mdg->outEdges.count(srcNode->id) > 0) {
+                SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
+                    mdg->outEdges[srcNode->id];
+                for (auto &outEdge : oldOutEdges) {
+                  if (outEdge.value == memref &&
+                      worklistSet.count(outEdge.id) == 0) {
+                    worklist.push_back(outEdge.id);
+                    worklistSet.insert(outEdge.id);
+                  }
+                }
+              }
             }
           }
         }