blob: 474f812c9c0bc181f8575f87de0654af679415d0 [file] [log] [blame]
Wen-Heng (Jack) Chung7bfcb912020-06-12 00:15:161#!/usr/bin/env python3
River Riddlea9d40152019-08-13 23:42:412"""A script to generate FileCheck statements for mlir unit tests.
3
4This script is a utility to add FileCheck patterns to an mlir file.
5
6NOTE: The input .mlir is expected to be the output from the parser, not a
7stripped down variant.
8
9Example usage:
10$ generate-test-checks.py foo.mlir
11$ mlir-opt foo.mlir -transformation | generate-test-checks.py
Tim Shenb877f332020-06-16 18:38:2612$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
13$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
14$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'
River Riddlea9d40152019-08-13 23:42:4115
Tim Shenb877f332020-06-16 18:38:2616The script will heuristically generate CHECK/CHECK-LABEL commands for each line
River Riddlea9d40152019-08-13 23:42:4117within the file. By default this script will also try to insert string
Tim Shenb877f332020-06-16 18:38:2618substitution blocks for all SSA value names. If --source file is specified, the
 19script will attempt to insert the generated CHECKs into the source file by looking
20for line positions matched by --source_delim_regex.
21
22The script is designed to make adding checks to a test case fast, it is *not*
23designed to be authoritative about what constitutes a good test!
River Riddlea9d40152019-08-13 23:42:4124"""
25
Mehdi Amini56222a02019-12-23 17:35:3626# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27# See https://llvm.org/LICENSE.txt for license information.
28# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
River Riddlea9d40152019-08-13 23:42:4129
30import argparse
31import os # Used to advertise this file's name ("autogenerated_note").
32import re
33import sys
River Riddlea9d40152019-08-13 23:42:4134
# First line of the note written at the top of the generated output; the path
# of this script is appended directly after it.
ADVERT_BEGIN = '// NOTE: Assertions have been autogenerated by '
# Remainder of the note. It begins with a newline (so a blank line follows the
# ADVERT_BEGIN line once the script path and its newline are appended) and
# ends with a newline.
ADVERT_END = ('\n'
              '// The script is designed to make adding checks to\n'
              '// a test case fast, it is *not* designed to be authoritative\n'
              '// about what constitutes a good test! The CHECK should be\n'
              '// minimized and named to reflect the test intent.\n')
42
River Riddlea9d40152019-08-13 23:42:4143
# Pattern for an SSA value name as it appears right after '%': either a run of
# digits (as in %0) or an identifier made of letters, digits and '$', '.', '_',
# '-' that does not start with a digit (as in %arg0 or %foo.bar).
SSA_RE_STR = r'[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)
47
48
class SSAVariableNamer:
    """Generate and manage FileCheck substitution names for SSA values.

    Keeps a stack of scopes; each scope maps an SSA value name (the text after
    '%') to the FileCheck variable generated for it. Generated names have the
    form 'VAL_<n>', where <n> comes from one counter shared by all scopes.
    """

    def __init__(self):
        # Stack of {ssa_name: variable_name} dicts; the last one is innermost.
        self.scopes = []
        # Number used for the next generated variable name.
        self.name_counter = 0

    def generate_name(self, ssa_name):
        """Bind ``ssa_name`` to a fresh variable in the innermost scope.

        Returns the new variable name. At least one scope must have been
        pushed, otherwise indexing the innermost scope raises IndexError.
        """
        variable = 'VAL_{}'.format(self.name_counter)
        self.name_counter = self.name_counter + 1
        innermost = self.scopes[-1]
        innermost[ssa_name] = variable
        return variable

    def push_name_scope(self):
        """Open a new, empty innermost scope."""
        self.scopes.append(dict())

    def pop_name_scope(self):
        """Discard the innermost scope together with every name bound in it."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the current nesting depth (number of open scopes)."""
        return len(self.scopes)

    def clear_counter(self):
        """Restart generated variable numbering so the next name is VAL_0."""
        self.name_counter = 0
