@@ -5033,29 +5033,33 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
5033
5033
mlir::OperationState &result) {
5034
5034
auto &builder = parser.getBuilder ();
5035
5035
// Parse an opening `(` followed by induction variables followed by `)`
5036
- llvm::SmallVector<mlir::OpAsmParser::Argument, 4 > ivs;
5037
- if (parser.parseArgumentList (ivs, mlir::OpAsmParser::Delimiter::Paren))
5036
+ llvm::SmallVector<mlir::OpAsmParser::Argument, 4 > regionArgs;
5037
+
5038
+ if (parser.parseArgumentList (regionArgs, mlir::OpAsmParser::Delimiter::Paren))
5038
5039
return mlir::failure ();
5039
5040
5041
+ llvm::SmallVector<mlir::Type> argTypes (regionArgs.size (),
5042
+ builder.getIndexType ());
5043
+
5040
5044
// Parse loop bounds.
5041
5045
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > lower;
5042
5046
if (parser.parseEqual () ||
5043
- parser.parseOperandList (lower, ivs .size (),
5047
+ parser.parseOperandList (lower, regionArgs .size (),
5044
5048
mlir::OpAsmParser::Delimiter::Paren) ||
5045
5049
parser.resolveOperands (lower, builder.getIndexType (), result.operands ))
5046
5050
return mlir::failure ();
5047
5051
5048
5052
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > upper;
5049
5053
if (parser.parseKeyword (" to" ) ||
5050
- parser.parseOperandList (upper, ivs .size (),
5054
+ parser.parseOperandList (upper, regionArgs .size (),
5051
5055
mlir::OpAsmParser::Delimiter::Paren) ||
5052
5056
parser.resolveOperands (upper, builder.getIndexType (), result.operands ))
5053
5057
return mlir::failure ();
5054
5058
5055
5059
// Parse step values.
5056
5060
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > steps;
5057
5061
if (parser.parseKeyword (" step" ) ||
5058
- parser.parseOperandList (steps, ivs .size (),
5062
+ parser.parseOperandList (steps, regionArgs .size (),
5059
5063
mlir::OpAsmParser::Delimiter::Paren) ||
5060
5064
parser.resolveOperands (steps, builder.getIndexType (), result.operands ))
5061
5065
return mlir::failure ();
@@ -5086,20 +5090,72 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
5086
5090
builder.getArrayAttr (arrayAttr));
5087
5091
}
5088
5092
5089
- // Now parse the body.
5090
- mlir::Region *body = result.addRegion ();
5091
- for (auto &iv : ivs)
5092
- iv.type = builder.getIndexType ();
5093
- if (parser.parseRegion (*body, ivs))
5094
- return mlir::failure ();
5093
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> localOperands;
5094
+ if (succeeded (parser.parseOptionalKeyword (" local" ))) {
5095
+ std::size_t oldArgTypesSize = argTypes.size ();
5096
+ if (failed (parser.parseLParen ()))
5097
+ return mlir::failure ();
5098
+
5099
+ llvm::SmallVector<mlir::SymbolRefAttr> localSymbolVec;
5100
+ if (failed (parser.parseCommaSeparatedList ([&]() {
5101
+ if (failed (parser.parseAttribute (localSymbolVec.emplace_back ())))
5102
+ return mlir::failure ();
5103
+
5104
+ if (parser.parseOperand (localOperands.emplace_back ()) ||
5105
+ parser.parseArrow () ||
5106
+ parser.parseArgument (regionArgs.emplace_back ()))
5107
+ return mlir::failure ();
5108
+
5109
+ return mlir::success ();
5110
+ })))
5111
+ return mlir::failure ();
5112
+
5113
+ if (failed (parser.parseColon ()))
5114
+ return mlir::failure ();
5115
+
5116
+ if (failed (parser.parseCommaSeparatedList ([&]() {
5117
+ if (failed (parser.parseType (argTypes.emplace_back ())))
5118
+ return mlir::failure ();
5119
+
5120
+ return mlir::success ();
5121
+ })))
5122
+ return mlir::failure ();
5123
+
5124
+ if (regionArgs.size () != argTypes.size ())
5125
+ return parser.emitError (parser.getNameLoc (),
5126
+ " mismatch in number of local arg and types" );
5127
+
5128
+ if (failed (parser.parseRParen ()))
5129
+ return mlir::failure ();
5130
+
5131
+ for (auto operandType : llvm::zip_equal (
5132
+ localOperands, llvm::drop_begin (argTypes, oldArgTypesSize)))
5133
+ if (parser.resolveOperand (std::get<0 >(operandType),
5134
+ std::get<1 >(operandType), result.operands ))
5135
+ return mlir::failure ();
5136
+
5137
+ llvm::SmallVector<mlir::Attribute> symbolAttrs (localSymbolVec.begin (),
5138
+ localSymbolVec.end ());
5139
+ result.addAttribute (getLocalSymsAttrName (result.name ),
5140
+ builder.getArrayAttr (symbolAttrs));
5141
+ }
5095
5142
5096
5143
// Set `operandSegmentSizes` attribute.
5097
5144
result.addAttribute (DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
5098
5145
builder.getDenseI32ArrayAttr (
5099
5146
{static_cast <int32_t >(lower.size ()),
5100
5147
static_cast <int32_t >(upper.size ()),
5101
5148
static_cast <int32_t >(steps.size ()),
5102
- static_cast <int32_t >(reduceOperands.size ())}));
5149
+ static_cast <int32_t >(reduceOperands.size ()),
5150
+ static_cast <int32_t >(localOperands.size ())}));
5151
+
5152
+ // Now parse the body.
5153
+ for (auto [arg, type] : llvm::zip_equal (regionArgs, argTypes))
5154
+ arg.type = type;
5155
+
5156
+ mlir::Region *body = result.addRegion ();
5157
+ if (parser.parseRegion (*body, regionArgs))
5158
+ return mlir::failure ();
5103
5159
5104
5160
// Parse attributes.
5105
5161
if (parser.parseOptionalAttrDict (result.attributes ))
@@ -5109,8 +5165,9 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
5109
5165
}
5110
5166
5111
5167
void fir::DoConcurrentLoopOp::print (mlir::OpAsmPrinter &p) {
5112
- p << " (" << getBody ()->getArguments () << " ) = (" << getLowerBound ()
5113
- << " ) to (" << getUpperBound () << " ) step (" << getStep () << " )" ;
5168
+ p << " (" << getBody ()->getArguments ().slice (0 , getNumInductionVars ())
5169
+ << " ) = (" << getLowerBound () << " ) to (" << getUpperBound () << " ) step ("
5170
+ << getStep () << " )" ;
5114
5171
5115
5172
if (!getReduceOperands ().empty ()) {
5116
5173
p << " reduce(" ;
@@ -5123,12 +5180,27 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
5123
5180
p << ' )' ;
5124
5181
}
5125
5182
5183
+ if (!getLocalVars ().empty ()) {
5184
+ p << " local(" ;
5185
+ llvm::interleaveComma (llvm::zip_equal (getLocalSymsAttr (), getLocalVars (),
5186
+ getRegionLocalArgs ()),
5187
+ p, [&](auto it) {
5188
+ p << std::get<0 >(it) << " " << std::get<1 >(it)
5189
+ << " -> " << std::get<2 >(it);
5190
+ });
5191
+ p << " : " ;
5192
+ llvm::interleaveComma (getLocalVars (), p,
5193
+ [&](auto it) { p << it.getType (); });
5194
+ p << " )" ;
5195
+ }
5196
+
5126
5197
p << ' ' ;
5127
5198
p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false );
5128
5199
p.printOptionalAttrDict (
5129
5200
(*this )->getAttrs (),
5130
5201
/* elidedAttrs=*/ {DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
5131
- DoConcurrentLoopOp::getReduceAttrsAttrName ()});
5202
+ DoConcurrentLoopOp::getReduceAttrsAttrName (),
5203
+ DoConcurrentLoopOp::getLocalSymsAttrName ()});
5132
5204
}
5133
5205
5134
5206
llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions () {
@@ -5139,6 +5211,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
5139
5211
mlir::Operation::operand_range lbValues = getLowerBound ();
5140
5212
mlir::Operation::operand_range ubValues = getUpperBound ();
5141
5213
mlir::Operation::operand_range stepValues = getStep ();
5214
+ mlir::Operation::operand_range localVars = getLocalVars ();
5142
5215
5143
5216
if (lbValues.empty ())
5144
5217
return emitOpError (
@@ -5152,11 +5225,13 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
5152
5225
// Check that the body defines the same number of block arguments as the
5153
5226
// number of tuple elements in step.
5154
5227
mlir::Block *body = getBody ();
5155
- if (body->getNumArguments () != stepValues.size ())
5228
+ unsigned numIndVarArgs = body->getNumArguments () - localVars.size ();
5229
+
5230
+ if (numIndVarArgs != stepValues.size ())
5156
5231
return emitOpError () << " expects the same number of induction variables: "
5157
5232
<< body->getNumArguments ()
5158
5233
<< " as bound and step values: " << stepValues.size ();
5159
- for (auto arg : body->getArguments ())
5234
+ for (auto arg : body->getArguments (). slice ( 0 , numIndVarArgs) )
5160
5235
if (!arg.getType ().isIndex ())
5161
5236
return emitOpError (
5162
5237
" expects arguments for the induction variable to be of index type" );
@@ -5171,7 +5246,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
5171
5246
5172
5247
std::optional<llvm::SmallVector<mlir::Value>>
5173
5248
fir::DoConcurrentLoopOp::getLoopInductionVars () {
5174
- return llvm::SmallVector<mlir::Value>{getBody ()->getArguments ()};
5249
+ return llvm::SmallVector<mlir::Value>{
5250
+ getBody ()->getArguments ().slice (0 , getLowerBound ().size ())};
5175
5251
}
5176
5252
5177
5253
// ===----------------------------------------------------------------------===//
0 commit comments