[RISCV] Extend zvqdot matching to handle reduction trees #138965
Conversation
Now that we have matching for vqdot in its basic variants, we can extend the matcher to handle reduction trees instead of individual reductions. This is important because we canonicalize reductions by performing a tree of adds in the vector domain before the root reduction instruction. The particular approach taken here has the unfortunate implication that non-matches visit the entire reduction tree once for each time the reduction root is visited in the DAG. While conceptually problematic for compile time, this is probably fine in practice, as we should only visit the root once per pass of DAGCombine. I don't really see a better solution - suggestions welcome.
@llvm/pr-subscribers-backend-risc-v Author: Philip Reames (preames). Changes: as described above. Full diff: https://github.com/llvm/llvm-project/pull/138965.diff (2 Files Affected):
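For context, the canonical reduction-tree shape this combine now walks looks roughly like the LLVM IR below. This is an illustrative sketch (the function and value names are made up), mirroring the vqdot_vv_split test in the diff: two sext-multiply products are combined by a vector add before the single vector.reduce.add root.

define i32 @reduce_tree_example(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
entry:
  ; two independent widened products
  %a.sext = sext <16 x i8> %a to <16 x i32>
  %b.sext = sext <16 x i8> %b to <16 x i32>
  %c.sext = sext <16 x i8> %c to <16 x i32>
  %d.sext = sext <16 x i8> %d to <16 x i32>
  %mul.ab = mul <16 x i32> %a.sext, %b.sext
  %mul.cd = mul <16 x i32> %c.sext, %d.sext
  ; the in-vector-domain add tree that feeds the reduction root
  %sum = add <16 x i32> %mul.ab, %mul.cd
  %res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %sum)
  ret i32 %res
}

declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

With this patch, foldReduceOperandViaVQDOT recurses through the ISD::ADD feeding the reduction, folds each multiply into a vqdot partial sum, and stitches the narrower partial sums back together with getZeroPaddedAdd.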
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 756c563f0194d..a4d778c054ecf 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18131,6 +18131,29 @@ static MVT getQDOTXResultType(MVT OpVT) {
return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
}
+/// Given fixed length vectors A and B with equal element types, but possibly
+/// different number of elements, return A + B where either A or B is zero
+/// padded to the larger number of elements.
+static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
+ SelectionDAG &DAG) {
+ // NOTE: Manually doing the extract/add/insert scheme produces
+ // significantly better codegen than the naive pad with zeros
+ // and add scheme.
+ EVT AVT = A.getValueType();
+ EVT BVT = B.getValueType();
+ assert(AVT.getVectorElementType() == BVT.getVectorElementType());
+ if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) {
+ std::swap(A, B);
+ std::swap(AVT, BVT);
+ }
+
+ SDValue BPart = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, AVT, B,
+ DAG.getVectorIdxConstant(0, DL));
+ SDValue Res = DAG.getNode(ISD::ADD, DL, AVT, A, BPart);
+ return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, BVT, B, Res,
+ DAG.getVectorIdxConstant(0, DL));
+}
+
static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
@@ -18142,6 +18165,26 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
return SDValue();
+ // Recurse through adds (since generic dag canonicalizes to that
+ // form).
+ if (InVec->getOpcode() == ISD::ADD) {
+ SDValue A = InVec.getOperand(0);
+ SDValue B = InVec.getOperand(1);
+ SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
+ SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
+ if (AOpt || BOpt) {
+ if (AOpt)
+ A = AOpt;
+ if (BOpt)
+ B = BOpt;
+ // From here, we're doing A + B with mixed types, implicitly zero
+ // padded to the wider type. Note that we *don't* need the result
+ // type to be the original VT, and in fact prefer narrower ones
+ // if possible.
+ return getZeroPaddedAdd(DL, A, B, DAG);
+ }
+ }
+
// reduce (zext a) <--> reduce (mul zext a, zext 1)
// reduce (sext a) <--> reduce (mul sext a, sext 1)
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index edc9886abc3b9..e5546ad404c1b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -299,17 +299,31 @@ entry:
}
define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdot_vv_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v10, v8
-; CHECK-NEXT: vsext.vf2 v16, v9
-; CHECK-NEXT: vwmacc.vv v12, v10, v16
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v10, v8
+; NODOT-NEXT: vsext.vf2 v16, v9
+; NODOT-NEXT: vwmacc.vv v12, v10, v16
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -320,17 +334,31 @@ entry:
}
define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotu_vv_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT: vwmulu.vv v10, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT: vwaddu.wv v12, v12, v10
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotu_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT: vwmulu.vv v10, v8, v9
+; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT: vwaddu.wv v12, v12, v10
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.zext = zext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -341,17 +369,31 @@ entry:
}
define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotsu_vv_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v10, v8
-; CHECK-NEXT: vzext.vf2 v16, v9
-; CHECK-NEXT: vwmaccsu.vv v12, v10, v16
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v10, v8
+; NODOT-NEXT: vzext.vf2 v16, v9
+; NODOT-NEXT: vwmaccsu.vv v12, v10, v16
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotsu.vv v10, v8, v9
+; DOT-NEXT: vadd.vv v8, v10, v12
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v8
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -455,20 +497,33 @@ entry:
}
define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
-; CHECK-LABEL: vqdot_vv_split:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vsext.vf2 v16, v10
-; CHECK-NEXT: vsext.vf2 v18, v11
-; CHECK-NEXT: vwmul.vv v8, v12, v14
-; CHECK-NEXT: vwmacc.vv v8, v16, v18
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv_split:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vsext.vf2 v16, v10
+; NODOT-NEXT: vsext.vf2 v18, v11
+; NODOT-NEXT: vwmul.vv v8, v12, v14
+; NODOT-NEXT: vwmacc.vv v8, v16, v18
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_split:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vmv.v.i v13, 0
+; DOT-NEXT: vqdot.vv v12, v8, v9
+; DOT-NEXT: vqdot.vv v13, v10, v11
+; DOT-NEXT: vadd.vv v8, v12, v13
+; DOT-NEXT: vmv.s.x v9, zero
+; DOT-NEXT: vredsum.vs v8, v8, v9
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; DOT-NEXT: vmv.v.i v10, 0
; DOT-NEXT: vqdot.vv v10, v8, v9
; DOT-NEXT: vadd.vv v8, v10, v12
This vadd will fold into the accumulator of the vqdot.vv, but it'll require another couple of patches first. I have to fix a DAG combine problem before the straightforward patch paralleling VWMUL works.
; DOT-NEXT: vqdot.vv v10, v8, v9
; DOT-NEXT: vadd.vv v8, v10, v12
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
; DOT-NEXT: vmv.v.v v12, v8
This is an interesting missed optimization, but not specific to this example. a) this vmv.v.v could be done at m4 since we know that 4 is less than m1, and b) it could be folded into the passthru operand of the vadd.vv.
LGTM