-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[InstCombine] Introduce foldICmpBinOpWithConstantViaTruthTable
folding
#139109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[InstCombine] Introduce foldICmpBinOpWithConstantViaTruthTable
folding
#139109
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Antonio Frighetto (antoniofrighetto) ChangesMatch icmps of binops where both operands are select with constant arms, i.e., Proofs: https://alive2.llvm.org/ce/z/_kkUfJ. Fixes: #138212. Full diff: https://github.com/llvm/llvm-project/pull/139109.diff 4 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 252781e54ab06..9cb3a80c2ea20 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3110,6 +3110,44 @@ static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0,
return nullptr;
}
+Instruction *InstCombinerImpl::foldICmpBinOpWithConstantViaTruthTable(
+ ICmpInst &Cmp, BinaryOperator *BO, const APInt &C) {
+ Value *A, *B;
+ Constant *C1, *C2, *C3, *C4;
+ if (match(BO->getOperand(0),
+ m_Select(m_Value(A), m_Constant(C1), m_Constant(C2))) &&
+ match(BO->getOperand(1),
+ m_Select(m_Value(B), m_Constant(C3), m_Constant(C4)))) {
+ std::bitset<4> Table;
+ auto ComputeTable = [&](bool First, bool Second) -> std::optional<bool> {
+ Constant *L = First ? C1 : C2;
+ Constant *R = Second ? C3 : C4;
+ if (auto *Res = ConstantFoldBinaryOpOperands(BO->getOpcode(), L, R, DL)) {
+ auto *Val = Res->getType()->isVectorTy() ? Res->getSplatValue() : Res;
+ if (auto *CI = dyn_cast_or_null<ConstantInt>(Val))
+ return ICmpInst::compare(CI->getValue(), C, Cmp.getPredicate());
+ }
+ return std::nullopt;
+ };
+
+ for (unsigned I = 0; I < 4; ++I) {
+ bool First = (I >> 1) & 1;
+ bool Second = I & 1;
+ if (auto Res = ComputeTable(First, Second))
+ Table[I] = *Res;
+ else
+ return nullptr;
+ }
+
+ // Synthesize optimal logic.
+ if (auto *Cond =
+ createLogicFromTable(Table, A, B, Builder, BO->hasOneUse()))
+ return replaceInstUsesWith(Cmp, Cond);
+ }
+
+ return nullptr;
+}
+
/// Fold icmp (add X, Y), C.
Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
BinaryOperator *Add,
@@ -4014,7 +4052,13 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
}
// TODO: These folds could be refactored to be part of the above calls.
- return foldICmpBinOpEqualityWithConstant(Cmp, BO, C);
+ if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, C))
+ return I;
+
+ // Fall back to handling `icmp pred (select A ? C1 : C2) binop (select B ? C3
+ // : C4), C5` pattern, by computing a truth table of the four constant
+ // variants.
+ return foldICmpBinOpWithConstantViaTruthTable(Cmp, BO, C);
}
static Instruction *
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 324738ef8c88e..8b657b3f8555c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -736,6 +736,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1,
const APInt &C2);
+ Instruction *foldICmpBinOpWithConstantViaTruthTable(ICmpInst &Cmp,
+ BinaryOperator *BO,
+ const APInt &C);
Instruction *foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
BinaryOperator *BO,
const APInt &C);
diff --git a/llvm/test/Transforms/InstCombine/icmp-binop.ll b/llvm/test/Transforms/InstCombine/icmp-binop.ll
index 356489716fff9..e240bf1cd95f9 100644
--- a/llvm/test/Transforms/InstCombine/icmp-binop.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-binop.ll
@@ -359,3 +359,99 @@ define i1 @test_icmp_sgt_and_negpow2_invalid_c(i32 %add) {
%cmp = icmp sgt i32 %and, 48
ret i1 %cmp
}
+
+define i1 @icmp_eq_or_of_selects_with_constant(i1 %a, i1 %b) {
+; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant(
+; CHECK-NEXT: [[CMP:%.*]] = and i1 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %s1 = select i1 %a, i64 65536, i64 0
+ %s2 = select i1 %b, i64 256, i64 0
+ %or = or i64 %s1, %s2
+ %cmp = icmp eq i64 %or, 65792
+ ret i1 %cmp
+}
+
+define i1 @icmp_slt_and_of_selects_with_constant(i1 %a, i1 %b) {
+; CHECK-LABEL: @icmp_slt_and_of_selects_with_constant(
+; CHECK-NEXT: [[TMP1:%.*]] = or i1 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %s1 = select i1 %a, i8 1, i8 254
+ %s2 = select i1 %b, i8 1, i8 253
+ %and = and i8 %s1, %s2
+ %cmp = icmp slt i8 %and, 254
+ ret i1 %cmp
+}
+
+define i1 @icmp_sge_add_of_selects_with_constant(i1 %a, i1 %b) {
+; CHECK-LABEL: @icmp_sge_add_of_selects_with_constant(
+; CHECK-NEXT: ret i1 true
+;
+ %s1 = select i1 %a, i8 248, i8 7
+ %s2 = select i1 %b, i8 16, i8 0
+ %add = add i8 %s1, %s2
+ %cmp = icmp sge i8 %add, 247
+ ret i1 %cmp
+}
+
+define <2 x i1> @icmp_eq_or_of_selects_with_constant_vectorized(<2 x i1> %a, <2 x i1> %b) {
+; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant_vectorized(
+; CHECK-NEXT: [[CMP:%.*]] = and <2 x i1> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: ret <2 x i1> [[CMP]]
+;
+ %s1 = select <2 x i1> %a, <2 x i64> <i64 65536, i64 65536>, <2 x i64> zeroinitializer
+ %s2 = select <2 x i1> %b, <2 x i64> <i64 256, i64 256>, <2 x i64> zeroinitializer
+ %or = or <2 x i64> %s1, %s2
+ %cmp = icmp eq <2 x i64> %or, <i64 65792, i64 65792>
+ ret <2 x i1> %cmp
+}
+
+; Negative tests.
+define i1 @icmp_eq_or_of_selects_with_constant_and_arg(i1 %a, i1 %b, i64 %arg) {
+; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant_and_arg(
+; CHECK-NEXT: [[S1:%.*]] = select i1 [[A:%.*]], i64 65536, i64 [[ARG:%.*]]
+; CHECK-NEXT: [[S2:%.*]] = select i1 [[B:%.*]], i64 256, i64 0
+; CHECK-NEXT: [[OR:%.*]] = or i64 [[S1]], [[S2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[OR]], 65792
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %s1 = select i1 %a, i64 65536, i64 %arg
+ %s2 = select i1 %b, i64 256, i64 0
+ %or = or i64 %s1, %s2
+ %cmp = icmp eq i64 %or, 65792
+ ret i1 %cmp
+}
+
+define i1 @icmp_eq_or_of_selects_with_constant_multiuse(i1 %a, i1 %b) {
+; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant_multiuse(
+; CHECK-NEXT: [[S1:%.*]] = select i1 [[A:%.*]], i64 0, i64 65536
+; CHECK-NEXT: [[S2:%.*]] = select i1 [[B:%.*]], i64 0, i64 256
+; CHECK-NEXT: [[OR:%.*]] = or disjoint i64 [[S1]], [[S2]]
+; CHECK-NEXT: call void @use64(i64 [[OR]])
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[OR]], 65792
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %s1 = select i1 %a, i64 0, i64 65536
+ %s2 = select i1 %b, i64 0, i64 256
+ %or = or i64 %s1, %s2
+ call void @use64(i64 %or)
+ %cmp = icmp eq i64 %or, 65792
+ ret i1 %cmp
+}
+
+define <2 x i1> @icmp_eq_or_of_selects_with_constant_vectorized_nonsplat(<2 x i1> %a, <2 x i1> %b) {
+; CHECK-LABEL: @icmp_eq_or_of_selects_with_constant_vectorized_nonsplat(
+; CHECK-NEXT: [[S1:%.*]] = select <2 x i1> [[A:%.*]], <2 x i64> splat (i64 65536), <2 x i64> zeroinitializer
+; CHECK-NEXT: [[S2:%.*]] = select <2 x i1> [[B:%.*]], <2 x i64> <i64 256, i64 128>, <2 x i64> zeroinitializer
+; CHECK-NEXT: [[OR:%.*]] = or disjoint <2 x i64> [[S1]], [[S2]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[OR]], <i64 65792, i64 65664>
+; CHECK-NEXT: ret <2 x i1> [[CMP]]
+;
+ %s1 = select <2 x i1> %a, <2 x i64> <i64 65536, i64 65536>, <2 x i64> zeroinitializer
+ %s2 = select <2 x i1> %b, <2 x i64> <i64 256, i64 128>, <2 x i64> zeroinitializer
+ %or = or <2 x i64> %s1, %s2
+ %cmp = icmp eq <2 x i64> %or, <i64 65792, i64 65664>
+ ret <2 x i1> %cmp
+}
diff --git a/llvm/test/Transforms/InstCombine/icmp-select.ll b/llvm/test/Transforms/InstCombine/icmp-select.ll
index 1aae91302dab1..a038731abbc48 100644
--- a/llvm/test/Transforms/InstCombine/icmp-select.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-select.ll
@@ -328,10 +328,7 @@ define i1 @select_constants_and_icmp_eq0_common_bit(i1 %x, i1 %y) {
define i1 @select_constants_and_icmp_eq0_no_common_op1(i1 %x, i1 %y) {
; CHECK-LABEL: @select_constants_and_icmp_eq0_no_common_op1(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 16, i8 3
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 24, i8 3
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], 0
+; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[CMP]]
;
%s1 = select i1 %x, i8 16, i8 3
@@ -345,10 +342,7 @@ define i1 @select_constants_and_icmp_eq0_no_common_op1(i1 %x, i1 %y) {
define i1 @select_constants_and_icmp_eq0_no_common_op2(i1 %x, i1 %y) {
; CHECK-LABEL: @select_constants_and_icmp_eq0_no_common_op2(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 16, i8 3
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 16, i8 7
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], 0
+; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[CMP]]
;
%s1 = select i1 %x, i8 16, i8 3
@@ -387,14 +381,9 @@ define i1 @select_constants_and_icmp_eq0_zero_fval(i1 %x, i1 %y) {
ret i1 %cmp
}
-; TODO: x & y
-
define i1 @select_constants_and_icmp_eq_tval(i1 %x, i1 %y) {
; CHECK-LABEL: @select_constants_and_icmp_eq_tval(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 6, i8 1
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 6, i8 1
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], 6
+; CHECK-NEXT: [[CMP:%.*]] = and i1 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[CMP]]
;
%s1 = select i1 %x, i8 6, i8 1
@@ -404,14 +393,10 @@ define i1 @select_constants_and_icmp_eq_tval(i1 %x, i1 %y) {
ret i1 %cmp
}
-; TODO: ~(x | y)
-
define i1 @select_constants_and_icmp_eq_fval(i1 %x, i1 %y) {
; CHECK-LABEL: @select_constants_and_icmp_eq_fval(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 12, i8 3
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 12, i8 3
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], 3
+; CHECK-NEXT: [[TMP1:%.*]] = or i1 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true
; CHECK-NEXT: ret i1 [[CMP]]
;
%s1 = select i1 %x, i8 12, i8 3
@@ -512,10 +497,8 @@ define i1 @select_constants_and_icmp_ne0_common_bit(i1 %x, i1 %y) {
define i1 @select_constants_and_icmp_ne0_no_common_op1(i1 %x, i1 %y) {
; CHECK-LABEL: @select_constants_and_icmp_ne0_no_common_op1(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 16, i8 3
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 24, i8 3
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 0
+; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true
; CHECK-NEXT: ret i1 [[CMP]]
;
%s1 = select i1 %x, i8 16, i8 3
@@ -529,10 +512,8 @@ define i1 @select_constants_and_icmp_ne0_no_common_op1(i1 %x, i1 %y) {
define i1 @select_constants_and_icmp_ne0_no_common_op2(i1 %x, i1 %y) {
; CHECK-LABEL: @select_constants_and_icmp_ne0_no_common_op2(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 16, i8 3
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 16, i8 7
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 0
+; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true
; CHECK-NEXT: ret i1 [[CMP]]
;
%s1 = select i1 %x, i8 16, i8 3
@@ -571,14 +552,10 @@ define i1 @select_constants_and_icmp_ne0_zero_fval(i1 %x, i1 %y) {
ret i1 %cmp
}
-; TODO: ~(x & y)
-
define i1 @select_constants_and_icmp_ne_tval(i1 %x, i1 %y) {
; CHECK-LABEL: @select_constants_and_icmp_ne_tval(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 6, i8 1
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 6, i8 1
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 6
+; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[CMP:%.*]] = xor i1 [[TMP1]], true
; CHECK-NEXT: ret i1 [[CMP]]
;
%s1 = select i1 %x, i8 6, i8 1
@@ -588,14 +565,9 @@ define i1 @select_constants_and_icmp_ne_tval(i1 %x, i1 %y) {
ret i1 %cmp
}
-; TODO: (x | y)
-
define i1 @select_constants_and_icmp_ne_fval(i1 %x, i1 %y) {
; CHECK-LABEL: @select_constants_and_icmp_ne_fval(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[X:%.*]], i8 12, i8 3
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[Y:%.*]], i8 12, i8 3
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[S1]], [[S2]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 3
+; CHECK-NEXT: [[CMP:%.*]] = or i1 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[CMP]]
;
%s1 = select i1 %x, i8 12, i8 3
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this subsume the fold at
llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Lines 1969 to 1991 in 948bffa
// If we are testing the intersection of 2 select-of-nonzero-constants with no | |
// common bits set, it's the same as checking if exactly one select condition | |
// is set: | |
// ((A ? TC : FC) & (B ? TC : FC)) == 0 --> xor A, B | |
// ((A ? TC : FC) & (B ? TC : FC)) != 0 --> not(xor A, B) | |
// TODO: Generalize for non-constant values. | |
// TODO: Handle signed/unsigned predicates. | |
// TODO: Handle other bitwise logic connectors. | |
// TODO: Extend to handle a non-zero compare constant. | |
if (C.isZero() && (Pred == CmpInst::ICMP_EQ || And->hasOneUse())) { | |
assert(Cmp.isEquality() && "Not expecting non-equality predicates"); | |
Value *A, *B; | |
const APInt *TC, *FC; | |
if (match(X, m_Select(m_Value(A), m_APInt(TC), m_APInt(FC))) && | |
match(Y, | |
m_Select(m_Value(B), m_SpecificInt(*TC), m_SpecificInt(*FC))) && | |
!TC->isZero() && !FC->isZero() && !TC->intersects(*FC)) { | |
Value *R = Builder.CreateXor(A, B); | |
if (Pred == CmpInst::ICMP_NE) | |
R = Builder.CreateNot(R); | |
return replaceInstUsesWith(Cmp, R); | |
} | |
} |
f7fba64
to
e3c539c
Compare
Right, thanks for catching it, PR description updated. |
// Synthesize optimal logic. | ||
if (auto *Cond = | ||
createLogicFromTable(Table, A, B, Builder, BO->hasOneUse()); | ||
Cond && Cmp.getType() == Cond->getType()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If Cmp.getType() != Cond->getType()
, the newly created instruction will not be used. It is the root cause of https://github.com/llvm/llvm-project/pull/139109/files#r2083295338.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, sure thing, check anticipated, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Match icmps of binops where both operands are select with constant arms, i.e., `icmp pred (select A ? C1 : C2) binop (select B ? C3 : C4), C5`. Fold such patterns by creating a truth table of the possible four constant variants, and materialize back the optimal logic from it via `createLogicFromTable` helper. This also generalizes an existing fold, which has therefore been dropped. Proofs: https://alive2.llvm.org/ce/z/NS7Vzu. Fixes: llvm#138212.
5c728fb
to
adfd59f
Compare
Match icmps of binops where both operands are select with constant arms, i.e.,
icmp pred (select A ? C1 : C2) binop (select B ? C3 : C4), C5
. Fold such patterns by creating a truth table of the possible four constant variants, and materialize back the optimal logic from it viacreateLogicFromTable
helper. This also generalizes an existing fold, which has therefore been dropped.Proofs: https://alive2.llvm.org/ce/z/NS7Vzu.
Fixes: #138212.