Skip to content

[BOLT] Gadget scanner: prevent false positives due to jump tables #138884

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/atrosinenko/bolt-mcinstmatcher
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions bolt/include/bolt/Core/MCInstUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ class MCInstReference {
return nullptr;
}

/// Returns the only preceding instruction, or std::nullopt if multiple or no
/// predecessors are possible.
///
/// If CFG information is available, basic block boundary can be crossed,
/// provided there is exactly one predecessor. If CFG is not available, the
/// preceding instruction in the offset order is returned, unless this is the
/// first instruction of the function.
std::optional<MCInstReference> getSinglePredecessor();

raw_ostream &print(raw_ostream &OS) const;
};

Expand Down
14 changes: 14 additions & 0 deletions bolt/include/bolt/Core/MCPlusBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef BOLT_CORE_MCPLUSBUILDER_H
#define BOLT_CORE_MCPLUSBUILDER_H

#include "bolt/Core/MCInstUtils.h"
#include "bolt/Core/MCPlus.h"
#include "bolt/Core/Relocation.h"
#include "llvm/ADT/ArrayRef.h"
Expand Down Expand Up @@ -699,6 +700,19 @@ class MCPlusBuilder {
return std::nullopt;
}

/// Tests if BranchInst corresponds to an instruction sequence which is known
/// to be a safe dispatch via jump table.
///
/// The target can decide which instruction sequences to consider "safe" from
/// the Pointer Authentication point of view, such as any jump table dispatch
/// sequence without function calls inside, any sequence which is contiguous,
/// or only some specific well-known sequences.
virtual bool
isSafeJumpTableBranchForPtrAuth(MCInstReference BranchInst) const {
llvm_unreachable("not implemented");
return false;
}

virtual bool isTerminator(const MCInst &Inst) const;

virtual bool isNoop(const MCInst &Inst) const {
Expand Down
20 changes: 20 additions & 0 deletions bolt/lib/Core/MCInstUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,23 @@ raw_ostream &MCInstReference::print(raw_ostream &OS) const {
OS << ">";
return OS;
}

std::optional<MCInstReference> MCInstReference::getSinglePredecessor() {
if (const RefInBB *Ref = tryGetRefInBB()) {
if (Ref->It != Ref->BB->begin())
return MCInstReference(Ref->BB, &*std::prev(Ref->It));

if (Ref->BB->pred_size() != 1)
return std::nullopt;

BinaryBasicBlock *PredBB = *Ref->BB->pred_begin();
assert(!PredBB->empty() && "Empty basic blocks are not supported yet");
return MCInstReference(PredBB, &*PredBB->rbegin());
}

const RefInBF &Ref = getRefInBF();
if (Ref.It == Ref.BF->instrs().begin())
return std::nullopt;

return MCInstReference(Ref.BF, std::prev(Ref.It));
}
10 changes: 10 additions & 0 deletions bolt/lib/Passes/PAuthGadgetScanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,11 @@ shouldReportUnsafeTailCall(const BinaryContext &BC, const BinaryFunction &BF,
return std::nullopt;
}

if (BC.MIB->isSafeJumpTableBranchForPtrAuth(Inst)) {
LLVM_DEBUG({ dbgs() << " Safe jump table detected, skipping.\n"; });
return std::nullopt;
}

// Returns at most one report per instruction - this is probably OK...
for (auto Reg : RegsToCheck)
if (!S.TrustedRegs[Reg])
Expand Down Expand Up @@ -1358,6 +1363,11 @@ shouldReportCallGadget(const BinaryContext &BC, const MCInstReference &Inst,
if (S.SafeToDerefRegs[DestReg])
return std::nullopt;

if (BC.MIB->isSafeJumpTableBranchForPtrAuth(Inst)) {
LLVM_DEBUG({ dbgs() << " Safe jump table detected, skipping.\n"; });
return std::nullopt;
}

return make_gadget_report(CallKind, Inst, DestReg);
}

Expand Down
73 changes: 73 additions & 0 deletions bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,79 @@ class AArch64MCPlusBuilder : public MCPlusBuilder {
return std::nullopt;
}

bool
isSafeJumpTableBranchForPtrAuth(MCInstReference BranchInst) const override {
MCInstReference CurRef = BranchInst;
auto StepBack = [&]() {
do {
auto PredInst = CurRef.getSinglePredecessor();
if (!PredInst)
return false;
CurRef = *PredInst;
} while (isCFI(CurRef));

return true;
};

// Match this contiguous sequence:
// cmp Xm, #count
// csel Xm, Xm, xzr, ls
// adrp Xn, .LJTIxyz
// add Xn, Xn, :lo12:.LJTIxyz
// ldrsw Xm, [Xn, Xm, lsl #2]
// .Ltmp:
// adr Xn, .Ltmp
// add Xm, Xn, Xm
// br Xm

// FIXME: Check label operands of ADR/ADRP+ADD and #count operand of CMP.

using namespace MCInstMatcher;
Reg Xm, Xn;

if (!matchInst(CurRef, AArch64::BR, Xm) || !StepBack())
return false;

if (!matchInst(CurRef, AArch64::ADDXrs, Xm, Xn, Xm, Imm(0)) || !StepBack())
return false;

if (!matchInst(CurRef, AArch64::ADR, Xn /*, .Ltmp*/) || !StepBack())
return false;

if (!matchInst(CurRef, AArch64::LDRSWroX, Xm, Xn, Xm, Imm(0), Imm(1)) ||
!StepBack())
return false;

if (matchInst(CurRef, AArch64::ADR, Xn /*, .LJTIxyz*/)) {
if (!StepBack())
return false;
if (!matchInst(CurRef, AArch64::HINT, Imm(0)) || !StepBack())
return false;
} else if (matchInst(CurRef, AArch64::ADDXri, Xn,
Xn /*, :lo12:.LJTIxyz*/)) {
if (!StepBack())
return false;
if (!matchInst(CurRef, AArch64::ADRP, Xn /*, .LJTIxyz*/) || !StepBack())
return false;
} else {
return false;
}

if (!matchInst(CurRef, AArch64::CSELXr, Xm, Xm, Reg(AArch64::XZR),
Imm(AArch64CC::LS)) ||
!StepBack())
return false;

if (!matchInst(CurRef, AArch64::SUBSXri, Reg(AArch64::XZR),
Xm /*, #count*/))
return false;

// Some platforms treat X16 and X17 as more protected registers, others
// do not make such distinction. So far, accept any registers as Xm and Xn.

return true;
}

bool isADRP(const MCInst &Inst) const override {
return Inst.getOpcode() == AArch64::ADRP;
}
Expand Down
Loading
Loading