diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 86f8873c135ef..698b951ad4928 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -6971,7 +6971,7 @@ static bool hasPassthruOp(unsigned Opcode) { Opcode <= RISCVISD::LAST_STRICTFP_OPCODE && "not a RISC-V target specific op"); static_assert( - RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 && + RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 && RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 && "adding target specific op should update this function"); if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL) @@ -6995,7 +6995,7 @@ static bool hasMaskOp(unsigned Opcode) { Opcode <= RISCVISD::LAST_STRICTFP_OPCODE && "not a RISC-V target specific op"); static_assert( - RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 && + RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 && RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 && "adding target specific op should update this function"); if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL) @@ -18101,6 +18101,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG, DAG.getBuildVector(VT, DL, RHSOps)); } +static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1, + const SDLoc &DL, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc || + RISCVISD::VQDOTSU_VL == Opc); + MVT VT = Op0.getSimpleValueType(); + assert(VT == Op1.getSimpleValueType() && + VT.getVectorElementType() == MVT::i32); + + assert(VT.isFixedLengthVector()); + MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget); + SDValue Passthru = convertToScalableVector( + ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget); + Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget); + Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget); + + auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget); + const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC; + SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT()); + SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT, + {Op0, Op1, Passthru, Mask, VL, PolicyOp}); + return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget); +} + +static MVT getQDOTXResultType(MVT OpVT) { + ElementCount OpEC = OpVT.getVectorElementCount(); + assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8); + return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4)); +} + +static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL, + SelectionDAG &DAG, + const RISCVSubtarget &Subtarget, + const RISCVTargetLowering &TLI) { + // Note: We intentionally do not check the legality of the reduction type. + // We want to handle the m4/m8 *src* types, and thus need to let illegal + // intermediate types flow through here. + if (InVec.getValueType().getVectorElementType() != MVT::i32 || + !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4)) + return SDValue(); + + // reduce (zext a) <--> reduce (mul zext a. zext 1) + // reduce (sext a) <--> reduce (mul sext a. 
sext 1) + if (InVec.getOpcode() == ISD::ZERO_EXTEND || + InVec.getOpcode() == ISD::SIGN_EXTEND) { + SDValue A = InVec.getOperand(0); + if (A.getValueType().getVectorElementType() != MVT::i8 || + !TLI.isTypeLegal(A.getValueType())) + return SDValue(); + + MVT ResVT = getQDOTXResultType(A.getSimpleValueType()); + A = DAG.getBitcast(ResVT, A); + SDValue B = DAG.getConstant(0x01010101, DL, ResVT); + + bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND; + unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL; + return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget); + } + + // mul (sext, sext) -> vqdot + // mul (zext, zext) -> vqdotu + // mul (sext, zext) -> vqdotsu + // mul (zext, sext) -> vqdotsu (swapped) + // TODO: Improve .vx handling - we end up with a sub-vector insert + // which confuses the splat pattern matching. Also, match vqdotus.vx + if (InVec.getOpcode() != ISD::MUL) + return SDValue(); + + SDValue A = InVec.getOperand(0); + SDValue B = InVec.getOperand(1); + unsigned Opc = 0; + if (A.getOpcode() == B.getOpcode()) { + if (A.getOpcode() == ISD::SIGN_EXTEND) + Opc = RISCVISD::VQDOT_VL; + else if (A.getOpcode() == ISD::ZERO_EXTEND) + Opc = RISCVISD::VQDOTU_VL; + else + return SDValue(); + } else { + if (B.getOpcode() != ISD::ZERO_EXTEND) + std::swap(A, B); + if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND) + return SDValue(); + Opc = RISCVISD::VQDOTSU_VL; + } + assert(Opc); + + if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 || + A.getOperand(0).getValueType() != B.getOperand(0).getValueType() || + !TLI.isTypeLegal(A.getValueType())) + return SDValue(); + + MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType()); + A = DAG.getBitcast(ResVT, A.getOperand(0)); + B = DAG.getBitcast(ResVT, B.getOperand(0)); + return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget); +} + +static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget, + const RISCVTargetLowering &TLI) { + if (!Subtarget.hasStdExtZvqdotq()) + return SDValue(); + + SDLoc DL(N); + EVT VT = N->getValueType(0); + SDValue InVec = N->getOperand(0); + if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI)) + return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V); + return SDValue(); +} + static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget, const RISCVTargetLowering &TLI) { @@ -19878,8 +19990,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, return SDValue(); } - case ISD::CTPOP: case ISD::VECREDUCE_ADD: + if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this)) + return V; + [[fallthrough]]; + case ISD::CTPOP: if (SDValue V = combineToVCPOP(N, DAG, Subtarget)) return V; break; @@ -22401,6 +22516,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(RI_VUNZIP2A_VL) NODE_NAME_CASE(RI_VUNZIP2B_VL) NODE_NAME_CASE(RI_VEXTRACT) + NODE_NAME_CASE(VQDOT_VL) + NODE_NAME_CASE(VQDOTU_VL) + NODE_NAME_CASE(VQDOTSU_VL) NODE_NAME_CASE(READ_CSR) NODE_NAME_CASE(WRITE_CSR) NODE_NAME_CASE(SWAP_CSR) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index ba24a0c324f51..3f1fce5d9f7e5 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -416,7 +416,12 @@ enum NodeType : unsigned { RI_VUNZIP2A_VL, RI_VUNZIP2B_VL, - LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL, + // zvqdot instructions with additional passthru, mask and VL 
operands
+  VQDOT_VL,
+  VQDOTU_VL,
+  VQDOTSU_VL,
+
+  LAST_VL_VECTOR_OP = VQDOTSU_VL,
 
   // XRivosVisni
   // VEXTRACT matches the semantics of ri.vextract.x.v. The result is always
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
index 205fffd5115ee..6018958f6eb27 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
@@ -26,3 +26,34 @@ let Predicates = [HasStdExtZvqdotq] in {
   def VQDOTSU_VX : VALUVX<0b101010, OPMVX, "vqdotsu.vx">;
   def VQDOTUS_VX : VALUVX<0b101110, OPMVX, "vqdotus.vx">;
 } // Predicates = [HasStdExtZvqdotq]
