blob: bdeba081ca4c760dba724ae77a44fbe232c0bf4b [file] [log] [blame]
Alexander Belyaev71c10802020-06-16 11:49:541//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
Frederik Gossen37133142020-06-03 16:14:422//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
10
11#include "../PassDetail.h"
Mogballa54f4ea2021-10-12 23:14:5712#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
River Riddle23aa5a72022-02-26 22:49:5413#include "mlir/Dialect/Func/IR/FuncOps.h"
Frederik Gossen37133142020-06-03 16:14:4214#include "mlir/Dialect/SCF/SCF.h"
15#include "mlir/Dialect/Shape/IR/Shape.h"
Sean Silva444822d2020-12-11 22:20:0316#include "mlir/Dialect/Tensor/IR/Tensor.h"
Frederik Gossena70f2eb2020-09-07 13:58:0117#include "mlir/IR/BlockAndValueMapping.h"
Tres Poppf30f3472021-02-01 08:49:5418#include "mlir/IR/ImplicitLocOpBuilder.h"
Frederik Gossen37133142020-06-03 16:14:4219#include "mlir/Transforms/DialectConversion.h"
Tres Poppf30f3472021-02-01 08:49:5420#include "llvm/ADT/STLExtras.h"
Frederik Gossen37133142020-06-03 16:14:4221
Frederik Gossen24edbdf2020-06-08 08:58:0622using namespace mlir;
Alexander Belyaev80be54c2020-06-08 15:48:0123using namespace mlir::shape;
Frederik Gossena70f2eb2020-09-07 13:58:0124using namespace mlir::scf;
Frederik Gossen24edbdf2020-06-08 08:58:0625
Frederik Gossen37133142020-06-03 16:14:4226/// Conversion patterns.
Frederik Gossen4baf18d2020-07-24 08:53:5427namespace {
Frederik Gossen9df6afb2020-07-13 08:28:1328class AnyOpConversion : public OpConversionPattern<AnyOp> {
29public:
30 using OpConversionPattern<AnyOp>::OpConversionPattern;
31
32 LogicalResult
River Riddleb54c7242021-09-24 17:50:5833 matchAndRewrite(AnyOp op, OpAdaptor adaptor,
Frederik Gossen4baf18d2020-07-24 08:53:5434 ConversionPatternRewriter &rewriter) const override;
Frederik Gossen9df6afb2020-07-13 08:28:1335};
Frederik Gossen4baf18d2020-07-24 08:53:5436} // namespace
Frederik Gossen9df6afb2020-07-13 08:28:1337
Frederik Gossen4baf18d2020-07-24 08:53:5438LogicalResult
River Riddleb54c7242021-09-24 17:50:5839AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
Frederik Gossen4baf18d2020-07-24 08:53:5440 ConversionPatternRewriter &rewriter) const {
Frederik Gossen4baf18d2020-07-24 08:53:5441 // Replace `any` with its first operand.
42 // Any operand would be a valid substitution.
Jacques Pienaarcfb72fd32021-10-25 01:36:3343 rewriter.replaceOp(op, {adaptor.getInputs().front()});
Frederik Gossen4baf18d2020-07-24 08:53:5444 return success();
45}
46
47namespace {
Alexander Belyaev80be54c2020-06-08 15:48:0148template <typename SrcOpTy, typename DstOpTy>
49class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
50public:
51 using OpConversionPattern<SrcOpTy>::OpConversionPattern;
52
53 LogicalResult
River Riddleb54c7242021-09-24 17:50:5854 matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
Alexander Belyaev80be54c2020-06-08 15:48:0155 ConversionPatternRewriter &rewriter) const override {
Frederik Gossen6673c6c2020-07-29 13:53:4156 // For now, only error-free types are supported by this lowering.
57 if (op.getType().template isa<SizeType>())
58 return failure();
59
Jacques Pienaarcfb72fd32021-10-25 01:36:3360 rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
61 adaptor.getRhs());
Alexander Belyaev80be54c2020-06-08 15:48:0162 return success();
63 }
64};
Frederik Gossen4baf18d2020-07-24 08:53:5465} // namespace
Alexander Belyaev80be54c2020-06-08 15:48:0166
Frederik Gossen4baf18d2020-07-24 08:53:5467namespace {
Frederik Gossena70f2eb2020-09-07 13:58:0168struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
69 using OpConversionPattern<BroadcastOp>::OpConversionPattern;
Stephan Herhut5d9f33a2020-07-28 17:08:4070
71 LogicalResult
River Riddleb54c7242021-09-24 17:50:5872 matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
Frederik Gossen4baf18d2020-07-24 08:53:5473 ConversionPatternRewriter &rewriter) const override;
Frederik Gossenac3e5c42020-06-19 15:09:3674};
Tres Poppf30f3472021-02-01 08:49:5475
76// Get the resulting extent in a given dimension. This is computed with any
77// number of extent tensors and shifted offsets into them.
78Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
79 ValueRange rankDiffs, Value outputDimension) {
Mogballa54f4ea2021-10-12 23:14:5780 Value one = lb.create<arith::ConstantIndexOp>(1);
Tres Poppf30f3472021-02-01 08:49:5481 Value broadcastedDim = one;
82 for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
83 Value shape = std::get<0>(tup);
84 Value rankDiff = std::get<1>(tup);
Mogballa54f4ea2021-10-12 23:14:5785 Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
86 outputDimension, rankDiff);
Tres Poppf30f3472021-02-01 08:49:5487 Type indexTy = lb.getIndexType();
88 broadcastedDim =
89 lb.create<IfOp>(
90 TypeRange{indexTy}, outOfBounds,
91 [&](OpBuilder &b, Location loc) {
92 b.create<scf::YieldOp>(loc, broadcastedDim);
93 },
94 [&](OpBuilder &b, Location loc) {
95 // The broadcasting logic is:
96 // - if one extent (here we arbitrarily choose the
97 // extent from the greater-rank operand) is equal to 1,
98 // then take the extent from the other operand
99 // - otherwise, take the extent as-is.
100 // Note that this logic remains correct in the presence
101 // of dimensions of zero extent.
Mogballa54f4ea2021-10-12 23:14:57102 Value lesserRankOperandDimension = b.create<arith::SubIOp>(
103 loc, indexTy, outputDimension, rankDiff);
Tres Poppf30f3472021-02-01 08:49:54104 Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
105 loc, shape, ValueRange{lesserRankOperandDimension});
106
Mogballa54f4ea2021-10-12 23:14:57107 Value dimIsOne =
108 b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
109 lesserRankOperandExtent, one);
River Riddledec8af72022-01-31 20:44:35110 Value dim = b.create<arith::SelectOp>(
111 loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
Tres Poppf30f3472021-02-01 08:49:54112 b.create<scf::YieldOp>(loc, dim);
113 })
114 .getResult(0);
115 }
116 return broadcastedDim;
117}
Frederik Gossen4baf18d2020-07-24 08:53:54118} // namespace
Frederik Gossenac3e5c42020-06-19 15:09:36119
Frederik Gossena70f2eb2020-09-07 13:58:01120LogicalResult BroadcastOpConverter::matchAndRewrite(
River Riddleb54c7242021-09-24 17:50:58121 BroadcastOp op, OpAdaptor adaptor,
Frederik Gossen4baf18d2020-07-24 08:53:54122 ConversionPatternRewriter &rewriter) const {
Frederik Gossena70f2eb2020-09-07 13:58:01123 // For now, this lowering is only defined on `tensor<?xindex>` operands, not
124 // on shapes.
Frederik Gossen6673c6c2020-07-29 13:53:41125 if (op.getType().isa<ShapeType>())
126 return failure();
Frederik Gossen4baf18d2020-07-24 08:53:54127
Frederik Gossen6673c6c2020-07-29 13:53:41128 auto loc = op.getLoc();
Tres Poppf30f3472021-02-01 08:49:54129 ImplicitLocOpBuilder lb(loc, rewriter);
Frederik Gossen4baf18d2020-07-24 08:53:54130
Mogballa54f4ea2021-10-12 23:14:57131 Value zero = lb.create<arith::ConstantIndexOp>(0);
Tres Poppf30f3472021-02-01 08:49:54132 Type indexTy = lb.getIndexType();
Frederik Gossena70f2eb2020-09-07 13:58:01133
Tres Poppf30f3472021-02-01 08:49:54134 // Save all the ranks for bounds checking. Because this is a tensor
135 // representing the shape extents, the rank is the extent of the only
136 // dimension in the tensor.
137 SmallVector<Value> ranks, rankDiffs;
Jacques Pienaarcfb72fd32021-10-25 01:36:33138 llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
Matthias Springerc0a63182021-07-01 00:58:48139 return lb.create<tensor::DimOp>(v, zero);
Tres Poppf30f3472021-02-01 08:49:54140 }));
141
142 // Find the maximum rank
143 Value maxRank = ranks.front();
144 for (Value v : llvm::drop_begin(ranks, 1)) {
Mogballa54f4ea2021-10-12 23:14:57145 Value rankIsGreater =
146 lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
River Riddledec8af72022-01-31 20:44:35147 maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
Tres Poppf30f3472021-02-01 08:49:54148 }
149
150 // Calculate the difference of ranks and the maximum rank for later offsets.
151 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
Mogballa54f4ea2021-10-12 23:14:57152 return lb.create<arith::SubIOp>(indexTy, maxRank, v);
Tres Poppf30f3472021-02-01 08:49:54153 }));
154
Frederik Gosseneb56fa92021-04-29 08:07:20155 Value replacement = lb.create<tensor::GenerateOp>(
156 getExtentTensorType(lb.getContext()), ValueRange{maxRank},
157 [&](OpBuilder &b, Location loc, ValueRange args) {
Jacques Pienaarcfb72fd32021-10-25 01:36:33158 Value broadcastedDim =
159 getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
160 rankDiffs, args[0]);
Tres Poppf30f3472021-02-01 08:49:54161
Frederik Gosseneb56fa92021-04-29 08:07:20162 b.create<tensor::YieldOp>(loc, broadcastedDim);
163 });
164 if (replacement.getType() != op.getType())
165 replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
166 rewriter.replaceOp(op, replacement);
Frederik Gossen4baf18d2020-07-24 08:53:54167 return success();
168}
169
170namespace {
Frederik Gossendfcc0982020-07-28 15:39:49171class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
172public:
173 using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
174
175 LogicalResult
River Riddleb54c7242021-09-24 17:50:58176 matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
Frederik Gossendfcc0982020-07-28 15:39:49177 ConversionPatternRewriter &rewriter) const override;
178};
179} // namespace
180
181LogicalResult ConstShapeOpConverter::matchAndRewrite(
River Riddleb54c7242021-09-24 17:50:58182 ConstShapeOp op, OpAdaptor adaptor,
Frederik Gossendfcc0982020-07-28 15:39:49183 ConversionPatternRewriter &rewriter) const {
184
185 // For now, this lowering supports only extent tensors, not `shape.shape`
186 // types.
187 if (op.getType().isa<ShapeType>())
188 return failure();
189
190 auto loc = op.getLoc();
191 SmallVector<Value, 4> extentOperands;
Jacques Pienaarcfb72fd32021-10-25 01:36:33192 for (auto extent : op.getShape()) {
Frederik Gossendfcc0982020-07-28 15:39:49193 extentOperands.push_back(
Mogballa54f4ea2021-10-12 23:14:57194 rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
Frederik Gossendfcc0982020-07-28 15:39:49195 }
Alexander Belyaevf77e9f82021-12-16 13:42:27196 Type resultTy =
197 RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
Sean Silva84a6da62020-09-11 05:04:58198 Value tensor =
Alexander Belyaevf77e9f82021-12-16 13:42:27199 rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
Sean Silva129d6e52020-12-16 00:47:19200 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
Frederik Gossendfcc0982020-07-28 15:39:49201 return success();
202}
203
204namespace {
Frederik Gossena70f2eb2020-09-07 13:58:01205class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
Stephan Herhut5d9f33a2020-07-28 17:08:40206public:
Frederik Gossena70f2eb2020-09-07 13:58:01207 using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
Stephan Herhut5d9f33a2020-07-28 17:08:40208
209 LogicalResult
River Riddleb54c7242021-09-24 17:50:58210 matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
Frederik Gossena70f2eb2020-09-07 13:58:01211 ConversionPatternRewriter &rewriter) const override;
Stephan Herhut5d9f33a2020-07-28 17:08:40212};
213} // namespace
214
Frederik Gossena70f2eb2020-09-07 13:58:01215LogicalResult ConstSizeOpConversion::matchAndRewrite(
River Riddleb54c7242021-09-24 17:50:58216 ConstSizeOp op, OpAdaptor adaptor,
Frederik Gossena70f2eb2020-09-07 13:58:01217 ConversionPatternRewriter &rewriter) const {
Mogballa54f4ea2021-10-12 23:14:57218 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
Jacques Pienaarcfb72fd32021-10-25 01:36:33219 op, op.getValue().getSExtValue());
Frederik Gossena70f2eb2020-09-07 13:58:01220 return success();
221}
222
Stephan Herhut5d9f33a2020-07-28 17:08:40223namespace {
Tres Popp511484f2020-10-29 16:13:26224struct IsBroadcastableOpConverter
225 : public OpConversionPattern<IsBroadcastableOp> {
226 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
227
228 LogicalResult
River Riddleb54c7242021-09-24 17:50:58229 matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
Tres Popp511484f2020-10-29 16:13:26230 ConversionPatternRewriter &rewriter) const override;
231};
232} // namespace
233
234LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
River Riddleb54c7242021-09-24 17:50:58235 IsBroadcastableOp op, OpAdaptor adaptor,
Tres Popp511484f2020-10-29 16:13:26236 ConversionPatternRewriter &rewriter) const {
237 // For now, this lowering is only defined on `tensor<?xindex>` operands, not
238 // on shapes.
Jacques Pienaarcfb72fd32021-10-25 01:36:33239 if (!llvm::all_of(op.getShapes(),
Tres Popp3842d4b2021-02-10 09:24:32240 [](Value v) { return !v.getType().isa<ShapeType>(); }))
Tres Popp511484f2020-10-29 16:13:26241 return failure();
242
243 auto loc = op.getLoc();
Tres Popp3842d4b2021-02-10 09:24:32244 ImplicitLocOpBuilder lb(loc, rewriter);
Mogballa54f4ea2021-10-12 23:14:57245 Value zero = lb.create<arith::ConstantIndexOp>(0);
246 Value one = lb.create<arith::ConstantIndexOp>(1);
Tres Popp3842d4b2021-02-10 09:24:32247 Type indexTy = lb.getIndexType();
Tres Popp511484f2020-10-29 16:13:26248
Tres Popp3842d4b2021-02-10 09:24:32249 // Save all the ranks for bounds checking. Because this is a tensor
250 // representing the shape extents, the rank is the extent of the only
251 // dimension in the tensor.
252 SmallVector<Value> ranks, rankDiffs;
Jacques Pienaarcfb72fd32021-10-25 01:36:33253 llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
Matthias Springerc0a63182021-07-01 00:58:48254 return lb.create<tensor::DimOp>(v, zero);
Tres Popp3842d4b2021-02-10 09:24:32255 }));
256
257 // Find the maximum rank
258 Value maxRank = ranks.front();
259 for (Value v : llvm::drop_begin(ranks, 1)) {
Mogballa54f4ea2021-10-12 23:14:57260 Value rankIsGreater =
261 lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank);
River Riddledec8af72022-01-31 20:44:35262 maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank);
Tres Popp3842d4b2021-02-10 09:24:32263 }
264
265 // Calculate the difference of ranks and the maximum rank for later offsets.
266 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
Mogballa54f4ea2021-10-12 23:14:57267 return lb.create<arith::SubIOp>(indexTy, maxRank, v);
Tres Popp3842d4b2021-02-10 09:24:32268 }));
269
Tres Popp511484f2020-10-29 16:13:26270 Type i1Ty = rewriter.getI1Type();
Tres Popp3842d4b2021-02-10 09:24:32271 Value trueVal =
Mogballa54f4ea2021-10-12 23:14:57272 rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
Tres Popp511484f2020-10-29 16:13:26273
Tres Popp3842d4b2021-02-10 09:24:32274 auto reduceResult = lb.create<ForOp>(
275 loc, zero, maxRank, one, ValueRange{trueVal},
Tres Popp511484f2020-10-29 16:13:26276 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
Tres Popp3842d4b2021-02-10 09:24:32277 // Find a non-1 dim, if it exists. Note that the first part of this
278 // could reuse the Broadcast lowering entirely, but we redo the work
279 // here to make optimizations easier between the two loops.
280 Value broadcastedDim = getBroadcastedDim(
Jacques Pienaarcfb72fd32021-10-25 01:36:33281 ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);
Tres Popp3842d4b2021-02-10 09:24:32282
283 Value broadcastable = iterArgs[0];
Jacques Pienaarcfb72fd32021-10-25 01:36:33284 for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
Tres Popp3842d4b2021-02-10 09:24:32285 Value shape, rankDiff;
286 std::tie(shape, rankDiff) = tup;
Mogballa54f4ea2021-10-12 23:14:57287 Value outOfBounds = b.create<arith::CmpIOp>(
288 loc, arith::CmpIPredicate::ult, iv, rankDiff);
Tres Popp3842d4b2021-02-10 09:24:32289 broadcastable =
290 b.create<IfOp>(
291 loc, TypeRange{i1Ty}, outOfBounds,
292 [&](OpBuilder &b, Location loc) {
293 // Non existent dimensions are always broadcastable
294 b.create<scf::YieldOp>(loc, broadcastable);
295 },
296 [&](OpBuilder &b, Location loc) {
297 // Every value needs to be either 1, or the same non-1
298 // value to be broadcastable in this dim.
299 Value operandDimension =
Mogballa54f4ea2021-10-12 23:14:57300 b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
Tres Popp3842d4b2021-02-10 09:24:32301 Value dimensionExtent = b.create<tensor::ExtractOp>(
302 loc, shape, ValueRange{operandDimension});
303
Mogballa54f4ea2021-10-12 23:14:57304 Value equalOne = b.create<arith::CmpIOp>(
305 loc, arith::CmpIPredicate::eq, dimensionExtent, one);
306 Value equalBroadcasted = b.create<arith::CmpIOp>(
307 loc, arith::CmpIPredicate::eq, dimensionExtent,
308 broadcastedDim);
309 Value result = b.create<arith::AndIOp>(
Tres Popp3842d4b2021-02-10 09:24:32310 loc, broadcastable,
Mogballa54f4ea2021-10-12 23:14:57311 b.create<arith::OrIOp>(loc, equalOne,
312 equalBroadcasted));
Tres Popp3842d4b2021-02-10 09:24:32313 b.create<scf::YieldOp>(loc, result);
314 })
315 .getResult(0);
316 }
317
318 b.create<scf::YieldOp>(loc, broadcastable);
Tres Popp511484f2020-10-29 16:13:26319 });
320
Jacques Pienaarc0342a22021-12-20 16:03:43321 rewriter.replaceOp(op, reduceResult.getResults().front());
Tres Popp511484f2020-10-29 16:13:26322 return success();
323}
324
325namespace {
Frederik Gossen8577a092020-06-30 08:33:49326class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
327 using OpConversionPattern<GetExtentOp>::OpConversionPattern;
328
329 LogicalResult
River Riddleb54c7242021-09-24 17:50:58330 matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
Frederik Gossen4baf18d2020-07-24 08:53:54331 ConversionPatternRewriter &rewriter) const override;
332};
333} // namespace
Frederik Gossen8577a092020-06-30 08:33:49334
Frederik Gossen4baf18d2020-07-24 08:53:54335LogicalResult GetExtentOpConverter::matchAndRewrite(
River Riddleb54c7242021-09-24 17:50:58336 GetExtentOp op, OpAdaptor adaptor,
Frederik Gossen4baf18d2020-07-24 08:53:54337 ConversionPatternRewriter &rewriter) const {
Frederik Gossen6673c6c2020-07-29 13:53:41338 // For now, only error-free types are supported by this lowering.
339 if (op.getType().isa<SizeType>())
340 return failure();
341
342 // Derive shape extent directly from shape origin if possible. This
343 // circumvents the necessity to materialize the shape in memory.
Jacques Pienaarcfb72fd32021-10-25 01:36:33344 if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
345 if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
346 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
347 adaptor.getDim());
Frederik Gossen6673c6c2020-07-29 13:53:41348 return success();
349 }
Frederik Gossen8577a092020-06-30 08:33:49350 }
Frederik Gossen8577a092020-06-30 08:33:49351
Jacques Pienaarcfb72fd32021-10-25 01:36:33352 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
353 adaptor.getShape(),
354 ValueRange{adaptor.getDim()});
Frederik Gossen4baf18d2020-07-24 08:53:54355 return success();
356}
357
358namespace {
Frederik Gossen24debf52020-06-25 08:42:40359class RankOpConverter : public OpConversionPattern<shape::RankOp> {
360public:
361 using OpConversionPattern<shape::RankOp>::OpConversionPattern;
362
363 LogicalResult
River Riddleb54c7242021-09-24 17:50:58364 matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
Frederik Gossen4baf18d2020-07-24 08:53:54365 ConversionPatternRewriter &rewriter) const override;
Frederik Gossen24debf52020-06-25 08:42:40366};
Frederik Gossen4baf18d2020-07-24 08:53:54367} // namespace
Frederik Gossen24debf52020-06-25 08:42:40368
Frederik Gossen4baf18d2020-07-24 08:53:54369LogicalResult
River Riddleb54c7242021-09-24 17:50:58370RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
Frederik Gossen4baf18d2020-07-24 08:53:54371 ConversionPatternRewriter &rewriter) const {
Frederik Gossena97940d2020-07-30 11:40:16372 // For now, this lowering supports only error-free types.
373 if (op.getType().isa<SizeType>())
374 return failure();
375
Jacques Pienaarcfb72fd32021-10-25 01:36:33376 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
Frederik Gossen4baf18d2020-07-24 08:53:54377 return success();
378}
379
380namespace {
Frederik Gossena70f2eb2020-09-07 13:58:01381/// Converts `shape.reduce` to `scf.for`.
382struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
383public:
384 using OpConversionPattern::OpConversionPattern;
385
386 LogicalResult
River Riddleb54c7242021-09-24 17:50:58387 matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
Frederik Gossena70f2eb2020-09-07 13:58:01388 ConversionPatternRewriter &rewriter) const final;
389};
390} // namespace
391
392LogicalResult
River Riddleb54c7242021-09-24 17:50:58393ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
Frederik Gossena70f2eb2020-09-07 13:58:01394 ConversionPatternRewriter &rewriter) const {
395 // For now, this lowering is only defined on `tensor<?xindex>` operands.
Jacques Pienaarcfb72fd32021-10-25 01:36:33396 if (op.getShape().getType().isa<ShapeType>())
Frederik Gossena70f2eb2020-09-07 13:58:01397 return failure();
398
399 auto loc = op.getLoc();
Frederik Gossena70f2eb2020-09-07 13:58:01400
Mogballa54f4ea2021-10-12 23:14:57401 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
402 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Frederik Gossena70f2eb2020-09-07 13:58:01403 Type indexTy = rewriter.getIndexType();
Julian Grosse2310702021-02-10 12:53:11404 Value rank =
Jacques Pienaarcfb72fd32021-10-25 01:36:33405 rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);
Frederik Gossena70f2eb2020-09-07 13:58:01406
407 auto loop = rewriter.create<scf::ForOp>(
Jacques Pienaarcfb72fd32021-10-25 01:36:33408 loc, zero, rank, one, op.getInitVals(),
Frederik Gossena70f2eb2020-09-07 13:58:01409 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
Jacques Pienaarcfb72fd32021-10-25 01:36:33410 Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);
Frederik Gossena70f2eb2020-09-07 13:58:01411
412 SmallVector<Value, 2> mappedValues{iv, extent};
413 mappedValues.append(args.begin(), args.end());
414
415 BlockAndValueMapping mapping;
416 Block *reduceBody = op.getBody();
417 mapping.map(reduceBody->getArguments(), mappedValues);
418 for (auto &nested : reduceBody->without_terminator())
419 b.clone(nested, mapping);
420
421 SmallVector<Value, 2> mappedResults;
422 for (auto result : reduceBody->getTerminator()->getOperands())
423 mappedResults.push_back(mapping.lookup(result));
424 b.create<scf::YieldOp>(loc, mappedResults);
425 });
426
427 rewriter.replaceOp(op, loop.getResults());
428 return success();
429}
430
431namespace {
432/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
433/// only defined on `tensor<?xindex>` operands. The test for equality first
434/// compares their size and, if equal, checks every extent for equality.
435///
436/// Example:
437///
438/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
439///
440/// becomes
441///
Mogballcb3aa492021-10-14 16:55:33442/// %c0 = arith.constant 0 : index
Frederik Gossena70f2eb2020-09-07 13:58:01443/// %0 = dim %arg0, %c0 : tensor<?xindex>
444/// %1 = dim %arg1, %c0 : tensor<?xindex>
Mogballa54f4ea2021-10-12 23:14:57445/// %2 = arith.cmpi "eq", %0, %1 : index
Frederik Gossena70f2eb2020-09-07 13:58:01446/// %result = scf.if %2 -> (i1) {
Mogballa54f4ea2021-10-12 23:14:57447/// %c1 = arith.constant 1 : index
448/// %true = arith.constant true
Frederik Gossena70f2eb2020-09-07 13:58:01449/// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
Sean Silva444822d2020-12-11 22:20:03450/// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
451/// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
Mogballa54f4ea2021-10-12 23:14:57452/// %7 = arith.cmpi "eq", %5, %6 : index
453/// %8 = arith.andi %arg3, %7 : i1
Frederik Gossena70f2eb2020-09-07 13:58:01454/// scf.yield %8 : i1
455/// }
456/// scf.yield %4 : i1
457/// } else {
Mogballa54f4ea2021-10-12 23:14:57458/// %false = arith.constant false
Frederik Gossena70f2eb2020-09-07 13:58:01459/// scf.yield %false : i1
460/// }
461///
462struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
463 using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
464
465 LogicalResult
River Riddleb54c7242021-09-24 17:50:58466 matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
Frederik Gossena70f2eb2020-09-07 13:58:01467 ConversionPatternRewriter &rewriter) const override;
468};
469} // namespace
470
471LogicalResult
River Riddleb54c7242021-09-24 17:50:58472ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
Frederik Gossena70f2eb2020-09-07 13:58:01473 ConversionPatternRewriter &rewriter) const {
Jacques Pienaarcfb72fd32021-10-25 01:36:33474 if (!llvm::all_of(op.getShapes(),
Benjamin Kramer24acade2021-03-01 19:34:17475 [](Value v) { return !v.getType().isa<ShapeType>(); }))
Frederik Gossena70f2eb2020-09-07 13:58:01476 return failure();
Benjamin Kramer24acade2021-03-01 19:34:17477
478 Type i1Ty = rewriter.getI1Type();
Jacques Pienaarcfb72fd32021-10-25 01:36:33479 if (op.getShapes().size() <= 1) {
Mogballa54f4ea2021-10-12 23:14:57480 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
481 rewriter.getBoolAttr(true));
Benjamin Kramer24acade2021-03-01 19:34:17482 return success();
Frederik Gossena70f2eb2020-09-07 13:58:01483 }
484
Frederik Gossena70f2eb2020-09-07 13:58:01485 auto loc = op.getLoc();
486 Type indexTy = rewriter.getIndexType();
Mogballa54f4ea2021-10-12 23:14:57487 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Jacques Pienaarcfb72fd32021-10-25 01:36:33488 Value firstShape = adaptor.getShapes().front();
Julian Grosse2310702021-02-10 12:53:11489 Value firstRank =
Matthias Springerc0a63182021-07-01 00:58:48490 rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
Benjamin Kramer24acade2021-03-01 19:34:17491 Value result = nullptr;
492 // Generate a linear sequence of compares, all with firstShape as lhs.
Jacques Pienaarcfb72fd32021-10-25 01:36:33493 for (Value shape : adaptor.getShapes().drop_front(1)) {
Matthias Springerc0a63182021-07-01 00:58:48494 Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
Mogballa54f4ea2021-10-12 23:14:57495 Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
496 firstRank, rank);
Benjamin Kramer24acade2021-03-01 19:34:17497 auto same = rewriter.create<IfOp>(
498 loc, i1Ty, eqRank,
499 [&](OpBuilder &b, Location loc) {
Mogballa54f4ea2021-10-12 23:14:57500 Value one = b.create<arith::ConstantIndexOp>(loc, 1);
501 Value init =
502 b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
Benjamin Kramer24acade2021-03-01 19:34:17503 auto loop = b.create<scf::ForOp>(
504 loc, zero, firstRank, one, ValueRange{init},
505 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
506 Value conj = args[0];
507 Value lhsExtent =
508 b.create<tensor::ExtractOp>(loc, firstShape, iv);
509 Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
Mogballa54f4ea2021-10-12 23:14:57510 Value eqExtent = b.create<arith::CmpIOp>(
511 loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
512 Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
Benjamin Kramer24acade2021-03-01 19:34:17513 b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
514 });
515 b.create<scf::YieldOp>(loc, loop.getResults());
516 },
517 [&](OpBuilder &b, Location loc) {
Mogballa54f4ea2021-10-12 23:14:57518 Value result =
519 b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
Benjamin Kramer24acade2021-03-01 19:34:17520 b.create<scf::YieldOp>(loc, result);
521 });
522 result = !result ? same.getResult(0)
Mogballa54f4ea2021-10-12 23:14:57523 : rewriter.create<arith::AndIOp>(loc, result,
524 same.getResult(0));
Benjamin Kramer24acade2021-03-01 19:34:17525 }
526 rewriter.replaceOp(op, result);
Frederik Gossena70f2eb2020-09-07 13:58:01527 return success();
528}
529
530namespace {
531class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
532public:
533 using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
534
535 LogicalResult
River Riddleb54c7242021-09-24 17:50:58536 matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
Frederik Gossena70f2eb2020-09-07 13:58:01537 ConversionPatternRewriter &rewriter) const override;
538};
539} // namespace
540
541LogicalResult ShapeOfOpConversion::matchAndRewrite(
River Riddleb54c7242021-09-24 17:50:58542 ShapeOfOp op, OpAdaptor adaptor,
Frederik Gossena70f2eb2020-09-07 13:58:01543 ConversionPatternRewriter &rewriter) const {
544
545 // For now, only error-free types are supported by this lowering.
546 if (op.getType().isa<ShapeType>())
547 return failure();
548
Sean Silvabe7352c2021-01-15 02:28:48549 // For ranked tensor arguments, lower to `tensor.from_elements`.
Frederik Gossen5106a8b2020-09-09 07:53:13550 auto loc = op.getLoc();
Jacques Pienaarcfb72fd32021-10-25 01:36:33551 Value tensor = adaptor.getArg();
Frederik Gossena70f2eb2020-09-07 13:58:01552 Type tensorTy = tensor.getType();
553 if (tensorTy.isa<RankedTensorType>()) {
554
555 // Build values for individual extents.
556 SmallVector<Value, 8> extentValues;
557 RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
558 int64_t rank = rankedTensorTy.getRank();
Frederik Gossena70f2eb2020-09-07 13:58:01559 for (int64_t i = 0; i < rank; i++) {
560 if (rankedTensorTy.isDynamicDim(i)) {
Matthias Springerc0a63182021-07-01 00:58:48561 Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
Frederik Gossena70f2eb2020-09-07 13:58:01562 extentValues.push_back(extent);
563 } else {
Mogballa54f4ea2021-10-12 23:14:57564 Value extent = rewriter.create<arith::ConstantIndexOp>(
565 loc, rankedTensorTy.getDimSize(i));
Frederik Gossena70f2eb2020-09-07 13:58:01566 extentValues.push_back(extent);
567 }
568 }
569
570 // Materialize extent tensor.
Sean Silvabe7352c2021-01-15 02:28:48571 Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
Alexander Belyaevf77e9f82021-12-16 13:42:27572 loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
573 extentValues);
Sean Silva129d6e52020-12-16 00:47:19574 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
575 staticExtentTensor);
Frederik Gossena70f2eb2020-09-07 13:58:01576 return success();
577 }
578
Sean Silvabe7352c2021-01-15 02:28:48579 // Lower to `tensor.generate` otherwise.
Frederik Gossen5106a8b2020-09-09 07:53:13580 auto *ctx = rewriter.getContext();
Alexander Belyaev15f8f3e2021-12-14 08:35:14581 Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
Sean Silvabe7352c2021-01-15 02:28:48582 rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
Frederik Gossen5106a8b2020-09-09 07:53:13583 op, getExtentTensorType(ctx), ValueRange{rank},
584 [&](OpBuilder &b, Location loc, ValueRange args) {
585 Value dim = args.front();
Matthias Springerc0a63182021-07-01 00:58:48586 Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
Sean Silvabe7352c2021-01-15 02:28:48587 b.create<tensor::YieldOp>(loc, extent);
Frederik Gossena70f2eb2020-09-07 13:58:01588 });
589
Frederik Gossena70f2eb2020-09-07 13:58:01590 return success();
591}
592
593namespace {
Benjamin Kramer42c195f2021-03-08 14:23:28594class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
595public:
596 using OpConversionPattern<SplitAtOp>::OpConversionPattern;
597
598 LogicalResult
River Riddleb54c7242021-09-24 17:50:58599 matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
Benjamin Kramer42c195f2021-03-08 14:23:28600 ConversionPatternRewriter &rewriter) const override;
601};
602} // namespace
603
604LogicalResult SplitAtOpConversion::matchAndRewrite(
River Riddleb54c7242021-09-24 17:50:58605 SplitAtOp op, OpAdaptor adaptor,
Benjamin Kramer42c195f2021-03-08 14:23:28606 ConversionPatternRewriter &rewriter) const {
607 // Error conditions are not implemented, only lower if all operands and
608 // results are extent tensors.
Jacques Pienaarcfb72fd32021-10-25 01:36:33609 if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
Benjamin Kramer42c195f2021-03-08 14:23:28610 [](Value v) { return v.getType().isa<ShapeType>(); }))
611 return failure();
612
Benjamin Kramer42c195f2021-03-08 14:23:28613 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Mogballa54f4ea2021-10-12 23:14:57614 Value zero = b.create<arith::ConstantIndexOp>(0);
Jacques Pienaarcfb72fd32021-10-25 01:36:33615 Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);
Benjamin Kramer42c195f2021-03-08 14:23:28616
617 // index < 0 ? index + rank : index
Jacques Pienaarcfb72fd32021-10-25 01:36:33618 Value originalIndex = adaptor.getIndex();
Mogballa54f4ea2021-10-12 23:14:57619 Value add = b.create<arith::AddIOp>(originalIndex, rank);
Benjamin Kramer42c195f2021-03-08 14:23:28620 Value indexIsNegative =
Mogballa54f4ea2021-10-12 23:14:57621 b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
River Riddledec8af72022-01-31 20:44:35622 Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);
Benjamin Kramer42c195f2021-03-08 14:23:28623
Mogballa54f4ea2021-10-12 23:14:57624 Value one = b.create<arith::ConstantIndexOp>(1);
Matthias Springer060208b2021-06-22 07:49:08625 Value head =
Jacques Pienaarcfb72fd32021-10-25 01:36:33626 b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
Mogballa54f4ea2021-10-12 23:14:57627 Value tailSize = b.create<arith::SubIOp>(rank, index);
Jacques Pienaarcfb72fd32021-10-25 01:36:33628 Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
629 tailSize, one);
Benjamin Kramer42c195f2021-03-08 14:23:28630 rewriter.replaceOp(op, {head, tail});
631 return success();
632}
633
634namespace {
Frederik Gossena70f2eb2020-09-07 13:58:01635class ToExtentTensorOpConversion
636 : public OpConversionPattern<ToExtentTensorOp> {
637public:
638 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
639
640 LogicalResult
River Riddleb54c7242021-09-24 17:50:58641 matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
Frederik Gossena70f2eb2020-09-07 13:58:01642 ConversionPatternRewriter &rewriter) const override {
Jacques Pienaarcfb72fd32021-10-25 01:36:33643 if (!adaptor.getInput().getType().isa<RankedTensorType>())
Frederik Gossena70f2eb2020-09-07 13:58:01644 return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
645
Sean Silva129d6e52020-12-16 00:47:19646 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
Jacques Pienaarcfb72fd32021-10-25 01:36:33647 adaptor.getInput());
Frederik Gossena70f2eb2020-09-07 13:58:01648 return success();
649 }
650};
651} // namespace
652
653namespace {
Tres Poppd05d4212020-10-13 15:56:45654/// Import the Shape Ops to Std Patterns.
655#include "ShapeToStandard.cpp.inc"
656} // namespace
657
658namespace {
Frederik Gossen37133142020-06-03 16:14:42659/// Conversion pass.
660class ConvertShapeToStandardPass
661 : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
Frederik Gosseneaf49132020-06-18 07:51:03662
Frederik Gossen4baf18d2020-07-24 08:53:54663 void runOnOperation() override;
Frederik Gossen37133142020-06-03 16:14:42664};
Frederik Gossen37133142020-06-03 16:14:42665} // namespace
666
Frederik Gossen4baf18d2020-07-24 08:53:54667void ConvertShapeToStandardPass::runOnOperation() {
Frederik Gossen4baf18d2020-07-24 08:53:54668 // Setup target legality.
Frederik Gossenb6b9d3e2020-07-29 10:45:07669 MLIRContext &ctx = getContext();
Frederik Gossen4baf18d2020-07-24 08:53:54670 ConversionTarget target(ctx);
River Riddle23aa5a72022-02-26 22:49:54671 target.addLegalDialect<arith::ArithmeticDialect, func::FuncDialect,
Mogballa54f4ea2021-10-12 23:14:57672 SCFDialect, tensor::TensorDialect>();
Mehdi Amini973ddb72021-03-11 23:58:02673 target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
Frederik Gossen4baf18d2020-07-24 08:53:54674
675 // Setup conversion patterns.
Chris Lattnerdc4e9132021-03-22 23:58:34676 RewritePatternSet patterns(&ctx);
Chris Lattner3a506b32021-03-20 23:29:41677 populateShapeToStandardConversionPatterns(patterns);
Frederik Gossen4baf18d2020-07-24 08:53:54678
679 // Apply conversion.
680 auto module = getOperation();
River Riddle3fffffa82020-10-27 00:25:01681 if (failed(applyPartialConversion(module, target, std::move(patterns))))
Frederik Gossen4baf18d2020-07-24 08:53:54682 signalPassFailure();
683}
684
Frederik Gossen24edbdf2020-06-08 08:58:06685void mlir::populateShapeToStandardConversionPatterns(
Chris Lattnerdc4e9132021-03-22 23:58:34686 RewritePatternSet &patterns) {
Frederik Gossen37133142020-06-03 16:14:42687 // clang-format off
Chris Lattner1d909c92021-03-21 17:38:35688 populateWithGenerated(patterns);
Chris Lattnerdc4e9132021-03-22 23:58:34689 patterns.add<
Frederik Gossen9df6afb2020-07-13 08:28:13690 AnyOpConversion,
Mogballa54f4ea2021-10-12 23:14:57691 BinaryOpConversion<AddOp, arith::AddIOp>,
692 BinaryOpConversion<MulOp, arith::MulIOp>,
Frederik Gossena70f2eb2020-09-07 13:58:01693 BroadcastOpConverter,
694 ConstShapeOpConverter,
Stephan Herhut5d9f33a2020-07-28 17:08:40695 ConstSizeOpConversion,
Tres Popp511484f2020-10-29 16:13:26696 IsBroadcastableOpConverter,
Frederik Gossen8577a092020-06-30 08:33:49697 GetExtentOpConverter,
Frederik Gossen24debf52020-06-25 08:42:40698 RankOpConverter,
Frederik Gossena70f2eb2020-09-07 13:58:01699 ReduceOpConverter,
700 ShapeEqOpConverter,
Stephan Herhut5d9f33a2020-07-28 17:08:40701 ShapeOfOpConversion,
Benjamin Kramer42c195f2021-03-08 14:23:28702 SplitAtOpConversion,
Chris Lattner3a506b32021-03-20 23:29:41703 ToExtentTensorOpConversion>(patterns.getContext());
Frederik Gossen37133142020-06-03 16:14:42704 // clang-format on
705}
706
Frederik Gossen24edbdf2020-06-08 08:58:06707std::unique_ptr<OperationPass<ModuleOp>>
708mlir::createConvertShapeToStandardPass() {
Frederik Gossen37133142020-06-03 16:14:42709 return std::make_unique<ConvertShapeToStandardPass>();
710}