79
River Riddlea9d40152019-08-13 23:42:4180
# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer):
    """Rewrite the SSA-value chunks of a line as FileCheck substitutions.

    Args:
      line_chunks: The pieces of a line that followed each '%' (that is,
        ``line.split('%')[1:]``). Each piece normally starts with an SSA name.
      variable_namer: The SSAVariableNamer holding the names seen so far.

    Returns:
      The rewritten text with trailing whitespace removed and one newline
      appended. The first occurrence of a name becomes ``%[[VAL_n:.*]]``
      (a FileCheck variable definition); later occurrences become
      ``%[[VAL_n]]``. A chunk that does not start with an SSA name is
      emitted unchanged, preceded by its '%'.
    """
    output_line = ''

    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # The '%' did not introduce an SSA value (e.g. it sits inside a
            # string literal such as "100%", or is doubled/trailing). Keep it
            # literal instead of crashing on m.group(0).
            output_line += '%' + chunk
            continue
        ssa_name = m.group(0)

        # Look for an existing variable for this name in any open scope.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # The value was already captured: reference the variable.
            output_line += '%[[' + variable + ']]'
        else:
            # First occurrence: define a new FileCheck variable.
            variable = variable_namer.generate_name(ssa_name)
            output_line += '%[[' + variable + ':.*]]'

        # Append the remainder of the chunk after the SSA name verbatim.
        output_line += chunk[len(ssa_name):]

    return output_line.rstrip() + '\n'
River Riddlea9d40152019-08-13 23:42:41109
110
# Process the source file lines. The source file doesn't have to be .mlir.
def process_source_lines(source_lines, note, args):
    """Split source lines into segments, dropping previously generated text.

    Args:
      source_lines: Lines of the --source file, without trailing newlines.
      note: The autogenerated note. It may span several lines; any source
        line equal to one of its non-blank lines is dropped, so a note left
        by an earlier run is not duplicated.
      args: Parsed arguments; ``source_delim_regex`` and ``check_prefix`` are
        read.

    Returns:
      A list of segments (lists of lines with '\\n' re-appended). A new
      segment starts at every line matching ``args.source_delim_regex``; the
      first segment holds whatever precedes the first match. Lines
      containing ``args.check_prefix`` are removed.
    """
    source_split_re = re.compile(args.source_delim_regex)
    # The note is multi-line, so compare line by line. Blank lines are not
    # part of the set, so blank lines of the source itself are preserved.
    note_lines = {note_line for note_line in note.split('\n') if note_line}

    source_segments = [[]]
    for line in source_lines:
        # Remove previous note.
        if line in note_lines:
            continue
        # Remove previous CHECK lines.
        if args.check_prefix in line:
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + '\n')
    return source_segments
129
130
def preprocess_line(line):
    """Escape character sequences in ``line`` that FileCheck would misread.

    FileCheck treats '[[' as the start of a variable reference, so every '[['
    is rewritten to '{{\\[\\[}}'. A '[' directly before '%' would become '[['
    once the SSA name is replaced by a '[[VAL_n]]' substitution, so '[%' is
    rewritten to '{{\\[}}%'. Order matters: '[[' is handled before '[%'.
    """
    escapes = (
        ('[[', '{{\\[\\[}}'),
        ('[%', '{{\\[}}%'),
    )
    result = line
    for needle, replacement in escapes:
        result = result.replace(needle, replacement)
    return result
