| //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python |
| // binding classes wrapping a generic operation API. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "OpGenHelpers.h" |
| |
| #include "mlir/TableGen/GenInfo.h" |
| #include "mlir/TableGen/Operator.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Record.h" |
| |
| using namespace mlir; |
| using namespace mlir::tblgen; |
| using llvm::formatv; |
| using llvm::Record; |
| using llvm::RecordKeeper; |
| |
| /// File header and includes. |
| /// The template currently contains no `{N}` format placeholders. |
| constexpr const char *fileHeader = R"Py( |
| # Autogenerated by mlir-tblgen; don't manually edit. |
|
| from ._ods_common import _cext as _ods_cext |
| from ._ods_common import ( |
| equally_sized_accessor as _ods_equally_sized_accessor, |
| get_default_loc_context as _ods_get_default_loc_context, |
| get_op_result_or_op_results as _get_op_result_or_op_results, |
| get_op_results_or_values as _get_op_results_or_values, |
| segmented_accessor as _ods_segmented_accessor, |
| ) |
| _ods_ir = _ods_cext.ir |
|
| import builtins |
| from typing import Sequence as _Sequence, Union as _Union |
|
| )Py"; |
|
| /// Template for dialect class: |
| /// {0} is the dialect namespace. |
| constexpr const char *dialectClassTemplate = R"Py( |
| @_ods_cext.register_dialect |
| class _Dialect(_ods_ir.Dialect): |
| DIALECT_NAMESPACE = "{0}" |
| )Py"; |
|
| constexpr const char *dialectExtensionTemplate = R"Py( |
| from ._{0}_ops_gen import _Dialect |
| )Py"; |
|
| /// Template for operation class: |
| /// {0} is the Python class name; |
| /// {1} is the operation name. |
| constexpr const char *opClassTemplate = R"Py( |
| @_ods_cext.register_operation(_Dialect) |
| class {0}(_ods_ir.OpView): |
| OPERATION_NAME = "{1}" |
| )Py"; |
| |
| /// Template for class level declarations of operand and result |
| /// segment specs. |
| /// {0} is either "OPERAND" or "RESULT" |
| /// {1} is the segment spec |
| /// Each segment spec is either None (default) or an array of integers |
| /// where: |
| /// 1 = single element (expect non sequence operand/result) |
| /// 0 = optional element (expect a value or None) |
| /// -1 = operand/result is a sequence corresponding to a variadic |
| constexpr const char *opClassSizedSegmentsTemplate = R"Py( |
| _ODS_{0}_SEGMENTS = {1} |
| )Py"; |
|
| /// Template for class level declarations of the _ODS_REGIONS spec: |
| /// {0} is the minimum number of regions |
| /// {1} is the Python bool literal for hasNoVariadicRegions |
| constexpr const char *opClassRegionSpecTemplate = R"Py( |
| _ODS_REGIONS = ({0}, {1}) |
| )Py"; |
|
| /// Template for single-element accessor: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the position in the element list. |
| constexpr const char *opSingleTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| return self.operation.{1}s[{2}] |
| )Py"; |
|
| /// Template for single-element accessor after a variable-length group: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the total number of element groups; |
| /// {3} is the position of the current group in the group list. |
| /// This works for both a single variadic group (non-negative length) and a |
| /// single optional element (zero length if the element is absent). |
| constexpr const char *opSingleAfterVariableTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1 |
| return self.operation.{1}s[{3} + _ods_variadic_group_length - 1] |
| )Py"; |
|
| /// Template for an optional element accessor: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the total number of element groups; |
| /// {3} is the position of the current group in the group list. |
| /// This works if we have only one variable-length group (and it's the optional |
| /// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is |
| /// smaller than the total number of groups. |
| constexpr const char *opOneOptionalTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}] |
| )Py"; |
|
| /// Template for the variadic group accessor in the single variadic group case: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the total number of element groups; |
| /// {3} is the position of the current group in the group list. |
| constexpr const char *opOneVariadicTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1 |
| return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length] |
| )Py"; |
|
| /// First part of the template for equally-sized variadic group accessor: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the total number of non-variadic groups; |
| /// {3} is the total number of variadic groups; |
| /// {4} is the number of non-variadic groups preceding the current group; |
| /// {5} is the number of variadic groups preceding the current group. |
| constexpr const char *opVariadicEqualPrefixTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py"; |
|
| /// Second part of the template for equally-sized case, accessing a single |
| /// element: |
| /// {0} is either 'operand' or 'result'. |
| constexpr const char *opVariadicEqualSimpleTemplate = R"Py( |
| return self.operation.{0}s[start] |
| )Py"; |
|
| /// Second part of the template for equally-sized case, accessing a variadic |
| /// group: |
| /// {0} is either 'operand' or 'result'. |
| constexpr const char *opVariadicEqualVariadicTemplate = R"Py( |
| return self.operation.{0}s[start:start + elements_per_group] |
| )Py"; |
|
| /// Template for an attribute-sized group accessor: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the position of the group in the group list; |
| /// {3} is a return suffix ([0] for a single element, empty for a |
| /// variadic group, and opVariadicSegmentOptionalTrailingTemplate for optional). |
| constexpr const char *opVariadicSegmentTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| {1}_range = _ods_segmented_accessor( |
| self.operation.{1}s, |
| self.operation.attributes["{1}SegmentSizes"], {2}) |
| return {1}_range{3} |
| )Py"; |
|
| /// Template for a suffix when accessing an optional element in the |
| /// attribute-sized case (yields None when the segment is empty): |
| /// {0} is either 'operand' or 'result'; |
| constexpr const char *opVariadicSegmentOptionalTrailingTemplate = |
| R"Py([0] if len({0}_range) > 0 else None)Py"; |
| |
| /// Template for an operation attribute getter: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *attributeGetterTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| return self.operation.attributes["{1}"] |
| )Py"; |
|
| /// Template for an optional operation attribute getter (None when absent): |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *optionalAttributeGetterTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| if "{1}" not in self.operation.attributes: |
| return None |
| return self.operation.attributes["{1}"] |
| )Py"; |
|
| /// Template for a getter of a unit operation attribute, returns True if the |
| /// unit attribute is present, False otherwise (unit attributes have meaning |
| /// by mere presence): |
| /// {0} is the name of the attribute sanitized for Python, |
| /// {1} is the original name of the attribute. |
| constexpr const char *unitAttributeGetterTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| return "{1}" in self.operation.attributes |
| )Py"; |
|
| /// Template for an operation attribute setter (rejects None at runtime): |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *attributeSetterTemplate = R"Py( |
| @{0}.setter |
| def {0}(self, value): |
| if value is None: |
| raise ValueError("'None' not allowed as value for mandatory attributes") |
| self.operation.attributes["{1}"] = value |
| )Py"; |
|
| /// Template for a setter of an optional operation attribute, setting to None |
| /// removes the attribute: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *optionalAttributeSetterTemplate = R"Py( |
| @{0}.setter |
| def {0}(self, value): |
| if value is not None: |
| self.operation.attributes["{1}"] = value |
| elif "{1}" in self.operation.attributes: |
| del self.operation.attributes["{1}"] |
| )Py"; |
|
| /// Template for a setter of a unit operation attribute, setting to any falsy |
| /// value (e.g. None or False) removes the attribute: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *unitAttributeSetterTemplate = R"Py( |
| @{0}.setter |
| def {0}(self, value): |
| if bool(value): |
| self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get() |
| elif "{1}" in self.operation.attributes: |
| del self.operation.attributes["{1}"] |
| )Py"; |
|
| /// Template for a deleter of an optional or a unit operation attribute, removes |
| /// the attribute from the operation: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *attributeDeleterTemplate = R"Py( |
| @{0}.deleter |
| def {0}(self): |
| del self.operation.attributes["{1}"] |
| )Py"; |
|
| constexpr const char *regionAccessorTemplate = R"Py( |
| @builtins.property |
| def {0}(self): |
| return self.regions[{1}] |
| )Py"; |
|
| constexpr const char *valueBuilderTemplate = R"Py( |
| def {0}({2}) -> {4}: |
| return {1}({3}){5} |
| )Py"; |
|
| constexpr const char *valueBuilderVariadicTemplate = R"Py( |
| def {0}({2}) -> {4}: |
| return _get_op_result_or_op_results({1}({3})) |
| )Py"; |
| |
| static llvm::cl::OptionCategory |
| clOpPythonBindingCat("Options for -gen-python-op-bindings"); |
| |
| static llvm::cl::opt<std::string> |
| clDialectName("bind-dialect", |
| llvm::cl::desc("The dialect to run the generator for"), |
| llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); |
| |
| static llvm::cl::opt<std::string> clDialectExtensionName( |
| "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"), |
| llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); |
| |
| using AttributeClasses = DenseMap<StringRef, StringRef>; |
| |
| /// Checks whether `str` would shadow a generated variable or attribute |
| /// part of the OpView API. |
| static bool isODSReserved(StringRef str) { |
| static llvm::StringSet<> reserved( |
| {"attributes", "create", "context", "ip", "operands", "print", "get_asm", |
| "loc", "verify", "regions", "results", "self", "operation", |
| "DIALECT_NAMESPACE", "OPERATION_NAME"}); |
| return str.starts_with("_ods_") || str.ends_with("_ods") || |
| reserved.contains(str); |
| } |
| |
| /// Modifies the `name` in a way that it becomes suitable for Python bindings |
| /// (does not change the `name` if it already is suitable) and returns the |
| /// modified version. |
| static std::string sanitizeName(StringRef name) { |
| std::string processedStr = name.str(); |
| std::replace_if( |
| processedStr.begin(), processedStr.end(), |
| [](char c) { return !llvm::isAlnum(c); }, '_'); |
| |
| if (llvm::isDigit(*processedStr.begin())) |
| return "_" + processedStr; |
| |
| if (isPythonReserved(processedStr) || isODSReserved(processedStr)) |
| return processedStr + "_"; |
| return processedStr; |
| } |
| |
| static std::string attrSizedTraitForKind(const char *kind) { |
| return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", |
| StringRef(kind).take_front().upper(), |
| StringRef(kind).drop_front()); |
| } |
| |
| /// Emits accessors to "elements" of an Op definition. Currently, the supported |
| /// elements are operands and results, indicated by `kind`, which must be either |
| /// `operand` or `result` and is used verbatim in the emitted code. |
| static void emitElementAccessors( |
| const Operator &op, raw_ostream &os, const char *kind, |
| unsigned numVariadicGroups, unsigned numElements, |
| llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
| getElement) { |
| assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"}, |
| kind) && |
| "unsupported kind"); |
| |
| // Traits indicating how to process variadic elements. |
| std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", |
| StringRef(kind).take_front().upper(), |
| StringRef(kind).drop_front()); |
| std::string attrSizedTrait = attrSizedTraitForKind(kind); |
| |
| // If there is only one variable-length element group, its size can be |
| // inferred from the total number of elements. If there are none, the |
| // generation is straightforward. |
| if (numVariadicGroups <= 1) { |
| bool seenVariableLength = false; |
| for (unsigned i = 0; i < numElements; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (element.isVariableLength()) |
| seenVariableLength = true; |
| if (element.name.empty()) |
| continue; |
| if (element.isVariableLength()) { |
| os << formatv(element.isOptional() ? opOneOptionalTemplate |
| : opOneVariadicTemplate, |
| sanitizeName(element.name), kind, numElements, i); |
| } else if (seenVariableLength) { |
| os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name), |
| kind, numElements, i); |
| } else { |
| os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i); |
| } |
| } |
| return; |
| } |
| |
| // Handle the operations where variadic groups have the same size. |
| if (op.getTrait(sameSizeTrait)) { |
| // Count the number of simple elements |
| unsigned numSimpleLength = 0; |
| for (unsigned i = 0; i < numElements; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (!element.isVariableLength()) { |
| ++numSimpleLength; |
| } |
| } |
| |
| // Generate the accessors |
| int numPrecedingSimple = 0; |
| int numPrecedingVariadic = 0; |
| for (unsigned i = 0; i < numElements; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (!element.name.empty()) { |
| os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name), |
| kind, numSimpleLength, numVariadicGroups, |
| numPrecedingSimple, numPrecedingVariadic); |
| os << formatv(element.isVariableLength() |
| ? opVariadicEqualVariadicTemplate |
| : opVariadicEqualSimpleTemplate, |
| kind); |
| } |
| if (element.isVariableLength()) |
| ++numPrecedingVariadic; |
| else |
| ++numPrecedingSimple; |
| } |
| return; |
| } |
| |
| // Handle the operations where the size of groups (variadic or not) is |
| // provided as an attribute. For non-variadic elements, make sure to return |
| // an element rather than a singleton container. |
| if (op.getTrait(attrSizedTrait)) { |
| for (unsigned i = 0; i < numElements; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (element.name.empty()) |
| continue; |
| std::string trailing; |
| if (!element.isVariableLength()) |
| trailing = "[0]"; |
| else if (element.isOptional()) |
| trailing = std::string( |
| formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); |
| os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind, |
| i, trailing); |
| } |
| return; |
| } |
| |
| llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); |
| } |
| |
| /// Free function helpers accessing Operator components. |
| static int getNumOperands(const Operator &op) { return op.getNumOperands(); } |
| static const NamedTypeConstraint &getOperand(const Operator &op, int i) { |
| return op.getOperand(i); |
| } |
| static int getNumResults(const Operator &op) { return op.getNumResults(); } |
| static const NamedTypeConstraint &getResult(const Operator &op, int i) { |
| return op.getResult(i); |
| } |
| |
| /// Emits accessors to Op operands. |
| static void emitOperandAccessors(const Operator &op, raw_ostream &os) { |
| emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(), |
| getNumOperands(op), getOperand); |
| } |
| |
| /// Emits accessors Op results. |
| static void emitResultAccessors(const Operator &op, raw_ostream &os) { |
| emitElementAccessors(op, os, "result", op.getNumVariableLengthResults(), |
| getNumResults(op), getResult); |
| } |
| |
| /// Emits accessors to Op attributes. |
| static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { |
| for (const auto &namedAttr : op.getAttributes()) { |
| // Skip "derived" attributes because they are just C++ functions that we |
| // don't currently expose. |
| if (namedAttr.attr.isDerivedAttr()) |
| continue; |
| |
| if (namedAttr.name.empty()) |
| continue; |
| |
| std::string sanitizedName = sanitizeName(namedAttr.name); |
| |
| // Unit attributes are handled specially. |
| if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") { |
| os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name); |
| os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name); |
| os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name); |
| continue; |
| } |
| |
| if (namedAttr.attr.isOptional()) { |
| os << formatv(optionalAttributeGetterTemplate, sanitizedName, |
| namedAttr.name); |
| os << formatv(optionalAttributeSetterTemplate, sanitizedName, |
| namedAttr.name); |
| os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name); |
| } else { |
| os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name); |
| os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name); |
| // Non-optional attributes cannot be deleted. |
| } |
| } |
| } |
| |
| /// Template for the default auto-generated builder. |
| /// {0} is a comma-separated list of builder arguments, including the trailing |
| /// `loc` and `ip`; |
| /// {1} is the code populating `operands`, `results` and `attributes`, |
| /// `successors` fields; {2} is the argument list of the `super().__init__` call. |
| constexpr const char *initTemplate = R"Py( |
| def __init__(self, {0}): |
| operands = [] |
| results = [] |
| attributes = {{} |
| regions = None |
| {1} |
| super().__init__({2}) |
| )Py"; |
|
| /// Template for appending a single element to the operand/result list. |
| /// {0} is the field name. |
| constexpr const char *singleOperandAppendTemplate = "operands.append({0})"; |
| constexpr const char *singleResultAppendTemplate = "results.append({0})"; |
|
| /// Template for appending an optional element to the operand/result list. |
| /// {0} is the field name; `optionalAppendAttrSizedOperandsTemplate` appends unconditionally, even `None`. |
| constexpr const char *optionalAppendOperandTemplate = |
| "if {0} is not None: operands.append({0})"; |
| constexpr const char *optionalAppendAttrSizedOperandsTemplate = |
| "operands.append({0})"; |
| constexpr const char *optionalAppendResultTemplate = |
| "if {0} is not None: results.append({0})"; |
|
| /// Template for appending a list of elements to the operand/result list. |
| /// {0} is the field name. |
| constexpr const char *multiOperandAppendTemplate = |
| "operands.extend(_get_op_results_or_values({0}))"; |
| constexpr const char *multiOperandAppendPackTemplate = |
| "operands.append(_get_op_results_or_values({0}))"; |
| constexpr const char *multiResultAppendTemplate = "results.extend({0})"; |
|
| /// Template for attribute builder from raw input in the operation builder. |
| /// {0} is the builder argument name; |
| /// {1} is the attribute name (the key in `attributes`); |
| /// {2} is the attribute definition name used to look up an `AttrBuilder`. |
| /// Use the value the user passed in if either it is already an Attribute or |
| /// there is no method registered to make it an Attribute. |
| constexpr const char *initAttributeWithBuilderTemplate = |
| R"Py(attributes["{1}"] = ({0} if ( |
| isinstance({0}, _ods_ir.Attribute) or |
| not _ods_ir.AttrBuilder.contains('{2}')) else |
| _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py"; |
|
| /// Template for attribute builder from raw input for optional attribute in the |
| /// operation builder. |
| /// {0} is the builder argument name; |
| /// {1} is the attribute name (the key in `attributes`); |
| /// {2} is the attribute definition name used to look up an `AttrBuilder`. |
| /// Use the value the user passed in if either it is already an Attribute or |
| /// there is no method registered to make it an Attribute. |
| constexpr const char *initOptionalAttributeWithBuilderTemplate = |
| R"Py(if {0} is not None: attributes["{1}"] = ({0} if ( |
| isinstance({0}, _ods_ir.Attribute) or |
| not _ods_ir.AttrBuilder.contains('{2}')) else |
| _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py"; |
|
| constexpr const char *initUnitAttributeTemplate = |
| R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( |
| _ods_get_default_loc_context(loc)))Py"; |
|
| /// Template to initialize the successors list in the builder (it is set to |
| /// `None` when the op has no successors). |
| /// {0} is the value to initialize the successors list to. |
| constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; |
|
| /// Template to append or extend the list of successors in the builder. |
| /// {0} is the list method ('append' or 'extend'); |
| /// {1} is the value to add. |
| constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py"; |
| |
| /// Returns true if the SameArgumentAndResultTypes trait can be used to infer |
| /// result types of the given operation. |
| static bool hasSameArgumentAndResultTypes(const Operator &op) { |
| return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && |
| op.getNumVariableLengthResults() == 0; |
| } |
| |
| /// Returns true if the FirstAttrDerivedResultType trait can be used to infer |
| /// result types of the given operation. |
| static bool hasFirstAttrDerivedResultTypes(const Operator &op) { |
| return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && |
| op.getNumVariableLengthResults() == 0; |
| } |
| |
| /// Returns true if the InferTypeOpInterface can be used to infer result types |
| /// of the given operation. |
| static bool hasInferTypeInterface(const Operator &op) { |
| return op.getTrait("::mlir::InferTypeOpInterface::Trait") && |
| op.getNumRegions() == 0; |
| } |
| |
| /// Returns true if there is a trait or interface that can be used to infer |
| /// result types of the given operation. |
| static bool canInferType(const Operator &op) { |
| return hasSameArgumentAndResultTypes(op) || |
| hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); |
| } |
| |
| /// Populates `builderArgs` with result names if the builder is expected to |
| /// accept them as arguments. |
| static void |
| populateBuilderArgsResults(const Operator &op, |
| SmallVectorImpl<std::string> &builderArgs) { |
| if (canInferType(op)) |
| return; |
| |
| for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| std::string name = op.getResultName(i).str(); |
| if (name.empty()) { |
| if (op.getNumResults() == 1) { |
| // Special case for one result, make the default name be 'result' |
| // to properly match the built-in result accessor. |
| name = "result"; |
| } else { |
| name = formatv("_gen_res_{0}", i); |
| } |
| } |
| name = sanitizeName(name); |
| builderArgs.push_back(name); |
| } |
| } |
| |
| /// Populates `builderArgs` with the Python-compatible names of builder function |
| /// arguments using intermixed attributes and operands in the same order as they |
| /// appear in the `arguments` field of the op definition. Additionally, |
| /// `operandNames` is populated with names of operands in their order of |
| /// appearance. |
| static void populateBuilderArgs(const Operator &op, |
| SmallVectorImpl<std::string> &builderArgs, |
| SmallVectorImpl<std::string> &operandNames) { |
| for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
| std::string name = op.getArgName(i).str(); |
| if (name.empty()) |
| name = formatv("_gen_arg_{0}", i); |
| name = sanitizeName(name); |
| builderArgs.push_back(name); |
| if (!isa<NamedAttribute *>(op.getArg(i))) |
| operandNames.push_back(name); |
| } |
| } |
| |
| /// Populates `builderArgs` with the Python-compatible names of builder function |
| /// successor arguments. Additionally, `successorArgNames` is also populated. |
| static void |
| populateBuilderArgsSuccessors(const Operator &op, |
| SmallVectorImpl<std::string> &builderArgs, |
| SmallVectorImpl<std::string> &successorArgNames) { |
| |
| for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { |
| NamedSuccessor successor = op.getSuccessor(i); |
| std::string name = std::string(successor.name); |
| if (name.empty()) |
| name = formatv("_gen_successor_{0}", i); |
| name = sanitizeName(name); |
| builderArgs.push_back(name); |
| successorArgNames.push_back(name); |
| } |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder to set up operation attributes. `argNames` is expected to contain |
| /// the names of builder arguments that correspond to op arguments, i.e. to the |
| /// operands and attributes in the same order as they appear in the `arguments` |
| /// field. |
| static void |
| populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames, |
| SmallVectorImpl<std::string> &builderLines) { |
| builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)"); |
| for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
| Argument arg = op.getArg(i); |
| auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg); |
| if (!attribute) |
| continue; |
| |
| // Unit attributes are handled specially. |
| if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") { |
| builderLines.push_back( |
| formatv(initUnitAttributeTemplate, attribute->name, argNames[i])); |
| continue; |
| } |
| |
| builderLines.push_back(formatv( |
| attribute->attr.isOptional() || attribute->attr.hasDefaultValue() |
| ? initOptionalAttributeWithBuilderTemplate |
| : initAttributeWithBuilderTemplate, |
| argNames[i], attribute->name, attribute->attr.getAttrDefName())); |
| } |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder to set up successors. successorArgNames is expected to correspond |
| /// to the Python argument name for each successor on the op. |
| static void |
| populateBuilderLinesSuccessors(const Operator &op, |
| ArrayRef<std::string> successorArgNames, |
| SmallVectorImpl<std::string> &builderLines) { |
| if (successorArgNames.empty()) { |
| builderLines.push_back(formatv(initSuccessorsTemplate, "None")); |
| return; |
| } |
| |
| builderLines.push_back(formatv(initSuccessorsTemplate, "[]")); |
| for (int i = 0, e = successorArgNames.size(); i < e; ++i) { |
| auto &argName = successorArgNames[i]; |
| const NamedSuccessor &successor = op.getSuccessor(i); |
| builderLines.push_back(formatv(addSuccessorTemplate, |
| successor.isVariadic() ? "extend" : "append", |
| argName)); |
| } |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder to set up op operands. |
| static void |
| populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names, |
| SmallVectorImpl<std::string> &builderLines) { |
| bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; |
| |
| // For each element, find or generate a name. |
| for (int i = 0, e = op.getNumOperands(); i < e; ++i) { |
| const NamedTypeConstraint &element = op.getOperand(i); |
| std::string name = names[i]; |
| |
| // Choose the formatting string based on the element kind. |
| StringRef formatString; |
| if (!element.isVariableLength()) { |
| formatString = singleOperandAppendTemplate; |
| } else if (element.isOptional()) { |
| if (sizedSegments) { |
| formatString = optionalAppendAttrSizedOperandsTemplate; |
| } else { |
| formatString = optionalAppendOperandTemplate; |
| } |
| } else { |
| assert(element.isVariadic() && "unhandled element group type"); |
| // If emitting with sizedSegments, then we add the actual list-typed |
| // element. Otherwise, we extend the actual operands. |
| if (sizedSegments) { |
| formatString = multiOperandAppendPackTemplate; |
| } else { |
| formatString = multiOperandAppendTemplate; |
| } |
| } |
| |
| builderLines.push_back(formatv(formatString.data(), name)); |
| } |
| } |
| |
| /// Python code template deriving the result type from an attribute: the |
| /// wrapped type if it is a TypeAttr, otherwise the attribute's own type. |
| /// - {0} is the name of the attribute from which to derive the types. |
| constexpr const char *deriveTypeFromAttrTemplate = |
| R"Py(_ods_result_type_source_attr = attributes["{0}"] |
| _ods_derived_result_type = ( |
| _ods_ir.TypeAttr(_ods_result_type_source_attr).value |
| if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else |
| _ods_result_type_source_attr.type))Py"; |
|
| /// Python code template appending type {0} to the results list {1} times. |
| constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; |
| |
| /// Appends the given multiline string as individual strings into |
| /// `builderLines`. |
| static void appendLineByLine(StringRef string, |
| SmallVectorImpl<std::string> &builderLines) { |
| |
| std::pair<StringRef, StringRef> split = std::make_pair(string, string); |
| do { |
| split = split.second.split('\n'); |
| builderLines.push_back(split.first.str()); |
| } while (!split.second.empty()); |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder to set up op results. |
| static void |
| populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names, |
| SmallVectorImpl<std::string> &builderLines) { |
| bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; |
| |
| if (hasSameArgumentAndResultTypes(op)) { |
| builderLines.push_back(formatv(appendSameResultsTemplate, |
| "operands[0].type", op.getNumResults())); |
| return; |
| } |
| |
| if (hasFirstAttrDerivedResultTypes(op)) { |
| const NamedAttribute &firstAttr = op.getAttribute(0); |
| assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " |
| "from which the type is derived"); |
| appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), |
| builderLines); |
| builderLines.push_back(formatv(appendSameResultsTemplate, |
| "_ods_derived_result_type", |
| op.getNumResults())); |
| return; |
| } |
| |
| if (hasInferTypeInterface(op)) |
| return; |
| |
| // For each element, find or generate a name. |
| for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| const NamedTypeConstraint &element = op.getResult(i); |
| std::string name = names[i]; |
| |
| // Choose the formatting string based on the element kind. |
| StringRef formatString; |
| if (!element.isVariableLength()) { |
| formatString = singleResultAppendTemplate; |
| } else if (element.isOptional()) { |
| formatString = optionalAppendResultTemplate; |
| } else { |
| assert(element.isVariadic() && "unhandled element group type"); |
| // If emitting with sizedSegments, then we add the actual list-typed |
| // element. Otherwise, we extend the actual operands. |
| if (sizedSegments) { |
| formatString = singleResultAppendTemplate; |
| } else { |
| formatString = multiResultAppendTemplate; |
| } |
| } |
| |
| builderLines.push_back(formatv(formatString.data(), name)); |
| } |
| } |
| |
| /// If the operation has variadic regions, adds a builder argument to specify |
| /// the number of those regions and builder lines to forward it to the generic |
| /// constructor. |
| static void populateBuilderRegions(const Operator &op, |
| SmallVectorImpl<std::string> &builderArgs, |
| SmallVectorImpl<std::string> &builderLines) { |
| if (op.hasNoVariadicRegions()) |
| return; |
| |
| // This is currently enforced when Operator is constructed. |
| assert(op.getNumVariadicRegions() == 1 && |
| op.getRegion(op.getNumRegions() - 1).isVariadic() && |
| "expected the last region to be varidic"); |
| |
| const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1); |
| std::string name = |
| ("num_" + region.name.take_front().lower() + region.name.drop_front()) |
| .str(); |
| builderArgs.push_back(name); |
| builderLines.push_back( |
| formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); |
| } |
| |
| /// Emits a default builder constructing an operation from the list of its |
| /// result types, followed by a list of its operands. Returns vector |
| /// of fully built functionArgs for downstream users (to save having to |
| /// rebuild anew). |
| static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op, |
| raw_ostream &os) { |
| SmallVector<std::string> builderArgs; |
| SmallVector<std::string> builderLines; |
| SmallVector<std::string> operandArgNames; |
| SmallVector<std::string> successorArgNames; |
| builderArgs.reserve(op.getNumOperands() + op.getNumResults() + |
| op.getNumNativeAttributes() + op.getNumSuccessors()); |
| populateBuilderArgsResults(op, builderArgs); |
| size_t numResultArgs = builderArgs.size(); |
| populateBuilderArgs(op, builderArgs, operandArgNames); |
| size_t numOperandAttrArgs = builderArgs.size() - numResultArgs; |
| populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); |
| |
| populateBuilderLinesOperand(op, operandArgNames, builderLines); |
| populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs), |
| builderLines); |
| populateBuilderLinesResult( |
| op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines); |
| populateBuilderLinesSuccessors(op, successorArgNames, builderLines); |
| populateBuilderRegions(op, builderArgs, builderLines); |
| |
| // Layout of builderArgs vector elements: |
| // [ result_args operand_attr_args successor_args regions ] |
| |
| // Determine whether the argument corresponding to a given index into the |
| // builderArgs vector is a python keyword argument or not. |
| auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool { |
| // All result, successor, and region arguments are positional arguments. |
| if ((builderArgIndex < numResultArgs) || |
| (builderArgIndex >= (numResultArgs + numOperandAttrArgs))) |
| return false; |
| // Keyword arguments: |
| // - optional named attributes (including unit attributes) |
| // - default-valued named attributes |
| // - optional operands |
| Argument a = op.getArg(builderArgIndex - numResultArgs); |
| if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a)) |
| return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue()); |
| if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a)) |
| return ntype->isOptional(); |
| return false; |
| }; |
| |
| // StringRefs in functionArgs refer to strings allocated by builderArgs. |
| SmallVector<StringRef> functionArgs; |
| |
| // Add positional arguments. |
| for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { |
| if (!isKeywordArgFn(i)) |
| functionArgs.push_back(builderArgs[i]); |
| } |
| |
| // Add a bare '*' to indicate that all following arguments must be keyword |
| // arguments. |
| functionArgs.push_back("*"); |
| |
| // Add a default 'None' value to each keyword arg string, and then add to the |
| // function args list. |
| for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { |
| if (isKeywordArgFn(i)) { |
| builderArgs[i].append("=None"); |
| functionArgs.push_back(builderArgs[i]); |
| } |
| } |
| functionArgs.push_back("loc=None"); |
| functionArgs.push_back("ip=None"); |
| |
| SmallVector<std::string> initArgs; |
| initArgs.push_back("self.OPERATION_NAME"); |
| initArgs.push_back("self._ODS_REGIONS"); |
| initArgs.push_back("self._ODS_OPERAND_SEGMENTS"); |
| initArgs.push_back("self._ODS_RESULT_SEGMENTS"); |
| initArgs.push_back("attributes=attributes"); |
| if (!hasInferTypeInterface(op)) |
| initArgs.push_back("results=results"); |
| initArgs.push_back("operands=operands"); |
| initArgs.push_back("successors=_ods_successors"); |
| initArgs.push_back("regions=regions"); |
| initArgs.push_back("loc=loc"); |
| initArgs.push_back("ip=ip"); |
| |
| os << formatv(initTemplate, llvm::join(functionArgs, ", "), |
| llvm::join(builderLines, "\n "), llvm::join(initArgs, ", ")); |
| return llvm::to_vector<8>( |
| llvm::map_range(functionArgs, [](StringRef s) { return s.str(); })); |
| } |
| |
| static void emitSegmentSpec( |
| const Operator &op, const char *kind, |
| llvm::function_ref<int(const Operator &)> getNumElements, |
| llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
| getElement, |
| raw_ostream &os) { |
| std::string segmentSpec("["); |
| for (int i = 0, e = getNumElements(op); i < e; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (element.isOptional()) { |
| segmentSpec.append("0,"); |
| } else if (element.isVariadic()) { |
| segmentSpec.append("-1,"); |
| } else { |
| segmentSpec.append("1,"); |
| } |
| } |
| segmentSpec.append("]"); |
| |
| os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); |
| } |
| |
| static void emitRegionAttributes(const Operator &op, raw_ostream &os) { |
| // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). |
| // Note that the base OpView class defines this as (0, True). |
| unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); |
| os << formatv(opClassRegionSpecTemplate, minRegionCount, |
| op.hasNoVariadicRegions() ? "True" : "False"); |
| } |
| |
| /// Emits named accessors to regions. |
| static void emitRegionAccessors(const Operator &op, raw_ostream &os) { |
| for (const auto &en : llvm::enumerate(op.getRegions())) { |
| const NamedRegion ®ion = en.value(); |
| if (region.name.empty()) |
| continue; |
| |
| assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && |
| "expected only the last region to be variadic"); |
| os << formatv(regionAccessorTemplate, sanitizeName(region.name), |
| std::to_string(en.index()) + |
| (region.isVariadic() ? ":" : "")); |
| } |
| } |
| |
| /// Emits builder that extracts results from op |
| static void emitValueBuilder(const Operator &op, |
| SmallVector<std::string> functionArgs, |
| raw_ostream &os) { |
| // Params with (possibly) default args. |
| auto valueBuilderParams = |
| llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) { |
| SmallVector<StringRef> argMaybeDefault = |
| llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "=")); |
| auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]); |
| if (argMaybeDefault.size() == 2) |
| return arg + "=" + argMaybeDefault[1].str(); |
| return arg; |
| }); |
| // Actual args passed to op builder (e.g., opParam=op_param). |
| auto opBuilderArgs = llvm::map_range( |
| llvm::make_filter_range(functionArgs, |
| [](const std::string &s) { return s != "*"; }), |
| [](const std::string &arg) { |
| auto lhs = *llvm::split(arg, "=").begin(); |
| return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str(); |
| }); |
| std::string nameWithoutDialect = sanitizeName( |
| op.getOperationName().substr(op.getOperationName().find('.') + 1)); |
| if (nameWithoutDialect == op.getCppClassName()) |
| nameWithoutDialect += "_"; |
| std::string params = llvm::join(valueBuilderParams, ", "); |
| std::string args = llvm::join(opBuilderArgs, ", "); |
| const char *type = |
| (op.getNumResults() > 1 |
| ? "_Sequence[_ods_ir.Value]" |
| : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")); |
| if (op.getNumVariableLengthResults() > 0) { |
| os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect, |
| op.getCppClassName(), params, args, type); |
| } else { |
| const char *results; |
| if (op.getNumResults() == 0) { |
| results = ""; |
| } else if (op.getNumResults() == 1) { |
| results = ".result"; |
| } else { |
| results = ".results"; |
| } |
| os << formatv(valueBuilderTemplate, nameWithoutDialect, |
| op.getCppClassName(), params, args, type, results); |
| } |
| } |
| |
| /// Emits bindings for a specific Op to the given output stream. |
| static void emitOpBindings(const Operator &op, raw_ostream &os) { |
| os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName()); |
| |
| // Sized segments. |
| if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { |
| emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); |
| } |
| if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { |
| emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); |
| } |
| |
| emitRegionAttributes(op, os); |
| SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os); |
| emitOperandAccessors(op, os); |
| emitAttributeAccessors(op, os); |
| emitResultAccessors(op, os); |
| emitRegionAccessors(op, os); |
| emitValueBuilder(op, functionArgs, os); |
| } |
| |
| /// Emits bindings for the dialect specified in the command line, including file |
| /// headers and utilities. Returns `false` on success to comply with Tablegen |
| /// registration requirements. |
| static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) { |
| if (clDialectName.empty()) |
| llvm::PrintFatalError("dialect name not provided"); |
| |
| os << fileHeader; |
| if (!clDialectExtensionName.empty()) |
| os << formatv(dialectExtensionTemplate, clDialectName.getValue()); |
| else |
| os << formatv(dialectClassTemplate, clDialectName.getValue()); |
| |
| for (const Record *rec : records.getAllDerivedDefinitions("Op")) { |
| Operator op(rec); |
| if (op.getDialectName() == clDialectName.getValue()) |
| emitOpBindings(op, os); |
| } |
| return false; |
| } |
| |
| static GenRegistration |
| genPythonBindings("gen-python-op-bindings", |
| "Generate Python bindings for MLIR Ops", &emitAllOps); |