Globally switch the load/store/dma_start/dma_wait operations over to affine.load/affine.store/affine.dma_start/affine.dma_wait.
In most places, this is just a name change, with one exception: affine.dma_start swaps the operand positions of its tag memref and num_elements operands.
Significant code changes occur in:
*) Vectorization: LoopAnalysis.cpp, Vectorize.cpp
*) Affine Transforms: Transforms/Utils/Utils.cpp
PiperOrigin-RevId: 256395088
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 77b944f..1eee40b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -133,9 +133,9 @@
forOps.push_back(cast<AffineForOp>(op));
else if (op->getNumRegions() != 0)
hasNonForRegion = true;
- else if (isa<LoadOp>(op))
+ else if (isa<AffineLoadOp>(op))
loadOpInsts.push_back(op);
- else if (isa<StoreOp>(op))
+ else if (isa<AffineStoreOp>(op))
storeOpInsts.push_back(op);
});
}
@@ -143,8 +143,8 @@
// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(Operation &op) {
- if (isa<LoadOp>(op) || isa<StoreOp>(op) || isa<DmaStartOp>(op) ||
- isa<DmaWaitOp>(op))
+ if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
+ isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
return true;
return false;
}
@@ -174,7 +174,7 @@
unsigned getLoadOpCount(Value *memref) {
unsigned loadOpCount = 0;
for (auto *loadOpInst : loads) {
- if (memref == cast<LoadOp>(loadOpInst).getMemRef())
+ if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
++loadOpCount;
}
return loadOpCount;
@@ -184,7 +184,7 @@
unsigned getStoreOpCount(Value *memref) {
unsigned storeOpCount = 0;
for (auto *storeOpInst : stores) {
- if (memref == cast<StoreOp>(storeOpInst).getMemRef())
+ if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
++storeOpCount;
}
return storeOpCount;
@@ -194,7 +194,7 @@
void getStoreOpsForMemref(Value *memref,
SmallVectorImpl<Operation *> *storeOps) {
for (auto *storeOpInst : stores) {
- if (memref == cast<StoreOp>(storeOpInst).getMemRef())
+ if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
storeOps->push_back(storeOpInst);
}
}
@@ -203,7 +203,7 @@
void getLoadOpsForMemref(Value *memref,
SmallVectorImpl<Operation *> *loadOps) {
for (auto *loadOpInst : loads) {
- if (memref == cast<LoadOp>(loadOpInst).getMemRef())
+ if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
loadOps->push_back(loadOpInst);
}
}
@@ -213,10 +213,10 @@
void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
llvm::SmallDenseSet<Value *, 2> loadMemrefs;
for (auto *loadOpInst : loads) {
- loadMemrefs.insert(cast<LoadOp>(loadOpInst).getMemRef());
+ loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
}
for (auto *storeOpInst : stores) {
- auto *memref = cast<StoreOp>(storeOpInst).getMemRef();
+ auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
if (loadMemrefs.count(memref) > 0)
loadAndStoreMemrefSet->insert(memref);
}
@@ -308,7 +308,7 @@
bool writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
- auto *memref = cast<StoreOp>(storeOpInst).getMemRef();
+ auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
auto *op = memref->getDefiningOp();
// Return true if 'memref' is a block argument.
if (!op)
@@ -333,7 +333,7 @@
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
// Return false if there exist out edges from 'id' on 'memref'.
- if (getOutEdgeCount(id, cast<StoreOp>(storeOpInst).getMemRef()) > 0)
+ if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0)
return false;
}
return true;
@@ -658,28 +658,28 @@
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
- auto *memref = cast<LoadOp>(opInst).getMemRef();
+ auto *memref = cast<AffineLoadOp>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
- auto *memref = cast<StoreOp>(opInst).getMemRef();
+ auto *memref = cast<AffineStoreOp>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
forToNodeMap[&op] = node.id;
nodes.insert({node.id, node});
- } else if (auto loadOp = dyn_cast<LoadOp>(op)) {
+ } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
// Create graph node for top-level load op.
Node node(nextNodeId++, &op);
node.loads.push_back(&op);
- auto *memref = cast<LoadOp>(op).getMemRef();
+ auto *memref = cast<AffineLoadOp>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
- } else if (auto storeOp = dyn_cast<StoreOp>(op)) {
+ } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &op);
node.stores.push_back(&op);
- auto *memref = cast<StoreOp>(op).getMemRef();
+ auto *memref = cast<AffineStoreOp>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (op.getNumRegions() != 0) {
@@ -740,7 +740,7 @@
dstLoads->clear();
SmallVector<Operation *, 4> srcLoadsToKeep;
for (auto *load : *srcLoads) {
- if (cast<LoadOp>(load).getMemRef() == memref)
+ if (cast<AffineLoadOp>(load).getMemRef() == memref)
dstLoads->push_back(load);
else
srcLoadsToKeep.push_back(load);
@@ -861,7 +861,7 @@
// Builder to create constants at the top level.
OpBuilder top(forInst->getFunction().getBody());
// Create new memref type based on slice bounds.
- auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef();
+ auto *oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
unsigned rank = oldMemRefType.getRank();
@@ -976,7 +976,7 @@
// Gather all memrefs from 'srcNode' store ops.
DenseSet<Value *> storeMemrefs;
for (auto *storeOpInst : srcNode->stores) {
- storeMemrefs.insert(cast<StoreOp>(storeOpInst).getMemRef());
+ storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
}
// Return false if any of the following are true:
// *) 'srcNode' writes to a live in/out memref other than 'memref'.
@@ -1461,7 +1461,7 @@
DenseSet<Value *> visitedMemrefs;
while (!loads.empty()) {
// Get memref of load on top of the stack.
- auto *memref = cast<LoadOp>(loads.back()).getMemRef();
+ auto *memref = cast<AffineLoadOp>(loads.back()).getMemRef();
if (visitedMemrefs.count(memref) > 0)
continue;
visitedMemrefs.insert(memref);
@@ -1517,7 +1517,7 @@
// Gather 'dstNode' store ops to 'memref'.
SmallVector<Operation *, 2> dstStoreOpInsts;
for (auto *storeOpInst : dstNode->stores)
- if (cast<StoreOp>(storeOpInst).getMemRef() == memref)
+ if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
dstStoreOpInsts.push_back(storeOpInst);
unsigned bestDstLoopDepth;
@@ -1562,7 +1562,7 @@
// Create private memref for 'memref' in 'dstAffineForOp'.
SmallVector<Operation *, 4> storesForMemref;
for (auto *storeOpInst : sliceCollector.storeOpInsts) {
- if (cast<StoreOp>(storeOpInst).getMemRef() == memref)
+ if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
storesForMemref.push_back(storeOpInst);
}
assert(storesForMemref.size() == 1);
@@ -1584,7 +1584,7 @@
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
- auto *loadMemRef = cast<LoadOp>(loadOpInst).getMemRef();
+ auto *loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
if (visitedMemrefs.count(loadMemRef) == 0)
loads.push_back(loadOpInst);
}
@@ -1742,7 +1742,7 @@
// Check that all stores are to the same memref.
DenseSet<Value *> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
- storeMemrefs.insert(cast<StoreOp>(storeOpInst).getMemRef());
+ storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
}
if (storeMemrefs.size() != 1)
return false;
@@ -1753,7 +1753,7 @@
auto fn = dstNode->op->getFunction();
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
for (auto *user : fn.getArgument(i)->getUsers()) {
- if (auto loadOp = dyn_cast<LoadOp>(user)) {
+ if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;
getLoopIVs(*user, &loops);