diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 95a2588077..8043964cf2 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -9,7 +9,6 @@ from __future__ import annotations import functools -from typing import Any, Generic, TypeAlias, TypeVar, Union from gt4py import eve from gt4py.eve import ( @@ -22,6 +21,7 @@ datamodels, utils as eve_utils, ) +from gt4py.eve.extended_typing import Any, Generic, TypeAlias, TypeVar, Union from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.type_definitions import StrEnum from gt4py.next.ffront import dialect_ast_enums, type_specifications as ts_ffront @@ -123,6 +123,32 @@ class TupleExpr(Expr): elts: list[Expr] +# TODO(tehrengruber): extend this to support nested tuple comprehensions. +# e.g. `tuple(element_expr for child in nested_tuple for grand_child in child)` +# would be represented by: +# ``` +# class TupleComprehension(Expr): # ruff: noqa: ERA001 +# inner: TupleComprehensionMapper | NestedTupleCompr # ruff: noqa: ERA001 +# class NestedTupleCompr(Expr, SymbolTableTrait): # ruff: noqa: ERA001 +# params: tuple[DataSymbol] # ruff: noqa: ERA001 +# body: TupleComprehension # ruff: noqa: ERA001 +# ``` +class TupleComprehension(Expr): + """ + tuple(element_expr for target in iterable) + """ + + inner: TupleComprehensionMapper + iterable: Expr + + +# this is essentially a lambda, the difference is for a lambda we might not know the type of the +# args, therefore this is named differently at the moment. 
+class TupleComprehensionMapper(LocatedNode, SymbolTableTrait): + target: Any # should be: NestedInTuple[DataSymbol] but this has a problem in eve + element_expr: Expr + + class UnaryOp(Expr): op: dialect_ast_enums.UnaryOperator operand: Expr diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 11c0bfd88b..160dc05e4a 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -5,7 +5,6 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - import textwrap from typing import Any, Optional, Sequence, TypeAlias, TypeVar, cast @@ -24,6 +23,7 @@ from gt4py.next.ffront.foast_passes import utils as foast_utils from gt4py.next.iterator import builtins from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +from gt4py.next.utils import tree_map OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) @@ -428,6 +428,10 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri f"Tuples need to be indexed with literal integers, got '{node.index}'.", ) from ex new_type = types[index] + case ts.VarArgType(element_type=element_type): + new_type = ( + element_type # TODO: we only temporarily allow any index for vararg types + ) case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: raise errors.DSLError( @@ -674,6 +678,64 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx new_type = ts.TupleType(types=[element.type for element in new_elts]) return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location) + def visit_TupleComprehension( + self, node: foast.TupleComprehension, **kwargs: Any + ) -> foast.TupleComprehension: + target = self.visit(node.inner.target, **kwargs) + iterable = self.visit(node.iterable, **kwargs) + if 
isinstance(iterable.type, ts.TupleType): + if len(iterable.type.types) > 0 and not all( + t == iterable.type.types[0] for t in iterable.type.types + ): + raise errors.DSLError( + iterable.location, + "Not implemented. All elements of the iterable in a tuple comprehension must have the same type.", + ) + element_type = iterable.type.types[0] + elif isinstance(iterable.type, ts.VarArgType): + element_type = iterable.type.element_type + else: + raise errors.DSLError( + iterable.location, + f"Iterable in generator expression must be a tuple, got '{iterable.type}'.", + ) + + inner_kwargs = {"symtable": node.inner.annex.symtable, **kwargs} + + @tree_map(with_path_arg=True) + def process_target(target_el: foast.Symbol, path: tuple[int, ...]) -> None: + try: + type_ = element_type + for i in path: + if not isinstance(type_, ts.TupleType) or len(type_.types) <= i: + raise IndexError() + type_ = type_.types[i] + return self.visit(target_el, refine_type=type_, **inner_kwargs) + except IndexError: + raise errors.DSLError( + target_el.location, f"Cannot unpack non-iterable '{type_}' object." 
+ ) from None + + new_target = process_target(target) + + element_expr = self.visit(node.inner.element_expr, **inner_kwargs) + + return_type: ts.TupleType | ts.VarArgType + if isinstance(iterable.type, ts.TupleType): + return_type = ts.TupleType(types=[element_expr.type] * len(iterable.type.types)) + else: + assert isinstance(iterable.type, ts.VarArgType) + return_type = ts.VarArgType(element_type=element_expr.type) + + return foast.TupleComprehension( + inner=foast.TupleComprehensionMapper( + target=new_target, element_expr=element_expr, location=node.location + ), + iterable=iterable, + location=node.location, + type=return_type, + ) + def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: new_func = self.visit(node.func, **kwargs) new_args = self.visit(node.args, **kwargs) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 8b2e369501..273ecaacaf 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -120,6 +120,12 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o UnaryOp = as_fmt("{op}{operand}") + def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> str: + element_expr = self.visit(node.inner.element_expr, **kwargs) + target = self.visit(node.inner.target, **kwargs) + iterable = self.visit(node.iterable, **kwargs) + return f"tuple(({element_expr} for {target} in {iterable}))" + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: if node.op is dialect_ast_enums.UnaryOperator.NOT: op = "not " diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3825072cb7..7383ff9d9e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -8,6 +8,7 @@ import dataclasses +import functools from typing import Any, Callable, Optional from gt4py import eve @@ -257,6 +258,29 @@ def 
visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) + def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr: + target = self.visit(node.inner.target, **kwargs) + element_expr = self.visit(node.inner.element_expr, **kwargs) + + # e.g. `(... for el1, el2 in ...)` -> `(let el1 = t[0], el2 = t[1] ... for t in ...)` + if isinstance(target, tuple): + flat_targets = utils.flatten_nested_tuple(target) + new_target = next(self.uid_generator["__tuple_comprh"]) + flat_targets_vals = utils.flatten_nested_tuple( + utils.tree_map( + lambda _, path: functools.reduce( + lambda el, i: im.tuple_get(i, el), path, new_target + ), + with_path_arg=True, + )(target) + ) + target = new_target + element_expr = im.let(*zip(flat_targets, flat_targets_vals))(element_expr) + + return im.call(im.call("map_tuple")(im.lambda_(target)(element_expr)))( + self.visit(node.iterable, **kwargs) + ) + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 05b080b70b..c37cba5a78 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -21,7 +21,7 @@ from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.type_system import type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -113,9 +113,9 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef: 
*partial_program_type.definition.kw_only_args.keys(), ] assert isinstance(type_, ts.CallableType) - assert arg_types[-1] == type_info.return_type( - type_, with_args=list(arg_types), with_kwargs=kwarg_types - ) + # assert arg_types[-1] == type_info.return_type( + # type_, with_args=list(arg_types), with_kwargs=kwarg_types + # ) assert args_names[-1] == "out" params_decl: list[past.Symbol] = [ diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ced0ff3905..e4126546c0 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -11,9 +11,9 @@ import ast import textwrap import typing -from typing import Any, Type import gt4py.eve as eve +from gt4py.eve.extended_typing import Any, NestedTuple, Type from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, @@ -336,8 +336,13 @@ def visit_Return(self, node: ast.Return, **kwargs: Any) -> foast.Return: def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) - def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name: - return foast.Name(id=node.id, location=self.get_location(node)) + def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.DataSymbol | foast.Name: + loc = self.get_location(node) + if isinstance(node.ctx, ast.Store): + return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None)) + else: + assert isinstance(node.ctx, ast.Load) + return foast.Name(id=node.id, location=loc) def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp: return foast.UnaryOp( @@ -469,24 +474,64 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: - if len(node.args) > 0: - arg = node.args[0] + assert isinstance(node.func, ast.Name) + (arg,) = ( + node.args + ) # note for review: the change here is unrelated to the 
actual pr and just a small cleanup + if node.func.id == "tuple": if not ( isinstance(arg, ast.Constant) or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) + or isinstance(arg, ast.GeneratorExp) ): raise errors.DSLError( self.get_location(node), - f"'{self._func_name(node)}()' only takes literal arguments.", + f"'{self._func_name(node)}()' only takes literal arguments or a generator expression.", ) def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. - def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call: - # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? + def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call | foast.TupleComprehension: if isinstance(node.func, ast.Name): func_name = self._func_name(node) + + if ( + func_name == "tuple" + and len(node.args) == 1 + and isinstance(gen_expr := node.args[0], ast.GeneratorExp) + ): + if len(gen_expr.generators) != 1: + raise errors.DSLError( + self.get_location(node), + "Nested generator expressions are not supported.", + ) + if gen_expr.generators[0].ifs != []: + raise errors.DSLError( + self.get_location(node), + "Conditionals are not supported in generator expressions as the size of " + "the result can only be deduced at runtime.", + ) + + def parse_target(target: ast.expr) -> NestedTuple[foast.Name]: + if isinstance(target, ast.Tuple): + return tuple(parse_target(el) for el in target.elts) + assert isinstance(target, ast.Name) + return self.visit(target, **kwargs) + + target = parse_target(gen_expr.generators[0].target) + + return foast.TupleComprehension( + inner=foast.TupleComprehensionMapper( + target=target, + element_expr=self.visit(gen_expr.elt, **kwargs), + location=self.get_location(node), + ), + iterable=self.visit(gen_expr.generators[0].iter, **kwargs), + location=self.get_location(node), + ) + + # 
TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if func_name in fbuiltins.TYPE_BUILTIN_NAMES: self._verify_builtin_type_constructor(node) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 9d021ceb51..530d407459 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: operator_return_type = type_info.return_type( new_func.type, with_args=arg_types, with_kwargs=kwarg_types ) - if operator_return_type != new_kwargs["out"].type: + if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type): raise ValueError( "Expected keyword argument 'out' to be of " f"type '{operator_return_type}', got " diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index e54c6ea3d7..7b24c91884 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -498,7 +498,8 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "map_", + "map_tuple", + "map_", # TODO: rename to map_list "named_range", "neighbors", "reduce", diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index a6228c6125..8825ad00ed 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -23,6 +23,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, + unroll_map_tuple, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -176,6 +177,7 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = 
concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = infer_domain.infer_program( ir, @@ -290,6 +292,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py new file mode 100644 index 0000000000..b47f5ba7d7 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses + +from gt4py import eve +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass +class UnrollMapTuple(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): + return cls(uids=uids).visit(program) + + def visit_FunCall(self, node: itir.FunCall): + node = self.generic_visit(node) + + if cpm.is_call_to(node.fun, "map_tuple"): + # TODO: we have to duplicate the function here since the domain inference can not handle them yet + f = node.fun.args[0] + tup = node.args[0] + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + tup_ref = next(self.uids["_ump"]) + + result = im.let(tup_ref, tup)( + im.make_tuple( + 
*(im.call(f)(im.tuple_get(i, tup_ref)) for i in range(len(tup.type.types))) + ) + ) + itir_inference.reinfer(result) + + return result + return node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 16d5da7e3b..98f3540d91 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -633,6 +633,19 @@ def applied_map( return applied_map +@_register_builtin_type_synthesizer +def map_tuple(op: TypeSynthesizer) -> TypeSynthesizer: + @type_synthesizer + def applied_map( + arg: ts.TupleType, offset_provider_type: common.OffsetProviderType + ) -> ts.TupleType: + return ts.TupleType( + types=[op(arg_, offset_provider_type=offset_provider_type) for arg_ in arg.types] + ) + + return applied_map + + @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @type_synthesizer diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index eb70d15947..69fccd33da 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -566,6 +566,14 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: or issubclass(type_class(to_type), symbol_type.constraint) ): return True + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.VarArgType): + return is_concretizable(symbol_type.element_type, to_type.element_type) + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.TupleType): + if len(to_type.types) == 0 or ( + all(type_ == to_type.types[0] for type_ in to_type.types) + and is_concretizable(symbol_type.element_type, to_type.types[0]) + ): + return True elif is_concrete(symbol_type): return symbol_type == to_type return False diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 
59ac40f0f3..409138d593 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -148,6 +148,15 @@ def __len__(self) -> int: return len(self.types) +class VarArgType(DataType): + """Represents a variable number of arguments of the same type.""" + + element_type: DataType # TODO: maybe also support different DataTypes + + def __str__(self) -> str: + return f"VarArg[{self.element_type}]" + + class AnyPythonType: """Marker type representing any Python type which cannot be used for instantiation. diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 3671c5b344..1d7a9aa2f7 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -181,8 +181,12 @@ def from_type_hint( case builtins.tuple: if not args: raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") - if Ellipsis in args: - raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") + if len(args) == 2 and args[1] is Ellipsis: + return ts.VarArgType(element_type=from_type_hint_same_ns(args[0])) + elif Ellipsis in args: + raise ValueError( + f"Vararg tuple annotation '{type_hint}' cannot have more than one argument." + ) tuple_types = [from_type_hint_same_ns(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) return ts.TupleType(types=tuple_types) @@ -326,7 +330,19 @@ def from_value(value: Any) -> ts.TypeSpec: return NamespaceProxy(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) - symbol_type = from_type_hint(type_) + if type_ == type[tuple]: + # TODO: this special casing here is not nice, but infer_type is also called on the annotations where + # we don't want to allow unparameterized tuples (or do we?). 
+ symbol_type = ts.ConstructorType( + definition=ts.FunctionType( + pos_only_args=[ts.DeferredType(constraint=None)], + pos_or_kw_args={}, + kw_only_args={}, + returns=ts.DeferredType(constraint=ts.VarArgType), + ) + ) + else: + symbol_type = from_type_hint(type_) if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index d552a09a2a..d65ecefb10 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -603,6 +603,15 @@ def _allocate_from_type( for t in types ) ) + case ts.VarArgType(element_type=element_type): + return tuple( + ( + _allocate_from_type( + case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy + ) + for t in [element_type] * 3 # TODO: revisit + ) + ) case ts.NamedCollectionType(types=types) as named_collection_type_spec: container_constructor = ( named_collections.make_named_collection_constructor_from_type_spec( @@ -648,6 +657,8 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> return sum([get_param_size(t, sizes=sizes) for t in types]) case ts.NamedCollectionType(types=types): return sum([get_param_size(t, sizes=sizes) for t in types]) + case ts.VarArgType(element_type=element_type): + return get_param_size(ts.TupleType(types=[element_type] * 3), sizes) # TODO: revisit case _: raise TypeError(f"Can not get size for parameter of type '{param_type}'.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index c58ac5f497..27812ef5d1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -338,6 +338,92 @@ def testee(a: tuple[cases.IField, 
cases.IJField]) -> cases.IJField: ) +@pytest.mark.uses_tuple_args +def test_fixed_len_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IField, cases.IField], factor: int32 + ) -> tuple[cases.IField, cases.IField]: + return tuple(tracer * factor for tracer in tracers) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t), + ) + + +@pytest.mark.uses_tuple_args +def test_var_len_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee(tracers: tuple[cases.IField, ...], factor: int32) -> tuple[cases.IField, ...]: + return tuple(tracer * factor for tracer in tracers) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t), + ) + + +@pytest.mark.uses_tuple_args +def test_nested_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee( + vals: tuple[tuple[cases.IField, ...], ...], factor: int32 + ) -> tuple[tuple[cases.IField, ...], ...]: + return tuple(tuple(grand_child * factor for grand_child in child) for child in vals) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(tuple(grand_child * f for grand_child in child) for child in t), + ) + + +@pytest.mark.uses_tuple_args +def test_nested_tuple_comprehension_shadowing_names(cartesian_case): + @gtx.field_operator + def testee( + vals: tuple[tuple[cases.IField, ...], ...], factor: int32 + ) -> tuple[tuple[cases.IField, ...], ...]: + return tuple(tuple(child * factor for child in child) for child in vals) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(tuple(child * f for child in child) for child in t), + ) + + +@pytest.mark.uses_tuple_args +def test_multi_target_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee(nested_tuple: tuple[tuple[int32, cases.IField], ...]) -> tuple[cases.IField, ...]: + return tuple(factor * tracer for factor, 
tracer in nested_tuple) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t: tuple(f * el for f, el in t), + ) + + +@pytest.mark.uses_tuple_args +def test_tuple_vararg(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IFloatField, ...], factor: float + ) -> tuple[cases.IFloatField, cases.IFloatField]: + return tracers[0] * factor, tracers[1] * factor + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t[:2]), + ) + + @pytest.mark.uses_tuple_args @pytest.mark.xfail(reason="Iterator of tuple approach in lowering does not allow this.") def test_tuple_arg_with_unpromotable_dims(unstructured_case): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 57c2a8be3a..d231fc58ec 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -447,3 +447,36 @@ def tuple_index_failure( with pytest.raises(errors.DSLError, match=r"need .* literal"): _ = FieldOperatorParser.apply_to_function(tuple_index_failure) + + +def test_tuple_compr_non_tuple_iterable_failure(): + def testee(arg: float): + return tuple(_ for _ in arg) + + with pytest.raises( + errors.DSLError, + match=re.escape("Iterable in generator expression must be a tuple, got 'float64'."), + ): + _ = FieldOperatorParser.apply_to_function(testee) + + +def test_nested_tuple_compr_failure(): + def testee(nested_tuple: tuple[tuple[gtx.Field[[TDim], float64], ...], ...], factor: int32): + return tuple(grandchild * factor for child in nested_tuple for grandchild in child) + + with pytest.raises( + errors.DSLError, + match=re.escape("Nested generator expressions are not supported."), + ): + _ = FieldOperatorParser.apply_to_function(testee) + + +def test_tuple_compr_unpacking_failure(): + def testee(arg: tuple[int32, ...]): + return tuple(a * b for 
a, b in arg) + + with pytest.raises( + errors.DSLError, + match=re.escape("Cannot unpack non-iterable 'int32' object."), + ): + _ = FieldOperatorParser.apply_to_function(testee)