+
+
+def riscv_vqdot_vl : SDNode<"RISCVISD::VQDOT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotu_vl : SDNode<"RISCVISD::VQDOTU_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotsu_vl : SDNode<"RISCVISD::VQDOTSU_VL", SDT_RISCVIntBinOp_VL>;
+
+multiclass VPseudoVQDOT_VV_VX {
+  foreach m = MxSet<32>.m in {
+    defm "" : VPseudoBinaryV_VV<m>,
+              SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", m.MX,
+                          forcePassthruRead=true>;
+    defm "" : VPseudoBinaryV_VX<m>,
+              SchedBinary<"WriteVIALUX", "ReadVIALUV", "ReadVIALUX", m.MX,
+                          forcePassthruRead=true>;
+  }
+}
+
+// TODO: Add pseudo and patterns for vqdotus.vx
+// TODO: Add isCommutable for VQDOT and VQDOTU
+let Predicates = [HasStdExtZvqdotq], mayLoad = 0, mayStore = 0,
+    hasSideEffects = 0 in {
+  defm PseudoVQDOT : VPseudoVQDOT_VV_VX;
+  defm PseudoVQDOTU : VPseudoVQDOT_VV_VX;
+  defm PseudoVQDOTSU : VPseudoVQDOT_VV_VX;
+}
+
+defvar AllE32Vectors = [VI32MF2, VI32M1, VI32M2, VI32M4, VI32M8];
+defm : VPatBinaryVL_VV_VX<riscv_vqdot_vl, "PseudoVQDOT", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotu_vl, "PseudoVQDOTU", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotsu_vl, "PseudoVQDOTSU", AllE32Vectors>;
+
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 25192ea19aab3..e48bc9cdfea4e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,21 +1,31 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
 
 define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdot_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vwmul.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vwmul.vv v8, v12, v14
