Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 1 | //===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
// See https://llvm.org/LICENSE.txt for license information.
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" |
| 10 | |
| 11 | #include "../PassDetail.h" |
| 12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
River Riddle | 23aa5a7 | 2022-02-26 22:49:54 | [diff] [blame^] | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
gysit | b7f2c10 | 2021-12-15 12:14:35 | [diff] [blame] | 14 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
Nicolas Vasilache | e0dc3db | 2020-10-09 14:31:52 | [diff] [blame] | 15 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
Julian Gross | e231070 | 2021-02-10 12:53:11 | [diff] [blame] | 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 17 | #include "mlir/Dialect/SCF/SCF.h" |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 18 | |
| 19 | using namespace mlir; |
| 20 | using namespace mlir::linalg; |
| 21 | |
| 22 | /// Helper function to extract the operand types that are passed to the |
| 23 | /// generated CallOp. MemRefTypes have their layout canonicalized since the |
| 24 | /// information is not used in signature generation. |
| 25 | /// Note that static size information is not modified. |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 26 | static SmallVector<Type, 4> extractOperandTypes(Operation *op) { |
| 27 | SmallVector<Type, 4> result; |
| 28 | result.reserve(op->getNumOperands()); |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 29 | for (auto type : op->getOperandTypes()) { |
| 30 | // The underlying descriptor type (e.g. LLVM) does not have layout |
| 31 | // information. Canonicalizing the type at the level of std when going into |
| 32 | // a library call avoids needing to introduce DialectCastOp. |
| 33 | if (auto memrefType = type.dyn_cast<MemRefType>()) |
| 34 | result.push_back(eraseStridedLayout(memrefType)); |
| 35 | else |
| 36 | result.push_back(type); |
| 37 | } |
| 38 | return result; |
| 39 | } |
| 40 | |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 41 | // Get a SymbolRefAttr containing the library function name for the LinalgOp. |
| 42 | // If the library function does not exist, insert a declaration. |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 43 | static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, |
| 44 | PatternRewriter &rewriter) { |
| 45 | auto linalgOp = cast<LinalgOp>(op); |
| 46 | auto fnName = linalgOp.getLibraryCallName(); |
| 47 | if (fnName.empty()) { |
| 48 | op->emitWarning("No library call defined for: ") << *op; |
| 49 | return {}; |
| 50 | } |
| 51 | |
| 52 | // fnName is a dynamic std::string, unique it via a SymbolRefAttr. |
Chris Lattner | faf1c22 | 2021-08-30 16:31:48 | [diff] [blame] | 53 | FlatSymbolRefAttr fnNameAttr = |
| 54 | SymbolRefAttr::get(rewriter.getContext(), fnName); |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 55 | auto module = op->getParentOfType<ModuleOp>(); |
Chris Lattner | 41d4aa7 | 2021-08-29 21:22:24 | [diff] [blame] | 56 | if (module.lookupSymbol(fnNameAttr.getAttr())) |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 57 | return fnNameAttr; |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 58 | |
Nicolas Vasilache | e0dc3db | 2020-10-09 14:31:52 | [diff] [blame] | 59 | SmallVector<Type, 4> inputTypes(extractOperandTypes(op)); |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 60 | assert(op->getNumResults() == 0 && |
| 61 | "Library call for linalg operation can be generated only for ops that " |
| 62 | "have void return types"); |
River Riddle | 1b97cdf | 2020-12-17 20:24:45 | [diff] [blame] | 63 | auto libFnType = rewriter.getFunctionType(inputTypes, {}); |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 64 | |
| 65 | OpBuilder::InsertionGuard guard(rewriter); |
| 66 | // Insert before module terminator. |
| 67 | rewriter.setInsertionPoint(module.getBody(), |
| 68 | std::prev(module.getBody()->end())); |
| 69 | FuncOp funcOp = |
Rahul Joshi | 74145d5 | 2020-07-07 23:15:44 | [diff] [blame] | 70 | rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType); |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 71 | // Insert a function attribute that will trigger the emission of the |
| 72 | // corresponding `_mlir_ciface_xxx` interface so that external libraries see |
| 73 | // a normalized ABI. This interface is added during std to llvm conversion. |
Christian Sigg | 1ffc1aa | 2020-12-12 09:50:41 | [diff] [blame] | 74 | funcOp->setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext())); |
Rahul Joshi | b7382ed | 2020-11-13 21:04:53 | [diff] [blame] | 75 | funcOp.setPrivate(); |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 76 | return fnNameAttr; |
| 77 | } |
| 78 | |
Nicolas Vasilache | e0dc3db | 2020-10-09 14:31:52 | [diff] [blame] | 79 | static SmallVector<Value, 4> |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 80 | createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, |
| 81 | ValueRange operands) { |
| 82 | SmallVector<Value, 4> res; |
| 83 | res.reserve(operands.size()); |
| 84 | for (auto op : operands) { |
| 85 | auto memrefType = op.getType().dyn_cast<MemRefType>(); |
| 86 | if (!memrefType) { |
| 87 | res.push_back(op); |
| 88 | continue; |
| 89 | } |
| 90 | Value cast = |
Julian Gross | e231070 | 2021-02-10 12:53:11 | [diff] [blame] | 91 | b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op); |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 92 | res.push_back(cast); |
| 93 | } |
| 94 | return res; |
| 95 | } |
| 96 | |
Nicolas Vasilache | e0dc3db | 2020-10-09 14:31:52 | [diff] [blame] | 97 | LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( |
River Riddle | 76f3c2f | 2021-03-23 20:44:14 | [diff] [blame] | 98 | LinalgOp op, PatternRewriter &rewriter) const { |
Nicolas Vasilache | e0dc3db | 2020-10-09 14:31:52 | [diff] [blame] | 99 | auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); |
Tobias Gysi | f1844f1 | 2021-06-23 09:06:04 | [diff] [blame] | 100 | if (!libraryCallName) |
Nicolas Vasilache | e0dc3db | 2020-10-09 14:31:52 | [diff] [blame] | 101 | return failure(); |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 102 | |
Tobias Gysi | 27ad213 | 2021-04-19 12:23:11 | [diff] [blame] | 103 | // TODO: Add support for more complex library call signatures that include |
| 104 | // indices or captured values. |
River Riddle | 23aa5a7 | 2022-02-26 22:49:54 | [diff] [blame^] | 105 | rewriter.replaceOpWithNewOp<func::CallOp>( |
Nicolas Vasilache | e0dc3db | 2020-10-09 14:31:52 | [diff] [blame] | 106 | op, libraryCallName.getValue(), TypeRange(), |
| 107 | createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), |
| 108 | op->getOperands())); |
| 109 | return success(); |
| 110 | } |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 111 | |
Alexander Belyaev | 25bf6a2 | 2022-01-31 17:51:39 | [diff] [blame] | 112 | |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 113 | /// Populate the given list with patterns that convert from Linalg to Standard. |
Nicolas Vasilache | e0dc3db | 2020-10-09 14:31:52 | [diff] [blame] | 114 | void mlir::linalg::populateLinalgToStandardConversionPatterns( |
Chris Lattner | dc4e913 | 2021-03-22 23:58:34 | [diff] [blame] | 115 | RewritePatternSet &patterns) { |
River Riddle | 9db53a1 | 2020-07-07 08:35:23 | [diff] [blame] | 116 | // TODO: ConvOp conversion needs to export a descriptor with relevant |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 117 | // attribute values such as kernel striding and dilation. |
Alexander Belyaev | ebc8153 | 2022-02-01 17:07:33 | [diff] [blame] | 118 | patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext()); |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 119 | } |
| 120 | |
| 121 | namespace { |
| 122 | struct ConvertLinalgToStandardPass |
| 123 | : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> { |
| 124 | void runOnOperation() override; |
| 125 | }; |
| 126 | } // namespace |
| 127 | |
| 128 | void ConvertLinalgToStandardPass::runOnOperation() { |
| 129 | auto module = getOperation(); |
| 130 | ConversionTarget target(getContext()); |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 131 | target.addLegalDialect<AffineDialect, arith::ArithmeticDialect, |
River Riddle | 23aa5a7 | 2022-02-26 22:49:54 | [diff] [blame^] | 132 | func::FuncDialect, memref::MemRefDialect, |
| 133 | scf::SCFDialect>(); |
| 134 | target.addLegalOp<ModuleOp, FuncOp, func::ReturnOp>(); |
Chris Lattner | dc4e913 | 2021-03-22 23:58:34 | [diff] [blame] | 135 | RewritePatternSet patterns(&getContext()); |
Chris Lattner | 3a506b3 | 2021-03-20 23:29:41 | [diff] [blame] | 136 | populateLinalgToStandardConversionPatterns(patterns); |
River Riddle | 3fffffa8 | 2020-10-27 00:25:01 | [diff] [blame] | 137 | if (failed(applyFullConversion(module, target, std::move(patterns)))) |
Nicolas Vasilache | f1b9720 | 2020-05-15 04:22:21 | [diff] [blame] | 138 | signalPassFailure(); |
| 139 | } |
| 140 | |
| 141 | std::unique_ptr<OperationPass<ModuleOp>> |
| 142 | mlir::createConvertLinalgToStandardPass() { |
| 143 | return std::make_unique<ConvertLinalgToStandardPass>(); |
| 144 | } |