//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg named ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Input should be padded if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
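  // `pad` interleaves a [low, high] pair per input dimension, i.e.
  // [low_0, high_0, low_1, high_1, ...].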
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return tensor::createPadScalarOp(RankedTensorType::get(paddedShape, inputETy),
                                   input, padValue, lowIndices, highIndices,
                                   /*nofold=*/false, loc, rewriter)
      .result();
}

namespace {

class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
    bool isQuantized = op->hasAttr("quantization_info");

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
    if (!dynamicDimsOr.hasValue())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

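    // The TOSA pad attribute covers only the spatial (H, W) dimensions, so
    // prepend zero padding for the batch dimension and append zero padding
    // for the channel dimension to match the NHWC input layout.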
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Transpose the kernel to match dimension ordering of the linalg
    // convolution operation.
    // TODO(suderman): See if this can be efficiently folded - check whether
    // the input is used anywhere else, if not fold the constant.
    SmallVector<int64_t> weightPerm{1, 2, 3, 0};
    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
                                        weightShape[3], weightShape[0]};
    auto weightPermAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
    Value weightPermValue =
        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());
    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                weightPermValue);

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

    // Create the indexing maps for the bias broadcast: the 1-D bias is
    // broadcast along the output channel dimension, while the convolution
    // result and the output use identity maps.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultETy);

    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = rewriter.getI32IntegerAttr(
          quantizationInfo.input_zp().getValue().getSExtValue());
      auto kZp = rewriter.getI32IntegerAttr(
          quantizationInfo.weight_zp().getValue().getSExtValue());

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::Conv2DNhwcHwcfQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddIOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    Value conv = rewriter
                     .create<linalg::Conv2DNhwcHwcfOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
                     ->getResult(0);

    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added = nestedBuilder.create<arith::AddFOp>(
                      loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);

    rewriter.replaceOp(op, result);
    return success();
  }
};

class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      iZp = rewriter.getI32IntegerAttr(
          quantizationInfo.input_zp().getValue().getSExtValue());
      kZp = rewriter.getI32IntegerAttr(
          quantizationInfo.weight_zp().getValue().getSExtValue());
    }

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
    if (!dynamicDimsOr.hasValue())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = quantizationInfo.input_zp().getValue().getSExtValue();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);
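    // The depthwise convolution produces a 5-D (N, H, W, C, M) tensor; it is
    // collapsed back to the 4-D (N, H, W, C * M) result type with a reshape
    // after the convolution below.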

    // Create the indexing maps used below to broadcast the bias across the
    // result.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    // Broadcast the initial value to the output tensor before convolving.
    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, linalgConvTy.getShape(), resultETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultETy);
    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);
      Value convReshape = rewriter.create<tosa::ReshapeOp>(
          loc, resultTy, conv, rewriter.getI64ArrayAttr(resultTy.getShape()));
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      Value convReshape = rewriter.create<tosa::ReshapeOp>(
          loc, resultTy, conv, rewriter.getI64ArrayAttr(resultTy.getShape()));
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddIOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();

    auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>();
    auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

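    // The batch matmul result is (batch, rows, cols); materialize a dynamic
    // dimension for each one whose corresponding operand dimension is not
    // static.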
    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
    if (!op.quantization_info()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
          ValueRange{zeroTensor});
      return success();
    }

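    // Quantized case: lower to linalg.quantized_batch_matmul, passing the a/b
    // zero points as extra scalar operands.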
    auto quantizationInfo = op.quantization_info().getValue();
    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.a_zp().getValue().getSExtValue()));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.b_zp().getValue().getSExtValue()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();
    auto input = op.input();
    auto inputTy = input.getType().cast<ShapedType>();

    auto bias = op.bias();

    auto weight = op.weight();
    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    // Create the indexing maps for the matmul output and the bias.
    SmallVector<AffineMap, 4> indexingMaps;

    // Broadcast the bias.
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());

    // When quantized, the input element type is not the same as the output.
    Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor =
        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);

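    // tosa.fully_connected stores its weight as (out_channels, in_channels),
    // so transpose it to (in_channels, out_channels) for linalg.matmul.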
    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue =
        rewriter.create<arith::ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    auto biasInitTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, filteredDims,
                                          outputTy.getShape(), outputETy)
            ->getResults();

    if (!op.quantization_info()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, zeroTensor)
                         ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    auto quantizationInfo = op.quantization_info().getValue();
    auto inputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.input_zp().getValue().getSExtValue()));
    auto weightZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(
                 quantizationInfo.weight_zp().getValue().getSExtValue()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, weightZp},
                zeroTensor)
            ->getResult(0);
    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added = nestedBuilder.create<arith::AddIOp>(
                      loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};

class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = inputTy.getElementType();

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
    if (!dynamicDimsOr.hasValue())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();

    // Determine what the initial value needs to be for the max pool op.
    Attribute initialAttr;
    if (resultETy.isF32())
      initialAttr = rewriter.getFloatAttr(
          resultETy,
          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
                              true));

    if (resultETy.isa<IntegerType>())
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, initTensor).result();

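    // The pooling op only inspects the shape of this tensor to derive the
    // window dimensions; its contents are never read.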
    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);

    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
        filledInitTensor, strideAttr, dilationAttr);
    return success();
  }
};

class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = op.getType().cast<ShapedType>().getElementType();

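    // Integer inputs are accumulated in i32 so summing the window does not
    // overflow the narrower input type; floating-point inputs accumulate in
    // their own type.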
    Type accETy =
        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
    if (!dynamicDimsOr.hasValue())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.pad(), pad);
    pad.resize(pad.size() + 2, 0);
    Attribute padAttr = rewriter.getZeroAttr(inElementTy);
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    Attribute initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, accTy.getShape(), accETy);

    Value filledInitTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, poolInitTensor)
            .result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledInitTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());

    Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultETy);

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericInitTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
          auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
          auto iH = rewriter.create<arith::ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(1) - 1);
          auto iW = rewriter.create<arith::ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(2) - 1);

          // Compute the indices from either end.
          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
          auto y1 = rewriter.create<arith::SubIOp>(loc, iH, y0);
          auto x1 = rewriter.create<arith::SubIOp>(loc, iW, x0);

          // Determine how much of the valid input is covered by the kernel.
          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
            if (pad == 0)
              return v;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dx = rewriter.create<arith::SubIOp>(loc, x, padVal);

            Value cmp = rewriter.create<arith::CmpIOp>(
                loc, arith::CmpIPredicate::slt, dx, zero);
            Value offset = rewriter.create<arith::SelectOp>(loc, cmp, dx, zero);
            return rewriter.create<arith::AddIOp>(loc, v, offset)->getResult(0);
          };

          // Compute the vertical component of coverage.
          auto kH0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[0]);
          auto kH1 = padFn(kH0, y0, pad[2]);
          auto kH2 = padFn(kH1, y1, pad[3]);
          auto kHCmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::slt, kH2, one);
          auto kH3 = rewriter.create<arith::SelectOp>(loc, kHCmp, one, kH2);

          // Compute the horizontal component of coverage.
          auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
          auto kW1 = padFn(kW0, x0, pad[4]);
          auto kW2 = padFn(kW1, x1, pad[5]);
          auto kWCmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::slt, kW2, one);
          auto kW3 = rewriter.create<arith::SelectOp>(loc, kWCmp, one, kW2);

          // Compute the total number of elements and normalize.
          Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);
          auto countI = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(), count);

          // Divide by the number of summed values. For floats this is just a
          // div; for quantized values, input normalization must be applied
          // first.
          Value poolVal = args[0];
          if (accETy.isa<FloatType>()) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, countI);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (op.quantization_info()) {
              auto quantizationInfo = op.quantization_info().getValue();
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, quantizationInfo.input_zp());
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute the multiplier and shift values for the quantization
            // normalization. Preferably we would compute with more bits, but
            // 32 bits should be enough for the computation. Honestly we should
            // probably just divide directly.
            int64_t numerator = ((1 << 30) + 1);
            int64_t shift = 30;

            Value numeratorVal = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(numerator));
            Value multiplierVal =
                rewriter
                    .create<arith::DivUIOp>(loc, rewriter.getI32Type(),
                                            numeratorVal, countI)
                    .getResult();
            Value shiftVal = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(shift));

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(
                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
                        shiftVal, rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information we need to apply the output
            // zero point.
            if (op.quantization_info()) {
              auto quantizationInfo = op.quantization_info().getValue();
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, quantizationInfo.output_zp());
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply Clip.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            auto clamp = clampHelper<arith::CmpIOp>(
                loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

} // namespace

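// A minimal driver sketch for these patterns (hypothetical pass body shown
// only for illustration; the actual pass that wires this up lives elsewhere
// and may differ):
//
//   RewritePatternSet patterns(ctx);
//   ConversionTarget target(*ctx);
//   target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
//                          arith::ArithmeticDialect>();
//   target.addIllegalOp<tosa::Conv2DOp, tosa::DepthwiseConv2DOp,
//                       tosa::MatMulOp, tosa::FullyConnectedOp,
//                       tosa::MaxPool2dOp, tosa::AvgPool2dOp>();
//   mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
//   if (failed(applyFullConversion(func, target, std::move(patterns))))
//     signalPassFailure();
//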
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      ConvConverter,
      DepthwiseConvConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}