blob: dbfcbd52091a221a03ad56d4d117beb082bdf71c [file] [log] [blame]
River Riddlea9d40152019-08-13 23:42:411#!/usr/bin/env python
2"""A script to generate FileCheck statements for mlir unit tests.
3
4This script is a utility to add FileCheck patterns to an mlir file.
5
6NOTE: The input .mlir is expected to be the output from the parser, not a
7stripped down variant.
8
9Example usage:
10$ generate-test-checks.py foo.mlir
11$ mlir-opt foo.mlir -transformation | generate-test-checks.py
12
13The script will heuristically insert CHECK/CHECK-LABEL commands for each line
14within the file. By default this script will also try to insert string
15substitution blocks for all SSA value names. The script is designed to make
16adding checks to a test case fast, it is *not* designed to be authoritative
17about what constitutes a good test!
18"""
19
Mehdi Amini56222a02019-12-23 17:35:3620# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
21# See https://llvm.org/LICENSE.txt for license information.
22# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
River Riddlea9d40152019-08-13 23:42:4123
24import argparse
25import os # Used to advertise this file's name ("autogenerated_note").
26import re
27import sys
River Riddlea9d40152019-08-13 23:42:4128
# Leading text of the banner line written at the top of the generated output.
ADVERT = '// NOTE: Assertions have been autogenerated by '

# Pattern for the identifier part of an SSA value (the text following '%'):
# either a run of digits, or a name that starts with a letter or one of
# '$', '.', '_', '-' and continues with letters, digits or those symbols.
SSA_RE_STR = r'[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)
34
35
36# Class used to generate and manage string substitution blocks for SSA value
37# names.
class SSAVariableNamer:
    """Hand out FileCheck substitution names for SSA values, tracked per scope.

    Attributes:
        scopes: Stack (list) of dicts mapping an SSA value name to the
            substitution variable generated for it; the last entry is the
            innermost scope.
        name_counter: Number of variables generated so far; used as the
            numeric suffix of the next 'VAL_<n>' name.
    """

    def __init__(self):
        self.scopes = []
        self.name_counter = 0

    def generate_name(self, ssa_name):
        """Create a new substitution name for `ssa_name` in the innermost scope.

        Args:
            ssa_name: The SSA value name (without the leading '%').

        Returns:
            The newly generated variable name, e.g. 'VAL_0'.
        """
        # A value may be encountered before any scope has been pushed; open an
        # implicit outermost scope instead of failing on `self.scopes[-1]`.
        if not self.scopes:
            self.push_name_scope()

        variable = 'VAL_' + str(self.name_counter)
        self.name_counter += 1
        self.scopes[-1][ssa_name] = variable
        return variable

    def push_name_scope(self):
        """Push a new, empty variable name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Pop the innermost variable name scope.

        Raises:
            IndexError: If there is no scope left to pop.
        """
        self.scopes.pop()
