NFC: Introduce new ValuePtr/ValueRef typedefs to simplify the transition to Value being value-typed.
This is an initial step toward refactoring the representation of OpResult, as proposed in: https://groups.google.com/a/tensorflow.org/g/mlir/c/XXzzKhqqF_0/m/v6bKb08WCgAJ
This change will make it much simpler to incrementally transition all of the existing code to use value-typed semantics.
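For reference, the transition aliases have this minimal shape (a sketch: the exact definitions and their placement alongside the Value class are assumptions based on the commit title, not part of this diff):

  // Assumed transition aliases. While Value retains pointer semantics,
  // these are plain synonyms for the pointer and reference types.
  using ValuePtr = Value *;
  using ValueRef = Value &;

Once Value becomes value-typed, the aliases can be re-pointed at Value itself, so call sites migrated in this diff need not change again.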
PiperOrigin-RevId: 286844725
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 5694c99..60f0264 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -172,7 +172,7 @@
Node(unsigned id, Operation *op) : id(id), op(op) {}
// Returns the load op count for 'memref'.
- unsigned getLoadOpCount(Value *memref) {
+ unsigned getLoadOpCount(ValuePtr memref) {
unsigned loadOpCount = 0;
for (auto *loadOpInst : loads) {
if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
@@ -182,7 +182,7 @@
}
// Returns the store op count for 'memref'.
- unsigned getStoreOpCount(Value *memref) {
+ unsigned getStoreOpCount(ValuePtr memref) {
unsigned storeOpCount = 0;
for (auto *storeOpInst : stores) {
if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
@@ -192,7 +192,7 @@
}
// Returns all store ops in 'storeOps' which access 'memref'.
- void getStoreOpsForMemref(Value *memref,
+ void getStoreOpsForMemref(ValuePtr memref,
SmallVectorImpl<Operation *> *storeOps) {
for (auto *storeOpInst : stores) {
if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
@@ -201,7 +201,7 @@
}
// Returns all load ops in 'loadOps' which access 'memref'.
- void getLoadOpsForMemref(Value *memref,
+ void getLoadOpsForMemref(ValuePtr memref,
SmallVectorImpl<Operation *> *loadOps) {
for (auto *loadOpInst : loads) {
if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
@@ -211,13 +211,13 @@
// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
// has at least one load and store operation.
- void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
- llvm::SmallDenseSet<Value *, 2> loadMemrefs;
+ void getLoadAndStoreMemrefSet(DenseSet<ValuePtr> *loadAndStoreMemrefSet) {
+ llvm::SmallDenseSet<ValuePtr, 2> loadMemrefs;
for (auto *loadOpInst : loads) {
loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
}
for (auto *storeOpInst : stores) {
- auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
if (loadMemrefs.count(memref) > 0)
loadAndStoreMemrefSet->insert(memref);
}
@@ -239,7 +239,7 @@
// defines an SSA value and another graph node which uses the SSA value
// (e.g. a constant operation defining a value which is used inside a loop
// nest).
- Value *value;
+ ValuePtr value;
};
// Map from node id to Node.
@@ -250,7 +250,7 @@
DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
// Map from memref to a count on the dependence edges associated with that
// memref.
- DenseMap<Value *, unsigned> memrefEdgeCount;
+ DenseMap<ValuePtr, unsigned> memrefEdgeCount;
// The next unique identifier to use for newly created graph nodes.
unsigned nextNodeId = 0;
@@ -309,7 +309,7 @@
bool writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
- auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
+ auto memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
auto *op = memref->getDefiningOp();
// Return true if 'memref' is a block argument.
if (!op)
@@ -338,7 +338,7 @@
const auto &nodeOutEdges = outEdgeIt->second;
for (auto *op : node->stores) {
auto storeOp = cast<AffineStoreOp>(op);
- auto *memref = storeOp.getMemRef();
+ auto memref = storeOp.getMemRef();
// Skip this store if there are no dependences on its memref. This means
// that the store either:
// *) writes to a memref that is only read within the same loop nest
@@ -381,7 +381,7 @@
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
- bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) {
+ bool hasEdge(unsigned srcId, unsigned dstId, ValuePtr value = nullptr) {
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
return false;
}
@@ -395,7 +395,7 @@
}
// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
- void addEdge(unsigned srcId, unsigned dstId, Value *value) {
+ void addEdge(unsigned srcId, unsigned dstId, ValuePtr value) {
if (!hasEdge(srcId, dstId, value)) {
outEdges[srcId].push_back({dstId, value});
inEdges[dstId].push_back({srcId, value});
@@ -405,7 +405,7 @@
}
// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
- void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
+ void removeEdge(unsigned srcId, unsigned dstId, ValuePtr value) {
assert(inEdges.count(dstId) > 0);
assert(outEdges.count(srcId) > 0);
if (value->getType().isa<MemRefType>()) {
@@ -459,7 +459,7 @@
// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
- unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
+ unsigned getIncomingMemRefAccesses(unsigned id, ValuePtr memref) {
unsigned inEdgeCount = 0;
if (inEdges.count(id) > 0)
for (auto &inEdge : inEdges[id])
@@ -474,7 +474,7 @@
// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
- unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) {
+ unsigned getOutEdgeCount(unsigned id, ValuePtr memref = nullptr) {
unsigned outEdgeCount = 0;
if (outEdges.count(id) > 0)
for (auto &outEdge : outEdges[id])
@@ -548,7 +548,7 @@
// Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
// has been replaced in node at 'dstId' by a private memref depending
// on the value of 'createPrivateMemRef'.
- void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef,
+ void updateEdges(unsigned srcId, unsigned dstId, ValuePtr oldMemRef,
bool createPrivateMemRef) {
// For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
if (inEdges.count(srcId) > 0) {
@@ -681,7 +681,7 @@
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(FuncOp f) {
- DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
+ DenseMap<ValuePtr, SetVector<unsigned>> memrefAccesses;
// TODO: support multi-block functions.
if (f.getBlocks().size() != 1)
@@ -701,12 +701,12 @@
Node node(nextNodeId++, &op);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
- auto *memref = cast<AffineLoadOp>(opInst).getMemRef();
+ auto memref = cast<AffineLoadOp>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
- auto *memref = cast<AffineStoreOp>(opInst).getMemRef();
+ auto memref = cast<AffineStoreOp>(opInst).getMemRef();
memrefAccesses[memref].insert(node.id);
}
forToNodeMap[&op] = node.id;
@@ -715,14 +715,14 @@
// Create graph node for top-level load op.
Node node(nextNodeId++, &op);
node.loads.push_back(&op);
- auto *memref = cast<AffineLoadOp>(op).getMemRef();
+ auto memref = cast<AffineLoadOp>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
// Create graph node for top-level store op.
Node node(nextNodeId++, &op);
node.stores.push_back(&op);
- auto *memref = cast<AffineStoreOp>(op).getMemRef();
+ auto memref = cast<AffineStoreOp>(op).getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
} else if (op.getNumRegions() != 0) {
@@ -743,7 +743,7 @@
if (!node.loads.empty() || !node.stores.empty())
continue;
auto *opInst = node.op;
- for (auto *value : opInst->getResults()) {
+ for (auto value : opInst->getResults()) {
for (auto *user : value->getUsers()) {
SmallVector<AffineForOp, 4> loops;
getLoopIVs(*user, &loops);
@@ -777,7 +777,7 @@
// Removes load operations from 'srcLoads' which operate on 'memref', and
// adds them to 'dstLoads'.
-static void moveLoadsAccessingMemrefTo(Value *memref,
+static void moveLoadsAccessingMemrefTo(ValuePtr memref,
SmallVectorImpl<Operation *> *srcLoads,
SmallVectorImpl<Operation *> *dstLoads) {
dstLoads->clear();
@@ -893,10 +893,11 @@
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO(bondhugula): consider refactoring the common code from generateDma and
// this one.
-static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
- unsigned dstLoopDepth,
- Optional<unsigned> fastMemorySpace,
- uint64_t localBufSizeThreshold) {
+static ValuePtr createPrivateMemRef(AffineForOp forOp,
+ Operation *srcStoreOpInst,
+ unsigned dstLoopDepth,
+ Optional<unsigned> fastMemorySpace,
+ uint64_t localBufSizeThreshold) {
auto *forInst = forOp.getOperation();
// Create builder to insert alloc op just before 'forOp'.
@@ -904,7 +905,7 @@
// Builder to create constants at the top level.
OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
// Create new memref type based on slice bounds.
- auto *oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
+ auto oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
unsigned rank = oldMemRefType.getRank();
@@ -928,7 +929,7 @@
// 'outerIVs' holds the values that this memory region is symbolic/parametric
// on; this would correspond to loop IVs surrounding the level at which the
// slice is being materialized.
- SmallVector<Value *, 8> outerIVs;
+ SmallVector<ValuePtr, 8> outerIVs;
cst->getIdValues(rank, cst->getNumIds(), &outerIVs);
// Build 'rank' AffineExprs from MemRefRegion 'lbs'
@@ -960,7 +961,7 @@
auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
{}, newMemSpace);
// Gather alloc operands for the dynamic dimensions of the memref.
- SmallVector<Value *, 4> allocOperands;
+ SmallVector<ValuePtr, 4> allocOperands;
unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
@@ -973,7 +974,7 @@
// consumer loop nests to reduce their live range. Currently they are added
// at the beginning of the function, because loop nests can be reordered
// during the fusion pass.
- Value *newMemRef =
+ ValuePtr newMemRef =
top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);
// Build an AffineMap to remap access functions based on lower bound offsets.
@@ -1016,7 +1017,7 @@
MemRefDependenceGraph *mdg) {
assert(srcLiveOutStoreOp && "Expected a valid store op");
auto *dstNode = mdg->getNode(dstId);
- Value *memref = srcLiveOutStoreOp.getMemRef();
+ ValuePtr memref = srcLiveOutStoreOp.getMemRef();
// Return false if 'srcNode' has more than one output edge on 'memref'.
if (mdg->getOutEdgeCount(srcId, memref) > 1)
return false;
@@ -1495,10 +1496,10 @@
SmallVector<Operation *, 4> loads = dstNode->loads;
SmallVector<Operation *, 4> dstLoadOpInsts;
- DenseSet<Value *> visitedMemrefs;
+ DenseSet<ValuePtr> visitedMemrefs;
while (!loads.empty()) {
// Get memref of load on top of the stack.
- auto *memref = cast<AffineLoadOp>(loads.back()).getMemRef();
+ auto memref = cast<AffineLoadOp>(loads.back()).getMemRef();
if (visitedMemrefs.count(memref) > 0)
continue;
visitedMemrefs.insert(memref);
@@ -1653,7 +1654,7 @@
}
// TODO(andydavis) Use union of memref write regions to compute
// private memref footprint.
- auto *newMemRef = createPrivateMemRef(
+ auto newMemRef = createPrivateMemRef(
dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
fastMemorySpace, localBufSizeThreshold);
visitedMemrefs.insert(newMemRef);
@@ -1671,7 +1672,7 @@
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
- auto *loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
+ auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
if (visitedMemrefs.count(loadMemRef) == 0)
loads.push_back(loadOpInst);
}
@@ -1737,10 +1738,10 @@
// Attempt to fuse 'dstNode' with sibling nodes in the graph.
void fuseWithSiblingNodes(Node *dstNode) {
DenseSet<unsigned> visitedSibNodeIds;
- std::pair<unsigned, Value *> idAndMemref;
+ std::pair<unsigned, ValuePtr> idAndMemref;
while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
unsigned sibId = idAndMemref.first;
- Value *memref = idAndMemref.second;
+ ValuePtr memref = idAndMemref.second;
// TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
// stores to the same memref in 'sibNode' loop nest.
auto *sibNode = mdg->getNode(sibId);
@@ -1804,10 +1805,10 @@
// 'idAndMemrefToFuse' on success. Returns false otherwise.
bool findSiblingNodeToFuse(Node *dstNode,
DenseSet<unsigned> *visitedSibNodeIds,
- std::pair<unsigned, Value *> *idAndMemrefToFuse) {
+ std::pair<unsigned, ValuePtr> *idAndMemrefToFuse) {
// Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
// on 'memref'.
- auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) {
+ auto canFuseWithSibNode = [&](Node *sibNode, ValuePtr memref) {
// Skip if 'outEdge' is not a read-after-write dependence.
// TODO(andydavis) Remove the restriction to a single load op.
if (sibNode->getLoadOpCount(memref) != 1)
@@ -1819,15 +1820,15 @@
return false;
// Skip sib node if it loads from (and stores to) the same memref on
// which it also has an input dependence edge.
- DenseSet<Value *> loadAndStoreMemrefSet;
+ DenseSet<ValuePtr> loadAndStoreMemrefSet;
sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
- if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
+ if (llvm::any_of(loadAndStoreMemrefSet, [=](ValuePtr memref) {
return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
}))
return false;
// Check that all stores are to the same memref.
- DenseSet<Value *> storeMemrefs;
+ DenseSet<ValuePtr> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
}
@@ -1856,7 +1857,7 @@
if (visitedSibNodeIds->count(sibNode->id) > 0)
continue;
// Skip 'use' if it does not load from the same memref as 'dstNode'.
- auto *memref = loadOp.getMemRef();
+ auto memref = loadOp.getMemRef();
if (dstNode->getLoadOpCount(memref) == 0)
continue;
// Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
@@ -1950,7 +1951,7 @@
for (auto &pair : mdg->memrefEdgeCount) {
if (pair.second > 0)
continue;
- auto *memref = pair.first;
+ auto memref = pair.first;
// Skip if there exist other uses (return operations or function calls).
if (!memref->use_empty())
continue;