144
145
def main():
    """Parse arguments, generate FileCheck lines, and write them out.

    Reads MLIR from ``input`` (stdin by default) and turns each line into
    CHECK / CHECK-LABEL / CHECK-SAME lines. A new output segment starts at
    every line ending in '{' seen at nesting depth --starts_from_scope.
    Without --source the segments are printed separated by blank lines.
    With --source, each segment is written before the matching segment of
    the source file, either to the output or back into the source (-i).

    All output is computed and validated before any output file is opened,
    so a failure part-way through never truncates the --inplace target.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--check-prefix', default='CHECK', help='Prefix to use from check file.')
    parser.add_argument(
        '-o',
        '--output',
        nargs='?',
        type=argparse.FileType('w'),
        default=None)
    parser.add_argument(
        'input',
        nargs='?',
        type=argparse.FileType('r'),
        default=sys.stdin)
    parser.add_argument(
        '--source', type=str,
        help='Print each CHECK chunk before each delimiter line in the source '
        'file, respectively. The delimiter lines are identified by '
        '--source_delim_regex.')
    parser.add_argument('--source_delim_regex', type=str, default='func @')
    parser.add_argument(
        '--starts_from_scope', type=int, default=1,
        help='Omit the top specified level of content. For example, by default '
        'it omits "module {"')
    parser.add_argument('-i', '--inplace', action='store_true', default=False)

    args = parser.parse_args()

    # Validate flag combinations with real errors (asserts vanish under -O,
    # and --inplace without --source would otherwise crash in open(None)).
    if args.inplace and not args.source:
        parser.error('--inplace requires --source')
    if args.inplace and args.output is not None:
        parser.error('--inplace cannot be combined with --output')

    # Read the whole input.
    input_lines = [line.rstrip() for line in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = (ADVERT_BEGIN + 'utils/' + script_name + '\n' + ADVERT_END)

    source_segments = None
    if args.source:
        with open(args.source, 'r') as source_file:
            source_lines = [line.rstrip() for line in source_file]
        source_segments = process_source_lines(
            source_lines, autogenerated_note, args)

    output_segments = [[]]
    # A map containing data used for naming SSA value names.
    variable_namer = SSAVariableNamer()
    for input_line in input_lines:
        if not input_line:
            continue
        lstripped_input_line = input_line.lstrip()

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        is_block = lstripped_input_line[0] == '^'
        if is_block:
            input_line = input_line.rsplit('//', 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == '}':
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == '{':
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        # Each segment numbers its variables from VAL_0.
        if len(output_segments[-1]) == 0:
            variable_namer.clear_counter()

        # Preprocess the input to remove any sequences that may be problematic
        # with FileCheck.
        input_line = preprocess_line(input_line)

        # Split the line at each SSA value name.
        ssa_split = input_line.split('%')

        # The first line of a segment becomes 'CHECK-LABEL' (unless it starts
        # with an SSA value); every other line uses 'CHECK:'.
        if len(output_segments[-1]) != 0 or not ssa_split[0]:
            output_line = '// ' + args.check_prefix + ': '
            # Pad to align with the 'LABEL' statements.
            output_line += (' ' * len('-LABEL'))

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)

        else:
            # Output the first line chunk that does not contain an SSA name for
            # the label.
            output_line = '// ' + args.check_prefix + '-LABEL: ' + ssa_split[0] + '\n'

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += '// ' + args.check_prefix + '-SAME: '

                # Pad to align with the original position in the line.
                output_line += ' ' * len(ssa_split[0])

                # Process the rest of the line.
                output_line += process_line([argument], variable_namer)

        # Append the output line.
        output_segments[-1].append(output_line)

    # Validate before opening (and, with -i, truncating) any output file.
    if source_segments and len(output_segments) != len(source_segments):
        sys.exit(
            'error: generated %d check segment(s) but the source has %d '
            'segment(s); adjust --source_delim_regex or --starts_from_scope'
            % (len(output_segments), len(source_segments)))

    if args.inplace:
        output = open(args.source, 'w')
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        output.write(autogenerated_note + '\n')

        # Write the output.
        if source_segments:
            for check_segment, source_segment in zip(output_segments,
                                                     source_segments):
                output.writelines(check_segment)
                output.writelines(source_segment)
        else:
            for segment in output_segments:
                output.write('\n')
                output.writelines(segment)
            output.write('\n')
    finally:
        # Never close the interpreter's stdout; close real files.
        if output is not sys.stdout:
            output.close()
River Riddlea9d40152019-08-13 23:42:41288
289
# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()