-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][OpenMP] Convert omp.cancel parallel to LLVMIR #137192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -158,6 +158,12 @@ static LogicalResult checkImplementationStatus(Operation &op) { | |||
if (op.getBare()) | ||||
result = todo("ompx_bare"); | ||||
}; | ||||
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) { | ||||
omp::ClauseCancellationConstructType cancelledDirective = | ||||
op.getCancelDirective(); | ||||
if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel) | ||||
result = todo("cancel directive construct type not yet supported"); | ||||
}; | ||||
auto checkDepend = [&todo](auto op, LogicalResult &result) { | ||||
if (!op.getDependVars().empty() || op.getDependKinds()) | ||||
result = todo("depend"); | ||||
|
@@ -248,6 +254,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { | |||
|
||||
LogicalResult result = success(); | ||||
llvm::TypeSwitch<Operation &>(op) | ||||
.Case([&](omp::CancelOp op) { checkCancelDirective(op, result); }) | ||||
.Case([&](omp::DistributeOp op) { | ||||
checkAllocate(op, result); | ||||
checkDistSchedule(op, result); | ||||
|
@@ -1580,6 +1587,19 @@ cleanupPrivateVars(llvm::IRBuilderBase &builder, | |||
return success(); | ||||
} | ||||
|
||||
/// Returns true when walking `op` encounters an omp.cancel operation.
static bool constructIsCancellable(Operation *op) {
  // The verifier enforces that omp.cancel is "closely nested", so it is
  // always visible in the walked IR and never hidden inside a function call.
  WalkResult cancelSearch = op->walk([](Operation *nested) {
    return mlir::isa<omp::CancelOp>(nested) ? WalkResult::interrupt()
                                            : WalkResult::advance();
  });
  return cancelSearch.wasInterrupted();
}
|
||||
static LogicalResult | ||||
convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, | ||||
LLVM::ModuleTranslation &moduleTranslation) { | ||||
|
@@ -2524,8 +2544,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, | |||
auto pbKind = llvm::omp::OMP_PROC_BIND_default; | ||||
if (auto bind = opInst.getProcBindKind()) | ||||
pbKind = getProcBindKind(*bind); | ||||
// TODO: Is the Parallel construct cancellable? | ||||
bool isCancellable = false; | ||||
bool isCancellable = constructIsCancellable(opInst); | ||||
|
||||
llvm::OpenMPIRBuilder::InsertPointTy allocaIP = | ||||
findAllocaInsertPoint(builder, moduleTranslation); | ||||
|
@@ -2991,6 +3010,47 @@ convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp, | |||
return success(); | ||||
} | ||||
|
||||
/// Maps the construct type named by an omp.cancel clause to the matching
/// llvm::omp::Directive value.
static llvm::omp::Directive convertCancellationConstructType(
    omp::ClauseCancellationConstructType directive) {
  switch (directive) {
  case omp::ClauseCancellationConstructType::Loop:
    return llvm::omp::Directive::OMPD_for;
  case omp::ClauseCancellationConstructType::Parallel:
    return llvm::omp::Directive::OMPD_parallel;
  case omp::ClauseCancellationConstructType::Sections:
    return llvm::omp::Directive::OMPD_sections;
  case omp::ClauseCancellationConstructType::Taskgroup:
    return llvm::omp::Directive::OMPD_taskgroup;
  }
  // Every case above returns. Without this, control could fall off the end of
  // a non-void function for an unexpected enum value, and some compilers warn
  // about the missing return even when all enumerators are listed.
  llvm_unreachable("Unhandled cancellation construct type");
}
|
||||
/// Lowers an omp.cancel operation by calling OpenMPIRBuilder::createCancel. | ||||
/// Returns failure() when checkImplementationStatus rejects the operation or | ||||
/// when handleError reports a failure for the createCancel result; otherwise | ||||
/// moves `builder` to the returned insertion point and returns success(). | ||||
static LogicalResult | ||||
convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder, | ||||
LLVM::ModuleTranslation &moduleTranslation) { | ||||
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); | ||||
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); | ||||
|
||||
// Bail out before emitting anything if the operation is not supported. | ||||
if (failed(checkImplementationStatus(*op.getOperation()))) | ||||
return failure(); | ||||
|
||||
// The `if` expression is optional; ifCond stays null when it is absent. | ||||
llvm::Value *ifCond = nullptr; | ||||
if (Value ifVar = op.getIfExpr()) | ||||
ifCond = moduleTranslation.lookupValue(ifVar); | ||||
|
||||
llvm::omp::Directive cancelledDirective = | ||||
convertCancellationConstructType(op.getCancelDirective()); | ||||
|
||||
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||
ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective); | ||||
|
||||
// afterIP may hold an error instead of an insertion point; check it before | ||||
// calling get() on it below. | ||||
if (failed(handleError(afterIP, *op.getOperation()))) | ||||
return failure(); | ||||
|
||||
// Continue emitting subsequent IR at the point returned by createCancel. | ||||
builder.restoreIP(afterIP.get()); | ||||
|
||||
return success(); | ||||
} | ||||
|
||||
/// Converts an OpenMP Threadprivate operation into LLVM IR using | ||||
/// OpenMPIRBuilder. | ||||
static LogicalResult | ||||
|
@@ -5421,6 +5481,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, | |||
.Case([&](omp::AtomicCaptureOp op) { | ||||
return convertOmpAtomicCapture(op, builder, moduleTranslation); | ||||
}) | ||||
.Case([&](omp::CancelOp op) { | ||||
return convertOmpCancel(op, builder, moduleTranslation); | ||||
}) | ||||
.Case([&](omp::SectionsOp) { | ||||
return convertOmpSections(*op, builder, moduleTranslation); | ||||
}) | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s | ||
|
||
// Unconditional cancellation of the enclosing omp.parallel region. | ||
llvm.func @cancel_parallel() { | ||
omp.parallel { | ||
omp.cancel cancellation_construct_type(parallel) | ||
omp.terminator | ||
} | ||
llvm.return | ||
} | ||
// CHECK-LABEL: define internal void @cancel_parallel..omp_par | ||
// CHECK: omp.par.entry: | ||
// CHECK: %[[VAL_5:.*]] = alloca i32, align 4 | ||
// CHECK: %[[VAL_6:.*]] = load i32, ptr %[[VAL_7:.*]], align 4 | ||
// CHECK: store i32 %[[VAL_6]], ptr %[[VAL_5]], align 4 | ||
// CHECK: %[[VAL_8:.*]] = load i32, ptr %[[VAL_5]], align 4 | ||
// CHECK: br label %[[VAL_9:.*]] | ||
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_10:.*]] | ||
// CHECK: br label %[[VAL_11:.*]] | ||
// CHECK: omp.par.region: ; preds = %[[VAL_9]] | ||
// CHECK: br label %[[VAL_12:.*]] | ||
// CHECK: omp.par.region1: ; preds = %[[VAL_11]] | ||
// CHECK: %[[VAL_13:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: %[[VAL_14:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_13]], i32 1) | ||
// CHECK: %[[VAL_15:.*]] = icmp eq i32 %[[VAL_14]], 0 | ||
// CHECK: br i1 %[[VAL_15]], label %[[VAL_16:.*]], label %[[VAL_17:.*]] | ||
// CHECK: omp.par.region1.cncl: ; preds = %[[VAL_12]] | ||
// CHECK: %[[VAL_18:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: %[[VAL_19:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_18]]) | ||
// CHECK: br label %[[VAL_20:.*]] | ||
// CHECK: omp.par.region1.split: ; preds = %[[VAL_12]] | ||
// CHECK: br label %[[VAL_21:.*]] | ||
// CHECK: omp.region.cont: ; preds = %[[VAL_16]] | ||
// CHECK: br label %[[VAL_22:.*]] | ||
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_21]] | ||
// CHECK: br label %[[VAL_20]] | ||
// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_22]], %[[VAL_17]] | ||
// CHECK: ret void | ||
|
||
// Cancellation of the enclosing omp.parallel region guarded by the i1 | ||
// argument %arg0 passed through the `if` clause. | ||
llvm.func @cancel_parallel_if(%arg0 : i1) { | ||
omp.parallel { | ||
omp.cancel cancellation_construct_type(parallel) if(%arg0) | ||
omp.terminator | ||
} | ||
llvm.return | ||
} | ||
// CHECK-LABEL: define internal void @cancel_parallel_if..omp_par | ||
// CHECK: omp.par.entry: | ||
// CHECK: %[[VAL_9:.*]] = getelementptr { ptr }, ptr %[[VAL_10:.*]], i32 0, i32 0 | ||
// CHECK: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_9]], align 8 | ||
// CHECK: %[[VAL_12:.*]] = alloca i32, align 4 | ||
// CHECK: %[[VAL_13:.*]] = load i32, ptr %[[VAL_14:.*]], align 4 | ||
// CHECK: store i32 %[[VAL_13]], ptr %[[VAL_12]], align 4 | ||
// CHECK: %[[VAL_15:.*]] = load i32, ptr %[[VAL_12]], align 4 | ||
// CHECK: %[[VAL_16:.*]] = load i1, ptr %[[VAL_11]], align 1 | ||
// CHECK: br label %[[VAL_17:.*]] | ||
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_18:.*]] | ||
// CHECK: br label %[[VAL_19:.*]] | ||
// CHECK: omp.par.region: ; preds = %[[VAL_17]] | ||
// CHECK: br label %[[VAL_20:.*]] | ||
// CHECK: omp.par.region1: ; preds = %[[VAL_19]] | ||
// CHECK: br i1 %[[VAL_16]], label %[[VAL_21:.*]], label %[[VAL_22:.*]] | ||
// CHECK: 3: ; preds = %[[VAL_20]] | ||
// CHECK: br label %[[VAL_23:.*]] | ||
// CHECK: 4: ; preds = %[[VAL_22]], %[[VAL_24:.*]] | ||
// CHECK: br label %[[VAL_25:.*]] | ||
// CHECK: omp.region.cont: ; preds = %[[VAL_23]] | ||
// CHECK: br label %[[VAL_26:.*]] | ||
// CHECK: omp.par.pre_finalize: ; preds = %[[VAL_25]] | ||
// CHECK: br label %[[VAL_27:.*]] | ||
// CHECK: 5: ; preds = %[[VAL_20]] | ||
// CHECK: %[[VAL_28:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_28]], i32 1) | ||
// CHECK: %[[VAL_30:.*]] = icmp eq i32 %[[VAL_29]], 0 | ||
// CHECK: br i1 %[[VAL_30]], label %[[VAL_24]], label %[[VAL_31:.*]] | ||
// CHECK: .cncl: ; preds = %[[VAL_21]] | ||
// CHECK: %[[VAL_32:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: %[[VAL_33:.*]] = call i32 @__kmpc_cancel_barrier(ptr @2, i32 %[[VAL_32]]) | ||
// CHECK: br label %[[VAL_27]] | ||
// CHECK: .split: ; preds = %[[VAL_21]] | ||
// CHECK: br label %[[VAL_23]] | ||
// CHECK: omp.par.exit.exitStub: ; preds = %[[VAL_31]], %[[VAL_26]] | ||
// CHECK: ret void |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a TODO? I noticed that cancel on taskgroup is a TODO; all the rest are in the PR stack.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The TODO is generated by `checkImplementationStatus`.