diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 95364c26d1a7d..0b69cd2814fb9 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -82,15 +82,6 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, rhsOrResult); } -template -static arith::ConstantOp -createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType, - OpBuilder &rewriter) { - auto castedN = static_cast(zp); - return rewriter.create( - op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); -} - static Value createLinalgBodyCalculationForElementwiseOp( Operation *op, ValueRange args, ArrayRef resultTypes, ConversionPatternRewriter &rewriter) { @@ -1467,11 +1458,6 @@ class RescaleConverter : public OpRewritePattern { Value value = blockArgs[0]; Type valueTy = value.getType(); - // For now we do all of our math in 64-bit. This is not optimal but - // should be correct for now, consider computing correct bit depth - // later. - int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32; - FailureOr maybeIZp = op.getInputZeroPoint(); if (failed(maybeIZp)) { (void)rewriter.notifyMatchFailure( @@ -1479,9 +1465,12 @@ class RescaleConverter : public OpRewritePattern { return; } - auto inputZp = createConstOpFromZpVal( - op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth), - nestedBuilder); + const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth(); + // Extend zeropoint for sub-32bits widths. + const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32; + auto inputZp = nestedBuilder.create( + loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth), + *maybeIZp)); FailureOr maybeOZp = op.getOutputZeroPoint(); if (failed(maybeOZp)) { @@ -1490,16 +1479,14 @@ class RescaleConverter : public OpRewritePattern { return; }; - // pre-process OutputZP as it can be unsigned - auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth(); - APInt OZp(outBitwidth, !op.getOutputUnsigned()); - OZp = static_cast(*maybeOZp); - *maybeOZp = op.getOutputUnsigned() - ? static_cast(OZp.getZExtValue()) - : OZp.getSExtValue(); - - auto outputZp = createConstOpFromZpVal( - op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder); + IntegerType outIntType = + cast(blockArgs.back().getType()); + unsigned outBitWidth = outIntType.getWidth(); + const int32_t outAttrBitwidth = 32; + assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth"); + auto outputZp = nestedBuilder.create( + loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth), + *maybeOZp)); Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; @@ -1527,10 +1514,6 @@ class RescaleConverter : public OpRewritePattern { nestedBuilder.create(nestedLoc, value, outputZp); // Saturate to the output size. - IntegerType outIntType = - cast(blockArgs.back().getType()); - unsigned outBitWidth = outIntType.getWidth(); - int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue(); int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index de06b621cbe3d..2f9c6d7870782 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -2118,7 +2118,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() { // return failure if val is not a constant // set zp to -1 if val is non-zero float or val is not integer nor float // otherwise set zp to val's constant value -static FailureOr getZeroPoint(Value val) { +static FailureOr getZeroPoint(Value val, bool signExtend) { ElementsAttr zpAttr; if (!matchPattern(val, m_Constant(&zpAttr))) { return failure(); @@ -2135,7 +2135,10 @@ static FailureOr getZeroPoint(Value val) { } if (llvm::isa(zpElemType)) { - return zpAttr.getValues()[0].getSExtValue(); + if (signExtend) + return zpAttr.getValues()[0].getSExtValue(); + else + return zpAttr.getValues()[0].getZExtValue(); } // return non-zero value to trigger error check @@ -2186,30 +2189,30 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal, return success(); } -#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \ +#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \ FailureOr tosa::OP::get##OPERAND_NAME##ZeroPoint() { \ - return getZeroPoint(get##OPERAND_NAME##Zp()); \ + return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \ } \ LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \ return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \ } -ZERO_POINT_HELPER(Conv2DOp, Input) -ZERO_POINT_HELPER(Conv2DOp, Weight) -ZERO_POINT_HELPER(Conv3DOp, Input) -ZERO_POINT_HELPER(Conv3DOp, Weight) -ZERO_POINT_HELPER(DepthwiseConv2DOp, Input) -ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight) -ZERO_POINT_HELPER(TransposeConv2DOp, Input) -ZERO_POINT_HELPER(TransposeConv2DOp, Weight) -ZERO_POINT_HELPER(AvgPool2dOp, Input) -ZERO_POINT_HELPER(AvgPool2dOp, Output) -ZERO_POINT_HELPER(MatMulOp, A) -ZERO_POINT_HELPER(MatMulOp, B) -ZERO_POINT_HELPER(NegateOp, Input1) -ZERO_POINT_HELPER(NegateOp, Output) -ZERO_POINT_HELPER(RescaleOp, Input) -ZERO_POINT_HELPER(RescaleOp, Output) +ZERO_POINT_HELPER(Conv2DOp, Input, true) +ZERO_POINT_HELPER(Conv2DOp, Weight, true) +ZERO_POINT_HELPER(Conv3DOp, Input, true) +ZERO_POINT_HELPER(Conv3DOp, Weight, true) +ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true) +ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true) +ZERO_POINT_HELPER(TransposeConv2DOp, Input, true) +ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true) +ZERO_POINT_HELPER(AvgPool2dOp, Input, true) +ZERO_POINT_HELPER(AvgPool2dOp, Output, true) +ZERO_POINT_HELPER(MatMulOp, A, true) +ZERO_POINT_HELPER(MatMulOp, B, true) +ZERO_POINT_HELPER(NegateOp, Input1, true) +ZERO_POINT_HELPER(NegateOp, Output, true) +ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned()) +ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned()) #undef ZERO_POINT_HELPER LogicalResult tosa::TransposeOp::inferReturnTypeComponents( diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 7083d19f4372a..185f1973ecdc6 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () { // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>) // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8): - // CHECK: [[C17:%.+]] = arith.constant 17 + // CHECK: [[C128:%.+]] = arith.constant 128 // CHECK: [[C22:%.+]] = arith.constant 22 // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]] - // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]] + // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]] // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"} // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]] // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128 @@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () { // CHECK: linalg.yield [[TRUNC]] %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16> %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8> - %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8> + %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8> %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8> %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8> return } +// ----- +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: @rescale_i48_unsigned_output +// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]: +func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () { + // CHECK: [[C19689:%.+]] = arith.constant 19689 + // CHECK: [[C15:%.+]] = arith.constant 15 + // CHECK: [[INIT:%.+]] = tensor.empty() + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>) + // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8): + // CHECK: [[C0:%.+]] = arith.constant 0 + // CHECK: [[C234:%.+]] = arith.constant 234 + // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]] + // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"} + // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]] + // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0 + // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255 + // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]] + // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]] + // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]] + // CHECK: linalg.yield [[TRUNC]] + %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16> + %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8> + %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48> + %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> + %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8> + + // CHECK: return + return +} + // ----- // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 9ccb310c4491d..56d76585be71b 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -1517,7 +1517,7 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8> %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16> %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16> - // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}} + // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got 65535}} %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16> return %0 : tensor<13x21x3xi16> }