| |
| |
| //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
// See https://llvm.org/LICENSE.txt for license information.
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir-c/Interfaces.h" |
| |
| #include "mlir/CAPI/IR.h" |
| #include "mlir/CAPI/Interfaces.h" |
| #include "mlir/CAPI/Support.h" |
| #include "mlir/CAPI/Wrap.h" |
| #include "mlir/IR/ValueRange.h" |
| #include "mlir/Interfaces/InferTypeOpInterface.h" |
| #include "llvm/ADT/ScopeExit.h" |
| #include <optional> |
| |
| using namespace mlir; |
| |
| namespace { |
| |
| std::optional<RegisteredOperationName> |
| getRegisteredOperationName(MlirContext context, MlirStringRef opName) { |
| StringRef name(opName.data, opName.length); |
| std::optional<RegisteredOperationName> info = |
| RegisteredOperationName::lookup(name, unwrap(context)); |
| return info; |
| } |
| |
| std::optional<Location> maybeGetLocation(MlirLocation location) { |
| std::optional<Location> maybeLocation; |
| if (!mlirLocationIsNull(location)) |
| maybeLocation = unwrap(location); |
| return maybeLocation; |
| } |
| |
| SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) { |
| SmallVector<Value> unwrappedOperands; |
| (void)unwrapList(nOperands, operands, unwrappedOperands); |
| return unwrappedOperands; |
| } |
| |
| DictionaryAttr unwrapAttributes(MlirAttribute attributes) { |
| DictionaryAttr attributeDict; |
| if (!mlirAttributeIsNull(attributes)) |
| attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes)); |
| return attributeDict; |
| } |
| |
| SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions, |
| MlirRegion *regions) { |
| // Create a vector of unique pointers to regions and make sure they are not |
| // deleted when exiting the scope. This is a hack caused by C++ API expecting |
| // an list of unique pointers to regions (without ownership transfer |
| // semantics) and C API making ownership transfer explicit. |
| SmallVector<std::unique_ptr<Region>> unwrappedRegions; |
| unwrappedRegions.reserve(nRegions); |
| for (intptr_t i = 0; i < nRegions; ++i) |
| unwrappedRegions.emplace_back(unwrap(*(regions + i))); |
| auto cleaner = llvm::make_scope_exit([&]() { |
| for (auto ®ion : unwrappedRegions) |
| region.release(); |
| }); |
| return unwrappedRegions; |
| } |
| |
| } // namespace |
| |
| bool mlirOperationImplementsInterface(MlirOperation operation, |
| MlirTypeID interfaceTypeID) { |
| std::optional<RegisteredOperationName> info = |
| unwrap(operation)->getRegisteredInfo(); |
| return info && info->hasInterface(unwrap(interfaceTypeID)); |
| } |
| |
| bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, |
| MlirContext context, |
| MlirTypeID interfaceTypeID) { |
| std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup( |
| StringRef(operationName.data, operationName.length), unwrap(context)); |
| return info && info->hasInterface(unwrap(interfaceTypeID)); |
| } |
| |
| MlirTypeID mlirInferTypeOpInterfaceTypeID() { |
| return wrap(InferTypeOpInterface::getInterfaceID()); |
| } |
| |
| MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( |
| MlirStringRef opName, MlirContext context, MlirLocation location, |
| intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, |
| void *properties, intptr_t nRegions, MlirRegion *regions, |
| MlirTypesCallback callback, void *userData) { |
| StringRef name(opName.data, opName.length); |
| std::optional<RegisteredOperationName> info = |
| getRegisteredOperationName(context, opName); |
| if (!info) |
| return mlirLogicalResultFailure(); |
| |
| std::optional<Location> maybeLocation = maybeGetLocation(location); |
| SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands); |
| DictionaryAttr attributeDict = unwrapAttributes(attributes); |
| SmallVector<std::unique_ptr<Region>> unwrappedRegions = |
| unwrapRegions(nRegions, regions); |
| |
| SmallVector<Type> inferredTypes; |
| if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes( |
| unwrap(context), maybeLocation, unwrappedOperands, attributeDict, |
| properties, unwrappedRegions, inferredTypes))) |
| return mlirLogicalResultFailure(); |
| |
| SmallVector<MlirType> wrappedInferredTypes; |
| wrappedInferredTypes.reserve(inferredTypes.size()); |
| for (Type t : inferredTypes) |
| wrappedInferredTypes.push_back(wrap(t)); |
| callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData); |
| return mlirLogicalResultSuccess(); |
| } |
| |
| MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() { |
| return wrap(InferShapedTypeOpInterface::getInterfaceID()); |
| } |
| |
| MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes( |
| MlirStringRef opName, MlirContext context, MlirLocation location, |
| intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, |
| void *properties, intptr_t nRegions, MlirRegion *regions, |
| MlirShapedTypeComponentsCallback callback, void *userData) { |
| std::optional<RegisteredOperationName> info = |
| getRegisteredOperationName(context, opName); |
| if (!info) |
| return mlirLogicalResultFailure(); |
| |
| std::optional<Location> maybeLocation = maybeGetLocation(location); |
| SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands); |
| DictionaryAttr attributeDict = unwrapAttributes(attributes); |
| SmallVector<std::unique_ptr<Region>> unwrappedRegions = |
| unwrapRegions(nRegions, regions); |
| |
| SmallVector<ShapedTypeComponents> inferredTypeComponents; |
| if (failed(info->getInterface<InferShapedTypeOpInterface>() |
| ->inferReturnTypeComponents( |
| unwrap(context), maybeLocation, |
| mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)), |
| attributeDict, properties, unwrappedRegions, |
| inferredTypeComponents))) |
| return mlirLogicalResultFailure(); |
| |
| bool hasRank; |
| intptr_t rank; |
| const int64_t *shapeData; |
| for (const ShapedTypeComponents &t : inferredTypeComponents) { |
| if (t.hasRank()) { |
| hasRank = true; |
| rank = t.getDims().size(); |
| shapeData = t.getDims().data(); |
| } else { |
| hasRank = false; |
| rank = 0; |
| shapeData = nullptr; |
| } |
| callback(hasRank, rank, shapeData, wrap(t.getElementType()), |
| wrap(t.getAttribute()), userData); |
| } |
| return mlirLogicalResultSuccess(); |
| } |