[mlir][Affine] Introduce affine memory interfaces

This patch introduces interfaces for read and write ops with affine
restrictions. I used `read`/`write` instead of `load`/`store` for the
interfaces so that they can also be implemented by DMA ops.
For now, they are only implemented by affine.load, affine.store,
affine.vector_load, and affine.vector_store.
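
As a sketch (not part of this patch), a transform can now treat affine.load
and affine.vector_load uniformly through the read interface. `getMemRef()`
is the accessor used throughout the diff below; the helper name and the
rest are illustrative only:

  // Assumes the usual MLIR headers and `using namespace mlir;`.
  static void collectAffineReads(
      FuncOp func,
      DenseMap<Value, SmallVector<Operation *, 4>> &readsPerMemref) {
    func.walk([&](Operation *op) {
      // Matches affine.load and affine.vector_load alike.
      if (auto read = dyn_cast<AffineReadOpInterface>(op))
        readsPerMemref[read.getMemRef()].push_back(op);
    });
  }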

For testing purposes, this patch also migrates affine loop fusion and the
required analyses to use the new interfaces, as shown schematically below.
No other changes are made beyond that.
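
Schematically, the migration replaces casts to the concrete ops with casts
to the new interfaces, e.g.:

  // Before: tied to affine.load specifically.
  auto memref = cast<AffineLoadOp>(loadOpInst).getMemRef();
  // After: any op implementing the read interface qualifies.
  auto memref = cast<AffineReadOpInterface>(loadOpInst).getMemRef();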

Co-authored-by: Alex Zinenko <[email protected]>

Reviewed By: bondhugula, ftynse

Differential Revision: https://reviews.llvm.org/D79829
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 72dfc1d..bb219fa 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -70,7 +70,7 @@
 
 // TODO(b/117228571) Replace when this is modeled through side-effects/op traits
 static bool isMemRefDereferencingOp(Operation &op) {
-  if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+  if (isa<AffineReadOpInterface>(op) || isa<AffineWriteOpInterface>(op) ||
       isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
     return true;
   return false;
@@ -92,9 +92,9 @@
         forOps.push_back(cast<AffineForOp>(op));
       else if (op->getNumRegions() != 0)
         hasNonForRegion = true;
-      else if (isa<AffineLoadOp>(op))
+      else if (isa<AffineReadOpInterface>(op))
         loadOpInsts.push_back(op);
-      else if (isa<AffineStoreOp>(op))
+      else if (isa<AffineWriteOpInterface>(op))
         storeOpInsts.push_back(op);
     });
   }
@@ -125,7 +125,7 @@
     unsigned getLoadOpCount(Value memref) {
       unsigned loadOpCount = 0;
       for (auto *loadOpInst : loads) {
-        if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
           ++loadOpCount;
       }
       return loadOpCount;
@@ -135,7 +135,7 @@
     unsigned getStoreOpCount(Value memref) {
       unsigned storeOpCount = 0;
       for (auto *storeOpInst : stores) {
-        if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
           ++storeOpCount;
       }
       return storeOpCount;
@@ -145,7 +145,7 @@
     void getStoreOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *storeOps) {
       for (auto *storeOpInst : stores) {
-        if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+        if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
           storeOps->push_back(storeOpInst);
       }
     }
@@ -154,7 +154,7 @@
     void getLoadOpsForMemref(Value memref,
                              SmallVectorImpl<Operation *> *loadOps) {
       for (auto *loadOpInst : loads) {
-        if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+        if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
           loadOps->push_back(loadOpInst);
       }
     }
@@ -164,10 +164,10 @@
     void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
       llvm::SmallDenseSet<Value, 2> loadMemrefs;
       for (auto *loadOpInst : loads) {
-        loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
+        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
       }
       for (auto *storeOpInst : stores) {
-        auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+        auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
         if (loadMemrefs.count(memref) > 0)
           loadAndStoreMemrefSet->insert(memref);
       }
@@ -259,7 +259,7 @@
   bool writesToLiveInOrEscapingMemrefs(unsigned id) {
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
-      auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+      auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
       auto *op = memref.getDefiningOp();
       // Return true if 'memref' is a block argument.
       if (!op)
@@ -272,13 +272,14 @@
     return false;
   }
 
-  // Returns the unique AffineStoreOp in `node` that meets all the following:
+  // Returns the unique AffineWriteOpInterface in `node` that meets all the
+  // following:
   //   *) store is the only one that writes to a function-local memref live out
   //      of `node`,
   //   *) store is not the source of a self-dependence on `node`.
-  // Otherwise, returns a null AffineStoreOp.
-  AffineStoreOp getUniqueOutgoingStore(Node *node) {
-    AffineStoreOp uniqueStore;
+  // Otherwise, returns a null AffineWriteOpInterface.
+  AffineWriteOpInterface getUniqueOutgoingStore(Node *node) {
+    AffineWriteOpInterface uniqueStore;
 
     // Return null if `node` doesn't have any outgoing edges.
     auto outEdgeIt = outEdges.find(node->id);
@@ -287,7 +288,7 @@
 
     const auto &nodeOutEdges = outEdgeIt->second;
     for (auto *op : node->stores) {
-      auto storeOp = cast<AffineStoreOp>(op);
+      auto storeOp = cast<AffineWriteOpInterface>(op);
       auto memref = storeOp.getMemRef();
       // Skip this store if there are no dependences on its memref. This means
       // that store either:
@@ -322,7 +323,8 @@
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
       // Return false if there exist out edges from 'id' on 'memref'.
-      if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0)
+      auto storeMemref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+      if (getOutEdgeCount(id, storeMemref) > 0)
         return false;
     }
     return true;
@@ -651,28 +653,28 @@
       Node node(nextNodeId++, &op);
       for (auto *opInst : collector.loadOpInsts) {
         node.loads.push_back(opInst);
-        auto memref = cast<AffineLoadOp>(opInst).getMemRef();
+        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       for (auto *opInst : collector.storeOpInsts) {
         node.stores.push_back(opInst);
-        auto memref = cast<AffineStoreOp>(opInst).getMemRef();
+        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
         memrefAccesses[memref].insert(node.id);
       }
       forToNodeMap[&op] = node.id;
       nodes.insert({node.id, node});
-    } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+    } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
       // Create graph node for top-level load op.
       Node node(nextNodeId++, &op);
       node.loads.push_back(&op);
-      auto memref = cast<AffineLoadOp>(op).getMemRef();
+      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
-    } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
       // Create graph node for top-level store op.
       Node node(nextNodeId++, &op);
       node.stores.push_back(&op);
-      auto memref = cast<AffineStoreOp>(op).getMemRef();
+      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
     } else if (op.getNumRegions() != 0) {
@@ -733,7 +735,7 @@
   dstLoads->clear();
   SmallVector<Operation *, 4> srcLoadsToKeep;
   for (auto *load : *srcLoads) {
-    if (cast<AffineLoadOp>(load).getMemRef() == memref)
+    if (cast<AffineReadOpInterface>(load).getMemRef() == memref)
       dstLoads->push_back(load);
     else
       srcLoadsToKeep.push_back(load);
@@ -854,7 +856,7 @@
   // Builder to create constants at the top level.
   OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
   // Create new memref type based on slice bounds.
-  auto oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
+  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
   unsigned rank = oldMemRefType.getRank();
 
@@ -962,9 +964,10 @@
 // Returns true if 'dstNode's read/write region to 'memref' is a super set of
 // 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
 // TODO(andydavis) Generalize this to handle more live in/out cases.
-static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
-                                           AffineStoreOp srcLiveOutStoreOp,
-                                           MemRefDependenceGraph *mdg) {
+static bool
+canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+                               AffineWriteOpInterface srcLiveOutStoreOp,
+                               MemRefDependenceGraph *mdg) {
   assert(srcLiveOutStoreOp && "Expected a valid store op");
   auto *dstNode = mdg->getNode(dstId);
   Value memref = srcLiveOutStoreOp.getMemRef();
@@ -1450,7 +1453,7 @@
       DenseSet<Value> visitedMemrefs;
       while (!loads.empty()) {
         // Get memref of load on top of the stack.
-        auto memref = cast<AffineLoadOp>(loads.back()).getMemRef();
+        auto memref = cast<AffineReadOpInterface>(loads.back()).getMemRef();
         if (visitedMemrefs.count(memref) > 0)
           continue;
         visitedMemrefs.insert(memref);
@@ -1488,7 +1491,7 @@
             // feasibility for loops with multiple stores.
             unsigned maxLoopDepth = 0;
             for (auto *op : srcNode->stores) {
-              auto storeOp = cast<AffineStoreOp>(op);
+              auto storeOp = cast<AffineWriteOpInterface>(op);
               if (storeOp.getMemRef() != memref) {
                 srcStoreOp = nullptr;
                 break;
@@ -1563,7 +1566,7 @@
           // Gather 'dstNode' store ops to 'memref'.
           SmallVector<Operation *, 2> dstStoreOpInsts;
           for (auto *storeOpInst : dstNode->stores)
-            if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+            if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == memref)
               dstStoreOpInsts.push_back(storeOpInst);
 
           unsigned bestDstLoopDepth;
@@ -1601,7 +1604,8 @@
               // Create private memref for 'memref' in 'dstAffineForOp'.
               SmallVector<Operation *, 4> storesForMemref;
               for (auto *storeOpInst : sliceCollector.storeOpInsts) {
-                if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+                if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
+                    memref)
                   storesForMemref.push_back(storeOpInst);
               }
               // TODO(andydavis) Use union of memref write regions to compute
@@ -1624,7 +1628,8 @@
             // Add new load ops to current Node load op list 'loads' to
             // continue fusing based on new operands.
             for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
-              auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
+              auto loadMemRef =
+                  cast<AffineReadOpInterface>(loadOpInst).getMemRef();
               // NOTE: Change 'loads' to a hash set in case efficiency is an
               // issue. We still use a vector since it's expected to be small.
               if (visitedMemrefs.count(loadMemRef) == 0 &&
@@ -1785,7 +1790,8 @@
       // Check that all stores are to the same memref.
       DenseSet<Value> storeMemrefs;
       for (auto *storeOpInst : sibNode->stores) {
-        storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
+        storeMemrefs.insert(
+            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
       }
       if (storeMemrefs.size() != 1)
         return false;
@@ -1796,7 +1802,7 @@
     auto fn = dstNode->op->getParentOfType<FuncOp>();
     for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
       for (auto *user : fn.getArgument(i).getUsers()) {
-        if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
+        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
           // Gather loops surrounding 'use'.
           SmallVector<AffineForOp, 4> loops;
           getLoopIVs(*user, &loops);