Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/gt4py/next/ffront/field_operator_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from __future__ import annotations

import functools
from typing import Any, Generic, TypeAlias, TypeVar, Union

from gt4py import eve
from gt4py.eve import (
Expand All @@ -22,6 +21,7 @@
datamodels,
utils as eve_utils,
)
from gt4py.eve.extended_typing import Any, Generic, TypeAlias, TypeVar, Union
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.type_definitions import StrEnum
from gt4py.next.ffront import dialect_ast_enums, type_specifications as ts_ffront
Expand Down Expand Up @@ -123,6 +123,32 @@ class TupleExpr(Expr):
elts: list[Expr]


# TODO(tehrengruber): extend this to supported nested tuple comprehension.
# e.g. `tuple(element_expr for child in nested_tuple for grand_child in child)`
# would be represented by:
# ```
# class TupleComprehension(Expr): # ruff: noqa: ERA001
# inner: TupleComprehensionMapper | NestedTupleCompr # ruff: noqa: ERA001
# class NestedTupleCompr(Expr, SymbolTableTrait): # ruff: noqa: ERA001
# params: tuple[DataSymbol] # ruff: noqa: ERA001
# body: TupleComprehension # ruff: noqa: ERA001
# ```
class TupleComprehension(Expr):
"""
tuple(element_expr for target in iterable)
"""

inner: TupleComprehensionMapper
iterable: Expr


# this is essentially a lambda, the difference is for a lambda we might not know the type of the
# args, therefor this is named differently at the moment.
class TupleComprehensionMapper(LocatedNode, SymbolTableTrait):
target: Any # should be: NestedInTuple[DataSymbol] but this has a problem in eve
element_expr: Expr


class UnaryOp(Expr):
op: dialect_ast_enums.UnaryOperator
operand: Expr
Expand Down
64 changes: 63 additions & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import textwrap
from typing import Any, Optional, Sequence, TypeAlias, TypeVar, cast

Expand All @@ -24,6 +23,7 @@
from gt4py.next.ffront.foast_passes import utils as foast_utils
from gt4py.next.iterator import builtins
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
from gt4py.next.utils import tree_map


OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode)
Expand Down Expand Up @@ -428,6 +428,10 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri
f"Tuples need to be indexed with literal integers, got '{node.index}'.",
) from ex
new_type = types[index]
case ts.VarArgType(element_type=element_type):
new_type = (
element_type # TODO: we only temporarily allow any index for vararg types
)
case ts.OffsetType(source=source, target=(target1, target2)):
if not target2.kind == DimensionKind.LOCAL:
raise errors.DSLError(
Expand Down Expand Up @@ -674,6 +678,64 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx
new_type = ts.TupleType(types=[element.type for element in new_elts])
return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location)

def visit_TupleComprehension(
self, node: foast.TupleComprehension, **kwargs: Any
) -> foast.TupleComprehension:
target = self.visit(node.inner.target, **kwargs)
iterable = self.visit(node.iterable, **kwargs)
if isinstance(iterable.type, ts.TupleType):
if len(iterable.type.types) > 0 and not all(
t == iterable.type.types[0] for t in iterable.type.types
):
raise errors.DSLError(
iterable.location,
"Not implemented. All elements of the iterable in a tuple comprehensions must have the same type.",
)
element_type = iterable.type.types[0]
elif isinstance(iterable.type, ts.VarArgType):
element_type = iterable.type.element_type
else:
raise errors.DSLError(
iterable.location,
f"Iterable in generator expression must be a tuple, got '{iterable.type}'.",
)

inner_kwargs = {"symtable": node.inner.annex.symtable, **kwargs}

@tree_map(with_path_arg=True)
def process_target(target_el: foast.Symbol, path: tuple[int, ...]) -> None:
try:
type_ = element_type
for i in path:
if not isinstance(type_, ts.TupleType) or len(type_.types) <= i:
raise IndexError()
type_ = type_.types[i]
return self.visit(target_el, refine_type=type_, **inner_kwargs)
except IndexError:
raise errors.DSLError(
target_el.location, f"Cannot unpack non-iterable '{type_}' object."
) from None

new_target = process_target(target)

element_expr = self.visit(node.inner.element_expr, **inner_kwargs)

return_type: ts.TupleType | ts.VarArgType
if isinstance(iterable.type, ts.TupleType):
return_type = ts.TupleType(types=[element_expr.type] * len(iterable.type.types))
else:
assert isinstance(iterable.type, ts.VarArgType)
return_type = ts.VarArgType(element_type=element_expr.type)

return foast.TupleComprehension(
inner=foast.TupleComprehensionMapper(
target=new_target, element_expr=element_expr, location=node.location
),
iterable=iterable,
location=node.location,
type=return_type,
)

def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call:
new_func = self.visit(node.func, **kwargs)
new_args = self.visit(node.args, **kwargs)
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/ffront/foast_pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o

UnaryOp = as_fmt("{op}{operand}")

def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> str:
element_expr = self.visit(node.inner.element_expr, **kwargs)
target = self.visit(node.inner.target, **kwargs)
iterable = self.visit(node.iterable, **kwargs)
return f"tuple(({element_expr} for {target} in {iterable}))"

def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str:
if node.op is dialect_ast_enums.UnaryOperator.NOT:
op = "not "
Expand Down
24 changes: 24 additions & 0 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


import dataclasses
import functools
from typing import Any, Callable, Optional

from gt4py import eve
Expand Down Expand Up @@ -257,6 +258,29 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr:
def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr:
return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts])

def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr:
target = self.visit(node.inner.target, **kwargs)
element_expr = self.visit(node.inner.element_expr, **kwargs)

