Extend InstVisitor and Walker to handle arbitrary CFG functions, and expand the
Function::walk functionality into f->walkInsts/walkOps, which allow visiting all
instructions rather than just operations. Eliminate the Function::getBody() and
Function::getReturn() helpers, which crash in CFG functions and were only kept
around as a bridge.
This is step 25/n towards merging instructions and statements.
PiperOrigin-RevId: 227243966
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp
index e8b6688..d21f2f80 100644
--- a/mlir/lib/Analysis/MemRefBoundCheck.cpp
+++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -41,9 +41,7 @@
struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> {
explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {}
- PassResult runOnMLFunction(Function *f) override;
- // Not applicable to CFG functions.
- PassResult runOnCFGFunction(Function *f) override { return success(); }
+ PassResult runOnFunction(Function *f) override;
void visitOperationInst(OperationInst *opInst);
@@ -67,7 +65,7 @@
// TODO(bondhugula): do this for DMA ops as well.
}
-PassResult MemRefBoundCheck::runOnMLFunction(Function *f) {
+/// Walks every instruction of 'f'; each operation is handed to
+/// visitOperationInst. The comma expression runs the walk and then yields
+/// success(), so this pass never reports failure.
+PassResult MemRefBoundCheck::runOnFunction(Function *f) {
return walk(f), success();
}
diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
index 8391f15..1df935f 100644
--- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp
+++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
@@ -44,9 +44,7 @@
explicit MemRefDependenceCheck()
: FunctionPass(&MemRefDependenceCheck::passID) {}
- PassResult runOnMLFunction(Function *f) override;
- // Not applicable to CFG functions.
- PassResult runOnCFGFunction(Function *f) override { return success(); }
+ PassResult runOnFunction(Function *f) override;
void visitOperationInst(OperationInst *opInst) {
if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
@@ -168,7 +166,7 @@
// Walks the Function 'f' adding load and store ops to 'loadsAndStores'.
// Runs pair-wise dependence checks.
-PassResult MemRefDependenceCheck::runOnMLFunction(Function *f) {
+PassResult MemRefDependenceCheck::runOnFunction(Function *f) {
loadsAndStores.clear();
walk(f);
checkDependences(loadsAndStores);
diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp
index 07edb13..a8cad41 100644
--- a/mlir/lib/Analysis/OpStats.cpp
+++ b/mlir/lib/Analysis/OpStats.cpp
@@ -15,9 +15,9 @@
// limitations under the License.
// =============================================================================
-#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
+#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass.h"
#include "llvm/ADT/DenseMap.h"
@@ -33,11 +33,7 @@
// Prints the resultant operation stats post iterating over the module.
PassResult runOnModule(Module *m) override;
- // Process CFG function considering the instructions in basic blocks.
- PassResult runOnCFGFunction(Function *function) override;
-
- // Process ML functions and operation statments in ML functions.
- PassResult runOnMLFunction(Function *function) override;
+ PassResult runOnFunction(Function *function) override;
void visitOperationInst(OperationInst *inst);
// Print summary of op stats.
@@ -55,17 +51,9 @@
char PrintOpStatsPass::passID = 0;
PassResult PrintOpStatsPass::runOnModule(Module *m) {
- auto result = FunctionPass::runOnModule(m);
- if (!result)
- printSummary();
- return result;
-}
-
-PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) {
- for (const auto &bb : *function)
- for (const auto &inst : bb)
- if (auto *op = dyn_cast<OperationInst>(&inst))
- ++opCount[op->getName().getStringRef()];
+ // Tally operations across every function in the module, then print the
+ // accumulated summary once. The per-function result is discarded because
+ // PrintOpStatsPass::runOnFunction unconditionally returns success().
+ for (auto &fn : *m)
+ (void)runOnFunction(&fn);
+ printSummary();
return success();
}
@@ -73,7 +61,7 @@
++opCount[inst->getName().getStringRef()];
}
-PassResult PrintOpStatsPass::runOnMLFunction(Function *function) {
+/// Walks 'function' so that visitOperationInst (which increments opCount keyed
+/// by operation name) sees every operation in it. Never fails.
+PassResult PrintOpStatsPass::runOnFunction(Function *function) {
walk(function);
return success();
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4bc2c94..098439b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -113,17 +113,12 @@
}
// Visit functions.
- void visitFunction(const Function *fn);
- void visitExtFunction(const Function *fn);
- void visitCFGFunction(const Function *fn);
- void visitMLFunction(const Function *fn);
void visitInstruction(const Instruction *inst);
void visitForInst(const ForInst *forInst);
void visitIfInst(const IfInst *ifInst);
void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
- void visitOperation(const OperationInst *op);
DenseMap<AffineMap, int> affineMapIds;
std::vector<AffineMap> affineMapsById;
@@ -161,7 +156,21 @@
}
}
-void ModuleState::visitOperation(const OperationInst *op) {
+void ModuleState::visitIfInst(const IfInst *ifInst) {
+ // Only the condition's integer set is recorded here; nested instructions are
+ // not visited by this method (the walk in ModuleState::initialize reaches
+ // them).
+ auto condition = ifInst->getIntegerSet();
+ recordIntegerSetReference(condition);
+}
+
+void ModuleState::visitForInst(const ForInst *forInst) {
+ // Record a bound map only when it cannot be printed in shorthand form.
+ // Lower bound is handled before upper bound, as before.
+ auto recordUnlessShorthand = [&](AffineMap boundMap) {
+ if (!hasShorthandForm(boundMap))
+ recordAffineMapReference(boundMap);
+ };
+ recordUnlessShorthand(forInst->getLowerBoundMap());
+ recordUnlessShorthand(forInst->getUpperBoundMap());
+}
+
+void ModuleState::visitOperationInst(const OperationInst *op) {
// Visit all the types used in the operation.
for (auto *operand : op->getOperands())
visitType(operand->getType());
@@ -173,50 +182,6 @@
visitAttribute(elt.second);
}
-void ModuleState::visitExtFunction(const Function *fn) {
- visitType(fn->getType());
-}
-
-void ModuleState::visitCFGFunction(const Function *fn) {
- visitType(fn->getType());
- for (auto &block : *fn) {
- for (auto &op : block.getInstructions()) {
- if (auto *opInst = dyn_cast<OperationInst>(&op))
- visitOperation(opInst);
- else {
- llvm_unreachable("IfInst/ForInst in a CFG Function isn't supported");
- }
- }
- }
-}
-
-void ModuleState::visitIfInst(const IfInst *ifInst) {
- recordIntegerSetReference(ifInst->getIntegerSet());
- for (auto &childInst : *ifInst->getThen())
- visitInstruction(&childInst);
- if (ifInst->hasElse())
- for (auto &childInst : *ifInst->getElse())
- visitInstruction(&childInst);
-}
-
-void ModuleState::visitForInst(const ForInst *forInst) {
- AffineMap lbMap = forInst->getLowerBoundMap();
- if (!hasShorthandForm(lbMap))
- recordAffineMapReference(lbMap);
-
- AffineMap ubMap = forInst->getUpperBoundMap();
- if (!hasShorthandForm(ubMap))
- recordAffineMapReference(ubMap);
-
- for (auto &childInst : *forInst->getBody())
- visitInstruction(&childInst);
-}
-
-void ModuleState::visitOperationInst(const OperationInst *opInst) {
- for (auto attr : opInst->getAttrs())
- visitAttribute(attr.second);
-}
-
void ModuleState::visitInstruction(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::If:
@@ -225,33 +190,16 @@
return visitForInst(cast<ForInst>(inst));
case Instruction::Kind::OperationInst:
return visitOperationInst(cast<OperationInst>(inst));
- default:
- return;
- }
-}
-
-void ModuleState::visitMLFunction(const Function *fn) {
- visitType(fn->getType());
- for (auto &inst : *fn->getBody()) {
- ModuleState::visitInstruction(&inst);
- }
-}
-
-void ModuleState::visitFunction(const Function *fn) {
- switch (fn->getKind()) {
- case Function::Kind::ExtFunc:
- return visitExtFunction(fn);
- case Function::Kind::CFGFunc:
- return visitCFGFunction(fn);
- case Function::Kind::MLFunc:
- return visitMLFunction(fn);
}
}
// Initializes module state, populating affine map and integer set state.
void ModuleState::initialize(const Module *module) {
for (auto &fn : *module) {
- visitFunction(&fn);
+ // Record the function's own type first, then every instruction in it.
+ visitType(fn.getType());
+
+ // NOTE(review): Function::walkInsts is a non-const member, hence the
+ // const_cast. The callback only forwards each instruction to
+ // visitInstruction, which takes a const Instruction *; confirm no
+ // mutation is intended before relying on this.
+ const_cast<Function &>(fn).walkInsts(
+ [&](Instruction *op) { ModuleState::visitInstruction(op); });
}
}
@@ -1167,12 +1115,26 @@
}
}
+/// Return true if the introducer for the specified block should be printed.
+static bool shouldPrintBlockArguments(const Block *block) {
+ // Never print the entry block of the function - it is included in the
+ // argument list.
+ if (block == &block->getFunction()->front())
+ return false;
+
+ // If this is the first block in a nested region, and if there are no
+ // arguments, then we can omit it.
+ if (block == &block->getParent()->front() && block->getNumArguments() == 0)
+ return false;
+
+ // Otherwise print it.
+ return true;
+}
+
void FunctionPrinter::print(const Block *block) {
// Print the block label and argument list, unless this is the first block of
// the function, or the first block of an IfInst/ForInst with no arguments.
- if (block != &block->getFunction()->front() &&
- (block != &block->getParent()->front() ||
- block->getNumArguments() != 0)) {
+ if (shouldPrintBlockArguments(block)) {
os.indent(currentIndent);
printBlockName(block);
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index bacb504..b7346e9 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -161,15 +161,20 @@
// Function implementation.
//===----------------------------------------------------------------------===//
-const OperationInst *Function::getReturn() const {
- return cast<OperationInst>(&getBody()->back());
+void Function::walkInsts(std::function<void(Instruction *)> callback) {
+ // Adapter that forwards every instruction handed to visitInstruction,
+ // regardless of its kind, to the user-supplied callback.
+ struct CallbackWalker : public InstWalker<CallbackWalker> {
+ const std::function<void(Instruction *)> &fn;
+ explicit CallbackWalker(const std::function<void(Instruction *)> &fn)
+ : fn(fn) {}
+
+ void visitInstruction(Instruction *inst) { fn(inst); }
+ };
+
+ CallbackWalker walker(callback);
+ walker.walk(this);
}
-OperationInst *Function::getReturn() {
- return cast<OperationInst>(&getBody()->back());
-}
-
-void Function::walk(std::function<void(OperationInst *)> callback) {
+void Function::walkOps(std::function<void(OperationInst *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
@@ -182,7 +187,20 @@
v.walk(this);
}
-void Function::walkPostOrder(std::function<void(OperationInst *)> callback) {
+/// Walk every instruction in this function (not only operations) using
+/// InstWalker::walkPostOrder, invoking 'callback' on each one.
+void Function::walkInstsPostOrder(std::function<void(Instruction *)> callback) {
+ struct Walker : public InstWalker<Walker> {
+ std::function<void(Instruction *)> const &callback;
+ Walker(std::function<void(Instruction *)> const &callback)
+ : callback(callback) {}
+
+ // Hook visitInstruction, as walkInsts does, so ForInst and IfInst reach the
+ // callback too. Overriding visitOperationInst (even with an Instruction *
+ // parameter) would only ever fire for OperationInsts.
+ void visitInstruction(Instruction *inst) { callback(inst); }
+ };
+
+ Walker v(callback);
+ v.walkPostOrder(this);
+}
+
+void Function::walkOpsPostOrder(std::function<void(OperationInst *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp
index a9124b0..0ecd248 100644
--- a/mlir/lib/Transforms/ConvertToCFG.cpp
+++ b/mlir/lib/Transforms/ConvertToCFG.cpp
@@ -485,8 +485,10 @@
}
// Convert instructions in order.
- for (auto &inst : *mlFunc->getBody()) {
- visit(&inst);
+ for (auto &block : *mlFunc) {
+ for (auto &inst : block) {
+ visit(&inst);
+ }
}
return cfgFunc;
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index bc7f31f..b5e7653 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -62,9 +62,7 @@
}
}
- // Not applicable to CFG functions.
- PassResult runOnCFGFunction(Function *f) override { return success(); }
- PassResult runOnMLFunction(Function *f) override;
+ PassResult runOnFunction(Function *f) override;
void runOnForInst(ForInst *forInst);
void visitOperationInst(OperationInst *opInst);
@@ -425,10 +423,12 @@
<< " KiB of DMA buffers in fast memory space\n";);
}
-PassResult DmaGeneration::runOnMLFunction(Function *f) {
- for (auto &inst : *f->getBody()) {
- if (auto *forInst = dyn_cast<ForInst>(&inst)) {
- runOnForInst(forInst);
+PassResult DmaGeneration::runOnFunction(Function *f) {
+ for (auto &block : *f) {
+ for (auto &inst : block) {
+ if (auto *forInst = dyn_cast<ForInst>(&inst)) {
+ runOnForInst(forInst);
+ }
}
}
// This function never leaves the IR in an invalid state.
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 97dea75..1854cd9 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -348,7 +348,12 @@
bool MemRefDependenceGraph::init(Function *f) {
unsigned id = 0;
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
- for (auto &inst : *f->getBody()) {
+
+ // TODO: support multi-block functions.
+ if (f->getBlocks().size() != 1)
+ return false;
+
+ for (auto &inst : f->front()) {
if (auto *forInst = dyn_cast<ForInst>(&inst)) {
// Create graph node 'id' to represent top-level 'forInst' and record
// all loads and store accesses it contains.
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index 8f3be8a..fa39b7d 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -224,15 +224,17 @@
do {
band.push_back(currInst);
} while (currInst->getBody()->getInstructions().size() == 1 &&
- (currInst = dyn_cast<ForInst>(&*currInst->getBody()->begin())));
+ (currInst = dyn_cast<ForInst>(&currInst->getBody()->front())));
bands->push_back(band);
};
- for (auto &inst : *f->getBody()) {
- auto *forInst = dyn_cast<ForInst>(&inst);
- if (!forInst)
- continue;
- getMaximalPerfectLoopNest(forInst);
+ for (auto &block : *f) {
+ for (auto &inst : block) {
+ auto *forInst = dyn_cast<ForInst>(&inst);
+ if (!forInst)
+ continue;
+ getMaximalPerfectLoopNest(forInst);
+ }
}
}
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index f59659c..1297560 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -74,7 +74,7 @@
: FunctionPass(&LoopUnrollAndJam::passID),
unrollJamFactor(unrollJamFactor) {}
- PassResult runOnMLFunction(Function *f) override;
+ PassResult runOnFunction(Function *f) override;
bool runOnForInst(ForInst *forInst);
static char passID;
@@ -88,15 +88,15 @@
unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor));
}
-PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) {
+PassResult LoopUnrollAndJam::runOnFunction(Function *f) {
// Currently, just the outermost loop from the first loop nest is
// unroll-and-jammed by this pass. However, runOnForInst can be called on any
// for Inst.
- auto *forInst = dyn_cast<ForInst>(f->getBody()->begin());
- if (!forInst)
- return success();
+ // A function with no blocks (e.g. an external declaration) has no entry
+ // block to inspect; calling front() on it would be invalid.
+ if (f->getBlocks().empty())
+ return success();
+
+ // Only the first instruction of the entry block is considered.
+ auto &entryBlock = f->front();
+ if (!entryBlock.empty())
+ if (auto *forInst = dyn_cast<ForInst>(&entryBlock.front()))
+ runOnForInst(forInst);
- runOnForInst(forInst);
return success();
}
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index bcb2abf..5d55800 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -236,7 +236,6 @@
makeFuncWiseState(Function *f) const override {
auto state = llvm::make_unique<LowerVectorTransfersState>();
auto builder = FuncBuilder(f);
- builder.setInsertionPointToStart(f->getBody());
state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0);
return state;
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 6064d1f..c37b997 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -280,7 +280,7 @@
};
GreedyPatternRewriteDriver driver(std::move(patterns));
- fn->walk([&](OperationInst *inst) { driver.addToWorklist(inst); });
+ fn->walkOps([&](OperationInst *inst) { driver.addToWorklist(inst); });
FuncBuilder mlBuilder(fn);
MLFuncRewriter rewriter(driver, mlBuilder);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 9303937..4168dda 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -101,7 +101,7 @@
if (!forInst->use_empty()) {
if (forInst->hasConstantLowerBound()) {
auto *mlFunc = forInst->getFunction();
- FuncBuilder topBuilder(&mlFunc->getBody()->front());
+ FuncBuilder topBuilder(mlFunc);
auto constOp = topBuilder.create<ConstantIndexOp>(
forInst->getLoc(), forInst->getConstantLowerBound());
forInst->replaceAllUsesWith(constOp);
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index bbb703c..58bb390 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -651,7 +651,7 @@
struct Vectorize : public FunctionPass {
Vectorize() : FunctionPass(&Vectorize::passID) {}
- PassResult runOnMLFunction(Function *f) override;
+ PassResult runOnFunction(Function *f) override;
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
MLFunctionMatcherContext MLContext;
@@ -1267,7 +1267,7 @@
/// Applies vectorization to the current Function by searching over a bunch of
/// predetermined patterns.
-PassResult Vectorize::runOnMLFunction(Function *f) {
+PassResult Vectorize::runOnFunction(Function *f) {
for (auto pat : makePatterns()) {
LLVM_DEBUG(dbgs() << "\n******************************************");
LLVM_DEBUG(dbgs() << "\n******************************************");