NFC: Refactor Function to be value typed.
Move the data members out of Function and into a new impl storage class 'FunctionStorage'. This allows Function to become value typed, which will greatly simplify the transition of Function to FuncOp (given that FuncOp is also value typed).
PiperOrigin-RevId: 255983022
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 8a2002c..394b3ef 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -40,7 +40,7 @@
void Canonicalizer::runOnFunction() {
OwningRewritePatternList patterns;
- auto &func = getFunction();
+ auto func = getFunction();
// TODO: Instead of adding all known patterns from the whole system lazily add
// and cache the canonicalization patterns for ops we see in practice when
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index be60ada..84f00b9 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -849,7 +849,7 @@
/// error, success otherwise. If 'signatureConversion' is provided, the
/// arguments of the entry block are updated accordingly.
LogicalResult
- convertFunction(Function *f,
+ convertFunction(Function f,
TypeConverter::SignatureConversion *signatureConversion);
/// Converts the given region starting from the entry block and following the
@@ -957,22 +957,22 @@
}
LogicalResult FunctionConverter::convertFunction(
- Function *f, TypeConverter::SignatureConversion *signatureConversion) {
+ Function f, TypeConverter::SignatureConversion *signatureConversion) {
// If this is an external function, there is nothing else to do.
- if (f->isExternal())
+ if (f.isExternal())
return success();
- DialectConversionRewriter rewriter(f->getBody(), typeConverter);
+ DialectConversionRewriter rewriter(f.getBody(), typeConverter);
// Update the signature of the entry block.
if (signatureConversion) {
rewriter.argConverter.convertSignature(
- &f->getBody().front(), *signatureConversion, rewriter.mapping);
+ &f.getBody().front(), *signatureConversion, rewriter.mapping);
}
// Rewrite the function body.
if (failed(
- convertRegion(rewriter, f->getBody(), /*convertEntryTypes=*/false))) {
+ convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) {
// Reset any of the generated rewrites.
rewriter.discardRewrites();
return failure();
@@ -1124,24 +1124,6 @@
// applyConversionPatterns
//===----------------------------------------------------------------------===//
-namespace {
-/// This class represents a function to be converted. It allows for converting
-/// the body of functions and the signature in two phases.
-struct ConvertedFunction {
- ConvertedFunction(Function *fn, FunctionType newType,
- ArrayRef<NamedAttributeList> newFunctionArgAttrs)
- : fn(fn), newType(newType),
- newFunctionArgAttrs(newFunctionArgAttrs.begin(),
- newFunctionArgAttrs.end()) {}
-
- /// The function to convert.
- Function *fn;
- /// The new type and argument attributes for the function.
- FunctionType newType;
- SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
-};
-} // end anonymous namespace
-
/// Convert the given module with the provided conversion patterns and type
/// conversion object. If conversion fails for specific functions, those
/// functions remains unmodified.
@@ -1149,37 +1131,33 @@
mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
- std::vector<Function *> allFunctions;
- allFunctions.reserve(module.getFunctions().size());
- for (auto &func : module)
- allFunctions.push_back(&func);
+ SmallVector<Function, 32> allFunctions(module.getFunctions());
return applyConversionPatterns(allFunctions, target, converter,
std::move(patterns));
}
/// Convert the given functions with the provided conversion patterns.
LogicalResult mlir::applyConversionPatterns(
- ArrayRef<Function *> fns, ConversionTarget &target,
+ MutableArrayRef<Function> fns, ConversionTarget &target,
TypeConverter &converter, OwningRewritePatternList &&patterns) {
if (fns.empty())
return success();
// Build the function converter.
- FunctionConverter funcConverter(fns.front()->getContext(), target, patterns,
- &converter);
+ auto *ctx = fns.front().getContext();
+ FunctionConverter funcConverter(ctx, target, patterns, &converter);
// Try to convert each of the functions within the module.
- auto *ctx = fns.front()->getContext();
- for (auto *func : fns) {
+ for (auto func : fns) {
// Convert the function type using the type converter.
auto conversion =
- converter.convertSignature(func->getType(), func->getAllArgAttrs());
+ converter.convertSignature(func.getType(), func.getAllArgAttrs());
if (!conversion)
return failure();
// Update the function signature.
- func->setType(conversion->getConvertedType(ctx));
- func->setAllArgAttrs(conversion->getConvertedArgAttrs());
+ func.setType(conversion->getConvertedType(ctx));
+ func.setAllArgAttrs(conversion->getConvertedArgAttrs());
// Convert the body of this function.
if (failed(funcConverter.convertFunction(func, &*conversion)))
@@ -1193,9 +1171,9 @@
/// convert as many of the operations within 'fn' as possible given the set of
/// patterns.
LogicalResult
-mlir::applyConversionPatterns(Function &fn, ConversionTarget &target,
+mlir::applyConversionPatterns(Function fn, ConversionTarget &target,
OwningRewritePatternList &&patterns) {
// Convert the body of this function.
FunctionConverter converter(fn.getContext(), target, patterns);
- return converter.convertFunction(&fn, /*signatureConversion=*/nullptr);
+ return converter.convertFunction(fn, /*signatureConversion=*/nullptr);
}
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index 5a926ce..a3aa092 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -214,7 +214,7 @@
static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
emitRemarkForBlock(Block &block) {
auto *op = block.getContainingOp();
- return op ? op->emitRemark() : block.getFunction()->emitRemark();
+ return op ? op->emitRemark() : block.getFunction().emitRemark();
}
/// Creates a buffer in the faster memory space for the specified region;
@@ -246,8 +246,8 @@
OpBuilder &b = region.isWrite() ? epilogue : prologue;
// Builder to create constants at the top level.
- auto *func = block->getFunction();
- OpBuilder top(func->getBody());
+ auto func = block->getFunction();
+ OpBuilder top(func.getBody());
auto loc = region.loc;
auto *memref = region.memref;
@@ -751,14 +751,14 @@
if (auto *op = block->getContainingOp())
op->emitError(str);
else
- block->getFunction()->emitError(str);
+ block->getFunction().emitError(str);
}
return totalDmaBuffersSizeInBytes;
}
void DmaGeneration::runOnFunction() {
- Function &f = getFunction();
+ Function f = getFunction();
OpBuilder topBuilder(f.getBody());
zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 8d2e75b..77b944f 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -257,7 +257,7 @@
// Initializes the dependence graph based on operations in 'f'.
// Returns true on success, false otherwise.
- bool init(Function &f);
+ bool init(Function f);
// Returns the graph node for 'id'.
Node *getNode(unsigned id) {
@@ -637,7 +637,7 @@
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
-bool MemRefDependenceGraph::init(Function &f) {
+bool MemRefDependenceGraph::init(Function f) {
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
// TODO: support multi-block functions.
@@ -859,7 +859,7 @@
// Create builder to insert alloc op just before 'forOp'.
OpBuilder b(forInst);
// Builder to create constants at the top level.
- OpBuilder top(forInst->getFunction()->getBody());
+ OpBuilder top(forInst->getFunction().getBody());
// Create new memref type based on slice bounds.
auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
@@ -1750,9 +1750,9 @@
};
// Search for siblings which load the same memref function argument.
- auto *fn = dstNode->op->getFunction();
- for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) {
- for (auto *user : fn->getArgument(i)->getUsers()) {
+ auto fn = dstNode->op->getFunction();
+ for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
+ for (auto *user : fn.getArgument(i)->getUsers()) {
if (auto loadOp = dyn_cast<LoadOp>(user)) {
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index c1be6e8..2744e5c 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -261,7 +261,7 @@
// Identify valid and profitable bands of loops to tile. This is currently just
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
-static void getTileableBands(Function &f,
+static void getTileableBands(Function f,
std::vector<SmallVector<AffineForOp, 6>> *bands) {
// Get maximal perfect nest of 'affine.for' insts starting from root
// (inclusive).
diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp
index 0595392..6f13f62 100644
--- a/mlir/lib/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Transforms/LoopUnroll.cpp
@@ -92,8 +92,8 @@
// Store innermost loops as we walk.
std::vector<AffineForOp> loops;
- void walkPostOrder(Function *f) {
- for (auto &b : *f)
+ void walkPostOrder(Function f) {
+ for (auto &b : f)
walkPostOrder(b.begin(), b.end());
}
@@ -142,10 +142,10 @@
? clUnrollNumRepetitions
: 1;
// If the call back is provided, we will recurse until no loops are found.
- Function &func = getFunction();
+ Function func = getFunction();
for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
InnermostLoopGatherer ilg;
- ilg.walkPostOrder(&func);
+ ilg.walkPostOrder(func);
auto &loops = ilg.loops;
if (loops.empty())
break;
diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp
index 77a23b1..df30e27 100644
--- a/mlir/lib/Transforms/LowerAffine.cpp
+++ b/mlir/lib/Transforms/LowerAffine.cpp
@@ -726,7 +726,7 @@
} // end namespace
-LogicalResult mlir::lowerAffineConstructs(Function &function) {
+LogicalResult mlir::lowerAffineConstructs(Function function) {
OwningRewritePatternList patterns;
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering,
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index cd92198..f59f100 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -636,7 +636,7 @@
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
- LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs()));
+ LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs()));
// slice are topologically sorted, we can just erase them in reverse
// order. Reverse iterator does not just work simply with an operator*
@@ -667,7 +667,7 @@
/// because we currently disallow vectorization of defs that come from another
/// scope.
/// TODO(ntv): please document return value.
-static bool materialize(Function *f, const SetVector<Operation *> &terminators,
+static bool materialize(Function f, const SetVector<Operation *> &terminators,
MaterializationState *state) {
DenseSet<Operation *> seen;
DominanceInfo domInfo(f);
@@ -721,7 +721,7 @@
return true;
}
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
- LLVM_DEBUG(f->print(dbgs()));
+ LLVM_DEBUG(f.print(dbgs()));
}
return false;
}
@@ -731,13 +731,13 @@
NestedPatternContext mlContext;
// TODO(ntv): Check to see if this supports arbitrary top-level code.
- Function *f = &getFunction();
- if (f->getBlocks().size() != 1)
+ Function f = getFunction();
+ if (f.getBlocks().size() != 1)
return;
using matcher::Op;
LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
- LLVM_DEBUG(f->print(dbgs()));
+ LLVM_DEBUG(f.print(dbgs()));
MaterializationState state(hwVectorSize);
// Get the hardware vector type.
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index c5676af..1208e2f 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -212,7 +212,7 @@
void MemRefDataFlowOpt::runOnFunction() {
// Only supports single block functions at the moment.
- Function &f = getFunction();
+ Function f = getFunction();
if (f.getBlocks().size() != 1) {
markAllAnalysesPreserved();
return;
diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp
index f97f549..c7c3621 100644
--- a/mlir/lib/Transforms/StripDebugInfo.cpp
+++ b/mlir/lib/Transforms/StripDebugInfo.cpp
@@ -29,7 +29,7 @@
} // end anonymous namespace
void StripDebugInfo::runOnFunction() {
- Function &func = getFunction();
+ Function func = getFunction();
auto unknownLoc = UnknownLoc::get(&getContext());
// Strip the debug info from the function and its operations.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 47ca378..e185f70 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -44,7 +44,7 @@
/// applies the locally optimal patterns in a roughly "bottom up" way.
class GreedyPatternRewriteDriver : public PatternRewriter {
public:
- explicit GreedyPatternRewriteDriver(Function &fn,
+ explicit GreedyPatternRewriteDriver(Function fn,
OwningRewritePatternList &&patterns)
: PatternRewriter(fn.getBody()), matcher(std::move(patterns)) {
worklist.reserve(64);
@@ -213,7 +213,7 @@
/// patterns in a greedy work-list driven manner. Return true if no more
/// patterns can be matched in the result function.
///
-bool mlir::applyPatternsGreedily(Function &fn,
+bool mlir::applyPatternsGreedily(Function fn,
OwningRewritePatternList &&patterns) {
GreedyPatternRewriteDriver driver(fn, std::move(patterns));
bool converged = driver.simplifyFunction(maxPatternMatchIterations);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 728123f..4ddf93c 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -125,7 +125,7 @@
Operation *op = forOp.getOperation();
if (!iv->use_empty()) {
if (forOp.hasConstantLowerBound()) {
- OpBuilder topBuilder(op->getFunction()->getBody());
+ OpBuilder topBuilder(op->getFunction().getBody());
auto constOp = topBuilder.create<ConstantIndexOp>(
forOp.getLoc(), forOp.getConstantLowerBound());
iv->replaceAllUsesWith(constOp);
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index 39a05d8..3fca26b 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -1194,7 +1194,7 @@
/// Applies vectorization to the current Function by searching over a bunch of
/// predetermined patterns.
void Vectorize::runOnFunction() {
- Function &f = getFunction();
+ Function f = getFunction();
if (!fastestVaryingPattern.empty() &&
fastestVaryingPattern.size() != vectorSizes.size()) {
f.emitRemark("Fastest varying pattern specified with different size than "
@@ -1220,7 +1220,7 @@
unsigned patternDepth = pat.getDepth();
SmallVector<NestedMatch, 8> matches;
- pat.match(&f, &matches);
+ pat.match(f, &matches);
// Iterate over all the top-level matches and vectorize eagerly.
// This automatically prunes intersecting matches.
for (auto m : matches) {
diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp
index 1f2ab69..3c1a1b3 100644
--- a/mlir/lib/Transforms/ViewFunctionGraph.cpp
+++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp
@@ -53,13 +53,13 @@
} // end namespace llvm
-void mlir::viewGraph(Function &function, const llvm::Twine &name,
+void mlir::viewGraph(Function function, const llvm::Twine &name,
bool shortNames, const llvm::Twine &title,
llvm::GraphProgram::Name program) {
llvm::ViewGraph(&function, name, shortNames, title, program);
}
-llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function,
+llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function function,
bool shortNames, const llvm::Twine &title) {
return llvm::WriteGraph(os, &function, shortNames, title);
}