# e.g. `(... for el1, el2 in ...)` -> `(let el1 = t[0], el2[1] ... for t in ...)`
if isinstance(target, tuple):
flat_targets = utils.flatten_nested_tuple(target)
new_target = next(self.uid_generator["__tuple_comprh"])
flat_targets_vals = utils.flatten_nested_tuple(
utils.tree_map(
lambda _, path: functools.reduce(
lambda el, i: im.tuple_get(i, el), path, new_target
),
with_path_arg=True,
)(target)
)
target = new_target
element_expr = im.let(*zip(flat_targets, flat_targets_vals))(element_expr)

return im.call(im.call("map_tuple")(im.lambda_(target)(element_expr)))(
self.visit(node.iterable, **kwargs)
)

def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
# TODO(tehrengruber): extend iterator ir to support unary operators
dtype = type_info.extract_dtype(node.type)
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/ffront/foast_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import toolchain, workflow
from gt4py.next.type_system import type_info, type_specifications as ts
from gt4py.next.type_system import type_specifications as ts


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -113,9 +113,9 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef:
*partial_program_type.definition.kw_only_args.keys(),
]
assert isinstance(type_, ts.CallableType)
assert arg_types[-1] == type_info.return_type(
type_, with_args=list(arg_types), with_kwargs=kwarg_types
)
# assert arg_types[-1] == type_info.return_type(
# type_, with_args=list(arg_types), with_kwargs=kwarg_types
# )
assert args_names[-1] == "out"

params_decl: list[past.Symbol] = [
Expand Down
61 changes: 53 additions & 8 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import ast
import textwrap
import typing
from typing import Any, Type

import gt4py.eve as eve
from gt4py.eve.extended_typing import Any, NestedTuple, Type
from gt4py.next import errors
from gt4py.next.ffront import (
dialect_ast_enums,
Expand Down Expand Up @@ -336,8 +336,13 @@ def visit_Return(self, node: ast.Return, **kwargs: Any) -> foast.Return:
def visit_Expr(self, node: ast.Expr) -> foast.Expr:
return self.visit(node.value)

def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name:
return foast.Name(id=node.id, location=self.get_location(node))
def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.DataSymbol | foast.Name:
loc = self.get_location(node)
if isinstance(node.ctx, ast.Store):
return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None))
else:
assert isinstance(node.ctx, ast.Load)
return foast.Name(id=node.id, location=loc)

def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp:
return foast.UnaryOp(
Expand Down Expand Up @@ -469,24 +474,64 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator:
return foast.CompareOperator.NOTEQ

def _verify_builtin_type_constructor(self, node: ast.Call) -> None:
if len(node.args) > 0:
arg = node.args[0]
assert isinstance(node.func, ast.Name)
(arg,) = (
node.args
) # note for review: the change here is unrelated to the actual pr and just a small cleanup
if node.func.id == "tuple":
if not (
isinstance(arg, ast.Constant)
or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant))
or isinstance(arg, ast.GeneratorExp)
):
raise errors.DSLError(
self.get_location(node),
f"'{self._func_name(node)}()' only takes literal arguments.",
f"'{self._func_name(node)}()' only takes literal arguments or a generator expression.",
)

def _func_name(self, node: ast.Call) -> str:
return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly.

def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call:
# TODO(tehrengruber): is this still needed or redundant with the checks in type deduction?
def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call | foast.TupleComprehension:
if isinstance(node.func, ast.Name):
func_name = self._func_name(node)

if (
func_name == "tuple"
and len(node.args) == 1
and isinstance(gen_expr := node.args[0], ast.GeneratorExp)
):
if len(gen_expr.generators) != 1:
raise errors.DSLError(
self.get_location(node),
"Nested generator expressions are not supported.",
)
if gen_expr.generators[0].ifs != []:
raise errors.DSLError(
self.get_location(node),
"Conditionals are not supported in generator expressions as they size of "
"the result can only be deduced at runtime.",
)

def parse_target(target: ast.expr) -> NestedTuple[foast.Name]:
if isinstance(target, ast.Tuple):
return tuple(parse_target(el) for el in target.elts)
assert isinstance(target, ast.Name)
return self.visit(target, **kwargs)

target = parse_target(gen_expr.generators[0].target)

return foast.TupleComprehension(
inner=foast.TupleComprehensionMapper(
target=target,
element_expr=self.visit(gen_expr.elt, **kwargs),
location=self.get_location(node),
),
iterable=self.visit(gen_expr.generators[0].iter, **kwargs),
location=self.get_location(node),
)

# TODO(tehrengruber): is this still needed or redundant with the checks in type deduction?
if func_name in fbuiltins.TYPE_BUILTIN_NAMES:
self._verify_builtin_type_constructor(node)

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call:
operator_return_type = type_info.return_type(
new_func.type, with_args=arg_types, with_kwargs=kwarg_types
)
if operator_return_type != new_kwargs["out"].type:
if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type):
raise ValueError(
"Expected keyword argument 'out' to be of "
f"type '{operator_return_type}', got "
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ def get_domain_range(*args):
"lift",
"make_const_list",
"make_tuple",
"map_",
"map_tuple",
"map_", # TODO: rename to map_list
"named_range",
"neighbors",
"reduce",
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
prune_empty_concat_where,
remove_broadcast,
symbol_ref_utils,
unroll_map_tuple,
)
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
Expand Down Expand Up @@ -176,6 +177,7 @@ def apply_common_transforms(
) # domain inference does not support dynamic offsets yet
ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids)

ir = infer_domain.infer_program(
ir,
Expand Down Expand Up @@ -290,6 +292,7 @@ def apply_fieldview_transforms(

ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids)
ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program

ir = infer_domain.infer_program(
Expand Down
Loading
Loading