Skip to content

Commit 81714fd

Browse files
authored
[py][nfc] Move analysis to the proper place (#3335)
Moves the implementations of `ValidateArgumentAnnotations` and `ValidateReturnStatements` to the `analysis` module. Signed-off-by: boschmitt <[email protected]>
1 parent df2d760 commit 81714fd

File tree

2 files changed

+63
-69
lines changed

2 files changed

+63
-69
lines changed

python/cudaq/kernel/analysis.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,63 @@ def fetch(func_obj: object):
233233
code += src + '\n'
234234

235235
return code
236+
237+
238+
class ValidateArgumentAnnotations(ast.NodeVisitor):
239+
"""
240+
Utility visitor for finding argument annotations
241+
"""
242+
243+
def __init__(self, bridge):
244+
self.bridge = bridge
245+
246+
def visit_FunctionDef(self, node):
247+
for arg in node.args.args:
248+
if arg.annotation == None:
249+
self.bridge.emitFatalError(
250+
'cudaq.kernel functions must have argument type annotations.',
251+
arg)
252+
253+
254+
class ValidateReturnStatements(ast.NodeVisitor):
255+
"""
256+
Analyze the AST and ensure that functions with a return-type annotation
257+
actually have a return statement in all paths.
258+
"""
259+
260+
def __init__(self, bridge):
261+
self.bridge = bridge
262+
263+
def visit_FunctionDef(self, node):
264+
# skip if un-annotated or explicitly marked as None
265+
is_none_ret = (isinstance(node.returns, ast.Constant) and
266+
node.returns.value
267+
is None) or (isinstance(node.returns, ast.Name) and
268+
node.returns.id == 'None')
269+
270+
if node.returns is None or is_none_ret:
271+
return self.generic_visit(node)
272+
273+
def all_paths_return(stmts):
274+
for stmt in stmts:
275+
if isinstance(stmt, ast.Return):
276+
return True
277+
278+
if isinstance(stmt, ast.If):
279+
if all_paths_return(stmt.body) and all_paths_return(
280+
stmt.orelse):
281+
return True
282+
283+
if isinstance(stmt, (ast.For, ast.While)):
284+
if all_paths_return(stmt.body) or all_paths_return(
285+
stmt.orelse):
286+
return True
287+
288+
return False
289+
290+
if not all_paths_return(node.body):
291+
self.bridge.emitFatalError(
292+
'cudaq.kernel functions with return type annotations must have a return statement.',
293+
node)
294+
295+
self.generic_visit(node)

python/cudaq/kernel/ast_bridge.py

Lines changed: 3 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
InsertionPoint, IntegerAttr, IntegerType, Location,
2727
Module, StringAttr, SymbolTable, TypeAttr, UnitAttr)
2828
from cudaq.mlir.passmanager import PassManager
29-
from .analysis import FindDepKernelsVisitor
29+
from .analysis import FindDepKernelsVisitor, ValidateArgumentAnnotations, ValidateReturnStatements
3030
from .captured_data import CapturedDataStorage
3131
from .utils import (
3232
Color,
@@ -242,72 +242,6 @@ def emitFatalError(self, msg, astNode=None):
242242
hasattr(ast, 'unparse') and astNode is not None else '') + Color.END
243243
raise CompilerError(msg)
244244

245-
def validateArgumentAnnotations(self, astModule):
246-
"""
247-
Utility function for quickly validating that we have
248-
all arguments annotated.
249-
"""
250-
251-
class ValidateArgumentAnnotations(ast.NodeVisitor):
252-
"""
253-
Utility visitor for finding argument annotations
254-
"""
255-
256-
def __init__(self, bridge):
257-
self.bridge = bridge
258-
259-
def visit_FunctionDef(self, node):
260-
for arg in node.args.args:
261-
if arg.annotation == None:
262-
self.bridge.emitFatalError(
263-
'cudaq.kernel functions must have argument type annotations.',
264-
arg)
265-
266-
ValidateArgumentAnnotations(self).visit(astModule)
267-
268-
# Ensure that functions with a return-type annotation actually has a valid return statement
269-
# in all paths, if not throw an error.
270-
class ValidateReturnStatements(ast.NodeVisitor):
271-
272-
def __init__(self, bridge):
273-
self.bridge = bridge
274-
275-
def visit_FunctionDef(self, node):
276-
# skip if un-annotated or explicitly marked as None
277-
is_none_ret = (isinstance(node.returns, ast.Constant) and
278-
node.returns.value is None) or (
279-
isinstance(node.returns, ast.Name) and
280-
node.returns.id == 'None')
281-
282-
if node.returns is None or is_none_ret:
283-
return self.generic_visit(node)
284-
285-
def all_paths_return(stmts):
286-
for stmt in stmts:
287-
if isinstance(stmt, ast.Return):
288-
return True
289-
290-
if isinstance(stmt, ast.If):
291-
if all_paths_return(stmt.body) and all_paths_return(
292-
stmt.orelse):
293-
return True
294-
295-
if isinstance(stmt, (ast.For, ast.While)):
296-
if all_paths_return(stmt.body) or all_paths_return(
297-
stmt.orelse):
298-
return True
299-
300-
return False
301-
302-
if not all_paths_return(node.body):
303-
self.bridge.emitFatalError(
304-
'cudaq.kernel functions with return type annotations must have a return statement.',
305-
node)
306-
307-
self.generic_visit(node)
308-
309-
ValidateReturnStatements(self).visit(astModule)
310-
311245
def getVeqType(self, size=None):
312246
"""
313247
Return a `quake.VeqType`. Pass the size of the `quake.veq` if known.
@@ -4605,8 +4539,8 @@ def compile_to_mlir(astModule, capturedDataStorage: CapturedDataStorage,
46054539
locationOffset=lineNumberOffset,
46064540
capturedVariables=parentVariables)
46074541

4608-
# First validate the arguments, make sure they are annotated
4609-
bridge.validateArgumentAnnotations(astModule)
4542+
ValidateArgumentAnnotations(bridge).visit(astModule)
4543+
ValidateReturnStatements(bridge).visit(astModule)
46104544

46114545
# First we need to find any dependent kernels, they have to be
46124546
# built as part of this ModuleOp...

0 commit comments

Comments
 (0)