Extend InstVisitor and Walker to handle arbitrary CFG functions, and expand
the Function::walk functionality into f->walkInsts/walkOps, which allows
visiting all instructions, not just ops.  Eliminate the Function::getBody()
and Function::getReturn() helpers, which crash in CFG functions and were only
kept around as a bridge.

This is step 25/n towards merging instructions and statements.

PiperOrigin-RevId: 227243966
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp
index e8b6688..d21f2f80 100644
--- a/mlir/lib/Analysis/MemRefBoundCheck.cpp
+++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -41,9 +41,7 @@
 struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> {
   explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {}
 
-  PassResult runOnMLFunction(Function *f) override;
-  // Not applicable to CFG functions.
-  PassResult runOnCFGFunction(Function *f) override { return success(); }
+  PassResult runOnFunction(Function *f) override;
 
   void visitOperationInst(OperationInst *opInst);
 
@@ -67,7 +65,7 @@
   // TODO(bondhugula): do this for DMA ops as well.
 }
 
-PassResult MemRefBoundCheck::runOnMLFunction(Function *f) {
+PassResult MemRefBoundCheck::runOnFunction(Function *f) {
   return walk(f), success();
 }
 
diff --git a/mlir/lib/Analysis/MemRefDependenceCheck.cpp b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
index 8391f15..1df935f 100644
--- a/mlir/lib/Analysis/MemRefDependenceCheck.cpp
+++ b/mlir/lib/Analysis/MemRefDependenceCheck.cpp
@@ -44,9 +44,7 @@
   explicit MemRefDependenceCheck()
       : FunctionPass(&MemRefDependenceCheck::passID) {}
 
-  PassResult runOnMLFunction(Function *f) override;
-  // Not applicable to CFG functions.
-  PassResult runOnCFGFunction(Function *f) override { return success(); }
+  PassResult runOnFunction(Function *f) override;
 
   void visitOperationInst(OperationInst *opInst) {
     if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
@@ -168,7 +166,7 @@
 
 // Walks the Function 'f' adding load and store ops to 'loadsAndStores'.
 // Runs pair-wise dependence checks.
-PassResult MemRefDependenceCheck::runOnMLFunction(Function *f) {
+PassResult MemRefDependenceCheck::runOnFunction(Function *f) {
   loadsAndStores.clear();
   walk(f);
   checkDependences(loadsAndStores);
diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp
index 07edb13..a8cad41 100644
--- a/mlir/lib/Analysis/OpStats.cpp
+++ b/mlir/lib/Analysis/OpStats.cpp
@@ -15,9 +15,9 @@
 // limitations under the License.
 // =============================================================================
 
-#include "mlir/IR/Function.h"
 #include "mlir/IR/InstVisitor.h"
 #include "mlir/IR/Instructions.h"
+#include "mlir/IR/Module.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/Pass.h"
 #include "llvm/ADT/DenseMap.h"
@@ -33,11 +33,7 @@
   // Prints the resultant operation stats post iterating over the module.
   PassResult runOnModule(Module *m) override;
 
-  // Process CFG function considering the instructions in basic blocks.
-  PassResult runOnCFGFunction(Function *function) override;
-
-  // Process ML functions and operation statments in ML functions.
-  PassResult runOnMLFunction(Function *function) override;
+  PassResult runOnFunction(Function *function) override;
   void visitOperationInst(OperationInst *inst);
 
   // Print summary of op stats.
@@ -55,17 +51,9 @@
 char PrintOpStatsPass::passID = 0;
 
 PassResult PrintOpStatsPass::runOnModule(Module *m) {
-  auto result = FunctionPass::runOnModule(m);
-  if (!result)
-    printSummary();
-  return result;
-}
-
-PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) {
-  for (const auto &bb : *function)
-    for (const auto &inst : bb)
-      if (auto *op = dyn_cast<OperationInst>(&inst))
-        ++opCount[op->getName().getStringRef()];
+  for (auto &fn : *m)
+    (void)runOnFunction(&fn);
+  printSummary();
   return success();
 }
 
@@ -73,7 +61,7 @@
   ++opCount[inst->getName().getStringRef()];
 }
 
-PassResult PrintOpStatsPass::runOnMLFunction(Function *function) {
+PassResult PrintOpStatsPass::runOnFunction(Function *function) {
   walk(function);
   return success();
 }
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4bc2c94..098439b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -113,17 +113,12 @@
   }
 
   // Visit functions.
-  void visitFunction(const Function *fn);
-  void visitExtFunction(const Function *fn);
-  void visitCFGFunction(const Function *fn);
-  void visitMLFunction(const Function *fn);
   void visitInstruction(const Instruction *inst);
   void visitForInst(const ForInst *forInst);
   void visitIfInst(const IfInst *ifInst);
   void visitOperationInst(const OperationInst *opInst);
   void visitType(Type type);
   void visitAttribute(Attribute attr);
-  void visitOperation(const OperationInst *op);
 
   DenseMap<AffineMap, int> affineMapIds;
   std::vector<AffineMap> affineMapsById;
@@ -161,7 +156,21 @@
   }
 }
 
