feat[next]: Tracer support part 1: tree_map#2586
feat[next]: Tracer support part 1: tree_map#2586SF-N wants to merge 11 commits intoGridTools:mainfrom
Conversation
…porting nesting (extracted from GridTools#2487)
There was a problem hiding this comment.
Pull request overview
This PR introduces an iterator-level tree_map builtin as an IR operator to support mapping functions over (nested) tuples, as a first step towards tracer support and future vector operations. It includes type synthesis for tree_map and a transform (UnrollTreeMap) that lowers tree_map(f)(...) into explicit make_tuple / tuple_get IR.
Changes:
- Add
tree_mapbuiltin plumbing (builtin dispatch + IR maker helper) and update tuple-wherelowering to emittree_map. - Add
tree_maptype synthesizer and a newUnrollTreeMaptransform, wired into the iterator pass pipeline. - Add unit tests for the
_unrollhelper and adjust existing frontend lowering expectations.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py | Adds unit tests for _unroll tuple expansion behavior. |
| tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py | Updates tuple where reference IR to use tree_map. |
| src/gt4py/next/iterator/type_system/type_synthesizer.py | Registers and implements type synthesis for tree_map. |
| src/gt4py/next/iterator/transforms/unroll_tree_map.py | New transform to unroll tree_map into tuple primitives. |
| src/gt4py/next/iterator/transforms/pass_manager.py | Runs UnrollTreeMap and tuple-collapsing before domain inference. |
| src/gt4py/next/iterator/ir_utils/ir_makers.py | Adds im.tree_map(...) helper for constructing IR. |
| src/gt4py/next/iterator/builtins.py | Adds tree_map to builtin dispatch and builtin name set. |
| src/gt4py/next/ffront/foast_to_gtir.py | Lowers tuple where via tree_map instead of explicit tuple construction. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that | ||
| # domain inference does not encounter `as_fieldop` nodes inside dead tuple elements | ||
| # (which would receive NEVER domain). Do multiple iterations for nested `let`s. | ||
| for _ in range(10): | ||
| collapsed = ir | ||
| ir = CollapseTuple.apply( | ||
| ir, | ||
| enabled_transformations=( | ||
| CollapseTuple.Transformation.PROPAGATE_TUPLE_GET | ||
| | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE | ||
| ), | ||
| uids=uids, | ||
| offset_provider_type=offset_provider_type, | ||
| ) # type: ignore[assignment] # always an itir.Program | ||
| if ir == collapsed: | ||
| break | ||
| else: | ||
| raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") |
There was a problem hiding this comment.
Without this test_reduction_expression_with_where_and_tuples fails with ValueError: 'target_domain' cannot be 'NEVER' unless "allow_uninferred=True".
There was a problem hiding this comment.
Note: probably this is also the test case where the loop is required. I'll take a look if another configuration of the pass helps to avoid the loop.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| f"Call to object of type '{type(node.func.type).__name__}' not understood." | ||
| ) | ||
|
|
||
| def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: |
There was a problem hiding this comment.
I think _visit_astype cannot use tree_map because it needs type-dependent lowering per leaf: fields use _map(cast) while scalars use cast directly. Let's discuss if you have something else in mind.
|
|
||
| return im.let(cond_symref_name, cond_)(result) | ||
|
|
||
| def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: |
There was a problem hiding this comment.
As far as I see, _visit_concat_where already has its own expand_tuple_args pass for handling nested tuples, and each branch can have a different domain. Attempting to wrap it in tree_map caused type inference failures. Let's discuss if you have something else in mind.
There was a problem hiding this comment.
Other reasons: A single concat_where is better to digest for optimizations. The lowering would actually be complicated if we would emit tree_map in foast_to_gtir.
As a first step towards tracer support (enabling vector operations), this PR introduces a new
tree_mapoperator for mapping functions over tuples (including nesting).tree_mapis unrolled tomake_tuplecalls inUnrollTreeMap.