Skip to content

Commit 22c381e

Browse files
committed
Fixed atomic capture cases with atomic update inside.
1 parent d843f4e commit 22c381e

File tree

2 files changed

+55
-9
lines changed

2 files changed

+55
-9
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2729,7 +2729,8 @@ static void genAtomicUpdateStatement(
27292729
const parser::Expr &assignmentStmtExpr,
27302730
const parser::OmpAtomicClauseList *leftHandClauseList,
27312731
const parser::OmpAtomicClauseList *rightHandClauseList, mlir::Location loc,
2732-
mlir::Operation *atomicCaptureOp = nullptr) {
2732+
mlir::Operation *atomicCaptureOp = nullptr,
2733+
lower::StatementContext *atomicCaptureStmtCtx = nullptr) {
27332734
// Generate `atomic.update` operation for atomic assignment statements
27342735
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
27352736
mlir::Location currentLocation = converter.getCurrentLocation();
@@ -2803,15 +2804,24 @@ static void genAtomicUpdateStatement(
28032804
},
28042805
assignmentStmtExpr.u);
28052806
lower::StatementContext nonAtomicStmtCtx;
2807+
lower::StatementContext *stmtCtxPtr = &nonAtomicStmtCtx;
28062808
if (!nonAtomicSubExprs.empty()) {
28072809
// Generate non atomic part before all the atomic operations.
28082810
auto insertionPoint = firOpBuilder.saveInsertionPoint();
2809-
if (atomicCaptureOp)
2811+
if (atomicCaptureOp) {
2812+
assert(atomicCaptureStmtCtx && "must specify statement context");
28102813
firOpBuilder.setInsertionPoint(atomicCaptureOp);
2814+
// Any clean-ups associated with the expression lowering
2815+
// must also be generated outside of the atomic update operation
2816+
// and after the atomic capture operation.
2817+
// The atomicCaptureStmtCtx will be finalized at the end
2818+
// of the atomic capture operation generation.
2819+
stmtCtxPtr = atomicCaptureStmtCtx;
2820+
}
28112821
mlir::Value nonAtomicVal;
28122822
for (auto *nonAtomicSubExpr : nonAtomicSubExprs) {
28132823
nonAtomicVal = fir::getBase(converter.genExprValue(
2814-
currentLocation, *nonAtomicSubExpr, nonAtomicStmtCtx));
2824+
currentLocation, *nonAtomicSubExpr, *stmtCtxPtr));
28152825
exprValueOverrides.try_emplace(nonAtomicSubExpr, nonAtomicVal);
28162826
}
28172827
if (atomicCaptureOp)
@@ -3097,7 +3107,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
30973107
genAtomicUpdateStatement(
30983108
converter, stmt2LHSArg, stmt2VarType, stmt2Var, stmt2Expr,
30993109
/*leftHandClauseList=*/nullptr,
3100-
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
3110+
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
31013111
} else {
31023112
// Atomic capture construct is of the form [capture-stmt, write-stmt]
31033113
firOpBuilder.setInsertionPoint(atomicCaptureOp);
@@ -3121,7 +3131,7 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
31213131
genAtomicUpdateStatement(
31223132
converter, stmt1LHSArg, stmt1VarType, stmt1Var, stmt1Expr,
31233133
/*leftHandClauseList=*/nullptr,
3124-
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp);
3134+
/*rightHandClauseList=*/nullptr, loc, atomicCaptureOp, &stmtCtx);
31253135
genAtomicCaptureStatement(converter, stmt1LHSArg, stmt2LHSArg,
31263136
/*leftHandClauseList=*/nullptr,
31273137
/*rightHandClauseList=*/nullptr, elementType,

flang/test/Lower/OpenMP/atomic-capture.f90

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,54 @@ subroutine pointers_in_atomic_capture()
102102
! are generated after the omp.atomic.capture operation:
103103
! CHECK-LABEL: func.func @_QPfunc_call_cleanup(
104104
subroutine func_call_cleanup(x, v, vv)
105+
interface
106+
integer function func(x)
107+
integer :: x
108+
end function func
109+
end interface
105110
integer :: x, v, vv
106111

107112
! CHECK: %[[VAL_7:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
108-
! CHECK: %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> f32
109-
! CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (f32) -> i32
113+
! CHECK: %[[VAL_8:.*]] = fir.call @_QPfunc(%[[VAL_7]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
110114
! CHECK: omp.atomic.capture {
111-
! CHECK: omp.atomic.read %{{.*}} = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
112-
! CHECK: omp.atomic.write %[[VAL_3]]#0 = %[[VAL_9]] : !fir.ref<i32>, i32
115+
! CHECK: omp.atomic.read %[[VAL_1:.*]]#0 = %[[VAL_3:.*]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
116+
! CHECK: omp.atomic.write %[[VAL_3]]#0 = %[[VAL_8]] : !fir.ref<i32>, i32
113117
! CHECK: }
114118
! CHECK: hlfir.end_associate %[[VAL_7]]#1, %[[VAL_7]]#2 : !fir.ref<i32>, i1
115119
!$omp atomic capture
116120
v = x
117121
x = func(vv + 1)
118122
!$omp end atomic
123+
124+
! CHECK: %[[VAL_12:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
125+
! CHECK: %[[VAL_13:.*]] = fir.call @_QPfunc(%[[VAL_12]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
126+
! CHECK: omp.atomic.capture {
127+
! CHECK: omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
128+
! CHECK: omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
129+
! CHECK: ^bb0(%[[VAL_14:.*]]: i32):
130+
! CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : i32
131+
! CHECK: omp.yield(%[[VAL_15]] : i32)
132+
! CHECK: }
133+
! CHECK: }
134+
! CHECK: hlfir.end_associate %[[VAL_12]]#1, %[[VAL_12]]#2 : !fir.ref<i32>, i1
135+
!$omp atomic capture
136+
v = x
137+
x = func(vv + 1) + x
138+
!$omp end atomic
139+
140+
! CHECK: %[[VAL_19:.*]]:3 = hlfir.associate %{{.*}} {adapt.valuebyref} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
141+
! CHECK: %[[VAL_20:.*]] = fir.call @_QPfunc(%[[VAL_19]]#0) fastmath<contract> : (!fir.ref<i32>) -> i32
142+
! CHECK: omp.atomic.capture {
143+
! CHECK: omp.atomic.update %[[VAL_3]]#0 : !fir.ref<i32> {
144+
! CHECK: ^bb0(%[[VAL_21:.*]]: i32):
145+
! CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : i32
146+
! CHECK: omp.yield(%[[VAL_22]] : i32)
147+
! CHECK: }
148+
! CHECK: omp.atomic.read %[[VAL_1]]#0 = %[[VAL_3]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
149+
! CHECK: }
150+
! CHECK: hlfir.end_associate %[[VAL_19]]#1, %[[VAL_19]]#2 : !fir.ref<i32>, i1
151+
!$omp atomic capture
152+
x = func(vv + 1) + x
153+
v = x
154+
!$omp end atomic
119155
end subroutine func_call_cleanup

0 commit comments

Comments
 (0)