[mlir][Pass] Update the PassGen to generate base classes instead of utilities
Summary:
Generating base classes is much cleaner and matches the structure of many other TableGen backends. This was not done originally because the CRTP in the pass classes made it overly verbose and complex.
Differential Revision: https://reviews.llvm.org/D77367
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 919c957..8309099 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -7,16 +7,13 @@
//===----------------------------------------------------------------------===//
//
// This transformation pass performs a simple common sub-expression elimination
-// algorithm on operations within a function.
+// algorithm on operations within a region.
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/Analysis/Dominance.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Support/Functional.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMapInfo.h"
@@ -25,6 +22,7 @@
#include "llvm/Support/Allocator.h"
#include "llvm/Support/RecyclingAllocator.h"
#include <deque>
+
using namespace mlir;
namespace {
@@ -73,14 +71,7 @@
namespace {
/// Simple common sub-expression elimination.
-struct CSE : public PassWrapper<CSE, OperationPass<>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_CSE
-#include "mlir/Transforms/Passes.h.inc"
-
- CSE() = default;
- CSE(const CSE &) {}
-
+struct CSE : public CSEBase<CSE> {
/// Shared implementation of operation elimination and scoped map definitions.
using AllocatorTy = llvm::RecyclingAllocator<
llvm::BumpPtrAllocator,
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 3f3d302..c46a8b9 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
@@ -19,11 +20,7 @@
namespace {
/// Canonicalize operations in nested regions.
-struct Canonicalizer : public PassWrapper<Canonicalizer, OperationPass<>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_Canonicalizer
-#include "mlir/Transforms/Passes.h.inc"
-
+struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
void runOnOperation() override {
OwningRewritePatternList patterns;
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 60382ea..582f720 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -13,10 +13,10 @@
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffects.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
@@ -589,11 +589,7 @@
//===----------------------------------------------------------------------===//
namespace {
-struct InlinerPass : public PassWrapper<InlinerPass, OperationPass<>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_Inliner
-#include "mlir/Transforms/Passes.h.inc"
-
+struct InlinerPass : public InlinerBase<InlinerPass> {
void runOnOperation() override {
CallGraph &cg = getAnalysis<CallGraph>();
auto *context = &getContext();
diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp
index e9858bc..0b1d929 100644
--- a/mlir/lib/Transforms/LocationSnapshot.cpp
+++ b/mlir/lib/Transforms/LocationSnapshot.cpp
@@ -7,9 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/LocationSnapshot.h"
+#include "PassDetail.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Builders.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/ToolOutputFile.h"
@@ -124,13 +124,8 @@
namespace {
struct LocationSnapshotPass
- : public PassWrapper<LocationSnapshotPass, OperationPass<>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_LocationSnapshot
-#include "mlir/Transforms/Passes.h.inc"
-
+ : public LocationSnapshotBase<LocationSnapshotPass> {
LocationSnapshotPass() = default;
- LocationSnapshotPass(const LocationSnapshotPass &) {}
LocationSnapshotPass(OpPrintingFlags flags, StringRef fileName, StringRef tag)
: flags(flags) {
this->fileName = fileName.str();
diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp
index 57d8e2a..d47b377 100644
--- a/mlir/lib/Transforms/LoopCoalescing.cpp
+++ b/mlir/lib/Transforms/LoopCoalescing.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -19,12 +19,7 @@
using namespace mlir;
namespace {
-struct LoopCoalescingPass
- : public PassWrapper<LoopCoalescingPass, FunctionPass> {
-/// Include the generated pass utilities.
-#define GEN_PASS_LoopCoalescing
-#include "mlir/Transforms/Passes.h.inc"
-
+struct LoopCoalescingPass : public LoopCoalescingBase<LoopCoalescingPass> {
void runOnFunction() override {
FuncOp func = getFunction();
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index f802ba5..47ee502 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/LoopAnalysis.h"
@@ -18,7 +19,6 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopFusionUtils.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
@@ -77,11 +77,7 @@
// TODO(andydavis) Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.
-struct LoopFusion : public PassWrapper<LoopFusion, FunctionPass> {
-/// Include the generated pass utilities.
-#define GEN_PASS_AffineLoopFusion
-#include "mlir/Transforms/Passes.h.inc"
-
+struct LoopFusion : public AffineLoopFusionBase<LoopFusion> {
LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0,
bool maximalFusion = false)
: localBufSizeThreshold(localBufSizeThreshold),
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
index e7e48ac..dacd688 100644
--- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp
@@ -10,13 +10,13 @@
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffects.h"
-#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -28,11 +28,7 @@
namespace {
/// Loop invariant code motion (LICM) pass.
struct LoopInvariantCodeMotion
- : public PassWrapper<LoopInvariantCodeMotion, OperationPass<>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_LoopInvariantCodeMotion
-#include "mlir/Transforms/Passes.h.inc"
-
+ : public LoopInvariantCodeMotionBase<LoopInvariantCodeMotion> {
void runOnOperation() override;
};
} // end anonymous namespace
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index 5b03de9..a0a4175 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -13,12 +13,12 @@
// SSA scalars live out of 'affine.for'/'affine.if' statements is available.
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <algorithm>
@@ -60,11 +60,7 @@
// currently only eliminates the stores only if no other loads/uses (other
// than dealloc) remain.
//
-struct MemRefDataFlowOpt : public PassWrapper<MemRefDataFlowOpt, FunctionPass> {
-/// Include the generated pass utilities.
-#define GEN_PASS_MemRefDataFlowOpt
-#include "mlir/Transforms/Passes.h.inc"
-
+struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
void runOnFunction() override;
void forwardStoreToLoad(AffineLoadOp loadOp);
diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
index 667a0b4..7d64910 100644
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Format.h"
@@ -18,12 +18,7 @@
using namespace mlir;
namespace {
-struct PrintOpStatsPass
- : public PassWrapper<PrintOpStatsPass, OperationPass<ModuleOp>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_PrintOpStats
-#include "mlir/Transforms/Passes.h.inc"
-
+struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> {
explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
// Prints the resultant operation statistics post iterating over the module.
diff --git a/mlir/lib/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp
index 4380fe3..16ecb84 100644
--- a/mlir/lib/Transforms/ParallelLoopCollapsing.cpp
+++ b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp
@@ -6,9 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
@@ -21,13 +20,7 @@
namespace {
struct ParallelLoopCollapsing
- : public PassWrapper<ParallelLoopCollapsing, OperationPass<>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_ParallelLoopCollapsing
-#include "mlir/Transforms/Passes.h.inc"
-
- ParallelLoopCollapsing() = default;
- ParallelLoopCollapsing(const ParallelLoopCollapsing &) {}
+ : public ParallelLoopCollapsingBase<ParallelLoopCollapsing> {
void runOnOperation() override {
Operation *module = getOperation();
@@ -45,7 +38,6 @@
});
}
};
-
} // namespace
std::unique_ptr<Pass> mlir::createParallelLoopCollapsingPass() {
diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h
new file mode 100644
index 0000000..c6f7e22
--- /dev/null
+++ b/mlir/lib/Transforms/PassDetail.h
@@ -0,0 +1,21 @@
+//===- PassDetail.h - Transforms Pass class details -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRANSFORMS_PASSDETAIL_H_
+#define TRANSFORMS_PASSDETAIL_H_
+
+#include "mlir/Pass/Pass.h" // Pass infrastructure the generated classes build on.
+
+namespace mlir {
+
+#define GEN_PASS_CLASSES // Request the generated pass base classes from the .inc.
+#include "mlir/Transforms/Passes.h.inc" // Provides e.g. CSEBase<T>, InlinerBase<T>.
+
+} // end namespace mlir
+
+#endif // TRANSFORMS_PASSDETAIL_H_
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 8eeea89..01aa25a 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/AffineAnalysis.h"
@@ -17,7 +18,6 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Builders.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
@@ -29,11 +29,7 @@
namespace {
struct PipelineDataTransfer
- : public PassWrapper<PipelineDataTransfer, FunctionPass> {
-/// Include the generated pass utilities.
-#define GEN_PASS_AffinePipelineDataTransfer
-#include "mlir/Transforms/Passes.h.inc"
-
+ : public AffinePipelineDataTransferBase<PipelineDataTransfer> {
void runOnFunction() override;
void runOnAffineForOp(AffineForOp forOp);
diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp
index e5ba144..15ce1c2 100644
--- a/mlir/lib/Transforms/StripDebugInfo.cpp
+++ b/mlir/lib/Transforms/StripDebugInfo.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
@@ -14,11 +15,7 @@
using namespace mlir;
namespace {
-struct StripDebugInfo : public PassWrapper<StripDebugInfo, OperationPass<>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_StripDebugInfo
-#include "mlir/Transforms/Passes.h.inc"
-
+struct StripDebugInfo : public StripDebugInfoBase<StripDebugInfo> {
void runOnOperation() override;
};
} // end anonymous namespace
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
index 251a956..581857a6 100644
--- a/mlir/lib/Transforms/SymbolDCE.cpp
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -11,17 +11,13 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Pass/Pass.h"
+#include "PassDetail.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
namespace {
-struct SymbolDCE : public PassWrapper<SymbolDCE, OperationPass<>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_SymbolDCE
-#include "mlir/Transforms/Passes.h.inc"
-
+struct SymbolDCE : public SymbolDCEBase<SymbolDCE> {
void runOnOperation() override;
/// Compute the liveness of the symbols within the given symbol table.
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 8ac61fc..41e33e8 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -7,10 +7,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/ViewOpGraph.h"
+#include "PassDetail.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/CommandLine.h"
@@ -100,11 +100,7 @@
// PrintOpPass is simple pass to write graph per function.
// Note: this is a module pass only to avoid interleaving on the same ostream
// due to multi-threading over functions.
-struct PrintOpPass : public PassWrapper<PrintOpPass, OperationPass<ModuleOp>> {
-/// Include the generated pass utilities.
-#define GEN_PASS_PrintOpGraph
-#include "mlir/Transforms/Passes.h.inc"
-
+struct PrintOpPass : public PrintOpBase<PrintOpPass> {
explicit PrintOpPass(raw_ostream &os = llvm::errs(), bool short_names = false,
const Twine &title = "")
: os(os), title(title.str()), short_names(short_names) {}
diff --git a/mlir/lib/Transforms/ViewRegionGraph.cpp b/mlir/lib/Transforms/ViewRegionGraph.cpp
index 4f31a79..0c67f30 100644
--- a/mlir/lib/Transforms/ViewRegionGraph.cpp
+++ b/mlir/lib/Transforms/ViewRegionGraph.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/ViewRegionGraph.h"
+#include "PassDetail.h"
#include "mlir/IR/RegionGraphTraits.h"
-#include "mlir/Pass/Pass.h"
using namespace mlir;
@@ -60,11 +60,7 @@
void mlir::Region::viewGraph() { viewGraph("region"); }
namespace {
-struct PrintCFGPass : public PassWrapper<PrintCFGPass, FunctionPass> {
-/// Include the generated pass utilities.
-#define GEN_PASS_PrintCFG
-#include "mlir/Transforms/Passes.h.inc"
-
+struct PrintCFGPass : public PrintCFGBase<PrintCFGPass> {
PrintCFGPass(raw_ostream &os = llvm::errs(), bool shortNames = false,
const Twine &title = "")
: os(os), shortNames(shortNames), title(title.str()) {}