blob: e0115e51f269ec03e36875006fd7f6e92e242041 [file] [log] [blame]
River Riddlea9d40152019-08-13 23:42:411#!/usr/bin/env python
2"""A script to generate FileCheck statements for mlir unit tests.
3
4This script is a utility to add FileCheck patterns to an mlir file.
5
6NOTE: The input .mlir is expected to be the output from the parser, not a
7stripped down variant.
8
9Example usage:
10$ generate-test-checks.py foo.mlir
11$ mlir-opt foo.mlir -transformation | generate-test-checks.py
12
13The script will heuristically insert CHECK/CHECK-LABEL commands for each line
14within the file. By default this script will also try to insert string
15substitution blocks for all SSA value names. The script is designed to make
16adding checks to a test case fast, it is *not* designed to be authoritative
17about what constitutes a good test!
18"""
19
20# Copyright 2019 The MLIR Authors.
21#
22# Licensed under the Apache License, Version 2.0 (the "License");
23# you may not use this file except in compliance with the License.
24# You may obtain a copy of the License at
25#
26# http://www.apache.org/licenses/LICENSE-2.0
27#
28# Unless required by applicable law or agreed to in writing, software
29# distributed under the License is distributed on an "AS IS" BASIS,
30# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31# See the License for the specific language governing permissions and
32# limitations under the License.
33
34import argparse
35import os # Used to advertise this file's name ("autogenerated_note").
36import re
37import sys
38import string
39
# Leading text of the note emitted as the first line of the generated output;
# main() appends 'utils/' and this script's file name to it.
ADVERT = '// NOTE: Assertions have been autogenerated by '

# Regex matching an SSA identifier, i.e. the text expected right after a '%':
# either a run of digits, or a name that starts with a letter or one of
# '$', '.', '_', '-' and continues with letters, digits or those symbols.
SSA_RE_STR = '[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)
45
46
class SSAVariableNamer:
    """Generate and manage string substitution names for SSA value names.

    Names are handed out as 'VAL_0', 'VAL_1', ... from one counter that is
    shared by every scope. Each generated name is recorded in the innermost
    open scope so later uses of the same SSA value can be mapped back to it.

    Attributes:
      scopes: Stack (list) of dicts mapping an SSA value name to its
        substitution variable name; the last element is the innermost scope.
      name_counter: Index used for the next generated 'VAL_<n>' name.
    """

    def __init__(self):
        self.scopes = []
        self.name_counter = 0

    def generate_name(self, ssa_name):
        """Generate a substitution name for the given SSA value name.

        The new name is stored in the innermost scope. If no scope has been
        pushed yet, one is opened implicitly so that an SSA value appearing
        outside any '{ ... }' region does not raise IndexError.

        Args:
          ssa_name: The SSA value name (without the leading '%').

        Returns:
          The generated variable name, e.g. 'VAL_3'.
        """
        if not self.scopes:
            self.push_name_scope()

        variable = 'VAL_' + str(self.name_counter)
        self.name_counter += 1
        self.scopes[-1][ssa_name] = variable
        return variable

    def push_name_scope(self):
        """Push a new, empty variable name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Pop the innermost variable name scope.

        Popping with no open scope is a no-op rather than an IndexError, so
        an unmatched closing brace in the input does not abort the run.
        """
        if self.scopes:
            self.scopes.pop()
