blob: 8e2d6114e48eb0e8798e548d67b8ff9f4dfbe68a [file] [log] [blame]
//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
// generate the corresponding Python binding classes.
//
//===----------------------------------------------------------------------===//
#include "OpGenHelpers.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/EnumInfo.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
/// Prologue written at the top of every generated Python module: an
/// "autogenerated" banner plus the Python imports that the emitted enum
/// classes (`IntEnum`, `IntFlag`, `auto`) and attribute builders
/// (`register_attribute_builder`, `_ods_ir`) refer to.
static constexpr char fileHeader[] = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir
)Py";
/// Makes an ODS enum case symbol usable as a Python identifier.
///
/// The symbol is returned unchanged (no case conversion is performed) unless
/// `isPythonReserved` reports it as a Python reserved word. In that case a
/// trailing underscore is appended so the generated class attribute does not
/// clash with the keyword.
static std::string makePythonEnumCaseName(StringRef name) {
  // Build the result once and patch it in place instead of materializing the
  // string repeatedly.
  std::string pyName = name.str();
  if (isPythonReserved(pyName))
    pyName += '_';
  return pyName;
}
/// Writes the Python class mirroring one ODS enum to `os`.
///
/// The class derives from `IntFlag` for bit enums and from `IntEnum`
/// otherwise. Each case becomes a class attribute whose value is the case's
/// integer value, or `auto()` when that value is negative. Bit enums also get
/// `__iter__` (over contained cases) and `__len__` (number of set bits).
/// The generated `__str__` has three parts:
///   - for bit enums holding more than one case, it joins the case strings
///     with the enum's `separator`;
///   - otherwise it returns the ODS string of the matching case;
///   - it raises `ValueError` when no case matches.
static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
  const bool isBitEnum = enumInfo.isBitEnum();
  const auto className = enumInfo.getEnumClassName();
  const auto cases = enumInfo.getAllCases();

  // Class header and optional docstring.
  os << formatv("class {0}({1}):\n", className,
                isBitEnum ? "IntFlag" : "IntEnum");
  const auto summary = enumInfo.getSummary();
  if (!summary.empty())
    os << formatv(" \"\"\"{0}\"\"\"\n", summary);
  os << "\n";

  // One class attribute per case.
  for (const EnumCase &enumCase : cases) {
    const auto value = enumCase.getValue();
    const std::string rhs =
        value < 0 ? std::string("auto()") : std::to_string(value);
    os << formatv(" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()),
                  rhs);
  }
  os << "\n";

  // Flag helpers: the literals contain no braces, so they are streamed
  // directly rather than through formatv.
  if (isBitEnum) {
    os << " def __iter__(self):\n"
          " return iter([case for case in type(self) if "
          "(self & case) is case])\n";
    os << " def __len__(self):\n"
          " return bin(self).count(\"1\")\n";
    os << "\n";
  }

  // String conversion back to the ODS spelling.
  os << " def __str__(self):\n";
  if (isBitEnum)
    os << formatv(" if len(self) > 1:\n"
                  " return \"{0}\".join(map(str, self))\n",
                  enumInfo.getDef().getValueAsString("separator"));
  for (const EnumCase &enumCase : cases)
    os << formatv(" if self is {0}.{1}:\n", className,
                  makePythonEnumCaseName(enumCase.getSymbol()))
       << formatv(" return \"{0}\"\n", enumCase.getStr());
  os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
                className);
  os << "\n";
}
/// Registers a Python attribute builder for an enum that is exposed as an
/// attribute, so that Python enum values are converted automatically into a
/// signless `IntegerAttr` whose width is the enum's bitwidth.
///
/// Enums for which `asEnumAttr()` yields nothing produce no output. Returns
/// `false` (success) in every case; the `bool` result follows the file's
/// `false` on success / `true` on failure convention.
static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
  if (std::optional<Attribute> attr = enumInfo.asEnumAttr()) {
    const auto defName = attr->getAttrDefName();
    const int64_t width = enumInfo.getBitwidth();
    os << formatv("@register_attribute_builder(\"{0}\")\n", defName)
       << formatv("def _{0}(x, context):\n", defName.lower())
       << formatv(" return "
                  "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
                  "context=context), int(x))\n\n",
                  width);
  }
  return false;
}
/// Registers a Python attribute builder for the dialect enum attribute
/// `attrDefName`. The generated function builds the attribute with
/// `_ods_ir.Attribute.parse`, using `formatString` as the body of a Python
/// f-string. Any `{...}` fragment in it is therefore evaluated in Python, where
/// `x` is the enum value being converted.
///
/// Always returns `false` (success), per the file's `false` on success /
/// `true` on failure convention.
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
                                            StringRef formatString,
                                            raw_ostream &os) {
  const std::string builderName = attrDefName.lower();
  os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName)
     << formatv("def _{0}(x, context):\n", builderName)
     << formatv(" return "
                "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
                formatString);
  return false;
}
/// Emits Python bindings for all enums in the record keeper.
///
/// The output is produced in this order:
///   - the file prologue;
///   - for every `EnumInfo` def, a Python enum class plus, when the enum is
///     exposed as an attribute, an integer attribute builder;
///   - for every dialect `EnumAttr` def, an attribute builder that parses the
///     attribute from its textual form.
///
/// Only the assembly formats "`<` $value `>`" and "$value" are supported for
/// `EnumAttr`s, and a mnemonic is required.
///
/// Returns `false` on success and `true` on failure; a diagnostic is printed
/// to `llvm::errs()` before returning `true`.
static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
  os << fileHeader;

  for (const Record *def :
       records.getAllDerivedDefinitionsIfDefined("EnumInfo")) {
    EnumInfo enumInfo(*def);
    emitEnumClass(enumInfo, os);
    // Honor the helper's failure contract instead of discarding it.
    if (emitAttributeBuilder(enumInfo, os))
      return true;
  }

  for (const Record *def :
       records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
    AttrOrTypeDef attr(def);

    std::optional<StringRef> mnemonic = attr.getMnemonic();
    if (!mnemonic) {
      llvm::errs() << "enum attribute " << attr.getName()
                   << " needs mnemonic for python enum bindings generation\n";
      return true;
    }

    std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
    StringRef dialect = attr.getDialect().getName();

    // `{{` is formatv's escape for a literal `{`. The result is a Python
    // f-string body that interpolates `str(x)` at attribute-build time.
    std::string formatString;
    if (assemblyFormat == "`<` $value `>`") {
      formatString = formatv("#{0}.{1}<{{str(x)}>", dialect, *mnemonic).str();
    } else if (assemblyFormat == "$value") {
      formatString = formatv("#{0}<{1} {{str(x)}>", dialect, *mnemonic).str();
    } else {
      llvm::errs() << "enum attribute " << attr.getName()
                   << " has unsupported assembly format for python enum "
                      "bindings generation\n";
      return true;
    }

    if (emitDialectEnumAttributeBuilder(attr.getName(), formatString, os))
      return true;
  }
  return false;
}
// Makes `emitPythonEnums` available to mlir-tblgen as the generator named
// "gen-python-enum-bindings".
static mlir::GenRegistration genPythonEnumBindings{
    "gen-python-enum-bindings", "Generate Python bindings for enum attributes",
    emitPythonEnums};