Alexander Belyaev | 71c1080 | 2020-06-16 11:49:54 | [diff] [blame] | 1 | //===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" |
| 10 | |
| 11 | #include "../PassDetail.h" |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 12 | #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
River Riddle | 23aa5a7 | 2022-02-26 22:49:54 | [diff] [blame^] | 13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 14 | #include "mlir/Dialect/SCF/SCF.h" |
| 15 | #include "mlir/Dialect/Shape/IR/Shape.h" |
Sean Silva | 444822d | 2020-12-11 22:20:03 | [diff] [blame] | 16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 17 | #include "mlir/IR/BlockAndValueMapping.h" |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 18 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 19 | #include "mlir/Transforms/DialectConversion.h" |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 20 | #include "llvm/ADT/STLExtras.h" |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 21 | |
Frederik Gossen | 24edbdf | 2020-06-08 08:58:06 | [diff] [blame] | 22 | using namespace mlir; |
Alexander Belyaev | 80be54c | 2020-06-08 15:48:01 | [diff] [blame] | 23 | using namespace mlir::shape; |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 24 | using namespace mlir::scf; |
Frederik Gossen | 24edbdf | 2020-06-08 08:58:06 | [diff] [blame] | 25 | |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 26 | /// Conversion patterns. |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 27 | namespace { |
Frederik Gossen | 9df6afb | 2020-07-13 08:28:13 | [diff] [blame] | 28 | class AnyOpConversion : public OpConversionPattern<AnyOp> { |
| 29 | public: |
| 30 | using OpConversionPattern<AnyOp>::OpConversionPattern; |
| 31 | |
| 32 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 33 | matchAndRewrite(AnyOp op, OpAdaptor adaptor, |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 34 | ConversionPatternRewriter &rewriter) const override; |
Frederik Gossen | 9df6afb | 2020-07-13 08:28:13 | [diff] [blame] | 35 | }; |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 36 | } // namespace |
Frederik Gossen | 9df6afb | 2020-07-13 08:28:13 | [diff] [blame] | 37 | |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 38 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 39 | AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor, |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 40 | ConversionPatternRewriter &rewriter) const { |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 41 | // Replace `any` with its first operand. |
| 42 | // Any operand would be a valid substitution. |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 43 | rewriter.replaceOp(op, {adaptor.getInputs().front()}); |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 44 | return success(); |
| 45 | } |
| 46 | |
| 47 | namespace { |
Alexander Belyaev | 80be54c | 2020-06-08 15:48:01 | [diff] [blame] | 48 | template <typename SrcOpTy, typename DstOpTy> |
| 49 | class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { |
| 50 | public: |
| 51 | using OpConversionPattern<SrcOpTy>::OpConversionPattern; |
| 52 | |
| 53 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 54 | matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor, |
Alexander Belyaev | 80be54c | 2020-06-08 15:48:01 | [diff] [blame] | 55 | ConversionPatternRewriter &rewriter) const override { |
Frederik Gossen | 6673c6c | 2020-07-29 13:53:41 | [diff] [blame] | 56 | // For now, only error-free types are supported by this lowering. |
| 57 | if (op.getType().template isa<SizeType>()) |
| 58 | return failure(); |
| 59 | |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 60 | rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(), |
| 61 | adaptor.getRhs()); |
Alexander Belyaev | 80be54c | 2020-06-08 15:48:01 | [diff] [blame] | 62 | return success(); |
| 63 | } |
| 64 | }; |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 65 | } // namespace |
Alexander Belyaev | 80be54c | 2020-06-08 15:48:01 | [diff] [blame] | 66 | |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 67 | namespace { |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 68 | struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { |
| 69 | using OpConversionPattern<BroadcastOp>::OpConversionPattern; |
Stephan Herhut | 5d9f33a | 2020-07-28 17:08:40 | [diff] [blame] | 70 | |
| 71 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 72 | matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 73 | ConversionPatternRewriter &rewriter) const override; |
Frederik Gossen | ac3e5c4 | 2020-06-19 15:09:36 | [diff] [blame] | 74 | }; |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 75 | |
| 76 | // Get the resulting extent in a given dimension. This is computed with any |
| 77 | // number of extent tensors and shifted offsets into them. |
| 78 | Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors, |
| 79 | ValueRange rankDiffs, Value outputDimension) { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 80 | Value one = lb.create<arith::ConstantIndexOp>(1); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 81 | Value broadcastedDim = one; |
| 82 | for (auto tup : llvm::zip(extentTensors, rankDiffs)) { |
| 83 | Value shape = std::get<0>(tup); |
| 84 | Value rankDiff = std::get<1>(tup); |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 85 | Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult, |
| 86 | outputDimension, rankDiff); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 87 | Type indexTy = lb.getIndexType(); |
| 88 | broadcastedDim = |
| 89 | lb.create<IfOp>( |
| 90 | TypeRange{indexTy}, outOfBounds, |
| 91 | [&](OpBuilder &b, Location loc) { |
| 92 | b.create<scf::YieldOp>(loc, broadcastedDim); |
| 93 | }, |
| 94 | [&](OpBuilder &b, Location loc) { |
| 95 | // The broadcasting logic is: |
| 96 | // - if one extent (here we arbitrarily choose the |
| 97 | // extent from the greater-rank operand) is equal to 1, |
| 98 | // then take the extent from the other operand |
| 99 | // - otherwise, take the extent as-is. |
| 100 | // Note that this logic remains correct in the presence |
| 101 | // of dimensions of zero extent. |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 102 | Value lesserRankOperandDimension = b.create<arith::SubIOp>( |
| 103 | loc, indexTy, outputDimension, rankDiff); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 104 | Value lesserRankOperandExtent = b.create<tensor::ExtractOp>( |
| 105 | loc, shape, ValueRange{lesserRankOperandDimension}); |
| 106 | |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 107 | Value dimIsOne = |
| 108 | b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
| 109 | lesserRankOperandExtent, one); |
River Riddle | dec8af7 | 2022-01-31 20:44:35 | [diff] [blame] | 110 | Value dim = b.create<arith::SelectOp>( |
| 111 | loc, dimIsOne, broadcastedDim, lesserRankOperandExtent); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 112 | b.create<scf::YieldOp>(loc, dim); |
| 113 | }) |
| 114 | .getResult(0); |
| 115 | } |
| 116 | return broadcastedDim; |
| 117 | } |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 118 | } // namespace |
Frederik Gossen | ac3e5c4 | 2020-06-19 15:09:36 | [diff] [blame] | 119 | |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 120 | LogicalResult BroadcastOpConverter::matchAndRewrite( |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 121 | BroadcastOp op, OpAdaptor adaptor, |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 122 | ConversionPatternRewriter &rewriter) const { |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 123 | // For now, this lowering is only defined on `tensor<?xindex>` operands, not |
| 124 | // on shapes. |
Frederik Gossen | 6673c6c | 2020-07-29 13:53:41 | [diff] [blame] | 125 | if (op.getType().isa<ShapeType>()) |
| 126 | return failure(); |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 127 | |
Frederik Gossen | 6673c6c | 2020-07-29 13:53:41 | [diff] [blame] | 128 | auto loc = op.getLoc(); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 129 | ImplicitLocOpBuilder lb(loc, rewriter); |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 130 | |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 131 | Value zero = lb.create<arith::ConstantIndexOp>(0); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 132 | Type indexTy = lb.getIndexType(); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 133 | |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 134 | // Save all the ranks for bounds checking. Because this is a tensor |
| 135 | // representing the shape extents, the rank is the extent of the only |
| 136 | // dimension in the tensor. |
| 137 | SmallVector<Value> ranks, rankDiffs; |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 138 | llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { |
Matthias Springer | c0a6318 | 2021-07-01 00:58:48 | [diff] [blame] | 139 | return lb.create<tensor::DimOp>(v, zero); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 140 | })); |
| 141 | |
| 142 | // Find the maximum rank |
| 143 | Value maxRank = ranks.front(); |
| 144 | for (Value v : llvm::drop_begin(ranks, 1)) { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 145 | Value rankIsGreater = |
| 146 | lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank); |
River Riddle | dec8af7 | 2022-01-31 20:44:35 | [diff] [blame] | 147 | maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 148 | } |
| 149 | |
| 150 | // Calculate the difference of ranks and the maximum rank for later offsets. |
| 151 | llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 152 | return lb.create<arith::SubIOp>(indexTy, maxRank, v); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 153 | })); |
| 154 | |
Frederik Gossen | eb56fa9 | 2021-04-29 08:07:20 | [diff] [blame] | 155 | Value replacement = lb.create<tensor::GenerateOp>( |
| 156 | getExtentTensorType(lb.getContext()), ValueRange{maxRank}, |
| 157 | [&](OpBuilder &b, Location loc, ValueRange args) { |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 158 | Value broadcastedDim = |
| 159 | getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), |
| 160 | rankDiffs, args[0]); |
Tres Popp | f30f347 | 2021-02-01 08:49:54 | [diff] [blame] | 161 | |
Frederik Gossen | eb56fa9 | 2021-04-29 08:07:20 | [diff] [blame] | 162 | b.create<tensor::YieldOp>(loc, broadcastedDim); |
| 163 | }); |
| 164 | if (replacement.getType() != op.getType()) |
| 165 | replacement = lb.create<tensor::CastOp>(op.getType(), replacement); |
| 166 | rewriter.replaceOp(op, replacement); |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 167 | return success(); |
| 168 | } |
| 169 | |
| 170 | namespace { |
Frederik Gossen | dfcc098 | 2020-07-28 15:39:49 | [diff] [blame] | 171 | class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { |
| 172 | public: |
| 173 | using OpConversionPattern<ConstShapeOp>::OpConversionPattern; |
| 174 | |
| 175 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 176 | matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor, |
Frederik Gossen | dfcc098 | 2020-07-28 15:39:49 | [diff] [blame] | 177 | ConversionPatternRewriter &rewriter) const override; |
| 178 | }; |
| 179 | } // namespace |
| 180 | |
| 181 | LogicalResult ConstShapeOpConverter::matchAndRewrite( |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 182 | ConstShapeOp op, OpAdaptor adaptor, |
Frederik Gossen | dfcc098 | 2020-07-28 15:39:49 | [diff] [blame] | 183 | ConversionPatternRewriter &rewriter) const { |
| 184 | |
| 185 | // For now, this lowering supports only extent tensors, not `shape.shape` |
| 186 | // types. |
| 187 | if (op.getType().isa<ShapeType>()) |
| 188 | return failure(); |
| 189 | |
| 190 | auto loc = op.getLoc(); |
| 191 | SmallVector<Value, 4> extentOperands; |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 192 | for (auto extent : op.getShape()) { |
Frederik Gossen | dfcc098 | 2020-07-28 15:39:49 | [diff] [blame] | 193 | extentOperands.push_back( |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 194 | rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue())); |
Frederik Gossen | dfcc098 | 2020-07-28 15:39:49 | [diff] [blame] | 195 | } |
Alexander Belyaev | f77e9f8 | 2021-12-16 13:42:27 | [diff] [blame] | 196 | Type resultTy = |
| 197 | RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType()); |
Sean Silva | 84a6da6 | 2020-09-11 05:04:58 | [diff] [blame] | 198 | Value tensor = |
Alexander Belyaev | f77e9f8 | 2021-12-16 13:42:27 | [diff] [blame] | 199 | rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands); |
Sean Silva | 129d6e5 | 2020-12-16 00:47:19 | [diff] [blame] | 200 | rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor); |
Frederik Gossen | dfcc098 | 2020-07-28 15:39:49 | [diff] [blame] | 201 | return success(); |
| 202 | } |
| 203 | |
| 204 | namespace { |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 205 | class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { |
Stephan Herhut | 5d9f33a | 2020-07-28 17:08:40 | [diff] [blame] | 206 | public: |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 207 | using OpConversionPattern<ConstSizeOp>::OpConversionPattern; |
Stephan Herhut | 5d9f33a | 2020-07-28 17:08:40 | [diff] [blame] | 208 | |
| 209 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 210 | matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 211 | ConversionPatternRewriter &rewriter) const override; |
Stephan Herhut | 5d9f33a | 2020-07-28 17:08:40 | [diff] [blame] | 212 | }; |
| 213 | } // namespace |
| 214 | |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 215 | LogicalResult ConstSizeOpConversion::matchAndRewrite( |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 216 | ConstSizeOp op, OpAdaptor adaptor, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 217 | ConversionPatternRewriter &rewriter) const { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 218 | rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>( |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 219 | op, op.getValue().getSExtValue()); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 220 | return success(); |
| 221 | } |
| 222 | |
Stephan Herhut | 5d9f33a | 2020-07-28 17:08:40 | [diff] [blame] | 223 | namespace { |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 224 | struct IsBroadcastableOpConverter |
| 225 | : public OpConversionPattern<IsBroadcastableOp> { |
| 226 | using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; |
| 227 | |
| 228 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 229 | matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor, |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 230 | ConversionPatternRewriter &rewriter) const override; |
| 231 | }; |
| 232 | } // namespace |
| 233 | |
| 234 | LogicalResult IsBroadcastableOpConverter::matchAndRewrite( |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 235 | IsBroadcastableOp op, OpAdaptor adaptor, |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 236 | ConversionPatternRewriter &rewriter) const { |
| 237 | // For now, this lowering is only defined on `tensor<?xindex>` operands, not |
| 238 | // on shapes. |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 239 | if (!llvm::all_of(op.getShapes(), |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 240 | [](Value v) { return !v.getType().isa<ShapeType>(); })) |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 241 | return failure(); |
| 242 | |
| 243 | auto loc = op.getLoc(); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 244 | ImplicitLocOpBuilder lb(loc, rewriter); |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 245 | Value zero = lb.create<arith::ConstantIndexOp>(0); |
| 246 | Value one = lb.create<arith::ConstantIndexOp>(1); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 247 | Type indexTy = lb.getIndexType(); |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 248 | |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 249 | // Save all the ranks for bounds checking. Because this is a tensor |
| 250 | // representing the shape extents, the rank is the extent of the only |
| 251 | // dimension in the tensor. |
| 252 | SmallVector<Value> ranks, rankDiffs; |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 253 | llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) { |
Matthias Springer | c0a6318 | 2021-07-01 00:58:48 | [diff] [blame] | 254 | return lb.create<tensor::DimOp>(v, zero); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 255 | })); |
| 256 | |
| 257 | // Find the maximum rank |
| 258 | Value maxRank = ranks.front(); |
| 259 | for (Value v : llvm::drop_begin(ranks, 1)) { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 260 | Value rankIsGreater = |
| 261 | lb.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, v, maxRank); |
River Riddle | dec8af7 | 2022-01-31 20:44:35 | [diff] [blame] | 262 | maxRank = lb.create<arith::SelectOp>(rankIsGreater, v, maxRank); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 263 | } |
| 264 | |
| 265 | // Calculate the difference of ranks and the maximum rank for later offsets. |
| 266 | llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 267 | return lb.create<arith::SubIOp>(indexTy, maxRank, v); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 268 | })); |
| 269 | |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 270 | Type i1Ty = rewriter.getI1Type(); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 271 | Value trueVal = |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 272 | rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 273 | |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 274 | auto reduceResult = lb.create<ForOp>( |
| 275 | loc, zero, maxRank, one, ValueRange{trueVal}, |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 276 | [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) { |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 277 | // Find a non-1 dim, if it exists. Note that the first part of this |
| 278 | // could reuse the Broadcast lowering entirely, but we redo the work |
| 279 | // here to make optimizations easier between the two loops. |
| 280 | Value broadcastedDim = getBroadcastedDim( |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 281 | ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 282 | |
| 283 | Value broadcastable = iterArgs[0]; |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 284 | for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) { |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 285 | Value shape, rankDiff; |
| 286 | std::tie(shape, rankDiff) = tup; |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 287 | Value outOfBounds = b.create<arith::CmpIOp>( |
| 288 | loc, arith::CmpIPredicate::ult, iv, rankDiff); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 289 | broadcastable = |
| 290 | b.create<IfOp>( |
| 291 | loc, TypeRange{i1Ty}, outOfBounds, |
| 292 | [&](OpBuilder &b, Location loc) { |
| 293 | // Non existent dimensions are always broadcastable |
| 294 | b.create<scf::YieldOp>(loc, broadcastable); |
| 295 | }, |
| 296 | [&](OpBuilder &b, Location loc) { |
| 297 | // Every value needs to be either 1, or the same non-1 |
| 298 | // value to be broadcastable in this dim. |
| 299 | Value operandDimension = |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 300 | b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 301 | Value dimensionExtent = b.create<tensor::ExtractOp>( |
| 302 | loc, shape, ValueRange{operandDimension}); |
| 303 | |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 304 | Value equalOne = b.create<arith::CmpIOp>( |
| 305 | loc, arith::CmpIPredicate::eq, dimensionExtent, one); |
| 306 | Value equalBroadcasted = b.create<arith::CmpIOp>( |
| 307 | loc, arith::CmpIPredicate::eq, dimensionExtent, |
| 308 | broadcastedDim); |
| 309 | Value result = b.create<arith::AndIOp>( |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 310 | loc, broadcastable, |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 311 | b.create<arith::OrIOp>(loc, equalOne, |
| 312 | equalBroadcasted)); |
Tres Popp | 3842d4b | 2021-02-10 09:24:32 | [diff] [blame] | 313 | b.create<scf::YieldOp>(loc, result); |
| 314 | }) |
| 315 | .getResult(0); |
| 316 | } |
| 317 | |
| 318 | b.create<scf::YieldOp>(loc, broadcastable); |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 319 | }); |
| 320 | |
Jacques Pienaar | c0342a2 | 2021-12-20 16:03:43 | [diff] [blame] | 321 | rewriter.replaceOp(op, reduceResult.getResults().front()); |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 322 | return success(); |
| 323 | } |
| 324 | |
| 325 | namespace { |
Frederik Gossen | 8577a09 | 2020-06-30 08:33:49 | [diff] [blame] | 326 | class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { |
| 327 | using OpConversionPattern<GetExtentOp>::OpConversionPattern; |
| 328 | |
| 329 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 330 | matchAndRewrite(GetExtentOp op, OpAdaptor adaptor, |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 331 | ConversionPatternRewriter &rewriter) const override; |
| 332 | }; |
| 333 | } // namespace |
Frederik Gossen | 8577a09 | 2020-06-30 08:33:49 | [diff] [blame] | 334 | |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 335 | LogicalResult GetExtentOpConverter::matchAndRewrite( |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 336 | GetExtentOp op, OpAdaptor adaptor, |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 337 | ConversionPatternRewriter &rewriter) const { |
Frederik Gossen | 6673c6c | 2020-07-29 13:53:41 | [diff] [blame] | 338 | // For now, only error-free types are supported by this lowering. |
| 339 | if (op.getType().isa<SizeType>()) |
| 340 | return failure(); |
| 341 | |
| 342 | // Derive shape extent directly from shape origin if possible. This |
| 343 | // circumvents the necessity to materialize the shape in memory. |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 344 | if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) { |
| 345 | if (shapeOfOp.getArg().getType().isa<ShapedType>()) { |
| 346 | rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(), |
| 347 | adaptor.getDim()); |
Frederik Gossen | 6673c6c | 2020-07-29 13:53:41 | [diff] [blame] | 348 | return success(); |
| 349 | } |
Frederik Gossen | 8577a09 | 2020-06-30 08:33:49 | [diff] [blame] | 350 | } |
Frederik Gossen | 8577a09 | 2020-06-30 08:33:49 | [diff] [blame] | 351 | |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 352 | rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(), |
| 353 | adaptor.getShape(), |
| 354 | ValueRange{adaptor.getDim()}); |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 355 | return success(); |
| 356 | } |
| 357 | |
| 358 | namespace { |
Frederik Gossen | 24debf5 | 2020-06-25 08:42:40 | [diff] [blame] | 359 | class RankOpConverter : public OpConversionPattern<shape::RankOp> { |
| 360 | public: |
| 361 | using OpConversionPattern<shape::RankOp>::OpConversionPattern; |
| 362 | |
| 363 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 364 | matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 365 | ConversionPatternRewriter &rewriter) const override; |
Frederik Gossen | 24debf5 | 2020-06-25 08:42:40 | [diff] [blame] | 366 | }; |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 367 | } // namespace |
Frederik Gossen | 24debf5 | 2020-06-25 08:42:40 | [diff] [blame] | 368 | |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 369 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 370 | RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 371 | ConversionPatternRewriter &rewriter) const { |
Frederik Gossen | a97940d | 2020-07-30 11:40:16 | [diff] [blame] | 372 | // For now, this lowering supports only error-free types. |
| 373 | if (op.getType().isa<SizeType>()) |
| 374 | return failure(); |
| 375 | |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 376 | rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0); |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 377 | return success(); |
| 378 | } |
| 379 | |
| 380 | namespace { |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 381 | /// Converts `shape.reduce` to `scf.for`. |
| 382 | struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> { |
| 383 | public: |
| 384 | using OpConversionPattern::OpConversionPattern; |
| 385 | |
| 386 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 387 | matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 388 | ConversionPatternRewriter &rewriter) const final; |
| 389 | }; |
| 390 | } // namespace |
| 391 | |
| 392 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 393 | ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 394 | ConversionPatternRewriter &rewriter) const { |
| 395 | // For now, this lowering is only defined on `tensor<?xindex>` operands. |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 396 | if (op.getShape().getType().isa<ShapeType>()) |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 397 | return failure(); |
| 398 | |
| 399 | auto loc = op.getLoc(); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 400 | |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 401 | Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| 402 | Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 403 | Type indexTy = rewriter.getIndexType(); |
Julian Gross | e231070 | 2021-02-10 12:53:11 | [diff] [blame] | 404 | Value rank = |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 405 | rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 406 | |
| 407 | auto loop = rewriter.create<scf::ForOp>( |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 408 | loc, zero, rank, one, op.getInitVals(), |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 409 | [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 410 | Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 411 | |
| 412 | SmallVector<Value, 2> mappedValues{iv, extent}; |
| 413 | mappedValues.append(args.begin(), args.end()); |
| 414 | |
| 415 | BlockAndValueMapping mapping; |
| 416 | Block *reduceBody = op.getBody(); |
| 417 | mapping.map(reduceBody->getArguments(), mappedValues); |
| 418 | for (auto &nested : reduceBody->without_terminator()) |
| 419 | b.clone(nested, mapping); |
| 420 | |
| 421 | SmallVector<Value, 2> mappedResults; |
| 422 | for (auto result : reduceBody->getTerminator()->getOperands()) |
| 423 | mappedResults.push_back(mapping.lookup(result)); |
| 424 | b.create<scf::YieldOp>(loc, mappedResults); |
| 425 | }); |
| 426 | |
| 427 | rewriter.replaceOp(op, loop.getResults()); |
| 428 | return success(); |
| 429 | } |
| 430 | |
| 431 | namespace { |
| 432 | /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is |
| 433 | /// only defined on `tensor<?xindex>` operands. The test for equality first |
| 434 | /// compares their size and, if equal, checks every extent for equality. |
| 435 | /// |
| 436 | /// Example: |
| 437 | /// |
| 438 | /// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex> |
| 439 | /// |
| 440 | /// becomes |
| 441 | /// |
Mogball | cb3aa49 | 2021-10-14 16:55:33 | [diff] [blame] | 442 | /// %c0 = arith.constant 0 : index |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 443 | /// %0 = dim %arg0, %c0 : tensor<?xindex> |
| 444 | /// %1 = dim %arg1, %c0 : tensor<?xindex> |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 445 | /// %2 = arith.cmpi "eq", %0, %1 : index |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 446 | /// %result = scf.if %2 -> (i1) { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 447 | /// %c1 = arith.constant 1 : index |
| 448 | /// %true = arith.constant true |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 449 | /// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) { |
Sean Silva | 444822d | 2020-12-11 22:20:03 | [diff] [blame] | 450 | /// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex> |
| 451 | /// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex> |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 452 | /// %7 = arith.cmpi "eq", %5, %6 : index |
| 453 | /// %8 = arith.andi %arg3, %7 : i1 |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 454 | /// scf.yield %8 : i1 |
| 455 | /// } |
| 456 | /// scf.yield %4 : i1 |
| 457 | /// } else { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 458 | /// %false = arith.constant false |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 459 | /// scf.yield %false : i1 |
| 460 | /// } |
| 461 | /// |
| 462 | struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> { |
| 463 | using OpConversionPattern<ShapeEqOp>::OpConversionPattern; |
| 464 | |
| 465 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 466 | matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 467 | ConversionPatternRewriter &rewriter) const override; |
| 468 | }; |
| 469 | } // namespace |
| 470 | |
| 471 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 472 | ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 473 | ConversionPatternRewriter &rewriter) const { |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 474 | if (!llvm::all_of(op.getShapes(), |
Benjamin Kramer | 24acade | 2021-03-01 19:34:17 | [diff] [blame] | 475 | [](Value v) { return !v.getType().isa<ShapeType>(); })) |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 476 | return failure(); |
Benjamin Kramer | 24acade | 2021-03-01 19:34:17 | [diff] [blame] | 477 | |
| 478 | Type i1Ty = rewriter.getI1Type(); |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 479 | if (op.getShapes().size() <= 1) { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 480 | rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty, |
| 481 | rewriter.getBoolAttr(true)); |
Benjamin Kramer | 24acade | 2021-03-01 19:34:17 | [diff] [blame] | 482 | return success(); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 483 | } |
| 484 | |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 485 | auto loc = op.getLoc(); |
| 486 | Type indexTy = rewriter.getIndexType(); |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 487 | Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 488 | Value firstShape = adaptor.getShapes().front(); |
Julian Gross | e231070 | 2021-02-10 12:53:11 | [diff] [blame] | 489 | Value firstRank = |
Matthias Springer | c0a6318 | 2021-07-01 00:58:48 | [diff] [blame] | 490 | rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero); |
Benjamin Kramer | 24acade | 2021-03-01 19:34:17 | [diff] [blame] | 491 | Value result = nullptr; |
| 492 | // Generate a linear sequence of compares, all with firstShape as lhs. |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 493 | for (Value shape : adaptor.getShapes().drop_front(1)) { |
Matthias Springer | c0a6318 | 2021-07-01 00:58:48 | [diff] [blame] | 494 | Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero); |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 495 | Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
| 496 | firstRank, rank); |
Benjamin Kramer | 24acade | 2021-03-01 19:34:17 | [diff] [blame] | 497 | auto same = rewriter.create<IfOp>( |
| 498 | loc, i1Ty, eqRank, |
| 499 | [&](OpBuilder &b, Location loc) { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 500 | Value one = b.create<arith::ConstantIndexOp>(loc, 1); |
| 501 | Value init = |
| 502 | b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true)); |
Benjamin Kramer | 24acade | 2021-03-01 19:34:17 | [diff] [blame] | 503 | auto loop = b.create<scf::ForOp>( |
| 504 | loc, zero, firstRank, one, ValueRange{init}, |
| 505 | [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { |
| 506 | Value conj = args[0]; |
| 507 | Value lhsExtent = |
| 508 | b.create<tensor::ExtractOp>(loc, firstShape, iv); |
| 509 | Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv); |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 510 | Value eqExtent = b.create<arith::CmpIOp>( |
| 511 | loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent); |
| 512 | Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent); |
Benjamin Kramer | 24acade | 2021-03-01 19:34:17 | [diff] [blame] | 513 | b.create<scf::YieldOp>(loc, ValueRange({conjNext})); |
| 514 | }); |
| 515 | b.create<scf::YieldOp>(loc, loop.getResults()); |
| 516 | }, |
| 517 | [&](OpBuilder &b, Location loc) { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 518 | Value result = |
| 519 | b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false)); |
Benjamin Kramer | 24acade | 2021-03-01 19:34:17 | [diff] [blame] | 520 | b.create<scf::YieldOp>(loc, result); |
| 521 | }); |
| 522 | result = !result ? same.getResult(0) |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 523 | : rewriter.create<arith::AndIOp>(loc, result, |
| 524 | same.getResult(0)); |
Benjamin Kramer | 24acade | 2021-03-01 19:34:17 | [diff] [blame] | 525 | } |
| 526 | rewriter.replaceOp(op, result); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 527 | return success(); |
| 528 | } |
| 529 | |
| 530 | namespace { |
| 531 | class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { |
| 532 | public: |
| 533 | using OpConversionPattern<ShapeOfOp>::OpConversionPattern; |
| 534 | |
| 535 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 536 | matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 537 | ConversionPatternRewriter &rewriter) const override; |
| 538 | }; |
| 539 | } // namespace |
| 540 | |
| 541 | LogicalResult ShapeOfOpConversion::matchAndRewrite( |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 542 | ShapeOfOp op, OpAdaptor adaptor, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 543 | ConversionPatternRewriter &rewriter) const { |
| 544 | |
| 545 | // For now, only error-free types are supported by this lowering. |
| 546 | if (op.getType().isa<ShapeType>()) |
| 547 | return failure(); |
| 548 | |
Sean Silva | be7352c | 2021-01-15 02:28:48 | [diff] [blame] | 549 | // For ranked tensor arguments, lower to `tensor.from_elements`. |
Frederik Gossen | 5106a8b | 2020-09-09 07:53:13 | [diff] [blame] | 550 | auto loc = op.getLoc(); |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 551 | Value tensor = adaptor.getArg(); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 552 | Type tensorTy = tensor.getType(); |
| 553 | if (tensorTy.isa<RankedTensorType>()) { |
| 554 | |
| 555 | // Build values for individual extents. |
| 556 | SmallVector<Value, 8> extentValues; |
| 557 | RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>(); |
| 558 | int64_t rank = rankedTensorTy.getRank(); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 559 | for (int64_t i = 0; i < rank; i++) { |
| 560 | if (rankedTensorTy.isDynamicDim(i)) { |
Matthias Springer | c0a6318 | 2021-07-01 00:58:48 | [diff] [blame] | 561 | Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 562 | extentValues.push_back(extent); |
| 563 | } else { |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 564 | Value extent = rewriter.create<arith::ConstantIndexOp>( |
| 565 | loc, rankedTensorTy.getDimSize(i)); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 566 | extentValues.push_back(extent); |
| 567 | } |
| 568 | } |
| 569 | |
| 570 | // Materialize extent tensor. |
Sean Silva | be7352c | 2021-01-15 02:28:48 | [diff] [blame] | 571 | Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>( |
Alexander Belyaev | f77e9f8 | 2021-12-16 13:42:27 | [diff] [blame] | 572 | loc, RankedTensorType::get({rank}, rewriter.getIndexType()), |
| 573 | extentValues); |
Sean Silva | 129d6e5 | 2020-12-16 00:47:19 | [diff] [blame] | 574 | rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), |
| 575 | staticExtentTensor); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 576 | return success(); |
| 577 | } |
| 578 | |
Sean Silva | be7352c | 2021-01-15 02:28:48 | [diff] [blame] | 579 | // Lower to `tensor.generate` otherwise. |
Frederik Gossen | 5106a8b | 2020-09-09 07:53:13 | [diff] [blame] | 580 | auto *ctx = rewriter.getContext(); |
Alexander Belyaev | 15f8f3e | 2021-12-14 08:35:14 | [diff] [blame] | 581 | Value rank = rewriter.create<tensor::RankOp>(loc, tensor); |
Sean Silva | be7352c | 2021-01-15 02:28:48 | [diff] [blame] | 582 | rewriter.replaceOpWithNewOp<tensor::GenerateOp>( |
Frederik Gossen | 5106a8b | 2020-09-09 07:53:13 | [diff] [blame] | 583 | op, getExtentTensorType(ctx), ValueRange{rank}, |
| 584 | [&](OpBuilder &b, Location loc, ValueRange args) { |
| 585 | Value dim = args.front(); |
Matthias Springer | c0a6318 | 2021-07-01 00:58:48 | [diff] [blame] | 586 | Value extent = b.create<tensor::DimOp>(loc, tensor, dim); |
Sean Silva | be7352c | 2021-01-15 02:28:48 | [diff] [blame] | 587 | b.create<tensor::YieldOp>(loc, extent); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 588 | }); |
| 589 | |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 590 | return success(); |
| 591 | } |
| 592 | |
| 593 | namespace { |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 594 | class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> { |
| 595 | public: |
| 596 | using OpConversionPattern<SplitAtOp>::OpConversionPattern; |
| 597 | |
| 598 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 599 | matchAndRewrite(SplitAtOp op, OpAdaptor adaptor, |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 600 | ConversionPatternRewriter &rewriter) const override; |
| 601 | }; |
| 602 | } // namespace |
| 603 | |
| 604 | LogicalResult SplitAtOpConversion::matchAndRewrite( |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 605 | SplitAtOp op, OpAdaptor adaptor, |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 606 | ConversionPatternRewriter &rewriter) const { |
| 607 | // Error conditions are not implemented, only lower if all operands and |
| 608 | // results are extent tensors. |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 609 | if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()}, |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 610 | [](Value v) { return v.getType().isa<ShapeType>(); })) |
| 611 | return failure(); |
| 612 | |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 613 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 614 | Value zero = b.create<arith::ConstantIndexOp>(0); |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 615 | Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero); |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 616 | |
| 617 | // index < 0 ? index + rank : index |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 618 | Value originalIndex = adaptor.getIndex(); |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 619 | Value add = b.create<arith::AddIOp>(originalIndex, rank); |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 620 | Value indexIsNegative = |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 621 | b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero); |
River Riddle | dec8af7 | 2022-01-31 20:44:35 | [diff] [blame] | 622 | Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex); |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 623 | |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 624 | Value one = b.create<arith::ConstantIndexOp>(1); |
Matthias Springer | 060208b | 2021-06-22 07:49:08 | [diff] [blame] | 625 | Value head = |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 626 | b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one); |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 627 | Value tailSize = b.create<arith::SubIOp>(rank, index); |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 628 | Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index, |
| 629 | tailSize, one); |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 630 | rewriter.replaceOp(op, {head, tail}); |
| 631 | return success(); |
| 632 | } |
| 633 | |
| 634 | namespace { |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 635 | class ToExtentTensorOpConversion |
| 636 | : public OpConversionPattern<ToExtentTensorOp> { |
| 637 | public: |
| 638 | using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern; |
| 639 | |
| 640 | LogicalResult |
River Riddle | b54c724 | 2021-09-24 17:50:58 | [diff] [blame] | 641 | matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 642 | ConversionPatternRewriter &rewriter) const override { |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 643 | if (!adaptor.getInput().getType().isa<RankedTensorType>()) |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 644 | return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); |
| 645 | |
Sean Silva | 129d6e5 | 2020-12-16 00:47:19 | [diff] [blame] | 646 | rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), |
Jacques Pienaar | cfb72fd3 | 2021-10-25 01:36:33 | [diff] [blame] | 647 | adaptor.getInput()); |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 648 | return success(); |
| 649 | } |
| 650 | }; |
| 651 | } // namespace |
| 652 | |
| 653 | namespace { |
Tres Popp | d05d421 | 2020-10-13 15:56:45 | [diff] [blame] | 654 | /// Import the Shape Ops to Std Patterns. |
| 655 | #include "ShapeToStandard.cpp.inc" |
| 656 | } // namespace |
| 657 | |
| 658 | namespace { |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 659 | /// Conversion pass. |
| 660 | class ConvertShapeToStandardPass |
| 661 | : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> { |
Frederik Gossen | eaf4913 | 2020-06-18 07:51:03 | [diff] [blame] | 662 | |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 663 | void runOnOperation() override; |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 664 | }; |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 665 | } // namespace |
| 666 | |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 667 | void ConvertShapeToStandardPass::runOnOperation() { |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 668 | // Setup target legality. |
Frederik Gossen | b6b9d3e | 2020-07-29 10:45:07 | [diff] [blame] | 669 | MLIRContext &ctx = getContext(); |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 670 | ConversionTarget target(ctx); |
River Riddle | 23aa5a7 | 2022-02-26 22:49:54 | [diff] [blame^] | 671 | target.addLegalDialect<arith::ArithmeticDialect, func::FuncDialect, |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 672 | SCFDialect, tensor::TensorDialect>(); |
Mehdi Amini | 973ddb7 | 2021-03-11 23:58:02 | [diff] [blame] | 673 | target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>(); |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 674 | |
| 675 | // Setup conversion patterns. |
Chris Lattner | dc4e913 | 2021-03-22 23:58:34 | [diff] [blame] | 676 | RewritePatternSet patterns(&ctx); |
Chris Lattner | 3a506b3 | 2021-03-20 23:29:41 | [diff] [blame] | 677 | populateShapeToStandardConversionPatterns(patterns); |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 678 | |
| 679 | // Apply conversion. |
| 680 | auto module = getOperation(); |
River Riddle | 3fffffa8 | 2020-10-27 00:25:01 | [diff] [blame] | 681 | if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
Frederik Gossen | 4baf18d | 2020-07-24 08:53:54 | [diff] [blame] | 682 | signalPassFailure(); |
| 683 | } |
| 684 | |
Frederik Gossen | 24edbdf | 2020-06-08 08:58:06 | [diff] [blame] | 685 | void mlir::populateShapeToStandardConversionPatterns( |
Chris Lattner | dc4e913 | 2021-03-22 23:58:34 | [diff] [blame] | 686 | RewritePatternSet &patterns) { |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 687 | // clang-format off |
Chris Lattner | 1d909c9 | 2021-03-21 17:38:35 | [diff] [blame] | 688 | populateWithGenerated(patterns); |
Chris Lattner | dc4e913 | 2021-03-22 23:58:34 | [diff] [blame] | 689 | patterns.add< |
Frederik Gossen | 9df6afb | 2020-07-13 08:28:13 | [diff] [blame] | 690 | AnyOpConversion, |
Mogball | a54f4ea | 2021-10-12 23:14:57 | [diff] [blame] | 691 | BinaryOpConversion<AddOp, arith::AddIOp>, |
| 692 | BinaryOpConversion<MulOp, arith::MulIOp>, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 693 | BroadcastOpConverter, |
| 694 | ConstShapeOpConverter, |
Stephan Herhut | 5d9f33a | 2020-07-28 17:08:40 | [diff] [blame] | 695 | ConstSizeOpConversion, |
Tres Popp | 511484f | 2020-10-29 16:13:26 | [diff] [blame] | 696 | IsBroadcastableOpConverter, |
Frederik Gossen | 8577a09 | 2020-06-30 08:33:49 | [diff] [blame] | 697 | GetExtentOpConverter, |
Frederik Gossen | 24debf5 | 2020-06-25 08:42:40 | [diff] [blame] | 698 | RankOpConverter, |
Frederik Gossen | a70f2eb | 2020-09-07 13:58:01 | [diff] [blame] | 699 | ReduceOpConverter, |
| 700 | ShapeEqOpConverter, |
Stephan Herhut | 5d9f33a | 2020-07-28 17:08:40 | [diff] [blame] | 701 | ShapeOfOpConversion, |
Benjamin Kramer | 42c195f | 2021-03-08 14:23:28 | [diff] [blame] | 702 | SplitAtOpConversion, |
Chris Lattner | 3a506b3 | 2021-03-20 23:29:41 | [diff] [blame] | 703 | ToExtentTensorOpConversion>(patterns.getContext()); |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 704 | // clang-format on |
| 705 | } |
| 706 | |
Frederik Gossen | 24edbdf | 2020-06-08 08:58:06 | [diff] [blame] | 707 | std::unique_ptr<OperationPass<ModuleOp>> |
| 708 | mlir::createConvertShapeToStandardPass() { |
Frederik Gossen | 3713314 | 2020-06-03 16:14:42 | [diff] [blame] | 709 | return std::make_unique<ConvertShapeToStandardPass>(); |
| 710 | } |