| //===- Pass.cpp - MLIR pass registration generator ------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
// See https://llvm.org/LICENSE.txt for license information.
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // PassGen uses the description of passes to generate base classes for passes |
| // and command line registration. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/TableGen/GenInfo.h" |
| #include "mlir/TableGen/Pass.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Record.h" |
| |
| using namespace mlir; |
| using namespace mlir::tblgen; |
| using llvm::formatv; |
| using llvm::RecordKeeper; |
| |
/// Command-line category that groups the options specific to
/// -gen-pass-decls. It is defined before `groupName` because the initializer
/// of `groupName` below references it.
static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
/// `-name`: the name of the pass group. It is substituted into the generated
/// `register<name>Passes()` function (see passGroupRegistrationCode and
/// emitRegistrations).
static llvm::cl::opt<std::string>
    groupName("name", llvm::cl::desc("The name of this group of passes"),
              llvm::cl::cat(passGenCat));
| |
| /// Extract the list of passes from the TableGen records. |
| static std::vector<Pass> getPasses(const RecordKeeper &records) { |
| std::vector<Pass> passes; |
| |
| for (const auto *def : records.getAllDerivedDefinitions("PassBase")) |
| passes.emplace_back(def); |
| |
| return passes; |
| } |
| |
/// Banner comment emitted at the start of the generated code for each pass
/// (see emitPass).
///
/// {0}: The def name of the pass record.
const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
| |
| //===----------------------------------------------------------------------===// |
| // GEN: Pass registration generation |
| //===----------------------------------------------------------------------===// |
| |
/// The code snippet used to generate a pass registration.
///
/// Emits two inline functions with identical bodies. Each registers the pass
/// through `::mlir::registerPass`. `register<Name>()` is the current entry
/// point. `register<Name>Pass()` is the old spelling, kept for backwards
/// compatibility.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call. This is either the custom constructor
///      expression from the pass record or `create<Name>()` (see
///      emitRegistrations).
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
});
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
});
}
)";
| |
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// Only the opening of `register<Group>Passes()` is emitted here. The caller
/// (emitRegistrations) appends one `register<Name>();` call per pass and then
/// the closing brace.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Passes() {{
)";
| |
| /// Emits the definition of the struct to be used to control the pass options. |
| static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) { |
| StringRef passName = pass.getDef()->getName(); |
| ArrayRef<PassOption> options = pass.getOptions(); |
| |
| // Emit the struct only if the pass has at least one option. |
| if (options.empty()) |
| return; |
| |
| os << formatv("struct {0}Options {{\n", passName); |
| |
| for (const PassOption &opt : options) { |
| std::string type = opt.getType().str(); |
| |
| if (opt.isListOption()) |
| type = "::llvm::SmallVector<" + type + ">"; |
| |
| os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName()); |
| |
| if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) |
| os << " = " << defaultVal; |
| |
| os << ";\n"; |
| } |
| |
| os << "};\n"; |
| } |
| |
| static std::string getPassDeclVarName(const Pass &pass) { |
| return "GEN_PASS_DECL_" + pass.getDef()->getName().upper(); |
| } |
| |
| /// Emit the code to be included in the public header of the pass. |
| static void emitPassDecls(const Pass &pass, raw_ostream &os) { |
| StringRef passName = pass.getDef()->getName(); |
| std::string enableVarName = getPassDeclVarName(pass); |
| |
| os << "#ifdef " << enableVarName << "\n"; |
| emitPassOptionsStruct(pass, os); |
| |
| if (StringRef constructor = pass.getConstructor(); constructor.empty()) { |
| // Default constructor declaration. |
| os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n"; |
| |
| // Declaration of the constructor with options. |
| if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) |
| os << formatv("std::unique_ptr<::mlir::Pass> create{0}(" |
| "{0}Options options);\n", |
| passName); |
| } |
| |
| os << "#undef " << enableVarName << "\n"; |
| os << "#endif // " << enableVarName << "\n"; |
| } |
| |
| /// Emit the code for registering each of the given passes with the global |
| /// PassRegistry. |
| static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) { |
| os << "#ifdef GEN_PASS_REGISTRATION\n"; |
| |
| for (const Pass &pass : passes) { |
| std::string constructorCall; |
| if (StringRef constructor = pass.getConstructor(); !constructor.empty()) |
| constructorCall = constructor.str(); |
| else |
| constructorCall = formatv("create{0}()", pass.getDef()->getName()).str(); |
| |
| os << formatv(passRegistrationCode, pass.getDef()->getName(), |
| constructorCall); |
| } |
| |
| os << formatv(passGroupRegistrationCode, groupName); |
| |
| for (const Pass &pass : passes) |
| os << " register" << pass.getDef()->getName() << "();\n"; |
| |
| os << "}\n"; |
| os << "#undef GEN_PASS_REGISTRATION\n"; |
| os << "#endif // GEN_PASS_REGISTRATION\n"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GEN: Pass base class generation |
| //===----------------------------------------------------------------------===// |
| |
/// The code snippet used to generate the start of a pass base class.
///
/// The snippet opens `class <Name>Base` and leaves it open in its public
/// section. emitPassDefs appends the remaining members and the closing brace.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration (statements that call
///      `registry.insert<...>()`, so the parameter must be named `registry`).
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
using Base = {0}Base;

{0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
{0}Base(const {0}Base &other) : {1}(other) {{}
{0}Base& operator=(const {0}Base &) = delete;
{0}Base({0}Base &&) = delete;
{0}Base& operator=({0}Base &&) = delete;
~{0}Base() = default;

/// Returns the command-line argument attached to this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("{2}");
}
::llvm::StringRef getArgument() const override { return "{2}"; }

::llvm::StringRef getDescription() const override { return "{3}"; }

/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("{0}");
}
::llvm::StringRef getName() const override { return "{0}"; }

/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {{
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}

/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {{
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}

/// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{4}
}

/// Explicitly declare the TypeID for this class. We declare an explicit private
/// instantiation because Pass classes should only be visible by the current
/// library.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

)";
| |
/// Registration for a single dependent dialect. One instance is emitted per
/// dependent dialect inside the generated `getDependentDialects` override of
/// the pass base classes (the registrations are joined by newlines).
///
/// {0}: A dependent dialect entry from the pass record. It is used as the
///      template argument of `registry.insert<>`, so it is presumably a C++
///      dialect class name.
const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
| |
/// Namespace-scope declaration of `impl::create<Name>()`. emitPassDefs emits
/// it before the pass base class, which later defines the function as a
/// friend (friendDefaultConstructorDefTemplate). Without this declaration the
/// in-class friend definition would not be found by the qualified call
/// `impl::create<Name>()` in defaultConstructorDefTemplate.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";
| |
/// Namespace-scope declaration of `impl::create<Name>(<Name>Options)`. It is
/// emitted before the pass base class, only for passes that have options, and
/// is defined inside the class as a friend
/// (friendDefaultConstructorWithOptionsDefTemplate).
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
std::unique_ptr<::mlir::Pass> create{0}({0}Options options);
} // namespace impl
)";
| |
/// In-class friend definition of `create<Name>()`, emitted in the private
/// section of `<Name>Base`. It default-constructs the derived pass `DerivedT`.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDefTemplate = R"(
friend std::unique_ptr<::mlir::Pass> create{0}() {{
return std::make_unique<DerivedT>();
}
)";
| |
/// In-class friend definition of `create<Name>(<Name>Options)`, emitted in the
/// private section of `<Name>Base` for passes with options. It constructs
/// `DerivedT` from the moved options struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
return std::make_unique<DerivedT>(std::move(options));
}
)";
| |
/// Definition of the public factory `create<Name>()`, emitted after the
/// `impl` namespace is closed. It forwards to `impl::create<Name>()`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
return impl::create{0}();
}
)";
| |
/// Definition of the public factory `create<Name>(<Name>Options)`, emitted
/// after the `impl` namespace for passes with options. It forwards the moved
/// options to `impl::create<Name>(...)`.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
return impl::create{0}(std::move(options));
}
)";
| |
| /// Emit the declarations for each of the pass options. |
| static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { |
| for (const PassOption &opt : pass.getOptions()) { |
| os.indent(2) << "::mlir::Pass::" |
| << (opt.isListOption() ? "ListOption" : "Option"); |
| |
| os << formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))", |
| opt.getType(), opt.getCppVariableName(), opt.getArgument(), |
| opt.getDescription()); |
| if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) |
| os << ", ::llvm::cl::init(" << defaultVal << ")"; |
| if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags()) |
| os << ", " << *additionalFlags; |
| os << "};\n"; |
| } |
| } |
| |
| /// Emit the declarations for each of the pass statistics. |
| static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { |
| for (const PassStatistic &stat : pass.getStatistics()) { |
| os << formatv(" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", |
| stat.getCppVariableName(), stat.getName(), |
| stat.getDescription()); |
| } |
| } |
| |
| /// Emit the code to be used in the implementation of the pass. |
| static void emitPassDefs(const Pass &pass, raw_ostream &os) { |
| StringRef passName = pass.getDef()->getName(); |
| std::string enableVarName = "GEN_PASS_DEF_" + passName.upper(); |
| bool emitDefaultConstructors = pass.getConstructor().empty(); |
| bool emitDefaultConstructorWithOptions = !pass.getOptions().empty(); |
| |
| os << "#ifdef " << enableVarName << "\n"; |
| |
| if (emitDefaultConstructors) { |
| os << formatv(friendDefaultConstructorDeclTemplate, passName); |
| |
| if (emitDefaultConstructorWithOptions) |
| os << formatv(friendDefaultConstructorWithOptionsDeclTemplate, passName); |
| } |
| |
| std::string dependentDialectRegistrations; |
| { |
| llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); |
| llvm::interleave( |
| pass.getDependentDialects(), dialectsOs, |
| [&](StringRef dependentDialect) { |
| dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect); |
| }, |
| "\n "); |
| } |
| |
| os << "namespace impl {\n"; |
| os << formatv(baseClassBegin, passName, pass.getBaseClass(), |
| pass.getArgument(), pass.getSummary(), |
| dependentDialectRegistrations); |
| |
| if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) { |
| os.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n", |
| passName); |
| |
| for (const PassOption &opt : pass.getOptions()) |
| os.indent(4) << formatv("{0} = std::move(options.{0});\n", |
| opt.getCppVariableName()); |
| |
| os.indent(2) << "}\n"; |
| } |
| |
| // Protected content |
| os << "protected:\n"; |
| emitPassOptionDecls(pass, os); |
| emitPassStatisticDecls(pass, os); |
| |
| // Private content |
| os << "private:\n"; |
| |
| if (emitDefaultConstructors) { |
| os << formatv(friendDefaultConstructorDefTemplate, passName); |
| |
| if (!pass.getOptions().empty()) |
| os << formatv(friendDefaultConstructorWithOptionsDefTemplate, passName); |
| } |
| |
| os << "};\n"; |
| os << "} // namespace impl\n"; |
| |
| if (emitDefaultConstructors) { |
| os << formatv(defaultConstructorDefTemplate, passName); |
| |
| if (emitDefaultConstructorWithOptions) |
| os << formatv(defaultConstructorWithOptionsDefTemplate, passName); |
| } |
| |
| os << "#undef " << enableVarName << "\n"; |
| os << "#endif // " << enableVarName << "\n"; |
| } |
| |
| static void emitPass(const Pass &pass, raw_ostream &os) { |
| StringRef passName = pass.getDef()->getName(); |
| os << formatv(passHeader, passName); |
| |
| emitPassDecls(pass, os); |
| emitPassDefs(pass, os); |
| } |
| |
// TODO: Drop old pass declarations.
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.
//
/// Start of the legacy pass base class, emitted under GEN_PASS_CLASSES. The
/// snippet ends in the class's protected section. emitOldPassDecl appends the
/// option and statistic members and then the closing brace.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration (statements that call
///      `registry.insert<...>()`, so the parameter must be named `registry`).
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
using Base = {0}Base;