58
59
60# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer):
    """Rewrite the SSA value names of one line as FileCheck substitutions.

    Args:
        line_chunks: The pieces of an input line that followed each '%' after
            the line was split on '%'. Each piece is expected to start with an
            SSA value name, but pieces that do not are tolerated.
        variable_namer: Object exposing a `scopes` list of name -> variable
            dicts and a `generate_name(ssa_name)` method (see
            SSAVariableNamer).

    Returns:
        The rewritten text for all chunks, terminated by a newline. The first
        use of a name becomes '%[[VAR:.*]]'; later uses become '%[[VAR]]'.
    """
    pieces = []

    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # The '%' was not followed by an SSA identifier (e.g. a '%' inside
            # a string attribute or at the end of the line). Put back the '%'
            # that the split removed and keep the text verbatim.
            pieces.append('%' + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # Already defined: reference the existing variable.
            pieces.append('%[[' + variable + ']]')
        else:
            # First occurrence: define a new variable capturing any name.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append('%[[' + variable + ':.*]]')

        # Append the remainder of the chunk after the SSA name.
        pieces.append(chunk[len(ssa_name):])

    return ''.join(pieces) + '\n'
88
89
River Riddle3e2ac622019-09-17 20:59:1290# Pre-process a line of input to remove any character sequences that will be
91# problematic with FileCheck.
def preprocess_line(line):
    """Escape bracket sequences in `line` that FileCheck would misinterpret.

    Args:
        line: One line of input text.

    Returns:
        A copy of `line` with the problematic sequences wrapped in FileCheck
        regex blocks.
    """
    # Applied in order:
    #  1. '[[' introduces a variable in FileCheck, so a literal '[[' is
    #     escaped.
    #  2. A '[' directly before an SSA identifier ('[%') is escaped too: the
    #     identifier will later be replaced by a '[[...]]' variable, which
    #     would recreate the '[[' situation.
    escapes = (
        ('[[', '{{\\[\\[}}'),
        ('[%', '{{\\[}}%'),
    )

    escaped = line
    for needle, replacement in escapes:
        escaped = escaped.replace(needle, replacement)
    return escaped
103
104
def main():
    """Command-line entry point: read MLIR text and write FileCheck lines.

    Reads the input (a file argument, or stdin by default), generates one
    check line per non-empty input line, and writes the result to the output
    (`-o/--output`, or stdout by default). Operations at nesting level 1 get
    a '<prefix>-LABEL' line (plus a '<prefix>-SAME' line for the part of the
    line that contains SSA values); every other line gets a plain '<prefix>'
    line. SSA value names are replaced by substitution variables.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--check-prefix', default='CHECK', help='Prefix to use from check file.')
    parser.add_argument(
        '-o',
        '--output',
        nargs='?',
        type=argparse.FileType('w'),
        default=sys.stdout)
    parser.add_argument(
        'input',
        nargs='?',
        type=argparse.FileType('r'),
        default=sys.stdin)
    args = parser.parse_args()

    # argparse has already opened both files; make sure the output is closed
    # even if reading or processing the input fails.
    try:
        # Read the whole input up front, dropping trailing whitespace.
        with args.input as input_file:
            input_lines = [line.rstrip() for line in input_file]

        output_lines = []

        # Generate a note used for the generated check file.
        script_name = os.path.basename(__file__)
        autogenerated_note = (ADVERT + 'utils/' + script_name)
        output_lines.append(autogenerated_note + '\n')

        # Tracks the substitution variable assigned to each SSA value name.
        variable_namer = SSAVariableNamer()
        for input_line in input_lines:
            if not input_line:
                continue
            lstripped_input_line = input_line.lstrip()

            # Lines with blocks begin with a ^. These lines have a trailing
            # comment that needs to be stripped.
            is_block = lstripped_input_line[0] == '^'
            if is_block:
                input_line = input_line.rsplit('//', 1)[0].rstrip()

            # Top-level operations are heuristically the operations at nesting
            # level 1, i.e. indented by exactly two spaces and not a closing
            # brace. Index 2 is always valid here: the line is non-empty after
            # rstrip() and starts with two spaces, so at least one more
            # character follows them.
            is_toplevel_op = (not is_block and input_line.startswith('  ') and
                              input_line[2] != ' ' and input_line[2] != '}')

            # If the line starts with a '}', pop the last name scope.
            if lstripped_input_line[0] == '}':
                variable_namer.pop_name_scope()

            # If the line ends with a '{', push a new name scope.
            if input_line[-1] == '{':
                variable_namer.push_name_scope()

            # Preprocess the input to remove any sequences that may be
            # problematic with FileCheck.
            input_line = preprocess_line(input_line)

            # Split the line at each SSA value name.
            ssa_split = input_line.split('%')

            # If this is a top-level operation use 'CHECK-LABEL', otherwise
            # 'CHECK:'. A line that starts with an SSA value (empty first
            # chunk) is never used as a label.
            if not is_toplevel_op or not ssa_split[0]:
                output_line = '// ' + args.check_prefix + ': '
                # Pad to align with the 'LABEL' statements.
                output_line += (' ' * len('-LABEL'))

                # Output the first line chunk that does not contain an SSA
                # name.
                output_line += ssa_split[0]

                # Process the rest of the input line.
                output_line += process_line(ssa_split[1:], variable_namer)

            else:
                # Append a newline to the output to separate the logical
                # blocks.
                output_lines.append('\n')
                output_line = '// ' + args.check_prefix + '-LABEL: '

                # Output the first line chunk that does not contain an SSA
                # name for the label.
                output_line += ssa_split[0] + '\n'

                # Process the rest of the input line on a separate check line.
                if len(ssa_split) > 1:
                    output_line += '// ' + args.check_prefix + '-SAME: '

                    # Pad to align with the original position in the line.
                    output_line += ' ' * len(ssa_split[0])

                    # Process the rest of the line.
                    output_line += process_line(ssa_split[1:], variable_namer)

            # Append the output line.
            output_lines.append(output_line)

        # Write the output, followed by a final blank line.
        args.output.writelines(output_lines)
        args.output.write('\n')
    finally:
        args.output.close()
205
206
207if __name__ == '__main__':
208 main()