| //===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // EnumPythonBindingGen uses ODS specification of MLIR enum attributes to |
| // generate the corresponding Python binding classes. |
| // |
| //===----------------------------------------------------------------------===// |
| #include "OpGenHelpers.h" |
| |
| #include "mlir/TableGen/AttrOrTypeDef.h" |
| #include "mlir/TableGen/Attribute.h" |
| #include "mlir/TableGen/Dialect.h" |
| #include "mlir/TableGen/EnumInfo.h" |
| #include "mlir/TableGen/GenInfo.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/TableGen/Record.h" |
| |
| using namespace mlir; |
| using namespace mlir::tblgen; |
| using llvm::formatv; |
| using llvm::Record; |
| using llvm::RecordKeeper; |
| |
/// Preamble written once at the top of every generated Python module: an
/// autogeneration notice followed by the imports that the emitted enum classes
/// (`IntEnum`, `IntFlag`, `auto`) and attribute builders
/// (`register_attribute_builder`, `_ods_ir`) rely on.
static constexpr char fileHeader[] = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir

)Py";
| |
| /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE. |
| static std::string makePythonEnumCaseName(StringRef name) { |
| if (isPythonReserved(name.str())) |
| return (name + "_").str(); |
| return name.str(); |
| } |
| |
| /// Emits the Python class for the given enum. |
| static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) { |
| os << formatv("class {0}({1}):\n", enumInfo.getEnumClassName(), |
| enumInfo.isBitEnum() ? "IntFlag" : "IntEnum"); |
| if (!enumInfo.getSummary().empty()) |
| os << formatv(" \"\"\"{0}\"\"\"\n", enumInfo.getSummary()); |
| os << "\n"; |
| |
| for (const EnumCase &enumCase : enumInfo.getAllCases()) { |
| os << formatv(" {0} = {1}\n", |
| makePythonEnumCaseName(enumCase.getSymbol()), |
| enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue()) |
| : "auto()"); |
| } |
| |
| os << "\n"; |
| |
| if (enumInfo.isBitEnum()) { |
| os << formatv(" def __iter__(self):\n" |
| " return iter([case for case in type(self) if " |
| "(self & case) is case])\n"); |
| os << formatv(" def __len__(self):\n" |
| " return bin(self).count(\"1\")\n"); |
| os << "\n"; |
| } |
| |
| os << formatv(" def __str__(self):\n"); |
| if (enumInfo.isBitEnum()) |
| os << formatv(" if len(self) > 1:\n" |
| " return \"{0}\".join(map(str, self))\n", |
| enumInfo.getDef().getValueAsString("separator")); |
| for (const EnumCase &enumCase : enumInfo.getAllCases()) { |
| os << formatv(" if self is {0}.{1}:\n", enumInfo.getEnumClassName(), |
| makePythonEnumCaseName(enumCase.getSymbol())); |
| os << formatv(" return \"{0}\"\n", enumCase.getStr()); |
| } |
| os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n", |
| enumInfo.getEnumClassName()); |
| os << "\n"; |
| } |
| |
| /// Emits an attribute builder for the given enum attribute to support automatic |
| /// conversion between enum values and attributes in Python. Returns |
| /// `false` on success, `true` on failure. |
| static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) { |
| std::optional<Attribute> enumAttrInfo = enumInfo.asEnumAttr(); |
| if (!enumAttrInfo) |
| return false; |
| |
| int64_t bitwidth = enumInfo.getBitwidth(); |
| os << formatv("@register_attribute_builder(\"{0}\")\n", |
| enumAttrInfo->getAttrDefName()); |
| os << formatv("def _{0}(x, context):\n", |
| enumAttrInfo->getAttrDefName().lower()); |
| os << formatv(" return " |
| "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " |
| "context=context), int(x))\n\n", |
| bitwidth); |
| return false; |
| } |
| |
| /// Emits an attribute builder for the given dialect enum attribute to support |
| /// automatic conversion between enum values and attributes in Python. Returns |
| /// `false` on success, `true` on failure. |
| static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, |
| StringRef formatString, |
| raw_ostream &os) { |
| os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); |
| os << formatv("def _{0}(x, context):\n", attrDefName.lower()); |
| os << formatv(" return " |
| "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", |
| formatString); |
| return false; |
| } |
| |
| /// Emits Python bindings for all enums in the record keeper. Returns |
| /// `false` on success, `true` on failure. |
| static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) { |
| os << fileHeader; |
| for (const Record *it : |
| records.getAllDerivedDefinitionsIfDefined("EnumInfo")) { |
| EnumInfo enumInfo(*it); |
| emitEnumClass(enumInfo, os); |
| emitAttributeBuilder(enumInfo, os); |
| } |
| for (const Record *it : |
| records.getAllDerivedDefinitionsIfDefined("EnumAttr")) { |
| AttrOrTypeDef attr(&*it); |
| if (!attr.getMnemonic()) { |
| llvm::errs() << "enum case " << attr |
| << " needs mnemonic for python enum bindings generation"; |
| return true; |
| } |
| StringRef mnemonic = attr.getMnemonic().value(); |
| std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat(); |
| StringRef dialect = attr.getDialect().getName(); |
| if (assemblyFormat == "`<` $value `>`") { |
| emitDialectEnumAttributeBuilder( |
| attr.getName(), |
| formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os); |
| } else if (assemblyFormat == "$value") { |
| emitDialectEnumAttributeBuilder( |
| attr.getName(), |
| formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os); |
| } else { |
| llvm::errs() |
| << "unsupported assembly format for python enum bindings generation"; |
| return true; |
| } |
| } |
| |
| return false; |
| } |
| |
| // Registers the enum utility generator to mlir-tblgen. |
| static mlir::GenRegistration |
| genPythonEnumBindings("gen-python-enum-bindings", |
| "Generate Python bindings for enum attributes", |
| &emitPythonEnums); |