-
Notifications
You must be signed in to change notification settings - Fork 56
feat[next]: Tracer support part 1: tree_map #2586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1b4707c
902f8a3
02f881f
0ec4692
ab84ecc
36d6956
152300e
d459b0e
97af81e
8d75708
067bc29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| prune_empty_concat_where, | ||
| remove_broadcast, | ||
| symbol_ref_utils, | ||
| unroll_tree_map, | ||
| ) | ||
| from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet | ||
| from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple | ||
|
|
@@ -176,6 +177,26 @@ def apply_common_transforms( | |
| ) # domain inference does not support dynamic offsets yet | ||
| ir = infer_domain_ops.InferDomainOps.apply(ir) | ||
| ir = concat_where.canonicalize_domain_argument(ir) | ||
| ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) | ||
|
|
||
| # After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that | ||
| # domain inference does not encounter `as_fieldop` nodes inside dead tuple elements | ||
| # (which would receive NEVER domain). Do multiple iterations for nested `let`s. | ||
| for _ in range(10): | ||
| collapsed = ir | ||
| ir = CollapseTuple.apply( | ||
| ir, | ||
| enabled_transformations=( | ||
| CollapseTuple.Transformation.PROPAGATE_TUPLE_GET | ||
| | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE | ||
| ), | ||
| uids=uids, | ||
| offset_provider_type=offset_provider_type, | ||
| ) # type: ignore[assignment] # always an itir.Program | ||
| if ir == collapsed: | ||
| break | ||
| else: | ||
| raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") | ||
|
Comment on lines
+182
to
+199
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without this
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: probably this is also the test case where the loop is required. I'll take a look if another configuration of the pass helps to avoid the loop.
SF-N marked this conversation as resolved.
|
||
|
|
||
| ir = infer_domain.infer_program( | ||
| ir, | ||
|
|
@@ -290,6 +311,23 @@ def apply_fieldview_transforms( | |
|
|
||
| ir = infer_domain_ops.InferDomainOps.apply(ir) | ||
| ir = concat_where.canonicalize_domain_argument(ir) | ||
| ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) | ||
| for _ in range(10): | ||
| prev = ir | ||
| ir = CollapseTuple.apply( | ||
| ir, | ||
| enabled_transformations=( | ||
| CollapseTuple.Transformation.PROPAGATE_TUPLE_GET | ||
| | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE | ||
| ), | ||
| uids=uids, | ||
| offset_provider_type=offset_provider_type, | ||
| ) # type: ignore[assignment] # always an itir.Program | ||
| if ir == prev: | ||
| break | ||
|
SF-N marked this conversation as resolved.
|
||
| else: | ||
| raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") | ||
|
|
||
| ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program | ||
|
|
||
| ir = infer_domain.infer_program( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
| import dataclasses | ||
| import functools | ||
|
|
||
| from gt4py import eve | ||
| from gt4py.next import utils | ||
| from gt4py.next.iterator import ir as itir | ||
| from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im | ||
| from gt4py.next.iterator.type_system import inference as itir_inference | ||
| from gt4py.next.type_system import type_specifications as ts | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class UnrollTreeMap(eve.NodeTranslator): | ||
| PRESERVED_ANNEX_ATTRS = ("domain",) | ||
|
|
||
| uids: utils.IDGeneratorPool | ||
|
|
||
| @classmethod | ||
| def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): | ||
| return cls(uids=uids).visit(program) | ||
|
|
||
| def visit_FunCall(self, node: itir.FunCall): | ||
| node = self.generic_visit(node) | ||
|
|
||
| if not cpm.is_call_to(node.fun, "tree_map"): | ||
| return node | ||
|
|
||
| f = node.fun.args[0] | ||
| tup_args = node.args | ||
| tup_types: list[ts.TupleType] = [] | ||
| for tup in tup_args: | ||
| itir_inference.reinfer(tup) | ||
| assert isinstance(tup.type, ts.TupleType) | ||
| tup_types.append(tup.type) | ||
|
SF-N marked this conversation as resolved.
|
||
|
|
||
| tup_refs = [next(self.uids["_utm"]) for _ in tup_args] | ||
|
|
||
| @utils.tree_map( | ||
| collection_type=ts.TupleType, | ||
| result_collection_constructor=lambda _, elts: im.make_tuple(*elts), | ||
| with_path_arg=True, | ||
| ) | ||
| def mapper(*args): | ||
| *_el_types, path = args | ||
| return im.call(f)( | ||
| *( | ||
| functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, im.ref(ref_name)) | ||
| for ref_name in tup_refs | ||
| ) | ||
| ) | ||
|
|
||
| result = im.let(*zip(tup_refs, tup_args))(mapper(*tup_types)) | ||
|
|
||
| itir_inference.reinfer(result) | ||
| return result | ||
|
SF-N marked this conversation as resolved.
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from gt4py.next import common, utils | ||
| from gt4py.next.iterator import ir as itir | ||
| from gt4py.next.iterator.ir_utils import ir_makers as im | ||
| from gt4py.next.iterator.transforms.unroll_tree_map import UnrollTreeMap | ||
| from gt4py.next.type_system import type_specifications as ts | ||
|
|
||
| IDim = common.Dimension("IDim") | ||
| T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) | ||
| i_field = ts.FieldType(dims=[IDim], dtype=T) | ||
| i_tuple_field = ts.TupleType(types=[i_field, i_field]) | ||
| i_nested_tuple_field = ts.TupleType(types=[i_tuple_field, i_field]) | ||
|
|
||
| i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) | ||
|
|
||
|
|
||
| def _make_program( | ||
| params: list[itir.Sym], expr: itir.Expr, out_type: ts.TypeSpec = i_field | ||
| ) -> itir.Program: | ||
| return itir.Program( | ||
| id="testee", | ||
| function_definitions=[], | ||
| params=[*params, im.sym("out", out_type)], | ||
| declarations=[], | ||
| body=[ | ||
| itir.SetAt( | ||
| expr=expr, | ||
| domain=i_domain, | ||
| target=im.ref("out", out_type), | ||
| ) | ||
| ], | ||
| ) | ||
|
|
||
|
|
||
| def _neg(): | ||
| return im.lambda_("__a")(im.op_as_fieldop("neg")("__a")) | ||
|
|
||
|
|
||
| def _plus(): | ||
| return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b")) | ||
|
|
||
|
|
||
| def test_multi_arg(): | ||
| uids = utils.IDGeneratorPool() | ||
| program = _make_program( | ||
| [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], | ||
| im.call(im.call("tree_map")(_plus()))( | ||
| im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) | ||
| ), | ||
| out_type=i_tuple_field, | ||
| ) | ||
| result = UnrollTreeMap.apply(program, uids=uids) | ||
|
|
||
| expected = _make_program( | ||
| [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], | ||
| im.let(("_utm_0", "a"), ("_utm_1", "b"))( | ||
| im.make_tuple( | ||
| im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), | ||
| im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), | ||
| ) | ||
| ), | ||
| out_type=i_tuple_field, | ||
| ) | ||
| assert result == expected | ||
|
|
||
|
|
||
| def test_nested(): | ||
| uids = utils.IDGeneratorPool() | ||
| program = _make_program( | ||
| [im.sym("t", i_nested_tuple_field)], | ||
| im.call(im.call("tree_map")(_neg()))(im.ref("t", i_nested_tuple_field)), | ||
| out_type=i_nested_tuple_field, | ||
| ) | ||
| result = UnrollTreeMap.apply(program, uids=uids) | ||
|
|
||
| expected = _make_program( | ||
| [im.sym("t", i_nested_tuple_field)], | ||
| im.let("_utm_0", "t")( | ||
| im.make_tuple( | ||
| im.make_tuple( | ||
| im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "_utm_0"))), | ||
| im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "_utm_0"))), | ||
| ), | ||
| im.call(_neg())(im.tuple_get(1, "_utm_0")), | ||
| ) | ||
| ), | ||
| out_type=i_nested_tuple_field, | ||
| ) | ||
| assert result == expected |
Uh oh!
There was an error while loading. Please reload this page.