-void ModuleState::visitOperation(const OperationInst *op) {
+void ModuleState::visitIfInst(const IfInst *ifInst) {
+  recordIntegerSetReference(ifInst->getIntegerSet());
+}
+
+void ModuleState::visitForInst(const ForInst *forInst) {
+  AffineMap lbMap = forInst->getLowerBoundMap();
+  if (!hasShorthandForm(lbMap))
+    recordAffineMapReference(lbMap);
+
+  AffineMap ubMap = forInst->getUpperBoundMap();
+  if (!hasShorthandForm(ubMap))
+    recordAffineMapReference(ubMap);
+}
+
+void ModuleState::visitOperationInst(const OperationInst *op) {
   // Visit all the types used in the operation.
   for (auto *operand : op->getOperands())
     visitType(operand->getType());
@@ -173,50 +182,6 @@
     visitAttribute(elt.second);
 }
 
-void ModuleState::visitExtFunction(const Function *fn) {
-  visitType(fn->getType());
-}
-
-void ModuleState::visitCFGFunction(const Function *fn) {
-  visitType(fn->getType());
-  for (auto &block : *fn) {
-    for (auto &op : block.getInstructions()) {
-      if (auto *opInst = dyn_cast<OperationInst>(&op))
-        visitOperation(opInst);
-      else {
-        llvm_unreachable("IfInst/ForInst in a CFG Function isn't supported");
-      }
-    }
-  }
-}
-
-void ModuleState::visitIfInst(const IfInst *ifInst) {
-  recordIntegerSetReference(ifInst->getIntegerSet());
-  for (auto &childInst : *ifInst->getThen())
-    visitInstruction(&childInst);
-  if (ifInst->hasElse())
-    for (auto &childInst : *ifInst->getElse())
-      visitInstruction(&childInst);
-}
-
-void ModuleState::visitForInst(const ForInst *forInst) {
-  AffineMap lbMap = forInst->getLowerBoundMap();
-  if (!hasShorthandForm(lbMap))
-    recordAffineMapReference(lbMap);
-
-  AffineMap ubMap = forInst->getUpperBoundMap();
-  if (!hasShorthandForm(ubMap))
-    recordAffineMapReference(ubMap);
-
-  for (auto &childInst : *forInst->getBody())
-    visitInstruction(&childInst);
-}
-
-void ModuleState::visitOperationInst(const OperationInst *opInst) {
-  for (auto attr : opInst->getAttrs())
-    visitAttribute(attr.second);
-}
-
 void ModuleState::visitInstruction(const Instruction *inst) {
   switch (inst->getKind()) {
   case Instruction::Kind::If:
@@ -225,33 +190,16 @@
     return visitForInst(cast<ForInst>(inst));
   case Instruction::Kind::OperationInst:
     return visitOperationInst(cast<OperationInst>(inst));
-  default:
-    return;
-  }
-}
-
-void ModuleState::visitMLFunction(const Function *fn) {
-  visitType(fn->getType());
-  for (auto &inst : *fn->getBody()) {
-    ModuleState::visitInstruction(&inst);
-  }
-}
-
-void ModuleState::visitFunction(const Function *fn) {
-  switch (fn->getKind()) {
-  case Function::Kind::ExtFunc:
-    return visitExtFunction(fn);
-  case Function::Kind::CFGFunc:
-    return visitCFGFunction(fn);
-  case Function::Kind::MLFunc:
-    return visitMLFunction(fn);
   }
 }
 
 // Initializes module state, populating affine map and integer set state.
 void ModuleState::initialize(const Module *module) {
   for (auto &fn : *module) {
-    visitFunction(&fn);
+    visitType(fn.getType());
+
+    const_cast<Function &>(fn).walkInsts(
+        [&](Instruction *op) { ModuleState::visitInstruction(op); });
   }
 }
 
@@ -1167,12 +1115,26 @@
   }
 }
 
