Skip to content

Commit bf5351e

Browse files
committed
[HashRecognize] Address review
1 parent 00dd6ba commit bf5351e

File tree

1 file changed

+69
-57
lines changed

1 file changed

+69
-57
lines changed

llvm/lib/Analysis/HashRecognize.cpp

+69-57
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,7 @@ KnownBits ValueEvolution::computeBinOp(const BinaryOperator *I,
153153
case Instruction::BinaryOps::Mul: {
154154
Value *Op0 = I->getOperand(0);
155155
Value *Op1 = I->getOperand(1);
156-
bool SelfMultiply = Op0 == Op1;
157-
if (SelfMultiply)
158-
SelfMultiply &= isGuaranteedNotToBeUndef(Op0);
156+
bool SelfMultiply = Op0 == Op1 && isGuaranteedNotToBeUndef(Op0);
159157
return KnownBits::mul(KnownL, KnownR, SelfMultiply);
160158
}
161159
case Instruction::BinaryOps::UDiv:
@@ -278,6 +276,44 @@ ValueEvolution::computeEvolutions(ArrayRef<PhiStepPair> PhiEvolutions) {
278276
return KnownPhis;
279277
}
280278

279+
/// Digs for a recurrence starting with \p V hitting the PHI node \p P in a
280+
/// use-def chain. Used by matchConditionalRecurrence.
281+
static BinaryOperator *
282+
digRecurrence(Instruction *V, const PHINode *P, const Loop &L,
283+
const APInt *&ExtraConst,
284+
Instruction::BinaryOps BOWithConstOpToMatch) {
285+
using namespace llvm::PatternMatch;
286+
287+
SmallVector<Instruction *> Worklist;
288+
Worklist.push_back(V);
289+
while (!Worklist.empty()) {
290+
Instruction *I = Worklist.pop_back_val();
291+
292+
// Don't add a PHI's operands to the Worklist.
293+
if (isa<PHINode>(I))
294+
continue;
295+
296+
// Find a recurrence over a BinOp, by matching either of its operands
297+
// with with the PHINode.
298+
if (match(I, m_c_BinOp(m_Value(), m_Specific(P))))
299+
return cast<BinaryOperator>(I);
300+
301+
// Bind to ExtraConst, if we match exactly one.
302+
if (I->getOpcode() == BOWithConstOpToMatch) {
303+
if (ExtraConst)
304+
return nullptr;
305+
match(I, m_c_BinOp(m_APInt(ExtraConst), m_Value()));
306+
}
307+
308+
// Continue along the use-def chain.
309+
for (Use &U : I->operands())
310+
if (auto *UI = dyn_cast<Instruction>(U))
311+
if (L.contains(UI))
312+
Worklist.push_back(UI);
313+
}
314+
return nullptr;
315+
}
316+
281317
/// A Conditional Recurrence is a recurrence of the form:
282318
///
283319
/// loop:
@@ -311,42 +347,15 @@ static bool matchConditionalRecurrence(
311347
m_Select(m_Cmp(), m_Instruction(TV), m_Instruction(FV))))
312348
continue;
313349

314-
auto DigRecurrence = [&](Instruction *V) -> BinaryOperator * {
315-
SmallVector<Instruction *> Worklist;
316-
Worklist.push_back(V);
317-
while (!Worklist.empty()) {
318-
Instruction *I = Worklist.pop_back_val();
319-
320-
// Don't add a PHI's operands to the Worklist.
321-
if (isa<PHINode>(I))
322-
continue;
323-
324-
// Find a recurrence over a BinOp, by matching either of its operands
325-
// with with the PHINode.
326-
if (match(I, m_c_BinOp(m_Value(), m_Specific(P))))
327-
return cast<BinaryOperator>(I);
328-
329-
// Bind to ExtraConst, if we match exactly one.
330-
if (I->getOpcode() == BOWithConstOpToMatch) {
331-
if (ExtraConst)
332-
return nullptr;
333-
match(I, m_c_BinOp(m_APInt(ExtraConst), m_Value()));
334-
}
335-
336-
// Continue along the use-def chain.
337-
for (Use &U : I->operands())
338-
if (auto *UI = dyn_cast<Instruction>(U))
339-
if (L.contains(UI))
340-
Worklist.push_back(UI);
341-
}
342-
return nullptr;
343-
};
344-
345350
// For a conditional recurrence, both the true and false values of the
346351
// select must ultimately end up in the same recurrent BinOp.
347-
BinaryOperator *FoundBO = DigRecurrence(TV);
348-
BinaryOperator *AltBO = DigRecurrence(FV);
349-
if (!FoundBO || !AltBO || FoundBO != AltBO)
352+
ExtraConst = nullptr;
353+
BinaryOperator *FoundBO =
354+
digRecurrence(TV, P, L, ExtraConst, BOWithConstOpToMatch);
355+
BinaryOperator *AltBO =
356+
digRecurrence(FV, P, L, ExtraConst, BOWithConstOpToMatch);
357+
358+
if (!FoundBO || FoundBO != AltBO)
350359
return false;
351360