{0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
{0}Base(const {0}Base &other) : {1}(other) {{}
{0}Base& operator=(const {0}Base &) = delete;
{0}Base({0}Base &&) = delete;
{0}Base& operator=({0}Base &&) = delete;
~{0}Base() = default;

/// Returns the command-line argument attached to this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("{2}");
}
::llvm::StringRef getArgument() const override { return "{2}"; }

::llvm::StringRef getDescription() const override { return "{3}"; }

/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("{0}");
}
::llvm::StringRef getName() const override { return "{0}"; }

/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {{
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}

/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {{
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}

/// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{4}
}

/// Explicitly declare the TypeID for this class. We declare an explicit private
/// instantiation because Pass classes should only be visible by the current
/// library.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

protected:
)";
| |
| // TODO: Drop old pass declarations. |
| /// Emit a backward-compatible declaration of the pass base class. |
| static void emitOldPassDecl(const Pass &pass, raw_ostream &os) { |
| StringRef defName = pass.getDef()->getName(); |
| std::string dependentDialectRegistrations; |
| { |
| llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); |
| llvm::interleave( |
| pass.getDependentDialects(), dialectsOs, |
| [&](StringRef dependentDialect) { |
| dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect); |
| }, |
| "\n "); |
| } |
| os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(), |
| pass.getArgument(), pass.getSummary(), |
| dependentDialectRegistrations); |
| emitPassOptionDecls(pass, os); |
| emitPassStatisticDecls(pass, os); |
| os << "};\n"; |
| } |
| |
| static void emitPasses(const RecordKeeper &records, raw_ostream &os) { |
| std::vector<Pass> passes = getPasses(records); |
| os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; |
| |
| os << "\n"; |
| os << "#ifdef GEN_PASS_DECL\n"; |
| os << "// Generate declarations for all passes.\n"; |
| for (const Pass &pass : passes) |
| os << "#define " << getPassDeclVarName(pass) << "\n"; |
| os << "#undef GEN_PASS_DECL\n"; |
| os << "#endif // GEN_PASS_DECL\n"; |
| |
| for (const Pass &pass : passes) |
| emitPass(pass, os); |
| |
| emitRegistrations(passes, os); |
| |
| // TODO: Drop old pass declarations. |
| // Emit the old code until all the passes have switched to the new design. |
| os << "// Deprecated. Please use the new per-pass macros.\n"; |
| os << "#ifdef GEN_PASS_CLASSES\n"; |
| for (const Pass &pass : passes) |
| emitOldPassDecl(pass, os); |
| os << "#undef GEN_PASS_CLASSES\n"; |
| os << "#endif // GEN_PASS_CLASSES\n"; |
| } |
| |
| static mlir::GenRegistration |
| genPassDecls("gen-pass-decls", "Generate pass declarations", |
| [](const RecordKeeper &records, raw_ostream &os) { |
| emitPasses(records, os); |
| return false; |
| }); |