+/// Return true if the block's label and argument list should be printed.
+static bool shouldPrintBlockArguments(const Block *block) {
+  // Never print a label for the function's entry block - its arguments are
+  // already included in the function's argument list.
+  if (block == &block->getFunction()->front())
+    return false;
+
+  // If this is the first block in a nested region, and if there are no
+  // arguments, then we can omit it.
+  if (block == &block->getParent()->front() && block->getNumArguments() == 0)
+    return false;
+
+  // Otherwise print it.
+  return true;
+}
+
 void FunctionPrinter::print(const Block *block) {
   // Print the block label and argument list, unless this is the first block of
   // the function, or the first block of an IfInst/ForInst with no arguments.
-  if (block != &block->getFunction()->front() &&
-      (block != &block->getParent()->front() ||
-       block->getNumArguments() != 0)) {
+  if (shouldPrintBlockArguments(block)) {
     os.indent(currentIndent);
     printBlockName(block);
 
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
index bacb504..b7346e9 100644
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -161,15 +161,20 @@
 // Function implementation.
 //===----------------------------------------------------------------------===//
 
-const OperationInst *Function::getReturn() const {
-  return cast<OperationInst>(&getBody()->back());
+void Function::walkInsts(std::function<void(Instruction *)> callback) {
+  struct Walker : public InstWalker<Walker> {
+    std::function<void(Instruction *)> const &callback;
+    Walker(std::function<void(Instruction *)> const &callback)
+        : callback(callback) {}
+
+    void visitInstruction(Instruction *inst) { callback(inst); }
+  };
+
+  Walker v(callback);
+  v.walk(this);
 }
 
-OperationInst *Function::getReturn() {
-  return cast<OperationInst>(&getBody()->back());
-}
-
-void Function::walk(std::function<void(OperationInst *)> callback) {
+void Function::walkOps(std::function<void(OperationInst *)> callback) {
   struct Walker : public InstWalker<Walker> {
     std::function<void(OperationInst *)> const &callback;
     Walker(std::function<void(OperationInst *)> const &callback)
@@ -182,7 +187,20 @@
   v.walk(this);
 }
 
-void Function::walkPostOrder(std::function<void(OperationInst *)> callback) {
+/// Post-order walk over every instruction in this function (not just ops).
+void Function::walkInstsPostOrder(std::function<void(Instruction *)> callback) {
+  struct Walker : public InstWalker<Walker> {
+    std::function<void(Instruction *)> const &callback;
+    Walker(std::function<void(Instruction *)> const &callback)
+        : callback(callback) {}
+    // Same hook as walkInsts. NOTE(review): confirm walkPostOrder invokes it.
+    void visitInstruction(Instruction *inst) { callback(inst); }
+  };
+  // The walker only borrows `callback`; it must not outlive this call.
+  Walker v(callback);
+  v.walkPostOrder(this);
+}
+
+void Function::walkOpsPostOrder(std::function<void(OperationInst *)> callback) {
   struct Walker : public InstWalker<Walker> {
     std::function<void(OperationInst *)> const &callback;
     Walker(std::function<void(OperationInst *)> const &callback)
diff --git a/mlir/lib/Transforms/ConvertToCFG.cpp b/mlir/lib/Transforms/ConvertToCFG.cpp
index a9124b0..0ecd248 100644
--- a/mlir/lib/Transforms/ConvertToCFG.cpp
+++ b/mlir/lib/Transforms/ConvertToCFG.cpp
@@ -485,8 +485,10 @@
   }
 
   // Convert instructions in order.
-  for (auto &inst : *mlFunc->getBody()) {
-    visit(&inst);
+  for (auto &block : *mlFunc) {
+    for (auto &inst : block) {
+      visit(&inst);
+    }
   }
 
   return cfgFunc;
diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp
index bc7f31f..b5e7653 100644
--- a/mlir/lib/Transforms/DmaGeneration.cpp
+++ b/mlir/lib/Transforms/DmaGeneration.cpp
@@ -62,9 +62,7 @@
     }
   }
 
-  // Not applicable to CFG functions.
-  PassResult runOnCFGFunction(Function *f) override { return success(); }
-  PassResult runOnMLFunction(Function *f) override;
+  PassResult runOnFunction(Function *f) override;
   void runOnForInst(ForInst *forInst);
 
   void visitOperationInst(OperationInst *opInst);
@@ -425,10 +423,12 @@
                           << " KiB of DMA buffers in fast memory space\n";);
 }
 
-PassResult DmaGeneration::runOnMLFunction(Function *f) {
-  for (auto &inst : *f->getBody()) {
-    if (auto *forInst = dyn_cast<ForInst>(&inst)) {
-      runOnForInst(forInst);
+PassResult DmaGeneration::runOnFunction(Function *f) {
+  for (auto &block : *f) {
+    for (auto &inst : block) {
+      if (auto *forInst = dyn_cast<ForInst>(&inst)) {
+        runOnForInst(forInst);
+      }
     }
   }
   // This function never leaves the IR in an invalid state.
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 97dea75..1854cd9 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -348,7 +348,12 @@
 bool MemRefDependenceGraph::init(Function *f) {
   unsigned id = 0;
   DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
-  for (auto &inst : *f->getBody()) {
+
+  // TODO: support multi-block functions.
+  if (f->getBlocks().size() != 1)
+    return false;
+
+  for (auto &inst : f->front()) {
     if (auto *forInst = dyn_cast<ForInst>(&inst)) {
       // Create graph node 'id' to represent top-level 'forInst' and record
       // all loads and store accesses it contains.
diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp
index 8f3be8a..fa39b7d 100644
--- a/mlir/lib/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Transforms/LoopTiling.cpp
@@ -224,15 +224,17 @@
     do {
       band.push_back(currInst);
     } while (currInst->getBody()->getInstructions().size() == 1 &&
-             (currInst = dyn_cast<ForInst>(&*currInst->getBody()->begin())));
+             (currInst = dyn_cast<ForInst>(&currInst->getBody()->front())));
     bands->push_back(band);
   };
 
-  for (auto &inst : *f->getBody()) {
-    auto *forInst = dyn_cast<ForInst>(&inst);
-    if (!forInst)
-      continue;
-    getMaximalPerfectLoopNest(forInst);
+  for (auto &block : *f) {
+    for (auto &inst : block) {
+      auto *forInst = dyn_cast<ForInst>(&inst);
+      if (!forInst)
+        continue;
+      getMaximalPerfectLoopNest(forInst);
+    }
   }
 }
 
diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
index f59659c..1297560 100644
--- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp
@@ -74,7 +74,7 @@
       : FunctionPass(&LoopUnrollAndJam::passID),
         unrollJamFactor(unrollJamFactor) {}
 
-  PassResult runOnMLFunction(Function *f) override;
+  PassResult runOnFunction(Function *f) override;
   bool runOnForInst(ForInst *forInst);
 
   static char passID;
@@ -88,15 +88,15 @@
       unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor));
 }
 
-PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) {
+PassResult LoopUnrollAndJam::runOnFunction(Function *f) {
   // Currently, just the outermost loop from the first loop nest is
   // unroll-and-jammed by this pass. However, runOnForInst can be called on any
   // for Inst.
-  auto *forInst = dyn_cast<ForInst>(f->getBody()->begin());
-  if (!forInst)
-    return success();
+  auto &entryBlock = f->front();
+  if (!entryBlock.empty())
+    if (auto *forInst = dyn_cast<ForInst>(&entryBlock.front()))
+      runOnForInst(forInst);
 
-  runOnForInst(forInst);
   return success();
 }
 
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index bcb2abf..5d55800 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -236,7 +236,6 @@
   makeFuncWiseState(Function *f) const override {
     auto state = llvm::make_unique<LowerVectorTransfersState>();
     auto builder = FuncBuilder(f);
-    builder.setInsertionPointToStart(f->getBody());
     state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0);
     return state;
   }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 6064d1f..c37b997 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -280,7 +280,7 @@
   };
 
   GreedyPatternRewriteDriver driver(std::move(patterns));
-  fn->walk([&](OperationInst *inst) { driver.addToWorklist(inst); });
+  fn->walkOps([&](OperationInst *inst) { driver.addToWorklist(inst); });
 
   FuncBuilder mlBuilder(fn);
   MLFuncRewriter rewriter(driver, mlBuilder);
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 9303937..4168dda 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -101,7 +101,7 @@
   if (!forInst->use_empty()) {
     if (forInst->hasConstantLowerBound()) {
       auto *mlFunc = forInst->getFunction();
-      FuncBuilder topBuilder(&mlFunc->getBody()->front());
+      FuncBuilder topBuilder(mlFunc);
       auto constOp = topBuilder.create<ConstantIndexOp>(
           forInst->getLoc(), forInst->getConstantLowerBound());
       forInst->replaceAllUsesWith(constOp);
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index bbb703c..58bb390 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -651,7 +651,7 @@
 struct Vectorize : public FunctionPass {
   Vectorize() : FunctionPass(&Vectorize::passID) {}
 
-  PassResult runOnMLFunction(Function *f) override;
+  PassResult runOnFunction(Function *f) override;
 
   // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
   MLFunctionMatcherContext MLContext;
@@ -1267,7 +1267,7 @@
 
 /// Applies vectorization to the current Function by searching over a bunch of
 /// predetermined patterns.
-PassResult Vectorize::runOnMLFunction(Function *f) {
+PassResult Vectorize::runOnFunction(Function *f) {
   for (auto pat : makePatterns()) {
     LLVM_DEBUG(dbgs() << "\n******************************************");
     LLVM_DEBUG(dbgs() << "\n******************************************");