352361
if (BOWithConstOpToMatch != Instruction::BinaryOpsEnd && !ExtraConst) {
@@ -404,15 +413,16 @@ static std::pair<std::optional<RecurrenceInfo>, std::optional<RecurrenceInfo>>
404413
getRecurrences(BasicBlock *LoopLatch, const PHINode *IndVar, const Loop &L) {
405414
std::optional<RecurrenceInfo> SimpleRecurrence, ConditionalRecurrence;
406415
for (PHINode &P : LoopLatch->phis()) {
407-
BinaryOperator *BO;
408-
Value *Start, *Step;
409-
const APInt *GenPoly = nullptr;
410416
if (&P == IndVar)
411417
continue;
412418
if (!P.getType()->isIntegerTy()) {
413419
LLVM_DEBUG(dbgs() << "HashRecognize: Non-integral PHI found\n");
414420
return {};
415421
}
422+
423+
BinaryOperator *BO;
424+
Value *Start, *Step;
425+
const APInt *GenPoly;
416426
if (!SimpleRecurrence && matchSimpleRecurrence(&P, BO, Start, Step)) {
417427
SimpleRecurrence = {&P, BO, Start, Step};
418428
} else if (!ConditionalRecurrence &&
@@ -461,19 +471,24 @@ CRCTable HashRecognize::genSarwateTable(const APInt &GenPoly,
461471
unsigned MSB = 1 << (BW - 1);
462472
CRCTable Table;
463473
Table[0] = APInt::getZero(BW);
464-
APInt CRCInit(BW, ByteOrderSwapped ? 1 : 128);
465-
for (unsigned I = ByteOrderSwapped ? 1 : 128; ByteOrderSwapped ? I < 256 : I;
466-
ByteOrderSwapped ? I <<= 1 : I >>= 1) {
467-
APInt CRCShift = ByteOrderSwapped ? CRCInit.shl(1) : CRCInit.lshr(1);
468-
APInt SBCheck = ByteOrderSwapped ? (CRCInit & MSB) : (CRCInit & 1);
469-
CRCInit = CRCShift ^ (SBCheck.isZero() ? APInt::getZero(BW) : GenPoly);
470-
if (ByteOrderSwapped) {
474+
475+
if (ByteOrderSwapped) {
476+
APInt CRCInit(BW, 1);
477+
for (unsigned I = 1; I < 256; I <<= 1) {
478+
CRCInit = CRCInit.shl(1) ^
479+
((CRCInit & MSB).isZero() ? APInt::getZero(BW) : GenPoly);
471480
for (unsigned J = 0; J < I; ++J)
472481
Table[I + J] = CRCInit ^ Table[J];
473-
} else {
474-
for (unsigned J = 0; J < 256; J += (I << 1))
475-
Table[I + J] = CRCInit ^ Table[J];
476482
}
483+
return Table;
484+
}
485+
486+
APInt CRCInit(BW, 128);
487+
for (unsigned I = 128; I; I >>= 1) {
488+
CRCInit = CRCInit.lshr(1) ^
489+
((CRCInit & 1).isZero() ? APInt::getZero(BW) : GenPoly);
490+
for (unsigned J = 0; J < 256; J += (I << 1))
491+
Table[I + J] = CRCInit ^ Table[J];
477492
}
478493
return Table;
479494
}
@@ -576,7 +591,7 @@ HashRecognize::recognizeCRC() const {
576591
// true even if it is only really used in an outer loop's exit block, since
577592
// the loop is in LCSSA form.
578593
auto *ComputedValue = cast<SelectInst>(ConditionalRecurrence->Step);
579-
if (!count_if(ComputedValue->users(), [Exit](User *U) {
594+
if (none_of(ComputedValue->users(), [Exit](User *U) {
580595
auto *UI = dyn_cast<Instruction>(U);
581596
return UI && UI->getParent() == Exit;
582597
}))
@@ -611,12 +626,9 @@ HashRecognize::recognizeCRC() const {
611626
}
612627

613628
void CRCTable::print(raw_ostream &OS) const {
614-
for (unsigned I = 0; I < 256; I += 16) {
615-
for (unsigned J = I; J < I + 16; ++J) {
616-
std::array<APInt, 256>::operator[](J).print(OS, false);
617-
OS << " ";
618-
}
619-
OS << "\n";
629+
for (unsigned I = 0; I < 256; I++) {
630+
(*this)[I].print(OS, false);
631+
OS << (I % 16 == 15 ? '\n' : ' ');
620632
}
621633
}
622634

0 commit comments

Comments
 (0)