69
70
def process_line(line_chunks, variable_namer):
    """Process a line of input that has been split at each '%'.

    Every chunk is the text that followed one '%' in the original line. When
    the chunk starts with an SSA identifier, the identifier is replaced with
    a FileCheck variable: '[[VAL_n:%.*]]' the first time the name is seen
    (defining the variable) and '[[VAL_n]]' on later uses. The remainder of
    the chunk is copied through unchanged.

    Args:
      line_chunks: Iterable of strings, the pieces of the line after each '%'.
      variable_namer: Object providing `scopes` (an iterable of dicts mapping
        SSA names to variable names) and `generate_name(ssa_name)`.

    Returns:
      The concatenated, rewritten chunks followed by a newline.
    """
    pieces = []

    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # The '%' was not followed by an SSA identifier (for example a
            # '%' inside a string, or at the very end of the line). Restore
            # the '%' that the split removed and keep the text verbatim
            # instead of failing on m.group().
            pieces.append('%' + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # Known value: reference the existing variable.
            pieces.append('[[' + variable + ']]')
        else:
            # First occurrence: define a new variable.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append('[[' + variable + ':%.*]]')

        # Append the text that followed the SSA name.
        pieces.append(chunk[len(ssa_name):])

    return ''.join(pieces) + '\n'
99
100
def preprocess_line(line):
    """Escape character sequences in `line` that FileCheck would misread.

    Args:
      line: A line of input text.

    Returns:
      The line with the problematic bracket sequences escaped.
    """
    # Applied in this order:
    #  1. '[[' opens a variable in FileCheck, so literal double brackets are
    #     wrapped in an escaped regex block.
    #  2. A single '[' directly before an SSA value ('[%') would become '[['
    #     once the value is replaced by a variable, so the bracket is escaped
    #     as well.
    escapes = (
        ('[[', '{{\\[\\[}}'),
        ('[%', '{{\\[}}%'),
    )

    escaped = line
    for needle, replacement in escapes:
        escaped = escaped.replace(needle, replacement)
    return escaped
114
115
def _generate_check_lines(input_lines, check_prefix):
    """Build the FileCheck lines for the given input lines.

    Args:
      input_lines: Iterable of input lines; trailing whitespace is ignored
        and lines that are empty after stripping are skipped.
      check_prefix: The FileCheck prefix to emit, e.g. 'CHECK'.

    Returns:
      A list of strings, each ending in a newline. A top-level operation
      contributes a blank separator line followed by a '<prefix>-LABEL' line
      (plus a '<prefix>-SAME' line when it contains SSA values); every other
      line contributes a single '<prefix>' line.
    """
    output_lines = []

    # Tracks the substitution variable assigned to each SSA value name.
    variable_namer = SSAVariableNamer()
    for input_line in input_lines:
        input_line = input_line.rstrip()
        if not input_line:
            continue
        lstripped_input_line = input_line.lstrip()

        # Lines with blocks begin with a '^'. These lines have a trailing
        # comment that needs to be stripped.
        is_block = lstripped_input_line.startswith('^')
        if is_block:
            input_line = input_line.rsplit('//', 1)[0].rstrip()

        # Top-level operations are heuristically the operations at nesting
        # level 1, i.e. indented by exactly two spaces and not a closing
        # brace. The line has no trailing whitespace, so a two-space prefix
        # guarantees that index 2 exists.
        is_toplevel_op = (not is_block and input_line.startswith('  ') and
                          input_line[2] not in (' ', '}'))

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line.startswith('}'):
            variable_namer.pop_name_scope()

        # If the line ends with a '{', push a new name scope.
        if input_line.endswith('{'):
            variable_namer.push_name_scope()

        # Remove any sequences that may be problematic with FileCheck.
        input_line = preprocess_line(input_line)

        # Split the line at each SSA value name.
        ssa_split = input_line.split('%')

        if not is_toplevel_op or not ssa_split[0]:
            # Regular line: a plain check, padded to align with the
            # '-LABEL' statements.
            output_line = '// ' + check_prefix + ': '
            output_line += ' ' * len('-LABEL')

            # Output the first chunk, which does not contain an SSA name,
            # then the rest of the line.
            output_line += ssa_split[0]
            output_line += process_line(ssa_split[1:], variable_namer)
        else:
            # Top-level operation: separate the logical blocks with a blank
            # line and emit the text before the first SSA value as a label.
            output_lines.append('\n')
            output_line = '// ' + check_prefix + '-LABEL: '
            output_line += ssa_split[0] + '\n'

            # Process the rest of the input line on a separate check line.
            if len(ssa_split) > 1:
                # '-SAME:' is one character shorter than '-LABEL:', hence
                # the extra space so both prefixes have the same width.
                output_line += '// ' + check_prefix + '-SAME:  '

                # Pad to align with the original position in the line.
                output_line += ' ' * len(ssa_split[0])
                output_line += process_line(ssa_split[1:], variable_namer)

        output_lines.append(output_line)

    return output_lines


def main():
    """Command-line entry point: read MLIR text and write FileCheck lines.

    Parses the command line ('--check-prefix', '-o/--output' and an optional
    positional input file; stdin/stdout are the defaults), reads the input,
    and writes an autogenerated-by note followed by the generated check
    lines and a final blank line to the output.
    """
    from argparse import RawTextHelpFormatter
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        '--check-prefix', default='CHECK',
        help='Prefix to use from check file.')
    parser.add_argument(
        '-o',
        '--output',
        nargs='?',
        type=argparse.FileType('w'),
        default=sys.stdout)
    parser.add_argument(
        'input',
        nargs='?',
        type=argparse.FileType('r'),
        default=sys.stdin)
    args = parser.parse_args()

    # With nargs='?', a bare '-o' (no file name) produces None; fall back to
    # stdout instead of failing on the first write.
    output = args.output if args.output is not None else sys.stdout

    try:
        # Read the whole input up front and release the handle.
        with args.input as input_file:
            input_lines = list(input_file)

        # Generate a note used for the generated check file.
        script_name = os.path.basename(__file__)
        autogenerated_note = ADVERT + 'utils/' + script_name
        output_lines = [autogenerated_note + '\n']

        output_lines.extend(
            _generate_check_lines(input_lines, args.check_prefix))

        # Write the output.
        output.writelines(output_lines)
        output.write('\n')
    finally:
        # Close the output even if reading or processing failed.
        output.close()
217
218
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()