|
26 | 26 | InsertionPoint, IntegerAttr, IntegerType, Location, |
27 | 27 | Module, StringAttr, SymbolTable, TypeAttr, UnitAttr) |
28 | 28 | from cudaq.mlir.passmanager import PassManager |
29 | | -from .analysis import FindDepKernelsVisitor |
| 29 | +from .analysis import FindDepKernelsVisitor, ValidateArgumentAnnotations, ValidateReturnStatements |
30 | 30 | from .captured_data import CapturedDataStorage |
31 | 31 | from .utils import ( |
32 | 32 | Color, |
@@ -242,72 +242,6 @@ def emitFatalError(self, msg, astNode=None): |
242 | 242 | hasattr(ast, 'unparse') and astNode is not None else '') + Color.END |
243 | 243 | raise CompilerError(msg) |
244 | 244 |
|
245 | | - def validateArgumentAnnotations(self, astModule): |
246 | | - """ |
247 | | - Utility function for quickly validating that we have |
248 | | - all arguments annotated. |
249 | | - """ |
250 | | - |
251 | | - class ValidateArgumentAnnotations(ast.NodeVisitor): |
252 | | - """ |
253 | | - Utility visitor for finding argument annotations |
254 | | - """ |
255 | | - |
256 | | - def __init__(self, bridge): |
257 | | - self.bridge = bridge |
258 | | - |
259 | | - def visit_FunctionDef(self, node): |
260 | | - for arg in node.args.args: |
261 | | - if arg.annotation == None: |
262 | | - self.bridge.emitFatalError( |
263 | | - 'cudaq.kernel functions must have argument type annotations.', |
264 | | - arg) |
265 | | - |
266 | | - ValidateArgumentAnnotations(self).visit(astModule) |
267 | | - |
268 | | - # Ensure that functions with a return-type annotation actually has a valid return statement |
269 | | - # in all paths, if not throw an error. |
270 | | - class ValidateReturnStatements(ast.NodeVisitor): |
271 | | - |
272 | | - def __init__(self, bridge): |
273 | | - self.bridge = bridge |
274 | | - |
275 | | - def visit_FunctionDef(self, node): |
276 | | - # skip if un-annotated or explicitly marked as None |
277 | | - is_none_ret = (isinstance(node.returns, ast.Constant) and |
278 | | - node.returns.value is None) or ( |
279 | | - isinstance(node.returns, ast.Name) and |
280 | | - node.returns.id == 'None') |
281 | | - |
282 | | - if node.returns is None or is_none_ret: |
283 | | - return self.generic_visit(node) |
284 | | - |
285 | | - def all_paths_return(stmts): |
286 | | - for stmt in stmts: |
287 | | - if isinstance(stmt, ast.Return): |
288 | | - return True |
289 | | - |
290 | | - if isinstance(stmt, ast.If): |
291 | | - if all_paths_return(stmt.body) and all_paths_return( |
292 | | - stmt.orelse): |
293 | | - return True |
294 | | - |
295 | | - if isinstance(stmt, (ast.For, ast.While)): |
296 | | - if all_paths_return(stmt.body) or all_paths_return( |
297 | | - stmt.orelse): |
298 | | - return True |
299 | | - |
300 | | - return False |
301 | | - |
302 | | - if not all_paths_return(node.body): |
303 | | - self.bridge.emitFatalError( |
304 | | - 'cudaq.kernel functions with return type annotations must have a return statement.', |
305 | | - node) |
306 | | - |
307 | | - self.generic_visit(node) |
308 | | - |
309 | | - ValidateReturnStatements(self).visit(astModule) |
310 | | - |
311 | 245 | def getVeqType(self, size=None): |
312 | 246 | """ |
313 | 247 | Return a `quake.VeqType`. Pass the size of the `quake.veq` if known. |
@@ -4605,8 +4539,8 @@ def compile_to_mlir(astModule, capturedDataStorage: CapturedDataStorage, |
4605 | 4539 | locationOffset=lineNumberOffset, |
4606 | 4540 | capturedVariables=parentVariables) |
4607 | 4541 |
|
4608 | | - # First validate the arguments, make sure they are annotated |
4609 | | - bridge.validateArgumentAnnotations(astModule) |
| 4542 | + ValidateArgumentAnnotations(bridge).visit(astModule) |
| 4543 | + ValidateReturnStatements(bridge).visit(astModule) |
4610 | 4544 |
|
4611 | 4545 | # First we need to find any dependent kernels, they have to be |
4612 | 4546 | # built as part of this ModuleOp... |
|
0 commit comments