diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index acc0c6967c739..458b780806144 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -3647,6 +3647,13 @@ def fir_DoConcurrentOp : fir_Op<"do_concurrent", let hasVerifier = 1; } +def fir_LocalSpecifier { + dag arguments = (ins + Variadic:$local_vars, + OptionalAttr:$local_syms + ); +} + def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, @@ -3700,7 +3707,7 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop", LLVM. }]; - let arguments = (ins + defvar opArgs = (ins Variadic:$lowerBound, Variadic:$upperBound, Variadic:$step, @@ -3709,16 +3716,40 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop", OptionalAttr:$loopAnnotation ); + let arguments = !con(opArgs, fir_LocalSpecifier.arguments); + let regions = (region SizedRegion<1>:$region); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; let extraClassDeclaration = [{ + unsigned getNumInductionVars() { return getLowerBound().size(); } + + unsigned getNumLocalOperands() { return getLocalVars().size(); } + + mlir::Block::BlockArgListType getInductionVars() { + return getBody()->getArguments().slice(0, getNumInductionVars()); + } + + mlir::Block::BlockArgListType getRegionLocalArgs() { + return getBody()->getArguments().slice(getNumInductionVars(), + getNumLocalOperands()); + } + + /// Number of operands controlling the loop + unsigned getNumControlOperands() { return getLowerBound().size() * 3; } + // Get Number of reduction operands unsigned getNumReduceOperands() { return getReduceOperands().size(); } + + mlir::Operation::operand_range getLocalOperands() { + return getOperands() + .slice(getNumControlOperands() + getNumReduceOperands(), + getNumLocalOperands()); + } }]; } diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 8da05255d5f41..0a61f61ab8f75 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -2460,7 +2460,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { nestReduceAttrs.empty() ? nullptr : mlir::ArrayAttr::get(builder->getContext(), nestReduceAttrs), - nullptr); + nullptr, /*local_vars=*/std::nullopt, /*local_syms=*/nullptr); llvm::SmallVector loopBlockArgTypes( incrementLoopNestInfo.size(), builder->getIndexType()); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 955acbe7018d3..332cca1ab9f95 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -5033,21 +5033,25 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { auto &builder = parser.getBuilder(); // Parse an opening `(` followed by induction variables followed by `)` - llvm::SmallVector ivs; - if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren)) + llvm::SmallVector regionArgs; + + if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren)) return mlir::failure(); + llvm::SmallVector argTypes(regionArgs.size(), + builder.getIndexType()); + // Parse loop bounds. llvm::SmallVector lower; if (parser.parseEqual() || - parser.parseOperandList(lower, ivs.size(), + parser.parseOperandList(lower, regionArgs.size(), mlir::OpAsmParser::Delimiter::Paren) || parser.resolveOperands(lower, builder.getIndexType(), result.operands)) return mlir::failure(); llvm::SmallVector upper; if (parser.parseKeyword("to") || - parser.parseOperandList(upper, ivs.size(), + parser.parseOperandList(upper, regionArgs.size(), mlir::OpAsmParser::Delimiter::Paren) || parser.resolveOperands(upper, builder.getIndexType(), result.operands)) return mlir::failure(); @@ -5055,7 +5059,7 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, // Parse step values. llvm::SmallVector steps; if (parser.parseKeyword("step") || - parser.parseOperandList(steps, ivs.size(), + parser.parseOperandList(steps, regionArgs.size(), mlir::OpAsmParser::Delimiter::Paren) || parser.resolveOperands(steps, builder.getIndexType(), result.operands)) return mlir::failure(); @@ -5086,12 +5090,55 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, builder.getArrayAttr(arrayAttr)); } - // Now parse the body. - mlir::Region *body = result.addRegion(); - for (auto &iv : ivs) - iv.type = builder.getIndexType(); - if (parser.parseRegion(*body, ivs)) - return mlir::failure(); + llvm::SmallVector localOperands; + if (succeeded(parser.parseOptionalKeyword("local"))) { + std::size_t oldArgTypesSize = argTypes.size(); + if (failed(parser.parseLParen())) + return mlir::failure(); + + llvm::SmallVector localSymbolVec; + if (failed(parser.parseCommaSeparatedList([&]() { + if (failed(parser.parseAttribute(localSymbolVec.emplace_back()))) + return mlir::failure(); + + if (parser.parseOperand(localOperands.emplace_back()) || + parser.parseArrow() || + parser.parseArgument(regionArgs.emplace_back())) + return mlir::failure(); + + return mlir::success(); + }))) + return mlir::failure(); + + if (failed(parser.parseColon())) + return mlir::failure(); + + if (failed(parser.parseCommaSeparatedList([&]() { + if (failed(parser.parseType(argTypes.emplace_back()))) + return mlir::failure(); + + return mlir::success(); + }))) + return mlir::failure(); + + if (regionArgs.size() != argTypes.size()) + return parser.emitError(parser.getNameLoc(), + "mismatch in number of local arg and types"); + + if (failed(parser.parseRParen())) + return mlir::failure(); + + for (auto operandType : llvm::zip_equal( + localOperands, llvm::drop_begin(argTypes, oldArgTypesSize))) + if (parser.resolveOperand(std::get<0>(operandType), + std::get<1>(operandType), result.operands)) + return mlir::failure(); + + llvm::SmallVector symbolAttrs(localSymbolVec.begin(), + localSymbolVec.end()); + result.addAttribute(getLocalSymsAttrName(result.name), + builder.getArrayAttr(symbolAttrs)); + } // Set `operandSegmentSizes` attribute. result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(), @@ -5099,7 +5146,16 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, {static_cast(lower.size()), static_cast(upper.size()), static_cast(steps.size()), - static_cast(reduceOperands.size())})); + static_cast(reduceOperands.size()), + static_cast(localOperands.size())})); + + // Now parse the body. + for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes)) + arg.type = type; + + mlir::Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) + return mlir::failure(); // Parse attributes. if (parser.parseOptionalAttrDict(result.attributes)) @@ -5109,8 +5165,9 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, } void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) { - p << " (" << getBody()->getArguments() << ") = (" << getLowerBound() - << ") to (" << getUpperBound() << ") step (" << getStep() << ")"; + p << " (" << getBody()->getArguments().slice(0, getNumInductionVars()) + << ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step (" + << getStep() << ")"; if (!getReduceOperands().empty()) { p << " reduce("; @@ -5123,12 +5180,27 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) { p << ')'; } + if (!getLocalVars().empty()) { + p << " local("; + llvm::interleaveComma(llvm::zip_equal(getLocalSymsAttr(), getLocalVars(), + getRegionLocalArgs()), + p, [&](auto it) { + p << std::get<0>(it) << " " << std::get<1>(it) + << " -> " << std::get<2>(it); + }); + p << " : "; + llvm::interleaveComma(getLocalVars(), p, + [&](auto it) { p << it.getType(); }); + p << ")"; + } + p << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict( (*this)->getAttrs(), /*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(), - DoConcurrentLoopOp::getReduceAttrsAttrName()}); + DoConcurrentLoopOp::getReduceAttrsAttrName(), + DoConcurrentLoopOp::getLocalSymsAttrName()}); } llvm::SmallVector fir::DoConcurrentLoopOp::getLoopRegions() { @@ -5139,6 +5211,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() { mlir::Operation::operand_range lbValues = getLowerBound(); mlir::Operation::operand_range ubValues = getUpperBound(); mlir::Operation::operand_range stepValues = getStep(); + mlir::Operation::operand_range localVars = getLocalVars(); if (lbValues.empty()) return emitOpError( @@ -5152,11 +5225,13 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() { // Check that the body defines the same number of block arguments as the // number of tuple elements in step. mlir::Block *body = getBody(); - if (body->getNumArguments() != stepValues.size()) + unsigned numIndVarArgs = body->getNumArguments() - localVars.size(); + + if (numIndVarArgs != stepValues.size()) return emitOpError() << "expects the same number of induction variables: " << body->getNumArguments() << " as bound and step values: " << stepValues.size(); - for (auto arg : body->getArguments()) + for (auto arg : body->getArguments().slice(0, numIndVarArgs)) if (!arg.getType().isIndex()) return emitOpError( "expects arguments for the induction variable to be of index type"); @@ -5171,7 +5246,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() { std::optional> fir::DoConcurrentLoopOp::getLoopInductionVars() { - return llvm::SmallVector{getBody()->getArguments()}; + return llvm::SmallVector{ + getBody()->getArguments().slice(0, getLowerBound().size())}; } //===----------------------------------------------------------------------===// diff --git a/flang/test/Fir/do_concurrent.fir b/flang/test/Fir/do_concurrent.fir index 4e55777402428..cc1197ba56bd7 100644 --- a/flang/test/Fir/do_concurrent.fir +++ b/flang/test/Fir/do_concurrent.fir @@ -91,7 +91,6 @@ func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index, // CHECK: } // CHECK: } - fir.local {type = local} @local_privatizer : i32 // CHECK: fir.local {type = local} @[[LOCAL_PRIV_SYM:local_privatizer]] : i32 @@ -109,3 +108,56 @@ fir.local {type = local_init} @local_init_privatizer : i32 copy { // CHECK: fir.store %[[ORIG_VAL_LD]] to %[[LOCAL_VAL]] : !fir.ref // CHECK: fir.yield(%[[LOCAL_VAL]] : !fir.ref) // CHECK: } + +func.func @do_concurrent_with_locality_specs() { + %3 = fir.alloca i32 {bindc_name = "local_init_var"} + %4:2 = hlfir.declare %3 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %5 = fir.alloca i32 {bindc_name = "local_var"} + %6:2 = hlfir.declare %5 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + fir.do_concurrent { + %9 = fir.alloca i32 {bindc_name = "i"} + %10:2 = hlfir.declare %9 {uniq_name = "_QFdo_concurrentEi"} : (!fir.ref) -> (!fir.ref, !fir.ref) + + fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) + local(@local_privatizer %6#0 -> %arg1, @local_init_privatizer %4#0 -> %arg2 : !fir.ref, !fir.ref) { + %11 = fir.convert %arg0 : (index) -> i32 + fir.store %11 to %10#0 : !fir.ref + %13:2 = hlfir.declare %arg1 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref) -> (!fir.ref, !fir.ref) + %15:2 = hlfir.declare %arg2 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref) -> (!fir.ref, !fir.ref) + } + } + return +} + +// CHECK-LABEL: func.func @do_concurrent_with_locality_specs() { +// CHECK: %[[LOC_INIT_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_init_var"} +// CHECK: %[[LOC_INIT_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ALLOC]] + +// CHECK: %[[LOC_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_var"} +// CHECK: %[[LOC_DECL:.*]]:2 = hlfir.declare %[[LOC_ALLOC]] + +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C10:.*]] = arith.constant 10 : index + +// CHECK: fir.do_concurrent { +// CHECK: %[[DC_I_ALLOC:.*]] = fir.alloca i32 {bindc_name = "i"} +// CHECK: %[[DC_I_DECL:.*]]:2 = hlfir.declare %[[DC_I_ALLOC]] + +// CHECK: fir.do_concurrent.loop (%[[IV:.*]]) = (%[[C1]]) to +// CHECK-SAME: (%[[C10]]) step (%[[C1]]) +// CHECK-SAME: local(@[[LOCAL_PRIV_SYM]] %[[LOC_DECL]]#0 -> %[[LOC_ARG:[^,]*]], +// CHECK-SAME: @[[LOCAL_INIT_PRIV_SYM]] %[[LOC_INIT_DECL]]#0 -> %[[LOC_INIT_ARG:.*]] : +// CHECK-SAME: !fir.ref, !fir.ref) { + +// CHECK: %[[IV_CVT:.*]] = fir.convert %[[IV]] : (index) -> i32 +// CHECK: fir.store %[[IV_CVT]] to %[[DC_I_DECL]]#0 : !fir.ref + +// CHECK: %[[LOC_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_ARG]] +// CHECK: %[[LOC_INIT_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ARG]] +// CHECK: } +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir index 733227339bc39..1de48b87365b3 100644 --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -1196,7 +1196,7 @@ func.func @dc_0d() { func.func @dc_invalid_parent(%arg0: index, %arg1: index) { // expected-error@+1 {{'fir.do_concurrent.loop' op expects parent op 'fir.do_concurrent'}} - "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array}> ({ ^bb0(%arg2: index): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array}> : () -> !fir.ref }) : (index, index) -> () @@ -1208,7 +1208,7 @@ func.func @dc_invalid_parent(%arg0: index, %arg1: index) { func.func @dc_invalid_control(%arg0: index, %arg1: index) { // expected-error@+2 {{'fir.do_concurrent.loop' op different number of tuple elements for lowerBound, upperBound or step}} fir.do_concurrent { - "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array}> ({ ^bb0(%arg2: index): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array}> : () -> !fir.ref }) : (index, index) -> () @@ -1221,7 +1221,7 @@ func.func @dc_invalid_control(%arg0: index, %arg1: index) { func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) { // expected-error@+2 {{'fir.do_concurrent.loop' op expects the same number of induction variables: 2 as bound and step values: 1}} fir.do_concurrent { - "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array}> ({ ^bb0(%arg3: index, %arg4: index): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array}> : () -> !fir.ref }) : (index, index, index) -> () @@ -1234,7 +1234,7 @@ func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) { func.func @dc_invalid_ind_var_type(%arg0: index, %arg1: index) { // expected-error@+2 {{'fir.do_concurrent.loop' op expects arguments for the induction variable to be of index type}} fir.do_concurrent { - "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array}> ({ ^bb0(%arg3: i32): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array}> : () -> !fir.ref }) : (index, index, index) -> () @@ -1248,7 +1248,7 @@ func.func @dc_invalid_reduction(%arg0: index, %arg1: index) { %sum = fir.alloca i32 // expected-error@+2 {{'fir.do_concurrent.loop' op mismatch in number of reduction variables and reduction attributes}} fir.do_concurrent { - "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array}> ({ ^bb0(%arg3: index): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array}> : () -> !fir.ref }) : (index, index, index, !fir.ref) -> ()