blob: cfa93c7e0ea6a135250972843c722ce8867ce4a8 [file] [log] [blame]
Nicolas Vasilachef1b97202020-05-15 04:22:211//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
10
11#include "../PassDetail.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
River Riddle23aa5a72022-02-26 22:49:5413#include "mlir/Dialect/Func/IR/FuncOps.h"
gysitb7f2c102021-12-15 12:14:3514#include "mlir/Dialect/Linalg/IR/Linalg.h"
Nicolas Vasilachee0dc3db2020-10-09 14:31:5215#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
Julian Grosse2310702021-02-10 12:53:1116#include "mlir/Dialect/MemRef/IR/MemRef.h"
Nicolas Vasilachef1b97202020-05-15 04:22:2117#include "mlir/Dialect/SCF/SCF.h"
Nicolas Vasilachef1b97202020-05-15 04:22:2118
19using namespace mlir;
20using namespace mlir::linalg;
21
22/// Helper function to extract the operand types that are passed to the
23/// generated CallOp. MemRefTypes have their layout canonicalized since the
24/// information is not used in signature generation.
25/// Note that static size information is not modified.
Nicolas Vasilachef1b97202020-05-15 04:22:2126static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
27 SmallVector<Type, 4> result;
28 result.reserve(op->getNumOperands());
Nicolas Vasilachef1b97202020-05-15 04:22:2129 for (auto type : op->getOperandTypes()) {
30 // The underlying descriptor type (e.g. LLVM) does not have layout
31 // information. Canonicalizing the type at the level of std when going into
32 // a library call avoids needing to introduce DialectCastOp.
33 if (auto memrefType = type.dyn_cast<MemRefType>())
34 result.push_back(eraseStridedLayout(memrefType));
35 else
36 result.push_back(type);
37 }
38 return result;
39}
40
Nicolas Vasilachef1b97202020-05-15 04:22:2141// Get a SymbolRefAttr containing the library function name for the LinalgOp.
42// If the library function does not exist, insert a declaration.
Nicolas Vasilachef1b97202020-05-15 04:22:2143static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
44 PatternRewriter &rewriter) {
45 auto linalgOp = cast<LinalgOp>(op);
46 auto fnName = linalgOp.getLibraryCallName();
47 if (fnName.empty()) {
48 op->emitWarning("No library call defined for: ") << *op;
49 return {};
50 }
51
52 // fnName is a dynamic std::string, unique it via a SymbolRefAttr.
Chris Lattnerfaf1c222021-08-30 16:31:4853 FlatSymbolRefAttr fnNameAttr =
54 SymbolRefAttr::get(rewriter.getContext(), fnName);
Nicolas Vasilachef1b97202020-05-15 04:22:2155 auto module = op->getParentOfType<ModuleOp>();
Chris Lattner41d4aa72021-08-29 21:22:2456 if (module.lookupSymbol(fnNameAttr.getAttr()))
Nicolas Vasilachef1b97202020-05-15 04:22:2157 return fnNameAttr;
Nicolas Vasilachef1b97202020-05-15 04:22:2158
Nicolas Vasilachee0dc3db2020-10-09 14:31:5259 SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
Nicolas Vasilachef1b97202020-05-15 04:22:2160 assert(op->getNumResults() == 0 &&
61 "Library call for linalg operation can be generated only for ops that "
62 "have void return types");
River Riddle1b97cdf2020-12-17 20:24:4563 auto libFnType = rewriter.getFunctionType(inputTypes, {});
Nicolas Vasilachef1b97202020-05-15 04:22:2164
65 OpBuilder::InsertionGuard guard(rewriter);
66 // Insert before module terminator.
67 rewriter.setInsertionPoint(module.getBody(),
68 std::prev(module.getBody()->end()));
69 FuncOp funcOp =
Rahul Joshi74145d52020-07-07 23:15:4470 rewriter.create<FuncOp>(op->getLoc(), fnNameAttr.getValue(), libFnType);
Nicolas Vasilachef1b97202020-05-15 04:22:2171 // Insert a function attribute that will trigger the emission of the
72 // corresponding `_mlir_ciface_xxx` interface so that external libraries see
73 // a normalized ABI. This interface is added during std to llvm conversion.
Christian Sigg1ffc1aa2020-12-12 09:50:4174 funcOp->setAttr("llvm.emit_c_interface", UnitAttr::get(op->getContext()));
Rahul Joshib7382ed2020-11-13 21:04:5375 funcOp.setPrivate();
Nicolas Vasilachef1b97202020-05-15 04:22:2176 return fnNameAttr;
77}
78
Nicolas Vasilachee0dc3db2020-10-09 14:31:5279static SmallVector<Value, 4>
Nicolas Vasilachef1b97202020-05-15 04:22:2180createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
81 ValueRange operands) {
82 SmallVector<Value, 4> res;
83 res.reserve(operands.size());
84 for (auto op : operands) {
85 auto memrefType = op.getType().dyn_cast<MemRefType>();
86 if (!memrefType) {
87 res.push_back(op);
88 continue;
89 }
90 Value cast =
Julian Grosse2310702021-02-10 12:53:1191 b.create<memref::CastOp>(loc, eraseStridedLayout(memrefType), op);
Nicolas Vasilachef1b97202020-05-15 04:22:2192 res.push_back(cast);
93 }
94 return res;
95}
96
Nicolas Vasilachee0dc3db2020-10-09 14:31:5297LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
River Riddle76f3c2f2021-03-23 20:44:1498 LinalgOp op, PatternRewriter &rewriter) const {
Nicolas Vasilachee0dc3db2020-10-09 14:31:5299 auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
Tobias Gysif1844f12021-06-23 09:06:04100 if (!libraryCallName)
Nicolas Vasilachee0dc3db2020-10-09 14:31:52101 return failure();
Nicolas Vasilachef1b97202020-05-15 04:22:21102
Tobias Gysi27ad2132021-04-19 12:23:11103 // TODO: Add support for more complex library call signatures that include
104 // indices or captured values.
River Riddle23aa5a72022-02-26 22:49:54105 rewriter.replaceOpWithNewOp<func::CallOp>(
Nicolas Vasilachee0dc3db2020-10-09 14:31:52106 op, libraryCallName.getValue(), TypeRange(),
107 createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
108 op->getOperands()));
109 return success();
110}
Nicolas Vasilachef1b97202020-05-15 04:22:21111
Alexander Belyaev25bf6a22022-01-31 17:51:39112
Nicolas Vasilachef1b97202020-05-15 04:22:21113/// Populate the given list with patterns that convert from Linalg to Standard.
Nicolas Vasilachee0dc3db2020-10-09 14:31:52114void mlir::linalg::populateLinalgToStandardConversionPatterns(
Chris Lattnerdc4e9132021-03-22 23:58:34115 RewritePatternSet &patterns) {
River Riddle9db53a12020-07-07 08:35:23116 // TODO: ConvOp conversion needs to export a descriptor with relevant
Nicolas Vasilachef1b97202020-05-15 04:22:21117 // attribute values such as kernel striding and dilation.
Alexander Belyaevebc81532022-02-01 17:07:33118 patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext());
Nicolas Vasilachef1b97202020-05-15 04:22:21119}
120
121namespace {
122struct ConvertLinalgToStandardPass
123 : public ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
124 void runOnOperation() override;
125};
126} // namespace
127
128void ConvertLinalgToStandardPass::runOnOperation() {
129 auto module = getOperation();
130 ConversionTarget target(getContext());
Mogballa54f4ea2021-10-12 23:14:57131 target.addLegalDialect<AffineDialect, arith::ArithmeticDialect,
River Riddle23aa5a72022-02-26 22:49:54132 func::FuncDialect, memref::MemRefDialect,
133 scf::SCFDialect>();
134 target.addLegalOp<ModuleOp, FuncOp, func::ReturnOp>();
Chris Lattnerdc4e9132021-03-22 23:58:34135 RewritePatternSet patterns(&getContext());
Chris Lattner3a506b32021-03-20 23:29:41136 populateLinalgToStandardConversionPatterns(patterns);
River Riddle3fffffa82020-10-27 00:25:01137 if (failed(applyFullConversion(module, target, std::move(patterns))))
Nicolas Vasilachef1b97202020-05-15 04:22:21138 signalPassFailure();
139}
140
141std::unique_ptr<OperationPass<ModuleOp>>
142mlir::createConvertLinalgToStandardPass() {
143 return std::make_unique<ConvertLinalgToStandardPass>();
144}