+; NODOT-NEXT: 
vsetvli zero, zero, e32, m4, ta, ma +; NODOT-NEXT: vmv.s.x v12, zero +; NODOT-NEXT: vredsum.vs v8, v8, v12 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdot_vv: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma +; DOT-NEXT: vmv.v.i v10, 0 +; DOT-NEXT: vqdot.vv v10, v8, v9 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vredsum.vs v8, v10, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret entry: %a.sext = sext <16 x i8> %a to <16 x i32> %b.sext = sext <16 x i8> %b to <16 x i32> @@ -63,17 +73,27 @@ entry: } define i32 @vqdotu_vv(<16 x i8> %a, <16 x i8> %b) { -; CHECK-LABEL: vqdotu_vv: -; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma -; CHECK-NEXT: vwmulu.vv v10, v8, v9 -; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma -; CHECK-NEXT: vmv.s.x v8, zero -; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma -; CHECK-NEXT: vwredsumu.vs v8, v10, v8 -; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma -; CHECK-NEXT: vmv.x.s a0, v8 -; CHECK-NEXT: ret +; NODOT-LABEL: vqdotu_vv: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma +; NODOT-NEXT: vwmulu.vv v10, v8, v9 +; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma +; NODOT-NEXT: vmv.s.x v8, zero +; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma +; NODOT-NEXT: vwredsumu.vs v8, v10, v8 +; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotu_vv: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma +; DOT-NEXT: vmv.v.i v10, 0 +; DOT-NEXT: vqdotu.vv v10, v8, v9 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vredsum.vs v8, v10, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret entry: %a.zext = zext <16 x i8> %a to <16 x i32> %b.zext = zext <16 x i8> %b to <16 x i32> @@ -102,17 +122,27 @@ entry: } define i32 @vqdotsu_vv(<16 x i8> %a, <16 x i8> %b) { -; CHECK-LABEL: vqdotsu_vv: -; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma -; CHECK-NEXT: vsext.vf2 v12, v8 -; CHECK-NEXT: vzext.vf2 v14, v9 -; CHECK-NEXT: vwmulsu.vv v8, v12, v14 -; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma -; CHECK-NEXT: vmv.s.x v12, zero -; CHECK-NEXT: vredsum.vs v8, v8, v12 -; CHECK-NEXT: vmv.x.s a0, v8 -; CHECK-NEXT: ret +; NODOT-LABEL: vqdotsu_vv: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma +; NODOT-NEXT: vsext.vf2 v12, v8 +; NODOT-NEXT: vzext.vf2 v14, v9 +; NODOT-NEXT: vwmulsu.vv v8, v12, v14 +; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma +; NODOT-NEXT: vmv.s.x v12, zero +; NODOT-NEXT: vredsum.vs v8, v8, v12 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotsu_vv: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma +; DOT-NEXT: vmv.v.i v10, 0 +; DOT-NEXT: vqdotsu.vv v10, v8, v9 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vredsum.vs v8, v10, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret entry: %a.sext = sext <16 x i8> %a to <16 x i32> %b.zext = zext <16 x i8> %b to <16 x i32> @@ -122,17 +152,27 @@ entry: } define i32 @vqdotsu_vv_swapped(<16 x i8> %a, <16 x i8> %b) { -; CHECK-LABEL: vqdotsu_vv_swapped: -; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma -; CHECK-NEXT: vsext.vf2 v12, v8 -; CHECK-NEXT: vzext.vf2 v14, v9 -; CHECK-NEXT: vwmulsu.vv v8, v12, v14 -; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma -; CHECK-NEXT: vmv.s.x v12, zero -; CHECK-NEXT: vredsum.vs v8, v8, v12 -; CHECK-NEXT: vmv.x.s a0, v8 -; CHECK-NEXT: ret +; NODOT-LABEL: vqdotsu_vv_swapped: +; NODOT: # %bb.0: # 
%entry +; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma +; NODOT-NEXT: vsext.vf2 v12, v8 +; NODOT-NEXT: vzext.vf2 v14, v9 +; NODOT-NEXT: vwmulsu.vv v8, v12, v14 +; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma +; NODOT-NEXT: vmv.s.x v12, zero +; NODOT-NEXT: vredsum.vs v8, v8, v12 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT-LABEL: vqdotsu_vv_swapped: +; DOT: # %bb.0: # %entry +; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma +; DOT-NEXT: vmv.v.i v10, 0 +; DOT-NEXT: vqdotsu.vv v10, v8, v9 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vredsum.vs v8, v10, v8 +; DOT-NEXT: vmv.x.s a0, v8 +; DOT-NEXT: ret entry: %a.sext = sext <16 x i8> %a to <16 x i32> %b.zext = zext <16 x i8> %b to <16 x i32> @@ -181,14 +221,38 @@ entry: } define i32 @reduce_of_sext(<16 x i8> %a) { -; CHECK-LABEL: reduce_of_sext: -; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma -; CHECK-NEXT: vsext.vf4 v12, v8 -; CHECK-NEXT: vmv.s.x v8, zero -; CHECK-NEXT: vredsum.vs v8, v12, v8 -; CHECK-NEXT: vmv.x.s a0, v8 -; CHECK-NEXT: ret +; NODOT-LABEL: reduce_of_sext: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma +; NODOT-NEXT: vsext.vf4 v12, v8 +; NODOT-NEXT: vmv.s.x v8, zero +; NODOT-NEXT: vredsum.vs v8, v12, v8 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT32-LABEL: reduce_of_sext: +; DOT32: # %bb.0: # %entry +; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma +; DOT32-NEXT: vmv.v.i v9, 0 +; DOT32-NEXT: lui a0, 4112 +; DOT32-NEXT: addi a0, a0, 257 +; DOT32-NEXT: vqdot.vx v9, v8, a0 +; DOT32-NEXT: vmv.s.x v8, zero +; DOT32-NEXT: vredsum.vs v8, v9, v8 +; DOT32-NEXT: vmv.x.s a0, v8 +; DOT32-NEXT: ret +; +; DOT64-LABEL: reduce_of_sext: +; DOT64: # %bb.0: # %entry +; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma +; DOT64-NEXT: vmv.v.i v9, 0 +; DOT64-NEXT: lui a0, 4112 +; DOT64-NEXT: addiw a0, a0, 257 +; DOT64-NEXT: vqdot.vx v9, v8, a0 +; DOT64-NEXT: vmv.s.x v8, zero +; DOT64-NEXT: vredsum.vs v8, v9, v8 +; DOT64-NEXT: vmv.x.s a0, v8 +; DOT64-NEXT: ret entry: %a.ext = sext <16 x i8> %a to <16 x i32> %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext) @@ -196,14 +260,38 @@ entry: } define i32 @reduce_of_zext(<16 x i8> %a) { -; CHECK-LABEL: reduce_of_zext: -; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma -; CHECK-NEXT: vzext.vf4 v12, v8 -; CHECK-NEXT: vmv.s.x v8, zero -; CHECK-NEXT: vredsum.vs v8, v12, v8 -; CHECK-NEXT: vmv.x.s a0, v8 -; CHECK-NEXT: ret +; NODOT-LABEL: reduce_of_zext: +; NODOT: # %bb.0: # %entry +; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma +; NODOT-NEXT: vzext.vf4 v12, v8 +; NODOT-NEXT: vmv.s.x v8, zero +; NODOT-NEXT: vredsum.vs v8, v12, v8 +; NODOT-NEXT: vmv.x.s a0, v8 +; NODOT-NEXT: ret +; +; DOT32-LABEL: reduce_of_zext: +; DOT32: # %bb.0: # %entry +; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma +; DOT32-NEXT: vmv.v.i v9, 0 +; DOT32-NEXT: lui a0, 4112 +; DOT32-NEXT: addi a0, a0, 257 +; DOT32-NEXT: vqdotu.vx v9, v8, a0 +; DOT32-NEXT: vmv.s.x v8, zero +; DOT32-NEXT: vredsum.vs v8, v9, v8 +; DOT32-NEXT: vmv.x.s a0, v8 +; DOT32-NEXT: ret +; +; DOT64-LABEL: reduce_of_zext: +; DOT64: # %bb.0: # %entry +; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma +; DOT64-NEXT: vmv.v.i v9, 0 +; DOT64-NEXT: lui a0, 4112 +; DOT64-NEXT: addiw a0, a0, 257 +; DOT64-NEXT: vqdotu.vx v9, v8, a0 +; DOT64-NEXT: vmv.s.x v8, zero +; DOT64-NEXT: vredsum.vs v8, v9, v8 +; DOT64-NEXT: vmv.x.s a0, v8 +; DOT64-NEXT: ret entry: %a.ext = zext <16 x i8> %a to <16 x i32> %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> 
%a.ext)