Refactor / improve replaceAllMemRefUsesWith
Refactor replaceAllMemRefUsesWith to split it into two methods: the new
method does the replacement on a single op, and is used by the existing
one.
- make the methods return LogicalResult instead of bool
- Earlier, when replacement failed (due to non-dereferencing uses of the
memref), the set of ops that had already been processed would have
been replaced, leaving the IR in an inconsistent state. Now, a
pass is made over all ops to first check for non-dereferencing
uses, and then replacement is performed. No test cases were affected
because all clients of this method were first checking for
non-dereferencing uses before calling this method (for other reasons).
This isn't true for a use case in another upcoming PR (scalar
replacement); clients can now bail out with consistent IR on failure
of replaceAllMemRefUsesWith. Add test case.
- multiple dereferencing uses of the same memref in a single op is
possible (we have no such use cases/scenarios), and this has always
remained unsupported. Add an assertion for this.
- minor fix to another pipeline-data-transfer test case.
Signed-off-by: Uday Bondhugula <[email protected]>
Closes tensorflow/mlir#87
PiperOrigin-RevId: 265808183
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 46713dc..a17481f 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -952,12 +952,13 @@
? AffineMap()
: b.getAffineMap(outerIVs.size() + rank, 0, remapExprs);
// Replace all users of 'oldMemRef' with 'newMemRef'.
- bool ret =
+ LogicalResult res =
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
/*extraOperands=*/outerIVs,
/*domInstFilter=*/&*forOp.getBody()->begin());
- assert(ret && "replaceAllMemrefUsesWith should always succeed here");
- (void)ret;
+ assert(succeeded(res) &&
+ "replaceAllMemrefUsesWith should always succeed here");
+ (void)res;
return newMemRef;
}
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 0cd979a..a814af9 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -115,13 +115,14 @@
auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
forOp.getInductionVar());
- // replaceAllMemRefUsesWith will always succeed unless the forOp body has
- // non-deferencing uses of the memref (dealloc's are fine though).
- if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
- /*extraIndices=*/{ivModTwoOp},
- /*indexRemap=*/AffineMap(),
- /*extraOperands=*/{},
- /*domInstFilter=*/&*forOp.getBody()->begin())) {
+ // replaceAllMemRefUsesWith will succeed unless the forOp body has
+ // non-dereferencing uses of the memref (dealloc's are fine though).
+ if (failed(replaceAllMemRefUsesWith(
+ oldMemRef, newMemRef,
+ /*extraIndices=*/{ivModTwoOp},
+ /*indexRemap=*/AffineMap(),
+ /*extraOperands=*/{},
+ /*domInstFilter=*/&*forOp.getBody()->begin()))) {
LLVM_DEBUG(
forOp.emitError("memref replacement for double buffering failed"));
ivModTwoOp.erase();
@@ -276,9 +277,9 @@
if (!doubleBuffer(oldMemRef, forOp)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
- LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
- LLVM_DEBUG(dmaStartInst->dump());
- // IR still in a valid state.
+ LLVM_DEBUG(llvm::dbgs()
+ << "double buffering failed for" << dmaStartInst << "\n";);
+ // IR still valid and semantically correct.
return;
}
// If the old memref has no more uses, remove its 'dead' alloc if it was
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 8d7b7a8..b0c9b94 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -57,16 +57,181 @@
return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref);
}
-bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
- ArrayRef<Value *> extraIndices,
- AffineMap indexRemap,
- ArrayRef<Value *> extraOperands,
- Operation *domInstFilter,
- Operation *postDomInstFilter) {
+// Perform the replacement in `op`.
+LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+ Operation *op,
+ ArrayRef<Value *> extraIndices,
+ AffineMap indexRemap,
+ ArrayRef<Value *> extraOperands) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
- (void)newMemRefRank;
+ (void)oldMemRefRank;
+ if (indexRemap) {
+ assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
+ assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
+ assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
+ } else {
+ assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
+ }
+
+ // Assert same elemental type.
+ assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
+ newMemRef->getType().cast<MemRefType>().getElementType());
+
+ if (!isMemRefDereferencingOp(*op))
+ // Failure: memref used in a non-dereferencing context (potentially
+ // escapes); no replacement in these cases.
+ return failure();
+
+ SmallVector<unsigned, 2> usePositions;
+ for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
+ if (opEntry.value() == oldMemRef)
+ usePositions.push_back(opEntry.index());
+ }
+
+ // If memref doesn't appear, nothing to do.
+ if (usePositions.empty())
+ return success();
+
+ if (usePositions.size() > 1) {
+ // TODO(mlir-team): extend it for this case when needed (rare).
+ assert(false && "multiple dereferencing uses in a single op not supported");
+ return failure();
+ }
+
+ unsigned memRefOperandPos = usePositions.front();
+
+ OpBuilder builder(op);
+ NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
+ AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
+ unsigned oldMapNumInputs = oldMap.getNumInputs();
+ SmallVector<Value *, 4> oldMapOperands(
+ op->operand_begin() + memRefOperandPos + 1,
+ op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
+
+ // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
+ SmallVector<Value *, 4> oldMemRefOperands;
+ SmallVector<Value *, 4> affineApplyOps;
+ oldMemRefOperands.reserve(oldMemRefRank);
+ if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
+ for (auto resultExpr : oldMap.getResults()) {
+ auto singleResMap = builder.getAffineMap(
+ oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
+ auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+ oldMapOperands);
+ oldMemRefOperands.push_back(afOp);
+ affineApplyOps.push_back(afOp);
+ }
+ } else {
+ oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
+ }
+
+ // Construct new indices as a remap of the old ones if a remapping has been
+ // provided. The indices of a memref come right after it, i.e.,
+ // at position memRefOperandPos + 1.
+ SmallVector<Value *, 4> remapOperands;
+ remapOperands.reserve(extraOperands.size() + oldMemRefRank);
+ remapOperands.append(extraOperands.begin(), extraOperands.end());
+ remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
+
+ SmallVector<Value *, 4> remapOutputs;
+ remapOutputs.reserve(oldMemRefRank);
+
+ if (indexRemap &&
+ indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
+ // Remapped indices.
+ for (auto resultExpr : indexRemap.getResults()) {
+ auto singleResMap = builder.getAffineMap(
+ indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
+ auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
+ remapOperands);
+ remapOutputs.push_back(afOp);
+ affineApplyOps.push_back(afOp);
+ }
+ } else {
+ // No remapping specified.
+ remapOutputs.append(remapOperands.begin(), remapOperands.end());
+ }
+
+ SmallVector<Value *, 4> newMapOperands;
+ newMapOperands.reserve(newMemRefRank);
+
+ // Prepend 'extraIndices' in 'newMapOperands'.
+ for (auto *extraIndex : extraIndices) {
+ assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
+ "single result op's expected to generate these indices");
+ assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
+ "invalid memory op index");
+ newMapOperands.push_back(extraIndex);
+ }
+
+ // Append 'remapOutputs' to 'newMapOperands'.
+ newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
+
+ // Create new fully composed AffineMap for new op to be created.
+ assert(newMapOperands.size() == newMemRefRank);
+ auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
+ // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
+ fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
+ newMap = simplifyAffineMap(newMap);
+ canonicalizeMapAndOperands(&newMap, &newMapOperands);
+ // Remove any affine.apply's that became dead as a result of composition.
+ for (auto *value : affineApplyOps)
+ if (value->use_empty())
+ value->getDefiningOp()->erase();
+
+ // Construct the new operation using this memref.
+ OperationState state(op->getLoc(), op->getName());
+ state.setOperandListToResizable(op->hasResizableOperandsList());
+ state.operands.reserve(op->getNumOperands() + extraIndices.size());
+ // Insert the non-memref operands.
+ state.operands.append(op->operand_begin(),
+ op->operand_begin() + memRefOperandPos);
+ // Insert the new memref value.
+ state.operands.push_back(newMemRef);
+
+ // Insert the new memref map operands.
+ state.operands.append(newMapOperands.begin(), newMapOperands.end());
+
+ // Insert the remaining operands unmodified.
+ state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
+ oldMapNumInputs,
+ op->operand_end());
+
+ // Result types don't change. Both memref's are of the same elemental type.
+ state.types.reserve(op->getNumResults());
+ for (auto *result : op->getResults())
+ state.types.push_back(result->getType());
+
+ // Add attribute for 'newMap', other Attributes do not change.
+ auto newMapAttr = builder.getAffineMapAttr(newMap);
+ for (auto namedAttr : op->getAttrs()) {
+ if (namedAttr.first == oldMapAttrPair.first) {
+ state.attributes.push_back({namedAttr.first, newMapAttr});
+ } else {
+ state.attributes.push_back(namedAttr);
+ }
+ }
+
+ // Create the new operation.
+ auto *repOp = builder.createOperation(state);
+ op->replaceAllUsesWith(repOp);
+ op->erase();
+
+ return success();
+}
+
+LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
+ ArrayRef<Value *> extraIndices,
+ AffineMap indexRemap,
+ ArrayRef<Value *> extraOperands,
+ Operation *domInstFilter,
+ Operation *postDomInstFilter) {
+ unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
+ (void)newMemRefRank; // unused in opt mode
+ unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
+ (void)oldMemRefRank;
if (indexRemap) {
assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
@@ -89,170 +254,44 @@
postDomInfo = std::make_unique<PostDominanceInfo>(
postDomInstFilter->getParentOfType<FuncOp>());
- // The ops where memref replacement succeeds are replaced with new ones.
- SmallVector<Operation *, 8> opsToErase;
-
- // Walk all uses of old memref. Operation using the memref gets replaced.
- for (auto *opInst : llvm::make_early_inc_range(oldMemRef->getUsers())) {
+ // Walk all uses of old memref; collect ops to perform replacement. We use a
+ // DenseSet since an operation could potentially have multiple uses of a
+ // memref (although rare), and the replacement later is going to erase ops.
+ DenseSet<Operation *> opsToReplace;
+ for (auto *op : oldMemRef->getUsers()) {
// Skip this use if it's not dominated by domInstFilter.
- if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
+ if (domInstFilter && !domInfo->dominates(domInstFilter, op))
continue;
// Skip this use if it's not post-dominated by postDomInstFilter.
- if (postDomInstFilter &&
- !postDomInfo->postDominates(postDomInstFilter, opInst))
+ if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
continue;
- // Skip dealloc's - no replacement is necessary, and a replacement doesn't
- // hurt dealloc's.
- if (isa<DeallocOp>(opInst))
+ // Skip dealloc's - no replacement is necessary, and a memref replacement
+ // at other uses doesn't hurt these dealloc's.
+ if (isa<DeallocOp>(op))
continue;
- // Check if the memref was used in a non-deferencing context. It is fine for
- // the memref to be used in a non-deferencing way outside of the region
- // where this replacement is happening.
- if (!isMemRefDereferencingOp(*opInst))
- // Failure: memref used in a non-deferencing op (potentially escapes); no
- // replacement in these cases.
- return false;
+ // Check if the memref was used in a non-dereferencing context. It is fine
+ // for the memref to be used in a non-dereferencing way outside of the
+ // region where this replacement is happening.
+ if (!isMemRefDereferencingOp(*op))
+ // Failure: memref used in a non-dereferencing op (potentially escapes);
+ // no replacement in these cases.
+ return failure();
- auto getMemRefOperandPos = [&]() -> unsigned {
- unsigned i, e;
- for (i = 0, e = opInst->getNumOperands(); i < e; i++) {
- if (opInst->getOperand(i) == oldMemRef)
- break;
- }
- assert(i < opInst->getNumOperands() && "operand guaranteed to be found");
- return i;
- };
-
- OpBuilder builder(opInst);
- unsigned memRefOperandPos = getMemRefOperandPos();
- NamedAttribute oldMapAttrPair =
- getAffineMapAttrForMemRef(opInst, oldMemRef);
- AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
- unsigned oldMapNumInputs = oldMap.getNumInputs();
- SmallVector<Value *, 4> oldMapOperands(
- opInst->operand_begin() + memRefOperandPos + 1,
- opInst->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
- SmallVector<Value *, 4> affineApplyOps;
-
- // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
- SmallVector<Value *, 4> oldMemRefOperands;
- oldMemRefOperands.reserve(oldMemRefRank);
- if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
- for (auto resultExpr : oldMap.getResults()) {
- auto singleResMap = builder.getAffineMap(
- oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
- auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
- singleResMap, oldMapOperands);
- oldMemRefOperands.push_back(afOp);
- affineApplyOps.push_back(afOp);
- }
- } else {
- oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
- }
-
- // Construct new indices as a remap of the old ones if a remapping has been
- // provided. The indices of a memref come right after it, i.e.,
- // at position memRefOperandPos + 1.
- SmallVector<Value *, 4> remapOperands;
- remapOperands.reserve(extraOperands.size() + oldMemRefRank);
- remapOperands.append(extraOperands.begin(), extraOperands.end());
- remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
-
- SmallVector<Value *, 4> remapOutputs;
- remapOutputs.reserve(oldMemRefRank);
-
- if (indexRemap &&
- indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
- // Remapped indices.
- for (auto resultExpr : indexRemap.getResults()) {
- auto singleResMap = builder.getAffineMap(
- indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
- auto afOp = builder.create<AffineApplyOp>(opInst->getLoc(),
- singleResMap, remapOperands);
- remapOutputs.push_back(afOp);
- affineApplyOps.push_back(afOp);
- }
- } else {
- // No remapping specified.
- remapOutputs.append(remapOperands.begin(), remapOperands.end());
- }
-
- SmallVector<Value *, 4> newMapOperands;
- newMapOperands.reserve(newMemRefRank);
-
- // Prepend 'extraIndices' in 'newMapOperands'.
- for (auto *extraIndex : extraIndices) {
- assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
- "single result op's expected to generate these indices");
- assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
- "invalid memory op index");
- newMapOperands.push_back(extraIndex);
- }
-
- // Append 'remapOutputs' to 'newMapOperands'.
- newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
-
- // Create new fully composed AffineMap for new op to be created.
- assert(newMapOperands.size() == newMemRefRank);
- auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
- // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
- fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
- newMap = simplifyAffineMap(newMap);
- canonicalizeMapAndOperands(&newMap, &newMapOperands);
- // Remove any affine.apply's that became dead as a result of composition.
- for (auto *value : affineApplyOps)
- if (value->use_empty())
- value->getDefiningOp()->erase();
-
- // Construct the new operation using this memref.
- OperationState state(opInst->getLoc(), opInst->getName());
- state.setOperandListToResizable(opInst->hasResizableOperandsList());
- state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
- // Insert the non-memref operands.
- state.operands.append(opInst->operand_begin(),
- opInst->operand_begin() + memRefOperandPos);
- // Insert the new memref value.
- state.operands.push_back(newMemRef);
-
- // Insert the new memref map operands.
- state.operands.append(newMapOperands.begin(), newMapOperands.end());
-
- // Insert the remaining operands unmodified.
- state.operands.append(opInst->operand_begin() + memRefOperandPos + 1 +
- oldMapNumInputs,
- opInst->operand_end());
-
- // Result types don't change. Both memref's are of the same elemental type.
- state.types.reserve(opInst->getNumResults());
- for (auto *result : opInst->getResults())
- state.types.push_back(result->getType());
-
- // Add attribute for 'newMap', other Attributes do not change.
- auto newMapAttr = builder.getAffineMapAttr(newMap);
- for (auto namedAttr : opInst->getAttrs()) {
- if (namedAttr.first == oldMapAttrPair.first) {
- state.attributes.push_back({namedAttr.first, newMapAttr});
- } else {
- state.attributes.push_back(namedAttr);
- }
- }
-
- // Create the new operation.
- auto *repOp = builder.createOperation(state);
- opInst->replaceAllUsesWith(repOp);
-
- // Collect and erase at the end since one of these op's could be
- // domInstFilter or postDomInstFilter as well!
- opsToErase.push_back(opInst);
+ // We'll first collect and then replace --- since replacement erases the op
+ // that has the use, and that op could be postDomFilter or domFilter itself!
+ opsToReplace.insert(op);
}
- for (auto *opInst : opsToErase)
- opInst->erase();
+ for (auto *op : opsToReplace) {
+ if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
+ indexRemap, extraOperands)))
+ assert(false && "memref replacement guaranteed to succeed here");
+ }
- return true;
+ return success();
}
/// Given an operation, inserts one or more single result affine