diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 1b990e29158fd..87b6914f8a0ee 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -924,6 +924,22 @@ class SelectionDAG { /// Example: shuffle A, B, <0,5,2,7> -> shuffle B, A, <4,1,6,3> SDValue getCommutedVectorShuffle(const ShuffleVectorSDNode &SV); + /// Extract element at \p Idx from \p Vec. See EXTRACT_VECTOR_ELT + /// description for result type handling. + SDValue getExtractVectorElt(const SDLoc &DL, EVT VT, SDValue Vec, + unsigned Idx) { + return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Vec, + getVectorIdxConstant(Idx, DL)); + } + + /// Insert \p Elt into \p Vec at offset \p Idx. See INSERT_VECTOR_ELT + /// description for element type handling. + SDValue getInsertVectorElt(const SDLoc &DL, SDValue Vec, SDValue Elt, + unsigned Idx) { + return getNode(ISD::INSERT_VECTOR_ELT, DL, Vec.getValueType(), Vec, Elt, + getVectorIdxConstant(Idx, DL)); + } + /// Insert \p SubVec at the \p Idx element of \p Vec. SDValue getInsertSubvector(const SDLoc &DL, SDValue Vec, SDValue SubVec, unsigned Idx) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index effe08cdd44f8..bbf1b0fd590ef 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -3244,8 +3244,7 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) { if (LegalSVT.bitsLT(SVT)) return SDValue(); } - return getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), LegalSVT, SrcVector, - getVectorIdxConstant(SplatIdx, SDLoc(V))); + return getExtractVectorElt(SDLoc(V), LegalSVT, SrcVector, SplatIdx); } return SDValue(); } @@ -7557,11 +7556,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, // elements. if (N2C && N1.getOpcode() == ISD::CONCAT_VECTORS && N1.getOperand(0).getValueType().isFixedLengthVector()) { - unsigned Factor = - N1.getOperand(0).getValueType().getVectorNumElements(); - return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, - N1.getOperand(N2C->getZExtValue() / Factor), - getVectorIdxConstant(N2C->getZExtValue() % Factor, DL)); + unsigned Factor = N1.getOperand(0).getValueType().getVectorNumElements(); + return getExtractVectorElt(DL, VT, + N1.getOperand(N2C->getZExtValue() / Factor), + N2C->getZExtValue() % Factor); } // EXTRACT_VECTOR_ELT of BUILD_VECTOR or SPLAT_VECTOR is often formed while @@ -8624,8 +8622,7 @@ static SDValue getMemsetStores(SelectionDAG &DAG, const SDLoc &dl, // Target which can combine store(extractelement VectorTy, Idx) can get // the smaller value for free. SDValue TailValue = DAG.getNode(ISD::BITCAST, dl, SVT, MemSetValue); - Value = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, TailValue, - DAG.getVectorIdxConstant(Index, dl)); + Value = DAG.getExtractVectorElt(dl, VT, TailValue, Index); } else Value = getMemsetValue(Src, VT, DAG, dl); } @@ -12775,8 +12772,7 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) { // A vector operand; extract a single element. EVT OperandEltVT = OperandVT.getVectorElementType(); - Operands[j] = getNode(ISD::EXTRACT_VECTOR_ELT, dl, OperandEltVT, - Operand, getVectorIdxConstant(i, dl)); + Operands[j] = getExtractVectorElt(dl, OperandEltVT, Operand, i); } SDValue EltOp = getNode(N->getOpcode(), dl, {EltVT, EltVT1}, Operands); @@ -12810,8 +12806,7 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) { if (OperandVT.isVector()) { // A vector operand; extract a single element. EVT OperandEltVT = OperandVT.getVectorElementType(); - Operands[j] = getNode(ISD::EXTRACT_VECTOR_ELT, dl, OperandEltVT, - Operand, getVectorIdxConstant(i, dl)); + Operands[j] = getExtractVectorElt(dl, OperandEltVT, Operand, i); } else { // A scalar operand; just use it as is. Operands[j] = Operand; @@ -13090,8 +13085,7 @@ void SelectionDAG::ExtractVectorElements(SDValue Op, EltVT = VT.getVectorElementType(); SDLoc SL(Op); for (unsigned i = Start, e = Start + Count; i != e; ++i) { - Args.push_back(getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Op, - getVectorIdxConstant(i, SL))); + Args.push_back(getExtractVectorElt(SL, EltVT, Op, i)); } } diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index d65e921dfc660..68f01d1d675b7 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -3805,8 +3805,7 @@ static SDValue lowerBuildVectorViaDominantValues(SDValue Op, SelectionDAG &DAG, if (V.isUndef() || !Processed.insert(V).second) continue; if (ValueCounts[V] == 1) { - Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V, - DAG.getVectorIdxConstant(OpIdx.index(), DL)); + Vec = DAG.getInsertVectorElt(DL, Vec, V, OpIdx.index()); } else { // Blend in all instances of this value using a VSELECT, using a // mask where each bit signals whether that element is the one @@ -3963,10 +3962,9 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG, if (ViaIntVT == MVT::i32) SplatValue = SignExtend64<32>(SplatValue); - SDValue Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ViaVecVT, - DAG.getUNDEF(ViaVecVT), - DAG.getSignedConstant(SplatValue, DL, XLenVT), - DAG.getVectorIdxConstant(0, DL)); + SDValue Vec = DAG.getInsertVectorElt( + DL, DAG.getUNDEF(ViaVecVT), + DAG.getSignedConstant(SplatValue, DL, XLenVT), 0); if (ViaVecLen != 1) Vec = DAG.getExtractSubvector(DL, MVT::getVectorVT(ViaIntVT, 1), Vec, 0); return DAG.getBitcast(VT, Vec); @@ -7180,9 +7178,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, EVT BVT = EVT::getVectorVT(*DAG.getContext(), Op0VT, 1); if (!isTypeLegal(BVT)) return SDValue(); - return DAG.getBitcast(VT, DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, BVT, - DAG.getUNDEF(BVT), Op0, - DAG.getVectorIdxConstant(0, DL))); + return DAG.getBitcast( + VT, DAG.getInsertVectorElt(DL, DAG.getUNDEF(BVT), Op0, 0)); } return SDValue(); } @@ -7194,8 +7191,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, if (!isTypeLegal(BVT)) return SDValue(); SDValue BVec = DAG.getBitcast(BVT, Op0); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec, - DAG.getVectorIdxConstant(0, DL)); + return DAG.getExtractVectorElt(DL, VT, BVec, 0); } return SDValue(); } @@ -9916,8 +9912,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op, if (!EltVT.isInteger()) { // Floating-point extracts are handled in TableGen. - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec, - DAG.getVectorIdxConstant(0, DL)); + return DAG.getExtractVectorElt(DL, EltVT, Vec, 0); } SDValue Elt0 = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec); @@ -10321,8 +10316,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res); } case Intrinsic::riscv_vfmv_f_s: - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(), - Op.getOperand(1), DAG.getVectorIdxConstant(0, DL)); + return DAG.getExtractVectorElt(DL, Op.getValueType(), Op.getOperand(1), 0); case Intrinsic::riscv_vmv_v_x: return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2), Op.getOperand(3), Op.getSimpleValueType(), DL, DAG, @@ -10856,8 +10850,7 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT, SDValue Policy = DAG.getTargetConstant(RISCVVType::TAIL_AGNOSTIC, DL, XLenVT); SDValue Ops[] = {PassThru, Vec, InitialValue, Mask, VL, Policy}; SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, Ops); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction, - DAG.getVectorIdxConstant(0, DL)); + return DAG.getExtractVectorElt(DL, ResVT, Reduction, 0); } SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op, @@ -10902,8 +10895,7 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op, case ISD::UMIN: case ISD::SMAX: case ISD::SMIN: - StartV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec, - DAG.getVectorIdxConstant(0, DL)); + StartV = DAG.getExtractVectorElt(DL, VecEltVT, Vec, 0); } return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), StartV, Vec, Mask, VL, DL, DAG, Subtarget); @@ -10934,9 +10926,7 @@ getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT, case ISD::VECREDUCE_FMAXIMUM: case ISD::VECREDUCE_FMIN: case ISD::VECREDUCE_FMAX: { - SDValue Front = - DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0), - DAG.getVectorIdxConstant(0, DL)); + SDValue Front = DAG.getExtractVectorElt(DL, EltVT, Op.getOperand(0), 0); unsigned RVVOpc = (Opcode == ISD::VECREDUCE_FMIN || Opcode == ISD::VECREDUCE_FMINIMUM) ? RISCVISD::VECREDUCE_FMIN_VL @@ -14055,8 +14045,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N, EVT BVT = EVT::getVectorVT(*DAG.getContext(), VT, 1); if (isTypeLegal(BVT)) { SDValue BVec = DAG.getBitcast(BVT, Op0); - Results.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec, - DAG.getVectorIdxConstant(0, DL))); + Results.push_back(DAG.getExtractVectorElt(DL, VT, BVec, 0)); } } break; @@ -18204,12 +18193,11 @@ static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG, if (ConcatVT.getVectorElementType() != InVal.getValueType()) return SDValue(); unsigned ConcatNumElts = ConcatVT.getVectorNumElements(); - SDValue NewIdx = DAG.getVectorIdxConstant(Elt % ConcatNumElts, DL); + unsigned NewIdx = Elt % ConcatNumElts; unsigned ConcatOpIdx = Elt / ConcatNumElts; SDValue ConcatOp = InVec.getOperand(ConcatOpIdx); - ConcatOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ConcatVT, - ConcatOp, InVal, NewIdx); + ConcatOp = DAG.getInsertVectorElt(DL, ConcatOp, InVal, NewIdx); SmallVector ConcatOps(InVec->ops()); ConcatOps[ConcatOpIdx] = ConcatOp;