[mlir][Affine] Introduce affine memory interfaces
This patch introduces interfaces for read and write ops with affine
restrictions. I used `read`/`write` instead of `load`/`store` for the
interfaces so that they can also be implemented by dma ops.
For now, they are only implemented by affine.load, affine.store,
affine.vector_load and affine.vector_store.
For testing purposes, this patch also migrates affine loop fusion and
required analysis to use the new interfaces. No other changes are made
beyond that.
Co-authored-by: Alex Zinenko <[email protected]>
Reviewed By: bondhugula, ftynse
Differential Revision: https://reviews.llvm.org/D79829
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 72dfc1d..bb219fa 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -70,7 +70,7 @@
// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(Operation &op) {
- if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+ if (isa<AffineReadOpInterface>(op) || isa<AffineWriteOpInterface>(op) ||
isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
return true;
return false;
@@ -92,9 +92,9 @@
forOps.push_back(cast<AffineForOp>(op));
else if (op->getNumRegions() != 0)
hasNonForRegion = true;
- else if (isa<AffineLoadOp>(op))
+ else if (isa<AffineReadOpInterface>(op))
loadOpInsts.push_back(op);
- else if (isa<AffineStoreOp>(op))
+ else if (isa<AffineWriteOpInterface>(op))
storeOpInsts.push_back(op);
});
}
@@ -125,7 +125,7 @@
unsigned getLoadOpCount(Value memref) {
unsigned loadOpCount = 0;
for (auto *loadOpInst : loads) {
- if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+ if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
++loadOpCount;
}
return loadOpCount;
@@ -135,7 +135,7 @@
unsigned getStoreOpCount(Value memref) {
unsigned storeOpCount = 0;
for (auto *storeOpInst : stores) {
- if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+ if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
++storeOpCount;
}
return storeOpCount;
@@ -145,7 +145,7 @@
void getStoreOpsForMemref(Value memref,
SmallVectorImpl<Operation *> *storeOps) {
for (auto *storeOpInst : stores) {
- if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
+ if (memref == cast<AffineWriteOpInterface>(storeOpInst).getMemRef())
storeOps->push_back(storeOpInst);
}
}
@@ -154,7 +154,7 @@
void getLoadOpsForMemref(Value memref,
SmallVectorImpl<Operation *> *loadOps) {
for (auto *loadOpInst : loads) {
- if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
+ if (memref == cast<AffineReadOpInterface>(loadOpInst).getMemRef())
loadOps->push_back(loadOpInst);
}
}
@@ -164,10 +164,10 @@
void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) {
llvm::SmallDenseSet<Value, 2> loadMemrefs;
for (auto *loadOpInst : loads) {
- loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
+ loadMemrefs.insert(cast<AffineReadOpInterface>(loadOpInst).getMemRef());
}
for (auto *storeOpInst : stores) {
- auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
if (loadMemrefs.count(memref) > 0)
loadAndStoreMemrefSet->insert(memref);
}
@@ -259,7 +259,7 @@
bool writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
- auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
auto *op = memref.getDefiningOp();
// Return true if 'memref' is a block argument.
if (!op)
@@ -272,13 +272,14 @@
return false;
}
- // Returns the unique AffineStoreOp in `node` that meets all the following:
+ // Returns the unique AffineWriteOpInterface in `node` that meets all the
+ // following:
// *) store is the only one that writes to a function-local memref live out
// of `node`,
// *) store is not the source of a self-dependence on `node`.
- // Otherwise, returns a null AffineStoreOp.
- AffineStoreOp getUniqueOutgoingStore(Node *node) {
- AffineStoreOp uniqueStore;
+ // Otherwise, returns a null AffineWriteOpInterface.
+ AffineWriteOpInterface getUniqueOutgoingStore(Node *node) {
+ AffineWriteOpInterface uniqueStore;
// Return null if `node` doesn't have any outgoing edges.
auto outEdgeIt = outEdges.find(node->id);
@@ -287,7 +288,7 @@
const auto &nodeOutEdges = outEdgeIt->second;
for (auto *op : node->stores) {
- auto storeOp = cast<AffineStoreOp>(op);
+ auto storeOp = cast<AffineWriteOpInterface>(op);
auto memref = storeOp.getMemRef();
// Skip this store if there are no dependences on its memref. This means
// that store either:
@@ -322,7 +323,8 @@
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
// Return false if there exist out edges from 'id' on 'memref'.
- if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0)
+ auto storeMemref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+ if (getOutEdgeCount(id, storeMemref) > 0)
return false;
}
return true;
@@ -651,28 +653,28 @@
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
- auto memref = cast<AffineLoadOp>(opInst).getMemRef();
+ auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
- auto memref = cast<AffineStoreOp>(opInst).getMemRef();
+ auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
forToNodeMap[&op] = node.id;
nodes.insert({node.id, node});
- } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
+ } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &op);
node.loads.push_back(&op);
- auto memref = cast<AffineLoadOp>(op).getMemRef();
+ auto memref = cast<AffineReadOpInterface>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
- } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
+ } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &op);
node.stores.push_back(&op);
- auto memref = cast<AffineStoreOp>(op).getMemRef();
+ auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (op.getNumRegions() != 0) {
@@ -733,7 +735,7 @@
dstLoads->clear();
SmallVector<Operation *, 4> srcLoadsToKeep;
for (auto *load : *srcLoads) {
- if (cast<AffineLoadOp>(load).getMemRef() == memref)
+ if (cast<AffineReadOpInterface>(load).getMemRef() == memref)
dstLoads->push_back(load);
else
srcLoadsToKeep.push_back(load);
@@ -854,7 +856,7 @@
// Builder to create constants at the top level.
OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
// Create new memref type based on slice bounds.
- auto oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
+ auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
unsigned rank = oldMemRefType.getRank();
@@ -962,9 +964,10 @@
// Returns true if 'dstNode's read/write region to 'memref' is a super set of
// 'srcNode's write region to 'memref' and 'srcId' has only one output edge.
// TODO(andydavis) Generalize this to handle more live in/out cases.
-static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
- AffineStoreOp srcLiveOutStoreOp,
- MemRefDependenceGraph *mdg) {
+static bool
+canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
+ AffineWriteOpInterface srcLiveOutStoreOp,
+ MemRefDependenceGraph *mdg) {
assert(srcLiveOutStoreOp && "Expected a valid store op");
auto *dstNode = mdg->getNode(dstId);
Value memref = srcLiveOutStoreOp.getMemRef();
@@ -1450,7 +1453,7 @@
DenseSet<Value> visitedMemrefs;
while (!loads.empty()) {
// Get memref of load on top of the stack.
- auto memref = cast<AffineLoadOp>(loads.back()).getMemRef();
+ auto memref = cast<AffineReadOpInterface>(loads.back()).getMemRef();
if (visitedMemrefs.count(memref) > 0)
continue;
visitedMemrefs.insert(memref);
@@ -1488,7 +1491,7 @@
// feasibility for loops with multiple stores.
unsigned maxLoopDepth = 0;
for (auto *op : srcNode->stores) {
- auto storeOp = cast<AffineStoreOp>(op);
+ auto storeOp = cast<AffineWriteOpInterface>(op);
if (storeOp.getMemRef() != memref) {
srcStoreOp = nullptr;
break;
@@ -1563,7 +1566,7 @@
// Gather 'dstNode' store ops to 'memref'.
SmallVector<Operation *, 2> dstStoreOpInsts;
for (auto *storeOpInst : dstNode->stores)
- if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+ if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() == memref)
dstStoreOpInsts.push_back(storeOpInst);
unsigned bestDstLoopDepth;
@@ -1601,7 +1604,8 @@
// Create private memref for 'memref' in 'dstAffineForOp'.
SmallVector<Operation *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
- if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
+ if (cast<AffineWriteOpInterface>(storeOpInst).getMemRef() ==
+ memref)
storesForMemref.push_back(storeOpInst);
}
// TODO(andydavis) Use union of memref write regions to compute
@@ -1624,7 +1628,8 @@
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
- auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
+ auto loadMemRef =
+ cast<AffineReadOpInterface>(loadOpInst).getMemRef();
// NOTE: Change 'loads' to a hash set in case efficiency is an
// issue. We still use a vector since it's expected to be small.
if (visitedMemrefs.count(loadMemRef) == 0 &&
@@ -1785,7 +1790,8 @@
// Check that all stores are to the same memref.
DenseSet<Value> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
- storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
+ storeMemrefs.insert(
+ cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
}
if (storeMemrefs.size() != 1)
return false;
@@ -1796,7 +1802,7 @@
auto fn = dstNode->op->getParentOfType<FuncOp>();
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
for (auto *user : fn.getArgument(i).getUsers()) {
- if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
+ if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;
getLoopIVs(*user, &loops);