diff --git a/pyproject.toml b/pyproject.toml index 0b1205f10f..08a4f3e36c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,25 +224,17 @@ follow_imports = 'silent' module = 'gt4py.cartesian.*' [[tool.mypy.overrides]] -ignore_errors = true +disable_error_code = "call-arg" module = 'gt4py.cartesian.frontend.nodes' [[tool.mypy.overrides]] -ignore_errors = true -module = 'gt4py.cartesian.frontend.node_util' - -[[tool.mypy.overrides]] -ignore_errors = true +disable_error_code = "call-arg" module = 'gt4py.cartesian.frontend.gtscript_frontend' [[tool.mypy.overrides]] -ignore_errors = true +disable_error_code = "call-arg" module = 'gt4py.cartesian.frontend.defir_to_gtir' -[[tool.mypy.overrides]] -ignore_errors = true -module = 'gt4py.cartesian.frontend.meta' - [[tool.mypy.overrides]] module = 'gt4py.eve.extended_typing' warn_unused_ignores = false diff --git a/src/gt4py/cartesian/frontend/base.py b/src/gt4py/cartesian/frontend/base.py index 84c65f97ec..9f1da816f0 100644 --- a/src/gt4py/cartesian/frontend/base.py +++ b/src/gt4py/cartesian/frontend/base.py @@ -81,7 +81,10 @@ def generate( @classmethod @abc.abstractmethod def prepare_stencil_definition( - cls, definition: AnyStencilFunc, externals: dict[str, Any] + cls, + definition: AnyStencilFunc, + externals: dict[str, Any], + options: BuildOptions | None = None, ) -> AnnotatedStencilFunc: """ Annotate the stencil function if not already done so. diff --git a/src/gt4py/cartesian/frontend/defir_to_gtir.py b/src/gt4py/cartesian/frontend/defir_to_gtir.py index 9c70f8da38..0c2274ce78 100644 --- a/src/gt4py/cartesian/frontend/defir_to_gtir.py +++ b/src/gt4py/cartesian/frontend/defir_to_gtir.py @@ -10,7 +10,7 @@ import functools import itertools import numbers -from typing import Any, Final, List, Optional, Tuple, Union, cast +from typing import Any, Final, List, Optional, Tuple, Union import numpy as np @@ -66,9 +66,7 @@ def _convert_dtype(data_type) -> common.DataType: if dtype == common.DataType.DEFAULT: # TODO: this will be a frontend choice later # in non-GTC parts, this is set in the backend - dtype = cast( - common.DataType, common.DataType.FLOAT64 - ) # see https://github.com/GridTools/gtc/issues/100 + dtype = common.DataType.FLOAT64 return dtype @@ -161,7 +159,7 @@ def _nested_list_dim(self, a: List) -> List[int]: def visit_Assign( self, node: Assign, *, fields_decls: dict[str, FieldDecl], **kwargs - ) -> Union[gtir.ParAssignStmt, List[gtir.ParAssignStmt]]: + ) -> Assign | list[Assign]: if self._is_vector_assignment(node, fields_decls): assert isinstance(node.target, FieldRef) or isinstance(node.target, VarRef) target_dims = fields_decls[node.target.name].data_dims @@ -249,20 +247,20 @@ def visit_FieldRef(self, node: FieldRef, *, fields_decls: dict[str, FieldDecl], def visit_UnaryOpExpr(self, node: UnaryOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs): if node.op == UnaryOperator.TRANSPOSED: - node = self.visit(node.arg, fields_decls=fields_decls, **kwargs) - assert isinstance(node, list) and all( - isinstance(row, list) and len(row) == len(node[0]) for row in node + argument = self.visit(node.arg, fields_decls=fields_decls, **kwargs) + assert isinstance(argument, list) and all( + isinstance(row, list) and len(row) == len(argument[0]) for row in argument ) # transpose list - node = [list(x) for x in zip(*node)] - return node + argument = [list(x) for x in zip(*argument)] + return argument return self.generic_visit(node, **kwargs) def visit_BinOpExpr(self, node: BinOpExpr, *, fields_decls: dict[str, FieldDecl], **kwargs): lhs = self.visit(node.lhs, fields_decls=fields_decls, **kwargs) rhs = self.visit(node.rhs, fields_decls=fields_decls, **kwargs) - result: Union[List[BinOpExpr], BinOpExpr] = [] + result: list[BinOpExpr] = [] if node.op == BinaryOperator.MATMULT: for j in range(len(lhs)): @@ -587,20 +585,20 @@ def visit_While(self, node: While) -> gtir.While: def visit_VarRef(self, node: VarRef, **kwargs) -> gtir.ScalarAccess: return gtir.ScalarAccess(name=node.name, loc=location_to_source_location(node.loc)) - def visit_AxisInterval(self, node: AxisInterval) -> Tuple[gtir.AxisBound, gtir.AxisBound]: + def visit_AxisInterval(self, node: AxisInterval) -> tuple[common.AxisBound, common.AxisBound]: return self.visit(node.start), self.visit(node.end) - def visit_AxisBound(self, node: AxisBound) -> gtir.AxisBound: + def visit_AxisBound(self, node: AxisBound) -> common.AxisBound: # TODO(havogt) add support VarRef - return gtir.AxisBound( + return common.AxisBound( level=self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER[node.level], offset=node.offset ) - def visit_RuntimeAxisBound(self, node: RuntimeAxisBound) -> gtir.RuntimeAxisBound: + def visit_RuntimeAxisBound(self, node: RuntimeAxisBound) -> common.RuntimeAxisBound: utils.warn_experimental_feature( feature="Runtime Interval Bounds", ADR="experimental/runtime-intervals.md" ) - return gtir.RuntimeAxisBound( + return common.RuntimeAxisBound( level=self.GT4PY_LEVELMARKER_TO_GTIR_LEVELMARKER[node.level], offset=self.visit(node.offset), ) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 55d0320f3e..8956f5a1f8 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -15,24 +15,11 @@ import textwrap import time import types -from typing import ( - Any, - Callable, - Dict, - Final, - List, - Literal, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, -) +from typing import Any, Dict, Final, List, Literal, Optional, Sequence, Set, Type, Union import numpy as np -from gt4py.cartesian import definitions as gt_definitions, gtscript, utils as gt_utils +from gt4py.cartesian import definitions as gt_definitions, gtscript, type_hints, utils as gt_utils from gt4py.cartesian.frontend import node_util, nodes from gt4py.cartesian.frontend.base import Frontend, register from gt4py.cartesian.frontend.defir_builder import DefIRBuilder @@ -56,11 +43,11 @@ class AssertionChecker(ast.NodeTransformer): """Check assertions and remove from the AST for further parsing.""" @classmethod - def apply(cls, func_node: ast.FunctionDef, context: Dict[str, Any], source: str): + def apply(cls, func_node: ast.FunctionDef, context: Dict[str, Any], source: str) -> None: checker = cls(context, source) checker.visit(func_node) - def __init__(self, context: Dict[str, Any], source: str): + def __init__(self, context: Dict[str, Any], source: str) -> None: self.context = context self.source = source @@ -76,8 +63,6 @@ def _process_assertion(self, expr_node: ast.Expr) -> None: loc = nodes.Location.from_ast_node(expr_node) raise GTScriptAssertionError(source_lines[loc.line - 1], loc=loc) - return None - def _process_call(self, node: ast.Call) -> Optional[ast.Call]: name = gt_meta.get_qualified_name_from_node(node.func) if name != "compile_assert": @@ -87,7 +72,8 @@ def _process_call(self, node: ast.Call) -> Optional[ast.Call]: raise GTScriptSyntaxError( "Invalid assertion. Correct syntax: compile_assert(condition)" ) - return self._process_assertion(node.args[0]) + self._process_assertion(node.args[0]) + return None def visit_Expr(self, node: ast.Expr) -> Optional[ast.AST]: if isinstance(node.value, ast.Call): @@ -110,7 +96,7 @@ def __init__( error_msg = "Invalid interval range specification" - if self.loc is not None: + if loc is not None: error_msg = f"{error_msg} at line {loc.line} (column: {loc.column})" self.interval_error = GTScriptSyntaxError(error_msg) @@ -198,6 +184,8 @@ def visit_BinOp(self, node: ast.BinOp) -> Union[gtscript.AxisIndex, nodes.AxisBo if isinstance(left, nodes.VarRef): if not isinstance(right, numbers.Number): raise incompatible_types_error + if u_op is None: + raise GTScriptSyntaxError("Unexpected unary operator found in interval expression.") return nodes.AxisBound(level=left, offset=u_op(right), loc=self.loc) if isinstance(left, nodes.AxisBound): if not isinstance(right, numbers.Number): @@ -208,6 +196,8 @@ def visit_BinOp(self, node: ast.BinOp) -> Union[gtscript.AxisIndex, nodes.AxisBo if isinstance(left, numbers.Number) and isinstance(right, numbers.Number): return bin_op(left, right) + raise GTScriptSyntaxError("Unexpected arguments to binary operator in IntervalParser.") + def visit_UnaryOp(self, node: ast.UnaryOp) -> nodes.AxisBound: if not isinstance(node.op, ast.USub): raise self.interval_error @@ -257,23 +247,25 @@ def apply( return nodes.AxisInterval(start=start, end=end, loc=loc) - def visit_Subscript(self, node: ast.Subscript) -> nodes.AxisBound: + def visit_Subscript(self, node: ast.Subscript) -> gtscript.AxisIndex: # This allows for the syntax # `region[I[0] : I[0] + 2, J[0] : J[0] + 2]` # to exist if not isinstance(node.value, ast.Name): raise self.interval_error + if node.value.id != self.axis_name: raise GTScriptSyntaxError( "Invalid horizontal range specification:" f"Expected axis {self.axis_name}, got {node.value.id}" ) + if isinstance(node.slice, ast.Constant): if node.slice.value != 0: raise GTScriptSyntaxError( "Invalid horizontal range specification:" f"Expected specification {self.axis_name}[0] or {self.axis_name}[-1]" - f", got {self.axis_name}[{node.slice.value}]" + f", got {self.axis_name}[{node.slice.value!r}]" ) elif isinstance(node.slice, ast.UnaryOp): if not isinstance(node.slice.operand, ast.Constant) or node.slice.operand.value not in ( @@ -374,8 +366,11 @@ def visit_Subscript(self, node: ast.Subscript): field_name = node.value.value.id # Ensure the indexing is correct, first we need a 0-offset in i and j if isinstance(node.value.slice, ast.Tuple): - axis_offsets: list[ast.Constant] = node.value.slice.elts - if not all(offset.value == 0 for offset in axis_offsets): + axis_offsets = node.value.slice.elts + if not all( + isinstance(offset, ast.Constant) and offset.value == 0 + for offset in axis_offsets + ): raise self.interval_error else: raise self.interval_error @@ -398,8 +393,10 @@ def visit_Subscript(self, node: ast.Subscript): # This is a non-higher dimensional field, but a normal field accessed with an offset if isinstance(node.value, ast.Name) and isinstance(node.slice, ast.Tuple): # We need to check that the offset is 0 everywhere, since no horizontal dependencies are allowed - axis_offsets: list[ast.Constant] = node.slice.elts - if not all(offset.value == 0 for offset in axis_offsets): + axis_offsets = node.slice.elts + if not all( + isinstance(offset, ast.Constant) and offset.value == 0 for offset in axis_offsets + ): raise self.interval_error # If the offset is 0, we are safe to visit the field return self.visit(node.value) @@ -466,6 +463,9 @@ def _get_num_values(node: ast.AST) -> int: return len(node.elts) if isinstance(node, ast.Tuple) else 1 def visit_Return(self, node: ast.Return, *, target_node: ast.AST) -> ast.Assign: + if node.value is None: + raise GTScriptSyntaxError("Return replacer needs a return value to work.") + rhs_length = self._get_num_values(node.value) lhs_length = self._get_num_values(target_node) @@ -571,12 +571,17 @@ def visit_Assign(self, node: ast.Assign): return self.generic_visit(node) def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complexity too high + if self.current_block is None: + raise RuntimeError( + "CallInliner can't visit call without `self.current_block` being defined." + ) + if _filter_absolute_K_index_method(node): return node call_name = gt_meta.get_qualified_name_from_node(node.func) - if call_name in self.call_stack: + if self.call_stack is not None and call_name in self.call_stack: raise GTScriptSyntaxError( message=f"Found recursive function call '{call_name}' in the stack.", loc=nodes.Location.from_ast_node(node), @@ -752,11 +757,12 @@ def apply(cls, ast_object: ast.AST, context: Dict[str, Any]): def __init__(self, context: Dict[str, Any]): self.context = context - def visit_If(self, node: ast.If): + def visit_If(self, ast_node: ast.If): # Compile-time evaluation of "if" conditions - node = self.generic_visit(node) + node = self.generic_visit(ast_node) if ( - isinstance(node.test, ast.Call) + isinstance(node, ast.If) + and isinstance(node.test, ast.Call) and isinstance(node.test.func, ast.Name) and node.test.func.id == "__INLINED" and len(node.test.args) == 1 @@ -764,11 +770,11 @@ def visit_If(self, node: ast.If): eval_node = node.test.args[0] condition_value = gt_meta.ast_eval(eval_node, self.context, default=gt_utils.NOTHING) if condition_value is not gt_utils.NOTHING: - node = node.body if condition_value else node.orelse - else: - raise GTScriptSyntaxError( - "Evaluation of compile-time 'if' condition failed at the preprocessing step" - ) + return node.body if condition_value else node.orelse + + raise GTScriptSyntaxError( + "Evaluation of compile-time 'if' condition failed at the preprocessing step" + ) return node if node else None @@ -796,16 +802,16 @@ def _make_init_computations( if not temp_decls: return [] - stmts: List[nodes.Assign] = [] + statements: list[nodes.Statement] = [] for name in init_values: decl = temp_decls[name] - stmts.append(decl) + statements.append(decl) if decl.data_dims: for index in itertools.product(*(range(i) for i in decl.data_dims)): literal_index = [ nodes.ScalarLiteral(value=i, data_type=nodes.DataType.INT32) for i in index ] - stmts.append( + statements.append( nodes.Assign( target=nodes.FieldRef.at_center( name, axes=decl.axes, data_index=literal_index @@ -816,7 +822,7 @@ def _make_init_computations( ) ) else: - stmts.append( + statements.append( nodes.Assign( target=nodes.FieldRef.at_center(name, axes=decl.axes), value=init_values[name], @@ -827,7 +833,7 @@ def _make_init_computations( nodes.ComputationBlock( interval=nodes.AxisInterval.full_interval(), iteration_order=nodes.IterationOrder.PARALLEL, - body=nodes.BlockStmt(stmts=stmts), + body=nodes.BlockStmt(stmts=statements), ) ] @@ -913,9 +919,9 @@ def __init__( self.literal_int_precision = options.literal_int_precision self.literal_float_precision = options.literal_float_precision self.temp_decls = temp_decls or {} - self.parsing_context = None - self.iteration_order = None - self.decls_stack = [] + self.parsing_context: ParsingContext | None = None + self.iteration_order: nodes.IterationOrder | None = None + self.decls_stack: list[list[nodes.Decl]] = [] self.parsing_horizontal_region = False self.written_vars: Set[str] = set() self.dtypes = dtypes @@ -969,7 +975,7 @@ def __init__( } # Conversion table for functions to NativeFunctions # Filter the field type from `dtypes` - self.temporary_field_type = {} + self.temporary_field_type: dict[str | type, type] = {} if self.dtypes: for name, _type in self.dtypes.items(): if isinstance(_type, gtscript._FieldDescriptor): @@ -1247,7 +1253,7 @@ def visit_Constant( return nodes.ScalarLiteral(value=value, data_type=data_type) raise GTScriptSyntaxError( - f"Unknown constant value found: {value}. Expected boolean or number.", + f"Unknown constant value found: {value!r}. Expected boolean or number.", loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), ) @@ -1304,6 +1310,9 @@ def _eval_new_spatial_index( index_nodes: Sequence[nodes.Expr], field_axes: Optional[Set[Literal["I", "J", "K"]]], ) -> List[int]: + if field_axes is None: + return [] + index_dict = {} all_spatial_axes = ("I", "J", "K") last_index = -1 @@ -1314,38 +1323,38 @@ def _eval_new_spatial_index( value = gt_meta.ast_eval(index_node, axis_context) except Exception as ex: raise GTScriptSyntaxError( - message="Could not evaluate axis shift expression.", loc=index_node + message="Could not evaluate axis shift expression.", + loc=nodes.Location.from_ast_node(index_node), ) from ex if not isinstance(value, (gtscript.ShiftedAxis, gtscript.Axis)): raise GTScriptSyntaxError( message=f"Axis shift expression evaluated to unrecognized type {type(value)}.", - loc=index_node, + loc=nodes.Location.from_ast_node(index_node), ) axis_index = all_spatial_axes.index(value.name) if axis_index < 0: raise GTScriptSyntaxError( - message=f"Unrecognized axis: {value.name}", loc=index_node + message=f"Unrecognized axis: {value.name}", + loc=nodes.Location.from_ast_node(index_node), ) if axis_index < last_index: raise GTScriptSyntaxError( message=f"Axis {value.name} is specified out of order", - loc=index_node, + loc=nodes.Location.from_ast_node(index_node), ) if axis_index == last_index: raise GTScriptSyntaxError( - message=f"Duplicate axis found: {value.name}", loc=index_node + message=f"Duplicate axis found: {value.name}", + loc=nodes.Location.from_ast_node(index_node), ) last_index = axis_index - try: - shift = value.shift - except AttributeError: - shift = 0 + shift = value.shift if isinstance(value, gtscript.ShiftedAxis) else 0 index_dict[value.name] = shift - return [index_dict.get(axis, 0) for axis in ("I", "J", "K") if axis in field_axes] + return [index_dict.get(axis, 0) for axis in all_spatial_axes if axis in field_axes] def _eval_index( self, @@ -1358,7 +1367,9 @@ def _eval_index( ) if any(isinstance(cn, ast.Slice) for cn in index_nodes): - raise GTScriptSyntaxError(message="Invalid target in assignment.", loc=node) + raise GTScriptSyntaxError( + message="Invalid target in assignment.", loc=nodes.Location.from_ast_node(node) + ) if any(isinstance(cn, ELLIPSIS_TYPE) for cn in index_nodes): return None @@ -1397,6 +1408,7 @@ def visit_Subscript(self, node: ast.Subscript): index = self._eval_index(node, axes) if isinstance(result, nodes.VarRef): assert index is not None + assert not isinstance(index, nodes.AbsoluteKIndex) result.index = index[0] elif isinstance(result, nodes.FieldRef): if isinstance(index, nodes.AbsoluteKIndex): @@ -1415,14 +1427,18 @@ def visit_Subscript(self, node: ast.Subscript): ) result.offset = {axis: value for axis, value in zip(field_axes, index)} elif isinstance(node.value, ast.Subscript) or _is_datadims_indexing_node(node): - result.data_index = [ - ( - nodes.ScalarLiteral(value=value, data_type=nodes.DataType.INT32) - if isinstance(value, numbers.Integral) - else value - ) - for value in index - ] + result.data_index = ( + [ + ( + nodes.ScalarLiteral(value=value, data_type=nodes.DataType.INT32) + if isinstance(value, numbers.Integral) + else value + ) + for value in index + ] + if index is not None + else [] + ) if len(result.data_index) != len(self.fields[result.name].data_dims): raise GTScriptSyntaxError( f"Incorrect data index length {len(result.data_index)}. " @@ -1509,7 +1525,7 @@ def visit_MatMult(self, node: ast.MatMult) -> nodes.BinaryOperator: def visit_And(self, node: ast.And) -> nodes.BinaryOperator: return nodes.BinaryOperator.AND - def visit_Or(self, node: ast.And) -> nodes.BinaryOperator: + def visit_Or(self, node: ast.Or) -> nodes.BinaryOperator: return nodes.BinaryOperator.OR def visit_Eq(self, node: ast.Eq) -> nodes.BinaryOperator: @@ -1555,7 +1571,7 @@ def visit_Compare(self, node: ast.Compare) -> nodes.BinOpExpr: args.append(rhs) for i in range(len(node.comparators) - 2, -1, -1): - lhs = self.visit(node.values[i]) + lhs = self.visit(node.comparators[i]) rhs = nodes.BinOpExpr( op=op, lhs=lhs, @@ -1583,7 +1599,7 @@ def visit_IfExp(self, node: ast.IfExp) -> nodes.TernaryOpExpr: return result - def visit_If(self, node: ast.If) -> list: + def visit_If(self, node: ast.If) -> list[nodes.Statement]: self.decls_stack.append([]) main_stmts = [] @@ -1597,7 +1613,7 @@ def visit_If(self, node: ast.If) -> list: else_stmts.extend(gt_utils.listify(self.visit(stmt))) assert all(isinstance(item, nodes.Statement) for item in else_stmts) - result = [] + result: list[nodes.Statement] = [] if len(self.decls_stack) == 1: result.extend(self.decls_stack.pop()) elif len(self.decls_stack) > 1: @@ -1635,14 +1651,14 @@ def visit_If(self, node: ast.If) -> list: return result - def visit_While(self, node: ast.While) -> list: + def visit_While(self, node: ast.While) -> list[nodes.Statement]: loc = nodes.Location.from_ast_node(node, scope=self.stencil_name) self.decls_stack.append([]) stmts = gt_utils.flatten([self.visit(stmt) for stmt in node.body]) assert all(isinstance(item, nodes.Statement) for item in stmts) - result = [ + result: list[nodes.Statement] = [ nodes.While( condition=self.visit(node.test), loc=nodes.Location.from_ast_node(node, scope=self.stencil_name), @@ -1737,9 +1753,11 @@ def visit_Call(self, node: ast.Call): # -- Statement nodes -- def _parse_assign_target( self, target_node: Union[ast.Subscript, ast.Name] - ) -> Tuple[str, Optional[List[int]], Optional[List[int]]]: + ) -> tuple[ + str, list[int] | nodes.AbsoluteKIndex | None, list[int] | nodes.AbsoluteKIndex | None + ]: invalid_target = GTScriptSyntaxError( - message="Invalid target in assignment.", loc=target_node + message="Invalid target in assignment.", loc=nodes.Location.from_ast_node(target_node) ) spatial_offset = None data_index = None @@ -1789,8 +1807,8 @@ def _resolve_assign( node: Union[ast.AnnAssign, ast.Assign], targets: List[Any], target_annotation: Optional[Any] = None, - ) -> list: - result = [] + ) -> list[nodes.Statement]: + result: list[nodes.Statement] = [] # Create decls for temporary fields target = [] @@ -1803,6 +1821,11 @@ def _resolve_assign( for t in targets[0].elts if isinstance(targets[0], ast.Tuple) else targets: name, spatial_offset, data_index = self._parse_assign_target(t) if spatial_offset: + if isinstance(spatial_offset, nodes.AbsoluteKIndex): + raise GTScriptSyntaxError( + message="Assignment with absolute K index is not allowed.", + loc=nodes.Location.from_ast_node(t), + ) if spatial_offset[0] != 0 or spatial_offset[1] != 0: raise GTScriptSyntaxError( message="Assignment to non-zero offsets is not supported in IJ.", @@ -1835,10 +1858,15 @@ def _resolve_assign( if target_annotation is not None: source = ast.unparse(target_annotation) try: + just_string_types = { + k: v + for k, v in self.temporary_field_type.items() + if isinstance(k, str) + } dtype_or_field_desc = eval( source, self.temporary_type_as_str_to_native_type - | self.temporary_field_type + | just_string_types | gtscript.__dict__, ) except NameError: @@ -1990,7 +2018,7 @@ def visit_With(self, node: ast.With): # Splice `withItems` of current/primary with statement into nested with with_node.items.extend(node.items) - compute_blocks.append(self._visit_computation_node(with_node)) + compute_blocks.append(self._visit_computation_node(with_node)) # type: ignore[arg-type] # we check above # Validate block specification order # the nested computation blocks must be specified in their order of execution. The order of execution is @@ -2033,7 +2061,7 @@ def apply(cls, node: ast.FunctionDef): return cls()(node) def __call__(self, node: ast.FunctionDef): - self.local_symbols = set() + self.local_symbols: set[str] = set() self.visit(node) result = self.local_symbols del self.local_symbols @@ -2067,7 +2095,7 @@ def visit_Assign(self, node: ast.Assign): class GTScriptParser(ast.NodeVisitor): CONST_VALUE_TYPES = ( - *gtscript._VALID_DATA_TYPES, + *gtscript._VALID_DATA_TYPES, # type: ignore[has-type] types.FunctionType, type(None), gtscript.AxisIndex, @@ -2096,10 +2124,10 @@ def __str__(self) -> str: @staticmethod def annotate_definition( - definition: Callable, + definition: type_hints.StencilFunc, options: gt_definitions.BuildOptions | None = None, externals=None, - ) -> Callable: + ) -> type_hints.AnnotatedStencilFunc: """Annotate the function definition with dtypes, resolve externals and add default values. Args: @@ -2148,6 +2176,7 @@ def annotate_definition( ) default = param.default + dtype_annotation: str | gtscript._FieldDescriptor | np.dtype | None if isinstance(param.annotation, (str, gtscript._FieldDescriptor)): dtype_annotation = param.annotation elif ( @@ -2232,7 +2261,8 @@ def annotate_definition( assert isinstance(ann_assign.value, ast.Constant) temp_init_values[name] = ann_assign.value.value - definition._gtscript_ = dict( + annotated: type_hints.AnnotatedStencilFunc = definition # type: ignore[assignment] + annotated._gtscript_ = dict( qualified_name=qualified_name, api_signature=api_signature, api_annotations=api_annotations, @@ -2244,7 +2274,7 @@ def annotate_definition( externals=resolved_externals if externals is not None else {}, ) - return definition + return annotated @staticmethod def collect_external_symbols(definition): @@ -2356,9 +2386,9 @@ def resolve_external_symbols( resolved_imports[imported_name] = imported_value # Collect all imported and inlined values recursively through all the external symbols - last_resolved_values = set() + last_resolved_values: Any = set() while resolved_imports or resolved_values_list: - new_imports = {} + new_imports: dict = {} for name, accesses in resolved_imports.items(): if accesses: for attr_name, attr_nodes in accesses.items(): @@ -2575,10 +2605,10 @@ def get_stencil_id(cls, qualified_name, definition, externals, options_id): @classmethod def prepare_stencil_definition( cls, - definition: Callable, - externals, + definition: type_hints.AnyStencilFunc, + externals: dict[str, Any], options: gt_definitions.BuildOptions | None = None, - ) -> Callable: + ) -> type_hints.AnnotatedStencilFunc: """Return an annotated version of the stencil definition. Args: diff --git a/src/gt4py/cartesian/frontend/node_util.py b/src/gt4py/cartesian/frontend/node_util.py index 52496b12a4..dc6a6066be 100644 --- a/src/gt4py/cartesian/frontend/node_util.py +++ b/src/gt4py/cartesian/frontend/node_util.py @@ -8,7 +8,7 @@ import collections import operator -from typing import Generator, Optional, Type +from typing import Generator, Iterable, Mapping, Optional, Type import boltons.typeutils @@ -27,7 +27,7 @@ def iter_attributes(node: Node): Yield a tuple of ``(attrib_name, value)`` for each attribute in ``node.attributes`` that is present on *node*. """ - for attrib_name in node.attributes: + for attrib_name in node.attributes: # type: ignore[attr-defined] try: yield attrib_name, getattr(node, attrib_name) except AttributeError: @@ -50,20 +50,22 @@ def _visit(self, node: Node, **kwargs): return visitor(node, **kwargs) def generic_visit(self, node: Node, **kwargs): - items = [] if isinstance(node, (str, bytes, bytearray)): - pass - elif isinstance(node, collections.abc.Mapping): - items = node.items() - elif isinstance(node, collections.abc.Iterable): - items = enumerate(node) - elif isinstance(node, Node): - items = iter_attributes(node) - else: - pass + return - for _, value in items: - self._visit(value, **kwargs) + if isinstance(node, Mapping): + for value in node.values(): + self._visit(value, **kwargs) + return + + if isinstance(node, Iterable): + for value in node: + self._visit(value, **kwargs) + return + + if isinstance(node, Node): + for _, value in iter_attributes(node): + self._visit(value, **kwargs) class IRNodeMapper: @@ -84,8 +86,9 @@ def _visit(self, node: Node, **kwargs): def generic_visit(self, node: Node, **kwargs): if isinstance(node, (str, bytes, bytearray)): return node - elif isinstance(node, collections.abc.Iterable): - if isinstance(node, collections.abc.Mapping): + + if isinstance(node, Iterable): + if isinstance(node, Mapping): items = node.items() else: items = enumerate(node) @@ -93,20 +96,20 @@ def generic_visit(self, node: Node, **kwargs): delattr_op = operator.delitem elif isinstance(node, Node): items = iter_attributes(node) - setattr_op = setattr - delattr_op = delattr + setattr_op = setattr # type: ignore[assignment] + delattr_op = delattr # type: ignore[assignment] else: return node - del_items = [] + del_items: list = [] for key, old_value in items: new_value = self._visit(old_value, **kwargs) if new_value == NOTHING: del_items.append(key) elif new_value != old_value: - setattr_op(node, key, new_value) + setattr_op(node, key, new_value) # type: ignore[call-overload] for key in reversed(del_items): # reversed, so that keys remain valid in sequences - delattr_op(node, key) + delattr_op(node, key) # type: ignore[call-overload] return node diff --git a/src/gt4py/cartesian/frontend/nodes.py b/src/gt4py/cartesian/frontend/nodes.py index fbffc354a0..c713d28bb7 100644 --- a/src/gt4py/cartesian/frontend/nodes.py +++ b/src/gt4py/cartesian/frontend/nodes.py @@ -140,11 +140,11 @@ import operator import sys from abc import ABC -from typing import List, Optional, Sequence +from typing import Optional, Sequence import numpy as np -from gt4py.cartesian.definitions import CartesianSpace +from gt4py.cartesian.gtc.definitions import CartesianSpace from gt4py.cartesian.utils.attrib import ( Any as Any, Dict as DictOf, @@ -309,7 +309,7 @@ def frontend_type_to_native_type( } -DataType.NATIVE_TYPE_TO_NUMPY = { +DataType.NATIVE_TYPE_TO_NUMPY = { # type: ignore[attr-defined] DataType.DEFAULT: "float_", DataType.BOOL: "bool", DataType.INT8: "int8", @@ -320,7 +320,7 @@ def frontend_type_to_native_type( DataType.FLOAT64: "float64", } -DataType.NUMPY_TO_NATIVE_TYPE = {value: key for key, value in DataType.NATIVE_TYPE_TO_NUMPY.items()} +DataType.NUMPY_TO_NATIVE_TYPE = {value: key for key, value in DataType.NATIVE_TYPE_TO_NUMPY.items()} # type: ignore[attr-defined] # ---- IR: expressions ---- @@ -384,7 +384,7 @@ class FieldRef(Ref): @classmethod def at_center( - cls, name: str, axes: Sequence[str], data_index: Optional[List[int]] = None, loc=None + cls, name: str, axes: Sequence[str], data_index: Optional[Sequence[Expr]] = None, loc=None ): return cls( name=name, offset={axis: 0 for axis in axes}, data_index=data_index or [], loc=loc @@ -466,7 +466,7 @@ def arity(self): return type(self).IR_OP_TO_NUM_ARGS[self] -NativeFunction.IR_OP_TO_NUM_ARGS = { +NativeFunction.IR_OP_TO_NUM_ARGS = { # type: ignore[attr-defined] NativeFunction.ABS: 1, NativeFunction.MIN: 2, NativeFunction.MAX: 2, @@ -536,13 +536,13 @@ def python_symbol(self): return type(self).IR_OP_TO_PYTHON_SYMBOL[self] -UnaryOperator.IR_OP_TO_PYTHON_OP = { +UnaryOperator.IR_OP_TO_PYTHON_OP = { # type: ignore[attr-defined] UnaryOperator.POS: operator.pos, UnaryOperator.NEG: operator.neg, UnaryOperator.NOT: operator.not_, } -UnaryOperator.IR_OP_TO_PYTHON_SYMBOL = { +UnaryOperator.IR_OP_TO_PYTHON_SYMBOL = { # type: ignore[attr-defined] UnaryOperator.POS: "+", UnaryOperator.NEG: "-", UnaryOperator.NOT: "not", @@ -586,7 +586,7 @@ def python_symbol(self): return type(self).IR_OP_TO_PYTHON_SYMBOL[self] -BinaryOperator.IR_OP_TO_PYTHON_OP = { +BinaryOperator.IR_OP_TO_PYTHON_OP = { # type: ignore[attr-defined] BinaryOperator.ADD: operator.add, BinaryOperator.SUB: operator.sub, BinaryOperator.MUL: operator.mul, @@ -603,7 +603,7 @@ def python_symbol(self): BinaryOperator.NE: operator.ne, } -BinaryOperator.IR_OP_TO_PYTHON_SYMBOL = { +BinaryOperator.IR_OP_TO_PYTHON_SYMBOL = { # type: ignore[attr-defined] BinaryOperator.ADD: "+", BinaryOperator.SUB: "-", BinaryOperator.MUL: "*", @@ -722,12 +722,6 @@ def symbol(self): def __str__(self) -> str: return self.name - def __lshift__(self, steps: int): - return self.cycle(steps=-steps) - - def __rshift__(self, steps: int): - return self.cycle(steps=steps) - class BaseAxisBound(Node, ABC): level = attribute(of=LevelMarker)