Skip to content

[RISCV] Extend zvqdot matching to handle reduction trees #138965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 9, 2025

Conversation

preames
Copy link
Collaborator

@preames preames commented May 7, 2025

Now that we have matching for vqdot in it's basic variants, we can extend the matcher to handle reduction trees instead of individual reductions. This is important as we canonicalize reductions by performing a tree in the vector domain before the root reduction instruction.

The particular approach taken here has the unfortunate implication that non-matches visit the entire reduction tree once for each time the reduction root is visited in DAG. While conceptually problematic for compile time, this is probably fine in practice as we should only visit the root once per pass of DAGCombine. I don't really see a better solution - suggestions welcome.

Now that we have matching for vqdot in it's basic variants, we can
extend the matcher to handle reduction trees instead of individual
reductions.  This is important as we canonicalize reductions
by performing a tree in the vector domain before the root reduction
instruction.

The particular approach taken here has the unfortunate implication
that non-matches visit the entire reduction tree once for each
time the reduction root is visited in DAG.  While conceptually
problematic for compile time, this is probably fine in practice
as we should only visit the root once per pass of DAGCombine.
I don't really see a better solution - suggestions welcome.
@preames preames requested review from asb, lukel97 and topperc May 7, 2025 20:54
@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

Changes

Now that we have matching for vqdot in it's basic variants, we can extend the matcher to handle reduction trees instead of individual reductions. This is important as we canonicalize reductions by performing a tree in the vector domain before the root reduction instruction.

The particular approach taken here has the unfortunate implication that non-matches visit the entire reduction tree once for each time the reduction root is visited in DAG. While conceptually problematic for compile time, this is probably fine in practice as we should only visit the root once per pass of DAGCombine. I don't really see a better solution - suggestions welcome.


Full diff: https://github.com/llvm/llvm-project/pull/138965.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+43)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll (+102-47)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 756c563f0194d..a4d778c054ecf 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18131,6 +18131,29 @@ static MVT getQDOTXResultType(MVT OpVT) {
   return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
 }
 
+/// Given fixed length vectors A and B with equal element types, but possibly
+/// different number of elements, return A + B where either A or B is zero
+/// padded to the larger number of elements.
+static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
+                                SelectionDAG &DAG) {
+  // NOTE: Manually doing the extract/add/insert scheme produces
+  // significantly better coegen than the naive pad with zeros
+  // and add scheme.
+  EVT AVT = A.getValueType();
+  EVT BVT = B.getValueType();
+  assert(AVT.getVectorElementType() == BVT.getVectorElementType());
+  if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) {
+    std::swap(A, B);
+    std::swap(AVT, BVT);
+  }
+
+  SDValue BPart = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, AVT, B,
+                              DAG.getVectorIdxConstant(0, DL));
+  SDValue Res = DAG.getNode(ISD::ADD, DL, AVT, A, BPart);
+  return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, BVT, B, Res,
+                     DAG.getVectorIdxConstant(0, DL));
+}
+
 static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
                                          SelectionDAG &DAG,
                                          const RISCVSubtarget &Subtarget,
@@ -18142,6 +18165,26 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
       !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
     return SDValue();
 
+  // Recurse through adds (since generic dag canonicalizes to that
+  // form).
+  if (InVec->getOpcode() == ISD::ADD) {
+    SDValue A = InVec.getOperand(0);
+    SDValue B = InVec.getOperand(1);
+    SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
+    SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
+    if (AOpt || BOpt) {
+      if (AOpt)
+        A = AOpt;
+      if (BOpt)
+        B = BOpt;
+      // From here, we're doing A + B with mixed types, implicitly zero
+      // padded to the wider type.  Note that we *don't* need the result
+      // type to be the original VT, and in fact prefer narrower ones
+      // if possible.
+      return getZeroPaddedAdd(DL, A, B, DAG);
+    }
+  }
+
   // reduce (zext a) <--> reduce (mul zext a. zext 1)
   // reduce (sext a) <--> reduce (mul sext a. sext 1)
   if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index edc9886abc3b9..e5546ad404c1b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -299,17 +299,31 @@ entry:
 }
 
 define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdot_vv_accum:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vsext.vf2 v16, v9
-; CHECK-NEXT:    vwmacc.vv v12, v10, v16
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdot_vv_accum:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v10, v8
+; NODOT-NEXT:    vsext.vf2 v16, v9
+; NODOT-NEXT:    vwmacc.vv v12, v10, v16
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv_accum:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdot.vv v10, v8, v9
+; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.sext = sext <16 x i8> %b to <16 x i32>
@@ -320,17 +334,31 @@ entry:
 }
 
 define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotu_vv_accum:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT:    vwmulu.vv v10, v8, v9
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vwaddu.wv v12, v12, v10
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotu_vv_accum:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT:    vwmulu.vv v10, v8, v9
+; NODOT-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT:    vwaddu.wv v12, v12, v10
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotu_vv_accum:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotu.vv v10, v8, v9
+; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.zext = zext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -341,17 +369,31 @@ entry:
 }
 
 define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotsu_vv_accum:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v16, v9
-; CHECK-NEXT:    vwmaccsu.vv v12, v10, v16
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv_accum:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v10, v8
+; NODOT-NEXT:    vzext.vf2 v16, v9
+; NODOT-NEXT:    vwmaccsu.vv v12, v10, v16
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_accum:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -455,20 +497,33 @@ entry:
 }
 
 define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
-; CHECK-LABEL: vqdot_vv_split:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vsext.vf2 v14, v9
-; CHECK-NEXT:    vsext.vf2 v16, v10
-; CHECK-NEXT:    vsext.vf2 v18, v11
-; CHECK-NEXT:    vwmul.vv v8, v12, v14
-; CHECK-NEXT:    vwmacc.vv v8, v16, v18
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdot_vv_split:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vsext.vf2 v14, v9
+; NODOT-NEXT:    vsext.vf2 v16, v10
+; NODOT-NEXT:    vsext.vf2 v18, v11
+; NODOT-NEXT:    vwmul.vv v8, v12, v14
+; NODOT-NEXT:    vwmacc.vv v8, v16, v18
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv_split:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vmv.v.i v13, 0
+; DOT-NEXT:    vqdot.vv v12, v8, v9
+; DOT-NEXT:    vqdot.vv v13, v10, v11
+; DOT-NEXT:    vadd.vv v8, v12, v13
+; DOT-NEXT:    vmv.s.x v9, zero
+; DOT-NEXT:    vredsum.vs v8, v8, v9
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.sext = sext <16 x i8> %b to <16 x i32>

; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; DOT-NEXT: vmv.v.i v10, 0
; DOT-NEXT: vqdot.vv v10, v8, v9
; DOT-NEXT: vadd.vv v8, v10, v12
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This vadd will fold into the accumulator of the vqdot.vv, but it'll require another couple patches first. I have to fix a DAG combine problem before the straight forward patch paralleling VWMUL works.

; DOT-NEXT: vqdot.vv v10, v8, v9
; DOT-NEXT: vadd.vv v8, v10, v12
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
; DOT-NEXT: vmv.v.v v12, v8
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting missed optimization, but not specific to this example. a) this vmv.v.v could be done at m4 since we know that 4 is less than m1, and b) it could be folded into the passthru operand of the vadd.vv.

Copy link
Contributor

@lukel97 lukel97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@preames preames merged commit 7e64ade into llvm:main May 9, 2025
6 of 10 checks passed
@preames preames deleted the pr-zvqdotq-add-recurse-only branch May 9, 2025 15:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants