Skip to content
20 changes: 6 additions & 14 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,23 +412,15 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._lower_and_map("if_", *node.args)

cond_ = self.visit(node.args[0])
true_ = self.visit(node.args[1])
false_ = self.visit(node.args[2])
cond_symref_name = f"__cond_{cond_.fingerprint()}"

def create_if(
true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec]
) -> itir.FunCall:
return _map(
"if_",
(im.ref(cond_symref_name), true_, false_),
(node.args[0].type, *arg_types),
result = im.tree_map(
im.lambda_("__a", "__b")(
im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b"))
)

result = lowering_utils.process_elements(
create_if,
(self.visit(node.args[1]), self.visit(node.args[2])),
node.type,
arg_types=(node.args[1].type, node.args[2].type),
)
)(true_, false_)

return im.let(cond_symref_name, cond_)(result)

Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def map_(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def tree_map(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def make_const_list(*args):
raise BackendNotSelectedError()
Expand Down Expand Up @@ -498,7 +503,8 @@ def get_domain_range(*args):
"lift",
"make_const_list",
"make_tuple",
"map_",
"tree_map",
"map_", # TODO: rename to map_list
"named_range",
"neighbors",
"reduce",
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,11 @@ def map_(op):
return call(call("map_")(op))


def tree_map(op):
"""Create a `tree_map` call: tree_map(op)(tup1, tup2, ...)."""
return call(call("tree_map")(op))


def reduce(op, expr):
"""Create a `reduce` call."""
return call(call("reduce")(op, expr))
Expand Down
38 changes: 38 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
prune_empty_concat_where,
remove_broadcast,
symbol_ref_utils,
unroll_tree_map,
)
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
Expand Down Expand Up @@ -176,6 +177,26 @@ def apply_common_transforms(
) # domain inference does not support dynamic offsets yet
ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids)

# After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that
# domain inference does not encounter `as_fieldop` nodes inside dead tuple elements
# (which would receive NEVER domain). Do multiple iterations for nested `let`s.
for _ in range(10):
collapsed = ir
ir = CollapseTuple.apply(
ir,
enabled_transformations=(
CollapseTuple.Transformation.PROPAGATE_TUPLE_GET
| CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE
),
uids=uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
if ir == collapsed:
break
Comment thread
SF-N marked this conversation as resolved.
else:
raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.")
Comment on lines +182 to +199
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this test_reduction_expression_with_where_and_tuples fails with ValueError: 'target_domain' cannot be 'NEVER' unless "allow_uninferred=True".

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: probably this is also the test case where the loop is required. I'll take a look if another configuration of the pass helps to avoid the loop.

Comment thread
SF-N marked this conversation as resolved.

ir = infer_domain.infer_program(
ir,
Expand Down Expand Up @@ -290,6 +311,23 @@ def apply_fieldview_transforms(

ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids)
for _ in range(10):
prev = ir
ir = CollapseTuple.apply(
ir,
enabled_transformations=(
CollapseTuple.Transformation.PROPAGATE_TUPLE_GET
| CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE
),
uids=uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
if ir == prev:
break
Comment thread
SF-N marked this conversation as resolved.
else:
raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.")

ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program

ir = infer_domain.infer_program(
Expand Down
62 changes: 62 additions & 0 deletions src/gt4py/next/iterator/transforms/unroll_tree_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
import functools

from gt4py import eve
from gt4py.next import utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.type_system import inference as itir_inference
from gt4py.next.type_system import type_specifications as ts


@dataclasses.dataclass
class UnrollTreeMap(eve.NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("domain",)

uids: utils.IDGeneratorPool

@classmethod
def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool):
return cls(uids=uids).visit(program)

def visit_FunCall(self, node: itir.FunCall):
node = self.generic_visit(node)

if not cpm.is_call_to(node.fun, "tree_map"):
return node

f = node.fun.args[0]
tup_args = node.args
tup_types: list[ts.TupleType] = []
for tup in tup_args:
itir_inference.reinfer(tup)
assert isinstance(tup.type, ts.TupleType)
tup_types.append(tup.type)
Comment thread
SF-N marked this conversation as resolved.

tup_refs = [next(self.uids["_utm"]) for _ in tup_args]

@utils.tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda _, elts: im.make_tuple(*elts),
with_path_arg=True,
)
def mapper(*args):
*_el_types, path = args
return im.call(f)(
*(
functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, im.ref(ref_name))
for ref_name in tup_refs
)
)

result = im.let(*zip(tup_refs, tup_args))(mapper(*tup_types))

itir_inference.reinfer(result)
return result
Comment thread
SF-N marked this conversation as resolved.
29 changes: 27 additions & 2 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from gt4py.next.iterator import builtins, ir as itir
from gt4py.next.iterator.type_system import type_specifications as it_ts
from gt4py.next.type_system import type_info, type_specifications as ts
from gt4py.next.utils import tree_map


def _type_synth_arg_cache_key(type_or_synth: TypeOrTypeSynthesizer) -> int:
Expand Down Expand Up @@ -203,7 +202,7 @@ def if_(
pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType
) -> ts.DataType:
if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType):
return tree_map(
return utils.tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]),
)(functools.partial(if_, pred))(true_branch, false_branch)
Expand Down Expand Up @@ -633,6 +632,32 @@ def applied_map(
return applied_map


@_register_builtin_type_synthesizer
def tree_map(op: TypeSynthesizer) -> TypeSynthesizer:
@type_synthesizer
def applied_map(
Comment thread
SF-N marked this conversation as resolved.
*args: ts.TupleType, offset_provider_type: common.OffsetProviderType
) -> ts.TupleType:
if not args:
raise TypeError("tree_map requires at least one argument.")
if not all(isinstance(a, ts.TupleType) for a in args):
raise TypeError(
"tree_map requires all top-level arguments to be TupleType, "
f"got {[type(a).__name__ for a in args]}."
)

def leaf_op(*leaf_types: ts.TypeSpec) -> ts.TypeSpec:
return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value]

return utils.tree_map( # type: ignore[return-value]
leaf_op,
collection_type=ts.TupleType,
result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]),
)(*args)

return applied_map


@_register_builtin_type_synthesizer
def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer:
@type_synthesizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,9 @@ def foo(
lowered
) # we generate a let for the condition which is removed by inlining for easier testing

reference = im.make_tuple(
im.op_as_fieldop("if_")("a", im.tuple_get(0, "b"), im.tuple_get(0, "c")),
im.op_as_fieldop("if_")("a", im.tuple_get(1, "b"), im.tuple_get(1, "c")),
)
reference = im.tree_map(
im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b")))
)("b", "c")

assert lowered_inlined.expr == reference

Expand Down
17 changes: 17 additions & 0 deletions tests/next_tests/unit_tests/iterator_tests/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,23 @@ def expression_test_cases():
),
ts.ListType(element_type=int_type, offset_type=V2EDim),
),
# tree_map
(
im.tree_map(im.ref("plus"))(
im.ref("t1", ts.TupleType(types=[int_type, int_type])),
im.ref("t2", ts.TupleType(types=[int_type, int_type])),
),
ts.TupleType(types=[int_type, int_type]),
),
(
im.tree_map(im.ref("not_"))(
im.ref(
"t",
ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]),
),
),
ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]),
),
# reduce
(im.reduce("plus", 0)(im.ref("l", int_list_type)), int_type),
(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.next import common, utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.unroll_tree_map import UnrollTreeMap
from gt4py.next.type_system import type_specifications as ts

IDim = common.Dimension("IDim")
T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
i_field = ts.FieldType(dims=[IDim], dtype=T)
i_tuple_field = ts.TupleType(types=[i_field, i_field])
i_nested_tuple_field = ts.TupleType(types=[i_tuple_field, i_field])

i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1))


def _make_program(
params: list[itir.Sym], expr: itir.Expr, out_type: ts.TypeSpec = i_field
) -> itir.Program:
return itir.Program(
id="testee",
function_definitions=[],
params=[*params, im.sym("out", out_type)],
declarations=[],
body=[
itir.SetAt(
expr=expr,
domain=i_domain,
target=im.ref("out", out_type),
)
],
)


def _neg():
return im.lambda_("__a")(im.op_as_fieldop("neg")("__a"))


def _plus():
return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b"))


def test_multi_arg():
uids = utils.IDGeneratorPool()
program = _make_program(
[im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)],
im.call(im.call("tree_map")(_plus()))(
im.ref("a", i_tuple_field), im.ref("b", i_tuple_field)
),
out_type=i_tuple_field,
)
result = UnrollTreeMap.apply(program, uids=uids)

expected = _make_program(
[im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)],
im.let(("_utm_0", "a"), ("_utm_1", "b"))(
im.make_tuple(
im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")),
im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")),
)
),
out_type=i_tuple_field,
)
assert result == expected


def test_nested():
uids = utils.IDGeneratorPool()
program = _make_program(
[im.sym("t", i_nested_tuple_field)],
im.call(im.call("tree_map")(_neg()))(im.ref("t", i_nested_tuple_field)),
out_type=i_nested_tuple_field,
)
result = UnrollTreeMap.apply(program, uids=uids)

expected = _make_program(
[im.sym("t", i_nested_tuple_field)],
im.let("_utm_0", "t")(
im.make_tuple(
im.make_tuple(
im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "_utm_0"))),
im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "_utm_0"))),
),
im.call(_neg())(im.tuple_get(1, "_utm_0")),
)
),
out_type=i_nested_tuple_field,
)
assert result == expected