From d0e3ae4b35bf29c1eb1dde32d3efdfb94df38d11 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 5 Mar 2026 14:34:56 +0100 Subject: [PATCH 01/32] Add failing test --- .../test_move_dataflow_into_if_body.py | 253 ++++++++++++++++++ 1 file changed, 253 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index d2088a7bb7..0d369f7d40 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -76,6 +76,104 @@ def _make_if_block( ) +def _make_if_block_with_two_args( + state: dace.SDFGState, + outer_sdfg: dace.SDFG, + b1_name: str = "__arg1", + b2_name: str = "__arg2", + b3_name: str = "__arg3", + b4_name: str = "__arg4", + cond_name: str = "__cond", + output_name: str = "__output", + b1_type: dace.typeclass = dace.float64, + b2_type: dace.typeclass = dace.float64, + b3_type: dace.typeclass = dace.float64, + b4_type: dace.typeclass = dace.float64, + output_type: dace.typeclass = dace.float64, +) -> dace_nodes.NestedSDFG: + inner_sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_stmt_")) + + types = {b1_name: b1_type, b2_name: b2_type, b3_name: b3_type, b4_name: b4_type, cond_name: dace.bool_, output_name: output_type} + for name in {b1_name, b2_name, b3_name, b4_name, cond_name, output_name}: + inner_sdfg.add_scalar( + name, + dtype=types[name], + transient=False, + ) + + if_region = dace.sdfg.state.ConditionalBlock(gtx_transformations.utils.unique_name("if")) + inner_sdfg.add_node(if_region, is_start_block=True) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) + tstate = then_body.add_state("true_branch_0_1_2_3_4", is_start_block=True) + tasklet_true = tstate.add_tasklet( + "true_tasklet", + inputs={"__tasklet_in1", "__tasklet_in2"}, + outputs={"__tasklet_out"}, + code="__tasklet_out = __tasklet_in1 + __tasklet_in2", + ) + tstate.add_edge( + tstate.add_access(b1_name), + None, + tasklet_true, + "__tasklet_in1", + dace.Memlet(f"{b1_name}[0]"), + ) + tstate.add_edge( + tstate.add_access(b2_name), + None, + tasklet_true, + "__tasklet_in2", + dace.Memlet(f"{b2_name}[0]"), + ) + tstate.add_edge( + tasklet_true, + "__tasklet_out", + tstate.add_access(output_name), + None, + dace.Memlet(f"{output_name}[0]"), + ) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=inner_sdfg) + fstate = else_body.add_state("false_branch_0_1_2_3_4", is_start_block=True) + tasklet_false = fstate.add_tasklet( + "false_tasklet", + inputs={"__tasklet_in1", "__tasklet_in2"}, + outputs={"__tasklet_out"}, + code="__tasklet_out = __tasklet_in1 - __tasklet_in2", + ) + fstate.add_edge( + fstate.add_access(b3_name), + None, + tasklet_false, + "__tasklet_in1", + dace.Memlet(f"{b3_name}[0]"), + ) + fstate.add_edge( + fstate.add_access(b4_name), + None, + tasklet_false, + "__tasklet_in2", + dace.Memlet(f"{b4_name}[0]"), + ) + fstate.add_edge( + tasklet_false, + "__tasklet_out", + fstate.add_access(output_name), + None, + dace.Memlet(f"{output_name}[0]"), + ) + + if_region.add_branch(dace.sdfg.state.CodeBlock(cond_name), then_body) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"not {cond_name}"), else_body) + + return state.add_nested_sdfg( + sdfg=inner_sdfg, + inputs={b1_name, b2_name, b3_name, b4_name, cond_name}, + outputs={output_name}, + ) + + def _perform_test( sdfg: dace.SDFG, explected_applies: int, @@ -674,6 +772,161 @@ def test_if_mover_dependent_branch_3(): assert set(gnames) == sdfg.arrays.keys() +def test_if_mover_dependent_branch_4(): + """ + Essentially tests the following situation: + ```python + s = buu(...) + a = foo(s, ...) + b = bar(s, a, ...) + c = baz(...) + e = qux(...) + if c: + d = a + b + else: + d = e + s + ``` + """ + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_dependent_branches")) + state = sdfg.add_state(is_start_block=True) + + # Inputs + input_names = ["a", "b", "c", "d", "e", "s"] + for name in input_names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + # Temporaries + temporary_names = ["a1", "a2", "b1", "b2", "c1", "s1", "e1"] + for name in temporary_names: + sdfg.add_scalar( + name, dtype=dace.bool_ if name.startswith("c") else dace.float64, transient=True + ) + + a1, a2, b1, b2, c1, s1, e1 = (state.add_access(name) for name in temporary_names) + me, mx = state.add_map("comp", ndrange={"__i": "0:10"}) + + # The auxiliary computation involving `s`: + tasklet_s1 = state.add_tasklet( + "tasklet_s1", inputs={"__in"}, outputs={"__out"}, code="__out = - __in" + ) + + state.add_edge(state.add_access("s"), None, me, "IN_s", dace.Memlet("s[0:10]")) + state.add_edge(me, "OUT_s", tasklet_s1, "__in", dace.Memlet("s[__i]")) + state.add_edge(tasklet_s1, "__out", s1, None, dace.Memlet("s1[0]")) + + state.add_edge(state.add_access("e"), None, me, "IN_e", dace.Memlet("e[0:10]")) + state.add_edge(me, "OUT_e", e1, None, dace.Memlet("e1[0]")) + + # Computation involving `a`: + tasklet_a1 = state.add_tasklet( + "tasklet_a1", + inputs={"__in", "__in_s"}, + outputs={"__out"}, + code="__out = math.sin(__in) + __in_s", + ) + tasklet_a2 = state.add_tasklet( + "tasklet_a2", inputs={"__in"}, outputs={"__out"}, code="__out = math.exp(__in)" + ) + state.add_edge(state.add_access("a"), None, me, "IN_a", dace.Memlet("a[0:10]")) + state.add_edge(me, "OUT_a", tasklet_a1, "__in", dace.Memlet("a[__i]")) + state.add_edge(s1, None, tasklet_a1, "__in_s", dace.Memlet("s1[0]")) + state.add_edge(tasklet_a1, "__out", a1, None, dace.Memlet("a1[0]")) + state.add_edge(a1, None, tasklet_a2, "__in", dace.Memlet("a1[0]")) + state.add_edge(tasklet_a2, "__out", a2, None, dace.Memlet("a2[0]")) + + # Computation involving `b`: + tasklet_b1 = state.add_tasklet( + "tasklet_b1", inputs={"__in1", "__in2"}, outputs={"__out"}, code="__out = math.sin(__in1) * math.cos(__in2)" + ) + tasklet_b2 = state.add_tasklet( + "tasklet_b2", + inputs={"__in", "__in_s"}, + outputs={"__out"}, + code="__out = math.sin(__in) - __in_s", + ) + + state.add_edge(state.add_access("b"), None, me, "IN_b", dace.Memlet("b[0:10]")) + state.add_edge(me, "OUT_b", tasklet_b1, "__in1", dace.Memlet("b[__i]")) + state.add_edge(a2, None, tasklet_b1, "__in2", dace.Memlet("a2[0]")) + state.add_edge(tasklet_b1, "__out", b1, None, dace.Memlet("b1[0]")) + state.add_edge(b1, None, tasklet_b2, "__in", dace.Memlet("b1[0]")) + state.add_edge(s1, None, tasklet_b2, "__in_s", dace.Memlet("s1[0]")) + state.add_edge(tasklet_b2, "__out", b2, None, dace.Memlet("b2[0]")) + + # Now the condition. + tasklet_cond = state.add_tasklet( + "tasklet_cond", + inputs={"__in"}, + outputs={"__out"}, + code="__out = __in <= 0.5", + ) + state.add_edge(state.add_access("c"), None, me, "IN_c", dace.Memlet("c[0:10]")) + state.add_edge(me, "OUT_c", tasklet_cond, "__in", dace.Memlet("c[__i]")) + state.add_edge(tasklet_cond, "__out", c1, None, dace.Memlet("c1[0]")) + + # Make the if selection. + if_block = _make_if_block_with_two_args(state=state, outer_sdfg=sdfg) + state.add_edge(a2, None, if_block, "__arg1", dace.Memlet("a2[0]")) + state.add_edge(b2, None, if_block, "__arg2", dace.Memlet("b2[0]")) + state.add_edge(e1, None, if_block, "__arg3", dace.Memlet("e1[0]")) + state.add_edge(s1, None, if_block, "__arg4", dace.Memlet("s1[0]")) + state.add_edge(c1, None, if_block, "__cond", dace.Memlet("c1[0]")) + + # Now handle the output. + state.add_edge(if_block, "__output", mx, "IN_d", dace.Memlet("d[__i]")) + state.add_edge(mx, "OUT_d", state.add_access("d"), None, dace.Memlet("d[0:10]")) + + # Now add the connectors to the Map* + for iname in input_names: + if iname == "d": + continue + me.add_in_connector(f"IN_{iname}") + me.add_out_connector(f"OUT_{iname}") + mx.add_in_connector("IN_d") + mx.add_out_connector("OUT_d") + sdfg.validate() + + sdfg.view() + breakpoint() + + _perform_test(sdfg, explected_applies=1) + + sdfg.view() + breakpoint() + + # # Examine the structure of the SDFG. + # top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) + # assert {ac.data for ac in top_ac} == set(input_names).union(["c1", "s1"]) + # assert len(sdfg.arrays) == len(top_ac) + + # top_tlet: list[dace_nodes.Tasklet] = util.count_nodes(state, dace_nodes.Tasklet, True) + # assert len(top_tlet) == 2 + # assert {"tasklet_cond", "tasklet_s1"} == {tlet.label for tlet in top_tlet} + + # inner_ac: list[dace_nodes.AccessNode] = util.count_nodes( + # if_block.sdfg, dace_nodes.AccessNode, True + # ) + # expected_data: set[str] = ( + # set(temporary_names).union(input_names).union(["__arg1", "__arg2", "__output"]) + # ) + # expected_data.difference_update(["c1", "c", "d", "s"]) + # assert expected_data == {ac.data for ac in inner_ac} + # assert len([ac for ac in inner_ac if ac.data == "s1"]) == 2 + # assert len([ac for ac in inner_ac if ac.data == "__output"]) == 2 + # assert len(expected_data) + 2 == len(inner_ac) + # assert if_block.sdfg.arrays.keys() == expected_data.union(["__cond"]) + + # inner_tlet: list[dace_nodes.Tasklet] = util.count_nodes(if_block.sdfg, dace_nodes.Tasklet, True) + # assert len(inner_tlet) == 4 + # expected_tlet = {tlet.label for tlet in [tasklet_a1, tasklet_a2, tasklet_b1, tasklet_b2]} + # assert {tlet.label for tlet in inner_tlet} == expected_tlet + + def test_if_mover_no_ops(): """ Essentially tests the following situation: From b90393ac9a5717485b074404c959c90283f69952 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 5 Mar 2026 16:11:25 +0100 Subject: [PATCH 02/32] Fix test and pass --- .../move_dataflow_into_if_body.py | 88 +++++++++--- .../test_move_dataflow_into_if_body.py | 133 +++++++----------- 2 files changed, 122 insertions(+), 99 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index f421778051..72ef6d5129 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -219,6 +219,14 @@ def apply( enclosing_map=enclosing_map, ) + conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]] = {} + for conn_name in raw_relocatable_dataflow.keys(): + conn_name_to_access_node_map[conn_name] = self._find_branch_for( + if_block=if_block, + connector=conn_name, + ) + + already_moved_nodes: set[dace_nodes.Node] = set() # Finally relocate the dataflow for conn_name, nodes_to_move in relocatable_dataflow.items(): self._replicate_dataflow_into_branch( @@ -228,8 +236,18 @@ def apply( enclosing_map=enclosing_map, nodes_to_move=nodes_to_move, connector=conn_name, + conn_name_to_access_node_map=conn_name_to_access_node_map, + already_moved_nodes=already_moved_nodes, ) - + already_moved_nodes.update(nodes_to_move) + + # for conn_name, nodes_to_move in relocatable_dataflow.items(): + # if len(nodes_to_move) == 0: + # continue + # print(f"Making array '{conn_name}' transient and removing connector.", flush=True) + # if_block.sdfg.arrays[conn_name].transient = True + # if_block.remove_in_connector(conn_name) + # breakpoint() self._update_symbol_mapping(if_block, sdfg) self._remove_outside_dataflow( @@ -255,6 +273,8 @@ def _replicate_dataflow_into_branch( enclosing_map: dace_nodes.MapEntry, nodes_to_move: set[dace_nodes.Node], connector: str, + conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], + already_moved_nodes: set[dace_nodes.Node], ) -> None: """Replicate the dataflow in `nodes_to_move` from `state` into `if_block`. @@ -312,17 +332,34 @@ def _replicate_dataflow_into_branch( # branch state, these are the outgoing edges. The data dependencies of the # nodes that were not relocated are still missing. for node in nodes_to_move: + if node in already_moved_nodes: + continue for oedge in state.out_edges(node): if oedge.dst is if_block: - assert oedge.dst_conn == connector - # TODO(phimuell): Make subsets complete. - branch_state.add_edge( - new_nodes[oedge.src], - oedge.src_conn, - connector_node, - None, - dace.Memlet.from_memlet(oedge.data), - ) + if oedge.dst_conn == connector: + # TODO(phimuell): Make subsets complete. + branch_state.add_edge( + new_nodes[oedge.src], + oedge.src_conn, + connector_node, + None, + dace.Memlet.from_memlet(oedge.data), + ) + else: + assert oedge.dst_conn in conn_name_to_access_node_map + assert branch_state == conn_name_to_access_node_map[oedge.dst_conn][0] + branch_state.add_edge( + new_nodes[oedge.src], + oedge.src_conn, + conn_name_to_access_node_map[oedge.dst_conn][1], + None, + dace.Memlet.from_memlet(oedge.data), + ) + # If this is an existing node that is a connection as well + existing_connector = conn_name_to_access_node_map[oedge.dst_conn][1].data + if not inner_sdfg.arrays[existing_connector].transient: + inner_sdfg.arrays[existing_connector].transient = True + if_block.remove_in_connector(existing_connector) else: assert oedge.dst in nodes_to_move branch_state.add_edge( @@ -676,21 +713,34 @@ def _filter_relocatable_dataflow( conn_name: rel_df.difference(all_non_relocatable_dataflow) for conn_name, rel_df in raw_relocatable_dataflow.items() } + # if if_block.label == "if_stmt_100": + # breakpoint() + + # Find the known_nodes for each branch + known_nodes: dict[dace.SDFGState, set[dace_nodes.Node]] = dict() + for conn_name, rel_df in relocatable_dataflow.items(): + branch_state, _ = self._find_branch_for(if_block=if_block, connector=conn_name) + if branch_state not in known_nodes: + known_nodes[branch_state] = set() + known_nodes[branch_state].update(rel_df) + assert len(known_nodes) == 2 - # Relocating nodes that are in more than one set is difficult. In the most - # common case of just two branches, this anyway means they have to be - # executed in any case. Thus we remove them now. - known_nodes: set[dace_nodes.Node] = set() multiple_df_nodes: set[dace_nodes.Node] = set() - for rel_df in relocatable_dataflow.values(): - seen_before: set[dace_nodes.Node] = known_nodes.intersection(rel_df) - if len(seen_before) != 0: - multiple_df_nodes.update(seen_before) - known_nodes.update(rel_df) + # find intersect of all known_nodes sets which are the nodes that are in the dataflow + # of multiple branches and thus doesn't make sense to relocate + for branch_state, known_nodes_set in known_nodes.items(): + for other_branch_state, other_known_nodes_set in known_nodes.items(): + if branch_state == other_branch_state: + continue + multiple_df_nodes.update(known_nodes_set.intersection(other_known_nodes_set)) + + # Update the relocatable dataflow by removing the nodes that are in the dataflow of + # multiple branches relocatable_dataflow = { conn_name: rel_df.difference(multiple_df_nodes) for conn_name, rel_df in relocatable_dataflow.items() } + # breakpoint() # TODO(phimuell): If we operate outside of a Map we also have to make sure that # the data is single use data, is not an AccessNode that refers to global diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 0d369f7d40..d2297ec85d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -84,7 +84,8 @@ def _make_if_block_with_two_args( b3_name: str = "__arg3", b4_name: str = "__arg4", cond_name: str = "__cond", - output_name: str = "__output", + output1_name: str = "__output1", + output2_name: str = "__output2", b1_type: dace.typeclass = dace.float64, b2_type: dace.typeclass = dace.float64, b3_type: dace.typeclass = dace.float64, @@ -93,8 +94,8 @@ def _make_if_block_with_two_args( ) -> dace_nodes.NestedSDFG: inner_sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_stmt_")) - types = {b1_name: b1_type, b2_name: b2_type, b3_name: b3_type, b4_name: b4_type, cond_name: dace.bool_, output_name: output_type} - for name in {b1_name, b2_name, b3_name, b4_name, cond_name, output_name}: + types = {b1_name: b1_type, b2_name: b2_type, b3_name: b3_type, b4_name: b4_type, cond_name: dace.bool_, output1_name: output_type, output2_name: output_type} + for name in {b1_name, b2_name, b3_name, b4_name, cond_name, output1_name, output2_name}: inner_sdfg.add_scalar( name, dtype=types[name], @@ -106,62 +107,28 @@ def _make_if_block_with_two_args( then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) tstate = then_body.add_state("true_branch_0_1_2_3_4", is_start_block=True) - tasklet_true = tstate.add_tasklet( - "true_tasklet", - inputs={"__tasklet_in1", "__tasklet_in2"}, - outputs={"__tasklet_out"}, - code="__tasklet_out = __tasklet_in1 + __tasklet_in2", - ) - tstate.add_edge( + tstate.add_nedge( tstate.add_access(b1_name), - None, - tasklet_true, - "__tasklet_in1", - dace.Memlet(f"{b1_name}[0]"), + tstate.add_access(output1_name), + dace.Memlet(f"{b1_name}[0] -> [0]"), ) - tstate.add_edge( + tstate.add_nedge( tstate.add_access(b2_name), - None, - tasklet_true, - "__tasklet_in2", - dace.Memlet(f"{b2_name}[0]"), - ) - tstate.add_edge( - tasklet_true, - "__tasklet_out", - tstate.add_access(output_name), - None, - dace.Memlet(f"{output_name}[0]"), + tstate.add_access(output2_name), + dace.Memlet(f"{b2_name}[0] -> [0]"), ) else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=inner_sdfg) fstate = else_body.add_state("false_branch_0_1_2_3_4", is_start_block=True) - tasklet_false = fstate.add_tasklet( - "false_tasklet", - inputs={"__tasklet_in1", "__tasklet_in2"}, - outputs={"__tasklet_out"}, - code="__tasklet_out = __tasklet_in1 - __tasklet_in2", - ) - fstate.add_edge( + fstate.add_nedge( fstate.add_access(b3_name), - None, - tasklet_false, - "__tasklet_in1", - dace.Memlet(f"{b3_name}[0]"), + fstate.add_access(output1_name), + dace.Memlet(f"{b3_name}[0] -> [0]"), ) - fstate.add_edge( + fstate.add_nedge( fstate.add_access(b4_name), - None, - tasklet_false, - "__tasklet_in2", - dace.Memlet(f"{b4_name}[0]"), - ) - fstate.add_edge( - tasklet_false, - "__tasklet_out", - fstate.add_access(output_name), - None, - dace.Memlet(f"{output_name}[0]"), + fstate.add_access(output2_name), + dace.Memlet(f"{b4_name}[0] -> [0]"), ) if_region.add_branch(dace.sdfg.state.CodeBlock(cond_name), then_body) @@ -170,7 +137,7 @@ def _make_if_block_with_two_args( return state.add_nested_sdfg( sdfg=inner_sdfg, inputs={b1_name, b2_name, b3_name, b4_name, cond_name}, - outputs={output_name}, + outputs={output1_name, output2_name}, ) @@ -783,15 +750,17 @@ def test_if_mover_dependent_branch_4(): e = qux(...) if c: d = a + b + f = a else: d = e + s + f = e ``` """ sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_dependent_branches")) state = sdfg.add_state(is_start_block=True) # Inputs - input_names = ["a", "b", "c", "d", "e", "s"] + input_names = ["a", "b", "c", "d", "e", "f", "s"] for name in input_names: sdfg.add_array( name, @@ -801,13 +770,13 @@ def test_if_mover_dependent_branch_4(): ) # Temporaries - temporary_names = ["a1", "a2", "b1", "b2", "c1", "s1", "e1"] + temporary_names = ["a1", "a2", "b1", "b2", "c1", "s1"] for name in temporary_names: sdfg.add_scalar( name, dtype=dace.bool_ if name.startswith("c") else dace.float64, transient=True ) - a1, a2, b1, b2, c1, s1, e1 = (state.add_access(name) for name in temporary_names) + a1, a2, b1, b2, c1, s1 = (state.add_access(name) for name in temporary_names) me, mx = state.add_map("comp", ndrange={"__i": "0:10"}) # The auxiliary computation involving `s`: @@ -820,7 +789,6 @@ def test_if_mover_dependent_branch_4(): state.add_edge(tasklet_s1, "__out", s1, None, dace.Memlet("s1[0]")) state.add_edge(state.add_access("e"), None, me, "IN_e", dace.Memlet("e[0:10]")) - state.add_edge(me, "OUT_e", e1, None, dace.Memlet("e1[0]")) # Computation involving `a`: tasklet_a1 = state.add_tasklet( @@ -873,52 +841,57 @@ def test_if_mover_dependent_branch_4(): if_block = _make_if_block_with_two_args(state=state, outer_sdfg=sdfg) state.add_edge(a2, None, if_block, "__arg1", dace.Memlet("a2[0]")) state.add_edge(b2, None, if_block, "__arg2", dace.Memlet("b2[0]")) - state.add_edge(e1, None, if_block, "__arg3", dace.Memlet("e1[0]")) + state.add_edge(me, "OUT_e", if_block, "__arg3", dace.Memlet("e[__i]")) state.add_edge(s1, None, if_block, "__arg4", dace.Memlet("s1[0]")) state.add_edge(c1, None, if_block, "__cond", dace.Memlet("c1[0]")) # Now handle the output. - state.add_edge(if_block, "__output", mx, "IN_d", dace.Memlet("d[__i]")) + state.add_edge(if_block, "__output1", mx, "IN_d", dace.Memlet("d[__i]")) + state.add_edge(if_block, "__output2", mx, "IN_f", dace.Memlet("f[__i]")) state.add_edge(mx, "OUT_d", state.add_access("d"), None, dace.Memlet("d[0:10]")) + state.add_edge(mx, "OUT_f", state.add_access("f"), None, dace.Memlet("f[0:10]")) # Now add the connectors to the Map* for iname in input_names: - if iname == "d": + if iname == "d" or iname == "f": continue me.add_in_connector(f"IN_{iname}") me.add_out_connector(f"OUT_{iname}") mx.add_in_connector("IN_d") mx.add_out_connector("OUT_d") + mx.add_in_connector("IN_f") + mx.add_out_connector("OUT_f") sdfg.validate() - sdfg.view() - breakpoint() + # sdfg.view() + # breakpoint() _perform_test(sdfg, explected_applies=1) - sdfg.view() - breakpoint() + # sdfg.view() + # breakpoint() # # Examine the structure of the SDFG. - # top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) - # assert {ac.data for ac in top_ac} == set(input_names).union(["c1", "s1"]) - # assert len(sdfg.arrays) == len(top_ac) - - # top_tlet: list[dace_nodes.Tasklet] = util.count_nodes(state, dace_nodes.Tasklet, True) - # assert len(top_tlet) == 2 - # assert {"tasklet_cond", "tasklet_s1"} == {tlet.label for tlet in top_tlet} - - # inner_ac: list[dace_nodes.AccessNode] = util.count_nodes( - # if_block.sdfg, dace_nodes.AccessNode, True - # ) - # expected_data: set[str] = ( - # set(temporary_names).union(input_names).union(["__arg1", "__arg2", "__output"]) - # ) - # expected_data.difference_update(["c1", "c", "d", "s"]) - # assert expected_data == {ac.data for ac in inner_ac} - # assert len([ac for ac in inner_ac if ac.data == "s1"]) == 2 - # assert len([ac for ac in inner_ac if ac.data == "__output"]) == 2 - # assert len(expected_data) + 2 == len(inner_ac) + top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) + assert {ac.data for ac in top_ac} == set(input_names).union(["c1", "s1"]) + assert len(sdfg.arrays) == len(top_ac) + + top_tlet: list[dace_nodes.Tasklet] = util.count_nodes(state, dace_nodes.Tasklet, True) + assert len(top_tlet) == 2 + assert {"tasklet_cond", "tasklet_s1"} == {tlet.label for tlet in top_tlet} + + inner_ac: list[dace_nodes.AccessNode] = util.count_nodes( + if_block.sdfg, dace_nodes.AccessNode, True + ) + expected_data: set[str] = ( + set(temporary_names).union(input_names).union(["__arg1", "__arg2", "__arg3", "__arg4", "__output1", "__output2"]) + ) + expected_data.difference_update(["c1", "c", "d", "e", "f", "s"]) + assert expected_data == {ac.data for ac in inner_ac} + assert len([ac for ac in inner_ac if ac.data == "s1"]) == 1 + assert len([ac for ac in inner_ac if ac.data == "__output1"]) == 2 + assert len([ac for ac in inner_ac if ac.data == "__output2"]) == 2 + assert len(expected_data) + 2 == len(inner_ac) # assert if_block.sdfg.arrays.keys() == expected_data.union(["__cond"]) # inner_tlet: list[dace_nodes.Tasklet] = util.count_nodes(if_block.sdfg, dace_nodes.Tasklet, True) From d7d05bd843e85b6302e1f2d7186214565e12c7f2 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 5 Mar 2026 16:25:44 +0100 Subject: [PATCH 03/32] Remove commented code and added explanations --- .../move_dataflow_into_if_body.py | 30 +++++++++++-------- .../test_move_dataflow_into_if_body.py | 16 ++++------ 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 72ef6d5129..f6e97f497c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -219,6 +219,11 @@ def apply( enclosing_map=enclosing_map, ) + # Create a mapping from connector names to the corresponding AccessNode inside the branch and the branch state. + # This is necessary to properly patch the dataflow inside the branch, i.e. to connect it to the global AccessNode + # corresponding to the connector if necessary. We need to gather this information for all the connectors because + # a node in the dataflow of one connector might be the global AccessNode of another connector and we have to handle + # it properly in `_replicate_dataflow_into_branch`. conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]] = {} for conn_name in raw_relocatable_dataflow.keys(): conn_name_to_access_node_map[conn_name] = self._find_branch_for( @@ -226,6 +231,7 @@ def apply( connector=conn_name, ) + # Gather all the already moved nodes to avoid that we move the same node multiple times already_moved_nodes: set[dace_nodes.Node] = set() # Finally relocate the dataflow for conn_name, nodes_to_move in relocatable_dataflow.items(): @@ -241,13 +247,6 @@ def apply( ) already_moved_nodes.update(nodes_to_move) - # for conn_name, nodes_to_move in relocatable_dataflow.items(): - # if len(nodes_to_move) == 0: - # continue - # print(f"Making array '{conn_name}' transient and removing connector.", flush=True) - # if_block.sdfg.arrays[conn_name].transient = True - # if_block.remove_in_connector(conn_name) - # breakpoint() self._update_symbol_mapping(if_block, sdfg) self._remove_outside_dataflow( @@ -295,6 +294,12 @@ def _replicate_dataflow_into_branch( enclosing_map: The enclosing map. nodes_to_move: The list of nodes that should be removed. connector: The connector that should be inlined. + conn_name_to_access_node_map: A mapping from connector names to the + corresponding AccessNode inside the branch and the branch state. + already_moved_nodes: The set of nodes that have already been moved, this is + needed to avoid that we move the same node multiple times, which can + happen if there are multiple connectors whose dataflow we want to + move and they have some nodes in common. """ # Nothing to relocate nothing to do. if len(nodes_to_move) == 0: @@ -348,6 +353,9 @@ def _replicate_dataflow_into_branch( else: assert oedge.dst_conn in conn_name_to_access_node_map assert branch_state == conn_name_to_access_node_map[oedge.dst_conn][0] + # Some of the nodes_to_move are also in the dataflow of some other connector whose dataflow + # we should also move. Since we already moved the dataflow of the other connector we connect + # the dataflow to the global AccessNode corresponding to the other connector branch_state.add_edge( new_nodes[oedge.src], oedge.src_conn, @@ -355,7 +363,7 @@ def _replicate_dataflow_into_branch( None, dace.Memlet.from_memlet(oedge.data), ) - # If this is an existing node that is a connection as well + # Handle the other connector AccessNode and connector existing_connector = conn_name_to_access_node_map[oedge.dst_conn][1].data if not inner_sdfg.arrays[existing_connector].transient: inner_sdfg.arrays[existing_connector].transient = True @@ -713,8 +721,6 @@ def _filter_relocatable_dataflow( conn_name: rel_df.difference(all_non_relocatable_dataflow) for conn_name, rel_df in raw_relocatable_dataflow.items() } - # if if_block.label == "if_stmt_100": - # breakpoint() # Find the known_nodes for each branch known_nodes: dict[dace.SDFGState, set[dace_nodes.Node]] = dict() @@ -726,7 +732,7 @@ def _filter_relocatable_dataflow( assert len(known_nodes) == 2 multiple_df_nodes: set[dace_nodes.Node] = set() - # find intersect of all known_nodes sets which are the nodes that are in the dataflow + # Find intersect of all known_nodes sets which are the nodes that are in the dataflow # of multiple branches and thus doesn't make sense to relocate for branch_state, known_nodes_set in known_nodes.items(): for other_branch_state, other_known_nodes_set in known_nodes.items(): @@ -740,7 +746,6 @@ def _filter_relocatable_dataflow( conn_name: rel_df.difference(multiple_df_nodes) for conn_name, rel_df in relocatable_dataflow.items() } - # breakpoint() # TODO(phimuell): If we operate outside of a Map we also have to make sure that # the data is single use data, is not an AccessNode that refers to global @@ -886,6 +891,7 @@ def _partition_if_block( for conn_name in reference_count.keys() if conn_name not in relocatable_connectors } + if len(non_relocatable_connectors) == 0: return None if len(relocatable_connectors) == 0: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index d2297ec85d..454bb2dde9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -863,14 +863,8 @@ def test_if_mover_dependent_branch_4(): mx.add_out_connector("OUT_f") sdfg.validate() - # sdfg.view() - # breakpoint() - _perform_test(sdfg, explected_applies=1) - # sdfg.view() - # breakpoint() - # # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) assert {ac.data for ac in top_ac} == set(input_names).union(["c1", "s1"]) @@ -892,12 +886,12 @@ def test_if_mover_dependent_branch_4(): assert len([ac for ac in inner_ac if ac.data == "__output1"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output2"]) == 2 assert len(expected_data) + 2 == len(inner_ac) - # assert if_block.sdfg.arrays.keys() == expected_data.union(["__cond"]) + assert if_block.sdfg.arrays.keys() == expected_data.union(["__cond"]) - # inner_tlet: list[dace_nodes.Tasklet] = util.count_nodes(if_block.sdfg, dace_nodes.Tasklet, True) - # assert len(inner_tlet) == 4 - # expected_tlet = {tlet.label for tlet in [tasklet_a1, tasklet_a2, tasklet_b1, tasklet_b2]} - # assert {tlet.label for tlet in inner_tlet} == expected_tlet + inner_tlet: list[dace_nodes.Tasklet] = util.count_nodes(if_block.sdfg, dace_nodes.Tasklet, True) + assert len(inner_tlet) == 4 + expected_tlet = {tlet.label for tlet in [tasklet_a1, tasklet_a2, tasklet_b1, tasklet_b2]} + assert {tlet.label for tlet in inner_tlet} == expected_tlet def test_if_mover_no_ops(): From 2b7a501cdf9cb7360e2d4bdc42143f14b1604662 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 5 Mar 2026 16:28:37 +0100 Subject: [PATCH 04/32] Make formatting happy --- .../test_move_dataflow_into_if_body.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 454bb2dde9..d30d6cc207 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -94,7 +94,15 @@ def _make_if_block_with_two_args( ) -> dace_nodes.NestedSDFG: inner_sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_stmt_")) - types = {b1_name: b1_type, b2_name: b2_type, b3_name: b3_type, b4_name: b4_type, cond_name: dace.bool_, output1_name: output_type, output2_name: output_type} + types = { + b1_name: b1_type, + b2_name: b2_type, + b3_name: b3_type, + b4_name: b4_type, + cond_name: dace.bool_, + output1_name: output_type, + output2_name: output_type, + } for name in {b1_name, b2_name, b3_name, b4_name, cond_name, output1_name, output2_name}: inner_sdfg.add_scalar( name, @@ -809,7 +817,10 @@ def test_if_mover_dependent_branch_4(): # Computation involving `b`: tasklet_b1 = state.add_tasklet( - "tasklet_b1", inputs={"__in1", "__in2"}, outputs={"__out"}, code="__out = math.sin(__in1) * math.cos(__in2)" + "tasklet_b1", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code="__out = math.sin(__in1) * math.cos(__in2)", ) tasklet_b2 = state.add_tasklet( "tasklet_b2", @@ -878,7 +889,9 @@ def test_if_mover_dependent_branch_4(): if_block.sdfg, dace_nodes.AccessNode, True ) expected_data: set[str] = ( - set(temporary_names).union(input_names).union(["__arg1", "__arg2", "__arg3", "__arg4", "__output1", "__output2"]) + set(temporary_names) + .union(input_names) + .union(["__arg1", "__arg2", "__arg3", "__arg4", "__output1", "__output2"]) ) expected_data.difference_update(["c1", "c", "d", "e", "f", "s"]) assert expected_data == {ac.data for ac in inner_ac} From cd6f596e54f6ac6f13bbb4b1d2dcb010a3a2ae29 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 5 Mar 2026 17:13:51 +0100 Subject: [PATCH 05/32] We can also have only a false/true branch instead of both --- .../runners/dace/transformations/move_dataflow_into_if_body.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index f6e97f497c..bbe74c8fa6 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -729,7 +729,7 @@ def _filter_relocatable_dataflow( if branch_state not in known_nodes: known_nodes[branch_state] = set() known_nodes[branch_state].update(rel_df) - assert len(known_nodes) == 2 + assert len(known_nodes) <= 2 multiple_df_nodes: set[dace_nodes.Node] = set() # Find intersect of all known_nodes sets which are the nodes that are in the dataflow From 3ab44a0958822fcc94d2ad54f83261bd8d4438e8 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Fri, 6 Mar 2026 13:13:13 +0100 Subject: [PATCH 06/32] Address review comments --- .../dace/transformations/move_dataflow_into_if_body.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index bbe74c8fa6..c3cdb81b5a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -729,7 +729,6 @@ def _filter_relocatable_dataflow( if branch_state not in known_nodes: known_nodes[branch_state] = set() known_nodes[branch_state].update(rel_df) - assert len(known_nodes) <= 2 multiple_df_nodes: set[dace_nodes.Node] = set() # Find intersect of all known_nodes sets which are the nodes that are in the dataflow @@ -740,8 +739,7 @@ def _filter_relocatable_dataflow( continue multiple_df_nodes.update(known_nodes_set.intersection(other_known_nodes_set)) - # Update the relocatable dataflow by removing the nodes that are in the dataflow of - # multiple branches + # Remove from the relocatable dataflow the nodes that appear in multiple branches relocatable_dataflow = { conn_name: rel_df.difference(multiple_df_nodes) for conn_name, rel_df in relocatable_dataflow.items() From 16926ce651c8731ccfc32374f16a29880cd5845f Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 9 Mar 2026 12:21:34 +0100 Subject: [PATCH 07/32] Improve new node creation --- .../move_dataflow_into_if_body.py | 75 ++++++++++--------- .../test_move_dataflow_into_if_body.py | 26 +++++-- 2 files changed, 59 insertions(+), 42 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index c3cdb81b5a..813551d01e 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -231,8 +231,12 @@ def apply( connector=conn_name, ) - # Gather all the already moved nodes to avoid that we move the same node multiple times - already_moved_nodes: set[dace_nodes.Node] = set() + # Create a map of the old to the new nodes to keep track of the old nodes copied + # and their corresponding new nodes. The map should be per branch of the ConditionalBlock + # because there could be a node that has to be copied in all branches. Since the + # `relocatable_dataflow` nodes are not disjoint we need this mapping to avoid copying the same + # node multiple times and to properly connect the copied nodes. + old_to_new_nodes_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node] = dict() # Finally relocate the dataflow for conn_name, nodes_to_move in relocatable_dataflow.items(): self._replicate_dataflow_into_branch( @@ -243,9 +247,8 @@ def apply( nodes_to_move=nodes_to_move, connector=conn_name, conn_name_to_access_node_map=conn_name_to_access_node_map, - already_moved_nodes=already_moved_nodes, + old_to_new_nodes_map=old_to_new_nodes_map, ) - already_moved_nodes.update(nodes_to_move) self._update_symbol_mapping(if_block, sdfg) @@ -273,7 +276,7 @@ def _replicate_dataflow_into_branch( nodes_to_move: set[dace_nodes.Node], connector: str, conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], - already_moved_nodes: set[dace_nodes.Node], + old_to_new_nodes_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node], ) -> None: """Replicate the dataflow in `nodes_to_move` from `state` into `if_block`. @@ -296,28 +299,38 @@ def _replicate_dataflow_into_branch( connector: The connector that should be inlined. conn_name_to_access_node_map: A mapping from connector names to the corresponding AccessNode inside the branch and the branch state. - already_moved_nodes: The set of nodes that have already been moved, this is - needed to avoid that we move the same node multiple times, which can - happen if there are multiple connectors whose dataflow we want to - move and they have some nodes in common. + old_to_new_nodes_map: A mapping from the old nodes to the new nodes. + The keys of the mapping are tuples of the old node and the branch + state for which the new node was created. The values are the new nodes. """ # Nothing to relocate nothing to do. if len(nodes_to_move) == 0: return inner_sdfg: dace.SDFG = if_block.sdfg - branch_state, connector_node = self._find_branch_for( - if_block=if_block, - connector=connector, - ) + branch_state, connector_node = conn_name_to_access_node_map[connector] + + # Replicate the nodes and store them in the `old_to_new_nodes_map` mapping. + # Add the SDFGState to the key of the dictionary because we have to create + # new node for the different branches. + unique_old_nodes: list[dace_nodes.Node] = [] + for old_node in nodes_to_move: + if (old_node, branch_state) in old_to_new_nodes_map: + continue + unique_old_nodes.append(old_node) + copy_of_old_node = copy.deepcopy(old_node) + old_to_new_nodes_map[(old_node, branch_state)] = copy_of_old_node + branch_state.add_node(copy_of_old_node) # There might be AccessNodes inside `nodes_to_move`, we now have to make sure # that they are present inside the nested ones. By our base assumption they - # are transients, because they are only used in one place + # are transients, because they are only used in one place. Make sure that we + # don't add new nodes if the nodes are already in the arrays of the inner SDFG. for node in nodes_to_move: if not isinstance(node, dace_nodes.AccessNode): continue - assert node.data not in inner_sdfg.arrays + if node.data in inner_sdfg.arrays: + continue assert sdfg.arrays[node.data].transient # TODO(phimuell): Handle the case we need to rename something. inner_sdfg.add_datadesc( @@ -326,25 +339,16 @@ def _replicate_dataflow_into_branch( find_new_name=False, ) - # Replicate the nodes. Also make a mapping that allows to map the old ones - # to the new ones. - new_nodes: dict[dace_nodes.Node, dace_nodes.Node] = { - old_node: copy.deepcopy(old_node) for old_node in nodes_to_move - } - branch_state.add_nodes_from(new_nodes.values()) - # Now add the edges between the edges that have been replicated inside the # branch state, these are the outgoing edges. The data dependencies of the # nodes that were not relocated are still missing. - for node in nodes_to_move: - if node in already_moved_nodes: - continue + for node in unique_old_nodes: for oedge in state.out_edges(node): if oedge.dst is if_block: if oedge.dst_conn == connector: # TODO(phimuell): Make subsets complete. branch_state.add_edge( - new_nodes[oedge.src], + old_to_new_nodes_map[(oedge.src, branch_state)], oedge.src_conn, connector_node, None, @@ -357,7 +361,7 @@ def _replicate_dataflow_into_branch( # we should also move. Since we already moved the dataflow of the other connector we connect # the dataflow to the global AccessNode corresponding to the other connector branch_state.add_edge( - new_nodes[oedge.src], + old_to_new_nodes_map[(oedge.src, branch_state)], oedge.src_conn, conn_name_to_access_node_map[oedge.dst_conn][1], None, @@ -371,9 +375,9 @@ def _replicate_dataflow_into_branch( else: assert oedge.dst in nodes_to_move branch_state.add_edge( - new_nodes[oedge.src], + old_to_new_nodes_map[(oedge.src, branch_state)], oedge.src_conn, - new_nodes[oedge.dst], + old_to_new_nodes_map[(oedge.dst, branch_state)], oedge.dst_conn, dace.Memlet.from_memlet(oedge.data), ) @@ -382,7 +386,7 @@ def _replicate_dataflow_into_branch( # could not have been moved inside `if_block` but are still needed to compute # the final result. We find them by scanning the input edges of the nodes # that have been relocated. - for node in nodes_to_move: + for node in unique_old_nodes: for iedge in state.in_edges(node): if iedge.src in nodes_to_move: # Inner data dependency, there is nothing to do and the edge was @@ -442,23 +446,24 @@ def _replicate_dataflow_into_branch( # phase, see `_remove_outside_dataflow()`. pass - if outer_data not in new_nodes: + if (outer_data, branch_state) not in old_to_new_nodes_map: assert all( outer_data.data != mapped_node.data - for mapped_node in new_nodes.values() + for mapped_node, mapped_branch_state in old_to_new_nodes_map.keys() if isinstance(mapped_node, dace_nodes.AccessNode) + and mapped_branch_state == branch_state ) assert outer_data.data in inner_sdfg.arrays assert not inner_sdfg.arrays[outer_data.data].transient - new_nodes[outer_data] = branch_state.add_access( + old_to_new_nodes_map[(outer_data, branch_state)] = branch_state.add_access( outer_data.data, copy.copy(outer_data.debuginfo) ) # Now create the edge in the inner state. branch_state.add_edge( - new_nodes[outer_data], + old_to_new_nodes_map[(outer_data, branch_state)], None, - new_nodes[iedge.dst], + old_to_new_nodes_map[(iedge.dst, branch_state)], iedge.dst_conn, copy.deepcopy(iedge.data), ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index d30d6cc207..d1f4c9dbc9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -778,13 +778,13 @@ def test_if_mover_dependent_branch_4(): ) # Temporaries - temporary_names = ["a1", "a2", "b1", "b2", "c1", "s1"] + temporary_names = ["a1", "a2", "a3", "b1", "b2", "c1", "s1"] for name in temporary_names: sdfg.add_scalar( name, dtype=dace.bool_ if name.startswith("c") else dace.float64, transient=True ) - a1, a2, b1, b2, c1, s1 = (state.add_access(name) for name in temporary_names) + a1, a2, a3, b1, b2, c1, s1 = (state.add_access(name) for name in temporary_names) me, mx = state.add_map("comp", ndrange={"__i": "0:10"}) # The auxiliary computation involving `s`: @@ -848,11 +848,21 @@ def test_if_mover_dependent_branch_4(): state.add_edge(me, "OUT_c", tasklet_cond, "__in", dace.Memlet("c[__i]")) state.add_edge(tasklet_cond, "__out", c1, None, dace.Memlet("c1[0]")) + tasklet_node_reuse = state.add_tasklet( + "tasklet_node_reuse", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code="__out = __in1 + __in2", + ) + state.add_edge(me, "OUT_a", tasklet_node_reuse, "__in1", dace.Memlet("a[__i]")) + state.add_edge(me, "OUT_e", tasklet_node_reuse, "__in2", dace.Memlet("e[__i]")) + state.add_edge(tasklet_node_reuse, "__out", a3, None, dace.Memlet("a3[0]")) + # Make the if selection. if_block = _make_if_block_with_two_args(state=state, outer_sdfg=sdfg) state.add_edge(a2, None, if_block, "__arg1", dace.Memlet("a2[0]")) state.add_edge(b2, None, if_block, "__arg2", dace.Memlet("b2[0]")) - state.add_edge(me, "OUT_e", if_block, "__arg3", dace.Memlet("e[__i]")) + state.add_edge(a3, None, if_block, "__arg3", dace.Memlet("a3[0]")) state.add_edge(s1, None, if_block, "__arg4", dace.Memlet("s1[0]")) state.add_edge(c1, None, if_block, "__cond", dace.Memlet("c1[0]")) @@ -893,17 +903,19 @@ def test_if_mover_dependent_branch_4(): .union(input_names) .union(["__arg1", "__arg2", "__arg3", "__arg4", "__output1", "__output2"]) ) - expected_data.difference_update(["c1", "c", "d", "e", "f", "s"]) + expected_data.difference_update(["c1", "c", "d", "f", "s"]) assert expected_data == {ac.data for ac in inner_ac} assert len([ac for ac in inner_ac if ac.data == "s1"]) == 1 assert len([ac for ac in inner_ac if ac.data == "__output1"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output2"]) == 2 - assert len(expected_data) + 2 == len(inner_ac) + assert len(expected_data) + 3 == len(inner_ac) assert if_block.sdfg.arrays.keys() == expected_data.union(["__cond"]) inner_tlet: list[dace_nodes.Tasklet] = util.count_nodes(if_block.sdfg, dace_nodes.Tasklet, True) - assert len(inner_tlet) == 4 - expected_tlet = {tlet.label for tlet in [tasklet_a1, tasklet_a2, tasklet_b1, tasklet_b2]} + assert len(inner_tlet) == 5 + expected_tlet = { + tlet.label for tlet in [tasklet_a1, tasklet_a2, tasklet_b1, tasklet_b2, tasklet_node_reuse] + } assert {tlet.label for tlet in inner_tlet} == expected_tlet From 89bf52275c1f4292ba3bbf836e49b338294e7fd2 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 9 Mar 2026 13:17:13 +0100 Subject: [PATCH 08/32] Addressed Philip's comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Philip Müller --- .../move_dataflow_into_if_body.py | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 813551d01e..ddc2f96872 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -225,7 +225,7 @@ def apply( # a node in the dataflow of one connector might be the global AccessNode of another connector and we have to handle # it properly in `_replicate_dataflow_into_branch`. conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]] = {} - for conn_name in raw_relocatable_dataflow.keys(): + for conn_name in relocatable_dataflow.keys(): conn_name_to_access_node_map[conn_name] = self._find_branch_for( if_block=if_block, connector=conn_name, @@ -324,8 +324,9 @@ def _replicate_dataflow_into_branch( # There might be AccessNodes inside `nodes_to_move`, we now have to make sure # that they are present inside the nested ones. By our base assumption they - # are transients, because they are only used in one place. Make sure that we - # don't add new nodes if the nodes are already in the arrays of the inner SDFG. + # are transients and single use, because they are only used in one place. Make + # sure that we don't add new nodes if the nodes are already in the arrays of + # the inner SDFG. for node in nodes_to_move: if not isinstance(node, dace_nodes.AccessNode): continue @@ -340,13 +341,17 @@ def _replicate_dataflow_into_branch( ) # Now add the edges between the edges that have been replicated inside the - # branch state, these are the outgoing edges. The data dependencies of the - # nodes that were not relocated are still missing. + # branch state, these are the outgoing edges. + # Now add the outgoing edges between the replicated nodes inside the branch state. + # The data dependencies (incomming edges) of the nodes, i.e. the not relocated dataflow, + # are still missing. for node in unique_old_nodes: for oedge in state.out_edges(node): if oedge.dst is if_block: if oedge.dst_conn == connector: + # This connection maps the outside data into the nested SDFG, thus its destination is technically the same data. # TODO(phimuell): Make subsets complete. + # TODO(phimuell): Check if this Memlet is always correct, especially in case of slicing. branch_state.add_edge( old_to_new_nodes_map[(oedge.src, branch_state)], oedge.src_conn, @@ -356,7 +361,6 @@ def _replicate_dataflow_into_branch( ) else: assert oedge.dst_conn in conn_name_to_access_node_map - assert branch_state == conn_name_to_access_node_map[oedge.dst_conn][0] # Some of the nodes_to_move are also in the dataflow of some other connector whose dataflow # we should also move. Since we already moved the dataflow of the other connector we connect # the dataflow to the global AccessNode corresponding to the other connector @@ -368,11 +372,14 @@ def _replicate_dataflow_into_branch( dace.Memlet.from_memlet(oedge.data), ) # Handle the other connector AccessNode and connector - existing_connector = conn_name_to_access_node_map[oedge.dst_conn][1].data - if not inner_sdfg.arrays[existing_connector].transient: - inner_sdfg.arrays[existing_connector].transient = True - if_block.remove_in_connector(existing_connector) + inner_access_node_of_connector_name = conn_name_to_access_node_map[ + oedge.dst_conn + ][1].data + if not inner_sdfg.arrays[inner_access_node_of_connector_name].transient: + inner_sdfg.arrays[inner_access_node_of_connector_name].transient = True + if_block.remove_in_connector(inner_access_node_of_connector_name) else: + # Restore a connection between two nodes that were relocated. assert oedge.dst in nodes_to_move branch_state.add_edge( old_to_new_nodes_map[(oedge.src, branch_state)], @@ -740,15 +747,16 @@ def _filter_relocatable_dataflow( # of multiple branches and thus doesn't make sense to relocate for branch_state, known_nodes_set in known_nodes.items(): for other_branch_state, other_known_nodes_set in known_nodes.items(): - if branch_state == other_branch_state: - continue - multiple_df_nodes.update(known_nodes_set.intersection(other_known_nodes_set)) - - # Remove from the relocatable dataflow the nodes that appear in multiple branches - relocatable_dataflow = { - conn_name: rel_df.difference(multiple_df_nodes) - for conn_name, rel_df in relocatable_dataflow.items() - } + if branch_state != other_branch_state: + multiple_df_nodes.update(known_nodes_set.intersection(other_known_nodes_set)) + + if multiple_df_nodes: + # Remove from the relocatable dataflow the nodes that appear in multiple branches + # as it doesn't make sense to relocate them and duplicate them in both branches. + relocatable_dataflow = { + conn_name: rel_df.difference(multiple_df_nodes) + for conn_name, rel_df in relocatable_dataflow.items() + } # TODO(phimuell): If we operate outside of a Map we also have to make sure that # the data is single use data, is not an AccessNode that refers to global @@ -836,6 +844,7 @@ def _partition_if_block( two sets of strings. The first set contains the connectors that can be relocated and the second one of the conditions that can not be relocated. """ + # TODO(phimuell): Change the return type to `tuple[list[str], list[str]]` and sort the connectors, such that the operation is deterministic. # There shall only be one output and three inputs with given names. if len(if_block.out_connectors.keys()) == 0: return None From a609bc5eb0e38650dfd3821022f4fadd7aefc6c7 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 9 Mar 2026 14:53:43 +0100 Subject: [PATCH 09/32] Updated docstring --- .../dace/transformations/move_dataflow_into_if_body.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index ddc2f96872..61e352a5f6 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -712,7 +712,10 @@ def _filter_relocatable_dataflow( the parts that actually can be relocated and returns a `dict` mapping every relocatable input connector to the set of nodes that can be relocated. - Note that the sets that are returned by this function are distinct. + The returned sets can include duplicate nodes, i.e. a node can be in the + dataflow of multiple connectors. The function that performs the actual + relocation (_replicate_dataflow_into_branch) will take care of that and + make sure that such nodes are only copied once. Args: state: The state on which we operate. From af5051d604ad902df02bfca33de54eee2703ec00 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 9 Mar 2026 17:17:26 +0100 Subject: [PATCH 10/32] Handle remaining comments --- .../move_dataflow_into_if_body.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 61e352a5f6..9c9e71cdd7 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -289,6 +289,12 @@ def _replicate_dataflow_into_branch( `if_block`. At the end the function will remove the `connector`, but it will not remove the original dataflow. + In case of nodes existing in multiple `nodes_to_move` sets coming from multiple + connectors, the function will only copy the necessary nodes and edges only once + based on the `old_to_new_nodes_map` keys and values. In case a connector to the + NestedSDFG is exists in the dataflow of another connector, the function + will remove the connection of the original connector and replace the global + AccessNode of the NestedSDFG with a temporary one and remove the original connector. Args: sdfg: The sdfg that we process, the one that contains `state`. @@ -322,21 +328,20 @@ def _replicate_dataflow_into_branch( old_to_new_nodes_map[(old_node, branch_state)] = copy_of_old_node branch_state.add_node(copy_of_old_node) - # There might be AccessNodes inside `nodes_to_move`, we now have to make sure - # that they are present inside the nested ones. By our base assumption they - # are transients and single use, because they are only used in one place. Make - # sure that we don't add new nodes if the nodes are already in the arrays of - # the inner SDFG. - for node in nodes_to_move: - if not isinstance(node, dace_nodes.AccessNode): + # There might be AccessNodes inside `nodes_to_move`, we now have to make sure + # that they are present inside the nested ones. By our base assumption they + # are transients and single use, because they are only used in one place. Make + # sure that we don't add new nodes if the nodes are already in the arrays of + # the inner SDFG. + if not isinstance(old_node, dace_nodes.AccessNode): continue - if node.data in inner_sdfg.arrays: + if old_node.data in inner_sdfg.arrays: continue - assert sdfg.arrays[node.data].transient + assert sdfg.arrays[old_node.data].transient # TODO(phimuell): Handle the case we need to rename something. inner_sdfg.add_datadesc( - node.data, - sdfg.arrays[node.data].clone(), + old_node.data, + sdfg.arrays[old_node.data].clone(), find_new_name=False, ) @@ -714,7 +719,7 @@ def _filter_relocatable_dataflow( The returned sets can include duplicate nodes, i.e. a node can be in the dataflow of multiple connectors. The function that performs the actual - relocation (_replicate_dataflow_into_branch) will take care of that and + relocation (`_replicate_dataflow_into_branch`) will take care of that and make sure that such nodes are only copied once. Args: From a5ab4a290fb875d26837e1c581eae5a1625f0b51 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 10 Mar 2026 13:06:15 +0100 Subject: [PATCH 11/32] Started with some modifications and imporvements not done yet. --- .../move_dataflow_into_if_body.py | 183 ++++++++++-------- .../runners/dace/transformations/utils.py | 54 +++++- 2 files changed, 153 insertions(+), 84 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 9c9e71cdd7..68cbb33bea 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -8,6 +8,7 @@ import copy import functools +import collections from typing import Any, Optional import dace @@ -128,8 +129,10 @@ def can_be_applied( if_block_spec = self._partition_if_block(if_block) if if_block_spec is None: return False + relocatable_connectors, non_relocatable_connectors, connector_usage_location = if_block_spec # Compute the dataflow that is relocated. + # NOTE: That the nodes sets are not sorted in any way, however, the raw_relocatable_dataflow, non_relocatable_dataflow = ( { conn_name: gtx_transformations.utils.find_upstream_nodes( @@ -140,7 +143,7 @@ def can_be_applied( ) for conn_name in conn_names } - for conn_names in if_block_spec + for conn_names in [relocatable_connectors, non_relocatable_connectors] ) relocatable_dataflow = self._filter_relocatable_dataflow( sdfg=sdfg, @@ -193,9 +196,8 @@ def apply( sdfg: dace.SDFG, ) -> None: if_block: dace_nodes.NestedSDFG = self.if_block - if_block_spec = self._partition_if_block(if_block) - assert if_block_spec is not None enclosing_map = graph.scope_dict()[if_block] + relocatable_connectors, non_relocatable_connectors, connector_usage_location = self._partition_if_block(if_block) # Find the dataflow that should be relocated. raw_relocatable_dataflow, non_relocatable_dataflow = ( @@ -208,7 +210,7 @@ def apply( ) for conn_name in conn_names } - for conn_names in if_block_spec + for conn_names in [relocatable_connectors, non_relocatable_connectors] ) relocatable_dataflow = self._filter_relocatable_dataflow( sdfg=sdfg, @@ -216,6 +218,7 @@ def apply( if_block=if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + connector_usage_location=connector_usage_location, enclosing_map=enclosing_map, ) @@ -225,7 +228,7 @@ def apply( # a node in the dataflow of one connector might be the global AccessNode of another connector and we have to handle # it properly in `_replicate_dataflow_into_branch`. conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]] = {} - for conn_name in relocatable_dataflow.keys(): + for conn_name in relocatable_connectors: conn_name_to_access_node_map[conn_name] = self._find_branch_for( if_block=if_block, connector=conn_name, @@ -621,36 +624,6 @@ def _check_for_data_and_symbol_conflicts( return True - def _find_branch_for( - self, - if_block: dace_nodes.NestedSDFG, - connector: str, - ) -> tuple[dace.SDFGState, dace_nodes.AccessNode]: - """ - Locates the branch and the AccessNode to where the dataflow should be relocated. - """ - inner_sdfg: dace.SDFG = if_block.sdfg - conditional_block: dace.sdfg.state.ConditionalBlock = next(iter(inner_sdfg.nodes())) - - # This will locate the state where the first AccessNode that refers to - # `connector` is found. Since `_partition_if_block()` makes sure that - # there is only one match this is okay. But it must be changed, if we - # lift this restriction. - for inner_state in conditional_block.all_states(): - connector_nodes: list[dace_nodes.AccessNode] = [ - dnode for dnode in inner_state.data_nodes() if dnode.data == connector - ] - if len(connector_nodes) == 0: - continue - break - else: - raise ValueError(f"Did not find a branch associated to '{connector}'.") - - assert isinstance(inner_state, dace.SDFGState) - assert inner_state.in_degree(connector_nodes[0]) == 0 - assert inner_state.out_degree(connector_nodes[0]) > 0 - return inner_state, connector_nodes[0] - def _has_if_block_relocatable_dataflow( self, sdfg: dace.SDFG, @@ -675,6 +648,7 @@ def _has_if_block_relocatable_dataflow( if_block_spec = self._partition_if_block(upstream_if_block) if if_block_spec is None: return False + *classified_connectors, connector_usage_location = if_block_spec raw_relocatable_dataflow, non_relocatable_dataflow = ( { @@ -708,8 +682,9 @@ def _filter_relocatable_dataflow( if_block: dace_nodes.NestedSDFG, raw_relocatable_dataflow: dict[str, set[dace_nodes.Node]], non_relocatable_dataflow: dict[str, set[dace_nodes.Node]], + connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], enclosing_map: dace_nodes.MapEntry, - ) -> dict[str, set[dace_nodes.Node]]: + ) -> dict[str, list[dace_nodes.Node]]: """Partition the dependencies. The function expects the dataflow that is upstream of every connector @@ -729,6 +704,8 @@ def _filter_relocatable_dataflow( that can be relocated, not yet filtered. non_relocatable_dataflow: The connectors and their associated dataflow that can not be relocated. + connector_usage_location: Maps a connector to the state and AccessNode + inside the if block. enclosing_map: The limiting node, i.e. the MapEntry of the Map where `if_block` is located in. """ @@ -743,16 +720,14 @@ def _filter_relocatable_dataflow( } # Find the known_nodes for each branch - known_nodes: dict[dace.SDFGState, set[dace_nodes.Node]] = dict() + known_nodes: dict[dace.SDFGState, set[dace_nodes.Node]] = collections.defaultdict(set) for conn_name, rel_df in relocatable_dataflow.items(): - branch_state, _ = self._find_branch_for(if_block=if_block, connector=conn_name) - if branch_state not in known_nodes: - known_nodes[branch_state] = set() + branch_state = connector_usage_location[conn_name] known_nodes[branch_state].update(rel_df) + # Find intersect of all known_nodes sets which are the nodes that are in the + # dataflow of multiple branches and thus doesn't make sense to relocate multiple_df_nodes: set[dace_nodes.Node] = set() - # Find intersect of all known_nodes sets which are the nodes that are in the dataflow - # of multiple branches and thus doesn't make sense to relocate for branch_state, known_nodes_set in known_nodes.items(): for other_branch_state, other_known_nodes_set in known_nodes.items(): if branch_state != other_branch_state: @@ -829,7 +804,8 @@ def filter_nodes( nodes_proposed_for_reloc.remove(reloc_node) has_been_updated = True - return nodes_proposed_for_reloc + # Bring the nodes into a deterministic order. + return gtx_transformations.utils.order_nodes(nodes_proposed_for_reloc) return { conn_name: filter_nodes(rel_df) for conn_name, rel_df in relocatable_dataflow.items() @@ -837,8 +813,9 @@ def filter_nodes( def _partition_if_block( self, + sdfg: dace.SDFG, if_block: dace_nodes.NestedSDFG, - ) -> Optional[tuple[set[str], set[str]]]: + ) -> Optional[tuple[list[str], set[str], dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]]]]: """Check if `if_block` can be processed and partition the input connectors. The function will check if `if_block` has the right structure, i.e. if it is @@ -849,17 +826,28 @@ def _partition_if_block( Returns: If `if_block` is unsuitable the function will return `None`. If `if_block` meets the structural requirements the function will return - two sets of strings. The first set contains the connectors that can be - relocated and the second one of the conditions that can not be relocated. + a tuple of length three. The first element is a `list` containing the + connectors whose dataflow can be relocated. The second element is a `set` + containing the connector names whose dataflow can not be relocated. The + third element is a `dict` that maps connectors to a pair containing the + state (inside the nested SDFG) and the `AccessNode` that refers to to the + connector. + Note that only the first element, the `list` containing the relocatable + dataflow, has a stable order that depends on the connector names. All + other elements have an unspecific order! """ - # TODO(phimuell): Change the return type to `tuple[list[str], list[str]]` and sort the connectors, such that the operation is deterministic. - # There shall only be one output and three inputs with given names. if len(if_block.out_connectors.keys()) == 0: return None - # These are all the output names. output_names: set[str] = set(if_block.out_connectors.keys()) + # If data is used as input and output we ignore it. + # TODO(phimuell): Think if this case can be handled. + input_names: set[str] = set(if_block.in_connectors.keys()) + input_names.difference_update(output_names) + if not input_names: + return None + # We require that the nested SDFG contains a single node, which is a # `ConditionalBlock` containing two branches. inner_sdfg: dace.SDFG = if_block.sdfg @@ -869,51 +857,80 @@ def _partition_if_block( if not isinstance(inner_if_block, dace.sdfg.state.ConditionalBlock): return None - # Defining it outside will ensure that there is only one AccessNode for every - # inconnector, which is something `_find_branch_for()` relies on. - reference_count: dict[str, int] = {conn_name: 0 for conn_name in if_block.in_connectors} + # Mapping between the connector and the inner access node. + connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]] = {} + # This is the dataflow that can not be relocated. + non_relocatable_connectors: set[str] = set() + + # Now inspect all states. for _, branch in inner_if_block.branches: output_count: dict[str, int] = {conn_name: 0 for conn_name in output_names} for inner_state in branch.all_states(): assert isinstance(inner_state, dace.SDFGState) - for node in inner_state.nodes(): - if not isinstance(node, dace_nodes.AccessNode): - return None - if node.data in reference_count: - reference_count[node.data] += 1 - exp_in_deg, exp_out_deg = 0, 1 - elif node.data in output_count: - output_count[node.data] += 1 - exp_in_deg, exp_out_deg = 1, 0 + for dnode in inner_state.data_nodes(): + node_data = dnode.data + + # Check if we can skip the data. + if node_data in non_relocatable_connectors: + continue + elif dnode.desc(sdfg).transient: + continue + + if node_data in connector_usage_location: + # The connectors that can be pulled inside must appear exactly once + # inside a state. In theory they could appear more, but then we + # would have to replicate the dataflow to different locations + # which is not supported. We still allow such situation, but + # consider them as non relocatable. + connector_usage_location.pop(node_data) + non_relocatable_connectors.add(node_data) + else: + if node_data in output_names: + exp_in_deg, exp_out_deg = 0, 1 + else: + assert node_data in output_count + exp_in_deg, exp_out_deg = 1, 0 + + # Check if the node has the right degree. + # TODO(phimuell): Find out if we can remove or relax these checks. + if (inner_state.in_degree(node) == exp_in_deg) and (inner_state.out_degree(node) == exp_out_deg): + connector_usage_location[node_data] = (branch, dnode) + else: + non_relocatable_connectors.add(node_data) + + # All input connectors are now considered non relocatable, thus + # the decomposition does not exist. + if len(non_relocatable_dataflow) == len(input_names): + assert non_relocatable_dataflow == input_names.keys() return None - if inner_state.in_degree(node) != exp_in_deg: - return None - if inner_state.out_degree(node) != exp_out_deg: - return None + # Each branch must write to all outputs. # TODO(phimuell): Think if this should be lifted. if any(count != 1 for count in output_count.values()): return None - # The connectors that can be pulled inside must appear exactly once. - # In theory they could appear more, but then we would have to replicate - # the dataflow to different locations which is not supported. - # So the ones that can be relocated were found exactly once. Zero would - # mean they can not be relocated and more than one means that we do not - # support it yet. - relocatable_connectors = { - conn_name for conn_name, conn_count in reference_count.items() if conn_count == 1 - } - non_relocatable_connectors = { - conn_name - for conn_name in reference_count.keys() - if conn_name not in relocatable_connectors - } - - if len(non_relocatable_connectors) == 0: + # There is nothing to relocate. + if len(connector_usage_location) == 0: return None - if len(relocatable_connectors) == 0: + + # In addition to the non relocatable connectors that were found above, we also + # mark all connectors that were not found as non relocatable. These connectors + # are used for conditions or are not used. + non_relocatable_dataflow.update( + conn for conn in input_names if conn not in connector_usage_location + ) + + # We require that at least one non relocatable dataflow is there, this is for + # the condition. This is not strictly needed, as it could also be passed as + # a symbol, but currently the lowering does not do this and we keep it as + # a sanity check. + if len(non_relocatable_dataflow) == 0: return None - return relocatable_connectors, non_relocatable_connectors + + # We only guarantee that `relocatable_connectors` has an stable order, + # everything else has no guaranteed order, even `connector_usage_location`. + relocatable_connectors = sorted(connector_usage_location.keys()) + + return relocatable_connectors, non_relocatable_connectors, connector_usage_location diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 68a7c33201..f00fcb6b05 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -9,7 +9,7 @@ """Common functionality for the transformations/optimization pipeline.""" import uuid -from typing import Optional, Sequence, TypeVar, Union +from typing import Optional, Sequence, TypeVar, Union, Iterable import dace from dace import data as dace_data, libraries as dace_lib, subsets as dace_sbs, symbolic as dace_sym @@ -847,3 +847,55 @@ def gt_data_descriptor_mapping( name_mapping[data_inside] = data_outside return name_mapping + + +def order_nodes( + nodes: Iterable[dace_nodes.Node], + state: dace.SDFGState, +) -> list[dace_nodes.Node]: + """_Tries_ to order `nodes` in a stable and deterministic way. + + The result should be considered as the best way to order a group of nodes inside + a state. It does, however, not guarantees a stable order in any way. + + The condition this function works best is if the nodes have a unique labels + and AccessNodes referring to the same data have different degrees (one node used + for reading one for writing). + + Known pathological cases: + - Multiple top level AccessNodes referring to the same data that have the same + degree (uncommon in GT4Py). + """ + + # Describes when it works and when not. + def key_fun(node: dace_nodes.Node) -> tuple[str, str, int, int]: + if isinstance(node, dace_nodes.AccessNode): + nid: str = node.data + elif hasattr(node, "label"): + # TODO(phimuell): Maybe add `node.code` in case of Tasklets? + nid = node.label + else: + nid = str(node) + + return (type(node).__name__, nid, state.in_degree(node), state.out_degree(node)) + + return sorted(nodes, key=key_fun) + + +def order_edges( + edges: Iterable[dace_graph.MultiConnectorEdge[dace.Memlet]] +) -> list[dace_graph.MultiConnectorEdge[dace.Memlet]]: + """_Tries_ to order `edges` in a stable and deterministic way. + + There is no guarantee of a stable order, although it should be consider more stable + than the one generated by `order_nodes()`. However, the order might depends on the + selected specialization level. Similar to `order_nodes()` this function works best + if the string (not serialization) representation of the involved nodes is unique, + which means their label in most cases. + """ + + # This is probably the best way to sort edge, because it considers the source and + # destination node, as tie breaker the connectors are used and as second level of + # tie breaker the subsets are used. This means that the specialization level + # is also involved because the subset `[a, b]` is different from `[1, 10]`. + return sorted(edges, key=str) From 189311e1c7556f21ac79cb3bd06d919aab5f4b56 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 13 Mar 2026 09:25:33 +0100 Subject: [PATCH 12/32] First version, now running the tests. --- .../move_dataflow_into_if_body.py | 178 +++++++++--------- 1 file changed, 92 insertions(+), 86 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 68cbb33bea..6d607f131d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -6,9 +6,9 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import collections import copy import functools -import collections from typing import Any, Optional import dace @@ -126,7 +126,7 @@ def can_be_applied( return False # Test if the `if_block` is valid. This will also give us the names. - if_block_spec = self._partition_if_block(if_block) + if_block_spec = self._partition_if_block(sdfg, if_block) if if_block_spec is None: return False relocatable_connectors, non_relocatable_connectors, connector_usage_location = if_block_spec @@ -151,6 +151,7 @@ def can_be_applied( if_block=if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + connector_usage_location=connector_usage_location, enclosing_map=enclosing_map, ) @@ -197,7 +198,9 @@ def apply( ) -> None: if_block: dace_nodes.NestedSDFG = self.if_block enclosing_map = graph.scope_dict()[if_block] - relocatable_connectors, non_relocatable_connectors, connector_usage_location = self._partition_if_block(if_block) + relocatable_connectors, non_relocatable_connectors, connector_usage_location = ( + self._partition_if_block(sdfg, if_block) # type: ignore[misc] # Guaranteed to be not None. + ) # Find the dataflow that should be relocated. raw_relocatable_dataflow, non_relocatable_dataflow = ( @@ -210,7 +213,7 @@ def apply( ) for conn_name in conn_names } - for conn_names in [relocatable_connectors, non_relocatable_connectors] + for conn_names in [relocatable_connectors, non_relocatable_connectors] ) relocatable_dataflow = self._filter_relocatable_dataflow( sdfg=sdfg, @@ -222,34 +225,24 @@ def apply( enclosing_map=enclosing_map, ) - # Create a mapping from connector names to the corresponding AccessNode inside the branch and the branch state. - # This is necessary to properly patch the dataflow inside the branch, i.e. to connect it to the global AccessNode - # corresponding to the connector if necessary. We need to gather this information for all the connectors because - # a node in the dataflow of one connector might be the global AccessNode of another connector and we have to handle - # it properly in `_replicate_dataflow_into_branch`. - conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]] = {} - for conn_name in relocatable_connectors: - conn_name_to_access_node_map[conn_name] = self._find_branch_for( - if_block=if_block, - connector=conn_name, - ) - - # Create a map of the old to the new nodes to keep track of the old nodes copied - # and their corresponding new nodes. The map should be per branch of the ConditionalBlock - # because there could be a node that has to be copied in all branches. Since the - # `relocatable_dataflow` nodes are not disjoint we need this mapping to avoid copying the same - # node multiple times and to properly connect the copied nodes. + # Since the node sets in `relocatable_dataflow` are not necessarily disjoint, + # we have to keep track of what we have relocated already, such that we do + # not relocate something multiple times. + # TODO(phimuell, iomaganaris): Find out if the state needs to be part of the + # key especially because of the filtering we do. old_to_new_nodes_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node] = dict() - # Finally relocate the dataflow - for conn_name, nodes_to_move in relocatable_dataflow.items(): + + # Relocate the dataflow. + # NOTE: Important to iterate over `relocatable_connectors` for stability. + for conn_name in relocatable_connectors: self._replicate_dataflow_into_branch( state=graph, sdfg=sdfg, if_block=if_block, enclosing_map=enclosing_map, - nodes_to_move=nodes_to_move, + nodes_to_move=relocatable_dataflow[conn_name], connector=conn_name, - conn_name_to_access_node_map=conn_name_to_access_node_map, + connector_usage_location=connector_usage_location, old_to_new_nodes_map=old_to_new_nodes_map, ) @@ -278,7 +271,7 @@ def _replicate_dataflow_into_branch( enclosing_map: dace_nodes.MapEntry, nodes_to_move: set[dace_nodes.Node], connector: str, - conn_name_to_access_node_map: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], + connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], old_to_new_nodes_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node], ) -> None: """Replicate the dataflow in `nodes_to_move` from `state` into `if_block`. @@ -306,8 +299,8 @@ def _replicate_dataflow_into_branch( enclosing_map: The enclosing map. nodes_to_move: The list of nodes that should be removed. connector: The connector that should be inlined. - conn_name_to_access_node_map: A mapping from connector names to the - corresponding AccessNode inside the branch and the branch state. + connector_usage_location: Maps connector names to the state and AccessNode + where they appear inside the nested SDFG. old_to_new_nodes_map: A mapping from the old nodes to the new nodes. The keys of the mapping are tuples of the old node and the branch state for which the new node was created. The values are the new nodes. @@ -317,7 +310,10 @@ def _replicate_dataflow_into_branch( return inner_sdfg: dace.SDFG = if_block.sdfg - branch_state, connector_node = conn_name_to_access_node_map[connector] + branch_state, _connector_node = connector_usage_location[connector] + + # Sort the nodes such that we insert them in the same order every time. + nodes_to_move = gtx_transformations.utils.order_nodes(nodes_to_move, state) # type: ignore[assignment] # Replicate the nodes and store them in the `old_to_new_nodes_map` mapping. # Add the SDFGState to the key of the dictionary because we have to create @@ -354,38 +350,32 @@ def _replicate_dataflow_into_branch( # The data dependencies (incomming edges) of the nodes, i.e. the not relocated dataflow, # are still missing. for node in unique_old_nodes: - for oedge in state.out_edges(node): + for oedge in gtx_transformations.utils.order_edges(state.out_edges(node)): if oedge.dst is if_block: - if oedge.dst_conn == connector: - # This connection maps the outside data into the nested SDFG, thus its destination is technically the same data. - # TODO(phimuell): Make subsets complete. - # TODO(phimuell): Check if this Memlet is always correct, especially in case of slicing. - branch_state.add_edge( - old_to_new_nodes_map[(oedge.src, branch_state)], - oedge.src_conn, - connector_node, - None, - dace.Memlet.from_memlet(oedge.data), - ) - else: - assert oedge.dst_conn in conn_name_to_access_node_map - # Some of the nodes_to_move are also in the dataflow of some other connector whose dataflow - # we should also move. Since we already moved the dataflow of the other connector we connect - # the dataflow to the global AccessNode corresponding to the other connector - branch_state.add_edge( - old_to_new_nodes_map[(oedge.src, branch_state)], - oedge.src_conn, - conn_name_to_access_node_map[oedge.dst_conn][1], - None, - dace.Memlet.from_memlet(oedge.data), - ) - # Handle the other connector AccessNode and connector - inner_access_node_of_connector_name = conn_name_to_access_node_map[ - oedge.dst_conn - ][1].data - if not inner_sdfg.arrays[inner_access_node_of_connector_name].transient: - inner_sdfg.arrays[inner_access_node_of_connector_name].transient = True - if_block.remove_in_connector(inner_access_node_of_connector_name) + # Here the destination will change, it is no longer the if block, + # but instead the connector node inside branch. It is important + # that `oedge.dst_conn` is not necessarily `connector` as we allow + # for intersecting relocatable node sets. We have to handle these + # to cases slightly different. + assert connector_usage_location[oedge.dst_conn][0] is branch_state + + # TODO(phimuell): Check if this Memlet is always correct, especially in case of slicing. + # TODO(reviewer): Make sure I added a test. + branch_state.add_edge( + old_to_new_nodes_map[(oedge.src, branch_state)], + oedge.src_conn, + connector_usage_location[oedge.dst_conn][1], + None, + # THIS MEMLET IS SO WRONG. + dace.Memlet.from_memlet(oedge.data), + ) + + if connector != oedge.dst_conn: + # This is not the master connector. So we have to handle it. + # TODO(phimuell): Try to refactor such that we can remove it. + inner_sdfg.arrays[oedge.dst_conn].transient = True + if_block.remove_in_connector(oedge.dst_conn) + else: # Restore a connection between two nodes that were relocated. assert oedge.dst in nodes_to_move @@ -402,7 +392,7 @@ def _replicate_dataflow_into_branch( # the final result. We find them by scanning the input edges of the nodes # that have been relocated. for node in unique_old_nodes: - for iedge in state.in_edges(node): + for iedge in gtx_transformations.utils.order_edges(state.in_edges(node)): if iedge.src in nodes_to_move: # Inner data dependency, there is nothing to do and the edge was # created above. @@ -497,6 +487,7 @@ def _remove_outside_dataflow( The function will also remove data containers that are no longer in use. """ + # Creating the union of all first ensures that a node is only removed once. all_relocatable_dataflow: set[dace_nodes.Node] = functools.reduce( lambda s1, s2: s1.union(s2), relocatable_dataflow.values(), set() ) @@ -645,7 +636,7 @@ def _has_if_block_relocatable_dataflow( enclosing_map: The limiting node, i.e. the MapEntry of the Map `if_block` is located in. """ - if_block_spec = self._partition_if_block(upstream_if_block) + if_block_spec = self._partition_if_block(sdfg, upstream_if_block) if if_block_spec is None: return False *classified_connectors, connector_usage_location = if_block_spec @@ -660,7 +651,7 @@ def _has_if_block_relocatable_dataflow( ) for conn_name in conn_names } - for conn_names in if_block_spec + for conn_names in classified_connectors ) filtered_relocatable_dataflow = self._filter_relocatable_dataflow( sdfg=sdfg, @@ -668,6 +659,7 @@ def _has_if_block_relocatable_dataflow( if_block=upstream_if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + connector_usage_location=connector_usage_location, enclosing_map=enclosing_map, ) if all(len(rel_df) == 0 for rel_df in filtered_relocatable_dataflow.values()): @@ -684,7 +676,7 @@ def _filter_relocatable_dataflow( non_relocatable_dataflow: dict[str, set[dace_nodes.Node]], connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], enclosing_map: dace_nodes.MapEntry, - ) -> dict[str, list[dace_nodes.Node]]: + ) -> dict[str, set[dace_nodes.Node]]: """Partition the dependencies. The function expects the dataflow that is upstream of every connector @@ -696,6 +688,10 @@ def _filter_relocatable_dataflow( dataflow of multiple connectors. The function that performs the actual relocation (`_replicate_dataflow_into_branch`) will take care of that and make sure that such nodes are only copied once. + However, the function will make sure that if the sets of two connectors + intersects then it will be relocated into the same branch. I.e. a node + is only relocated and does not need to be replicated because it appears + in multiple branches. Args: state: The state on which we operate. @@ -719,23 +715,27 @@ def _filter_relocatable_dataflow( for conn_name, rel_df in raw_relocatable_dataflow.items() } - # Find the known_nodes for each branch - known_nodes: dict[dace.SDFGState, set[dace_nodes.Node]] = collections.defaultdict(set) + # For each branch inside the if block find the nodes that are relocated inside it. + known_nodes_coll: dict[dace.SDFGState, set[dace_nodes.Node]] = collections.defaultdict(set) for conn_name, rel_df in relocatable_dataflow.items(): - branch_state = connector_usage_location[conn_name] - known_nodes[branch_state].update(rel_df) + known_nodes_coll[connector_usage_location[conn_name]].update(rel_df) + known_nodes = list(known_nodes_coll.values()) # Order is unimportant here. - # Find intersect of all known_nodes sets which are the nodes that are in the - # dataflow of multiple branches and thus doesn't make sense to relocate + # We allow that the set of nodes to relocated that are associated to a connector + # can intersect. However, in that case they have to be inside the same branch. + # Thus we now find the nodes that would need to be relocated into different + # branches. multiple_df_nodes: set[dace_nodes.Node] = set() - for branch_state, known_nodes_set in known_nodes.items(): - for other_branch_state, other_known_nodes_set in known_nodes.items(): - if branch_state != other_branch_state: - multiple_df_nodes.update(known_nodes_set.intersection(other_known_nodes_set)) + for i, known_nodes_set in enumerate(known_nodes): + for j in range(i + 1, len(known_nodes)): + multiple_df_nodes.update(known_nodes_set.intersection(known_nodes[j])) if multiple_df_nodes: # Remove from the relocatable dataflow the nodes that appear in multiple branches - # as it doesn't make sense to relocate them and duplicate them in both branches. + # as it doesn't make sense to relocate them and duplicate them in both branches. + # NOTE: The reason for this filtering is not that it is impossible to do, but + # would be rather complex to do, as we now have to copy nodes instead of + # simply moving them around. relocatable_dataflow = { conn_name: rel_df.difference(multiple_df_nodes) for conn_name, rel_df in relocatable_dataflow.items() @@ -751,6 +751,9 @@ def filter_nodes( while has_been_updated: has_been_updated = False + # TODO(phimuell): Look at me. + # TODO(reviewer): Make sure I looked at it. + for reloc_node in list(nodes_proposed_for_reloc): # The node was already handled in a previous iteration. if reloc_node not in nodes_proposed_for_reloc: @@ -804,8 +807,7 @@ def filter_nodes( nodes_proposed_for_reloc.remove(reloc_node) has_been_updated = True - # Bring the nodes into a deterministic order. - return gtx_transformations.utils.order_nodes(nodes_proposed_for_reloc) + return nodes_proposed_for_reloc return { conn_name: filter_nodes(rel_df) for conn_name, rel_df in relocatable_dataflow.items() @@ -815,7 +817,9 @@ def _partition_if_block( self, sdfg: dace.SDFG, if_block: dace_nodes.NestedSDFG, - ) -> Optional[tuple[list[str], set[str], dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]]]]: + ) -> Optional[ + tuple[list[str], set[str], dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]]] + ]: """Check if `if_block` can be processed and partition the input connectors. The function will check if `if_block` has the right structure, i.e. if it is @@ -874,7 +878,7 @@ def _partition_if_block( # Check if we can skip the data. if node_data in non_relocatable_connectors: continue - elif dnode.desc(sdfg).transient: + elif dnode.desc(inner_sdfg).transient: continue if node_data in connector_usage_location: @@ -890,20 +894,22 @@ def _partition_if_block( if node_data in output_names: exp_in_deg, exp_out_deg = 0, 1 else: - assert node_data in output_count + assert node_data in input_names exp_in_deg, exp_out_deg = 1, 0 # Check if the node has the right degree. # TODO(phimuell): Find out if we can remove or relax these checks. - if (inner_state.in_degree(node) == exp_in_deg) and (inner_state.out_degree(node) == exp_out_deg): + if (inner_state.in_degree(dnode) == exp_in_deg) and ( + inner_state.out_degree(dnode) == exp_out_deg + ): connector_usage_location[node_data] = (branch, dnode) else: non_relocatable_connectors.add(node_data) # All input connectors are now considered non relocatable, thus # the decomposition does not exist. - if len(non_relocatable_dataflow) == len(input_names): - assert non_relocatable_dataflow == input_names.keys() + if len(non_relocatable_connectors) == len(input_names): + assert non_relocatable_connectors == input_names return None # Each branch must write to all outputs. @@ -918,15 +924,15 @@ def _partition_if_block( # In addition to the non relocatable connectors that were found above, we also # mark all connectors that were not found as non relocatable. These connectors # are used for conditions or are not used. - non_relocatable_dataflow.update( - conn for conn in input_names if conn not in connector_usage_location + non_relocatable_connectors.update( + conn for conn in input_names if conn not in connector_usage_location ) # We require that at least one non relocatable dataflow is there, this is for # the condition. This is not strictly needed, as it could also be passed as # a symbol, but currently the lowering does not do this and we keep it as # a sanity check. - if len(non_relocatable_dataflow) == 0: + if len(non_relocatable_connectors) == 0: return None # We only guarantee that `relocatable_connectors` has an stable order, From 674f1654657e293603923e39c1f42133dc8cbdf6 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 13 Mar 2026 09:41:54 +0100 Subject: [PATCH 13/32] Fixed some small things and did an improvement. --- .../move_dataflow_into_if_body.py | 30 +++++++------------ 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 6d607f131d..d6eccd1d93 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -869,7 +869,6 @@ def _partition_if_block( # Now inspect all states. for _, branch in inner_if_block.branches: - output_count: dict[str, int] = {conn_name: 0 for conn_name in output_names} for inner_state in branch.all_states(): assert isinstance(inner_state, dace.SDFGState) for dnode in inner_state.data_nodes(): @@ -878,8 +877,11 @@ def _partition_if_block( # Check if we can skip the data. if node_data in non_relocatable_connectors: continue + elif node_data in output_names: + continue elif dnode.desc(inner_sdfg).transient: continue + assert node_data in input_names if node_data in connector_usage_location: # The connectors that can be pulled inside must appear exactly once @@ -890,21 +892,14 @@ def _partition_if_block( connector_usage_location.pop(node_data) non_relocatable_connectors.add(node_data) + elif inner_state.in_degree(dnode) != 0: + # The node is also written to, a strange situation, that is + # however allowed. So we can not handle it. + non_relocatable_connectors.add(node_data) + else: - if node_data in output_names: - exp_in_deg, exp_out_deg = 0, 1 - else: - assert node_data in input_names - exp_in_deg, exp_out_deg = 1, 0 - - # Check if the node has the right degree. - # TODO(phimuell): Find out if we can remove or relax these checks. - if (inner_state.in_degree(dnode) == exp_in_deg) and ( - inner_state.out_degree(dnode) == exp_out_deg - ): - connector_usage_location[node_data] = (branch, dnode) - else: - non_relocatable_connectors.add(node_data) + # This is a proper input connector node. + connector_usage_location[node_data] = (inner_state, dnode) # All input connectors are now considered non relocatable, thus # the decomposition does not exist. @@ -912,11 +907,6 @@ def _partition_if_block( assert non_relocatable_connectors == input_names return None - # Each branch must write to all outputs. - # TODO(phimuell): Think if this should be lifted. - if any(count != 1 for count in output_count.values()): - return None - # There is nothing to relocate. if len(connector_usage_location) == 0: return None From eeceac28d1b17339d0c27e1eda83f80638d423ad Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 13 Mar 2026 10:07:31 +0100 Subject: [PATCH 14/32] Made a small fix, the unit tests should now pass. --- .../runners/dace/transformations/move_dataflow_into_if_body.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index d6eccd1d93..5842633e15 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -718,7 +718,7 @@ def _filter_relocatable_dataflow( # For each branch inside the if block find the nodes that are relocated inside it. known_nodes_coll: dict[dace.SDFGState, set[dace_nodes.Node]] = collections.defaultdict(set) for conn_name, rel_df in relocatable_dataflow.items(): - known_nodes_coll[connector_usage_location[conn_name]].update(rel_df) + known_nodes_coll[connector_usage_location[conn_name][0]].update(rel_df) known_nodes = list(known_nodes_coll.values()) # Order is unimportant here. # We allow that the set of nodes to relocated that are associated to a connector From a5ab642a98b75199b33177b3f2b0b13b87b89a41 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 13 Mar 2026 10:09:02 +0100 Subject: [PATCH 15/32] Added some more information about what I want to do. --- .../dace/transformations/move_dataflow_into_if_body.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 5842633e15..0900bde69c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -33,6 +33,13 @@ class MoveDataflowIntoIfBody(dace_transformation.SingleStateTransformation): """The transformation moves dataflow into the if branches. + ## TODO ## + - Slicing unit test + - Expending the unit test `_4` that Ioannis made such that it `__arg1` also has + something to relocate. + - Fix the naming issue thing in the if fuser. + - Make the `xfail` test run. + Essentially transforms code from this ```python __arg1 = foo(...) From bbc128fb04711f5d41dbba3c1f875b8850ac0952 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 13 Mar 2026 10:53:01 +0100 Subject: [PATCH 16/32] Added a small note. --- .../runners/dace/transformations/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index f00fcb6b05..2328444414 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -9,7 +9,7 @@ """Common functionality for the transformations/optimization pipeline.""" import uuid -from typing import Optional, Sequence, TypeVar, Union, Iterable +from typing import Iterable, Optional, Sequence, TypeVar, Union import dace from dace import data as dace_data, libraries as dace_lib, subsets as dace_sbs, symbolic as dace_sym @@ -29,6 +29,7 @@ def unique_name(name: str) -> str: not be used if a particular order should be enforced. This function is marked for deprecation. """ + # TODO Stabilize this. maximal_length = 200 unique_sufix = str(uuid.uuid1()).replace("-", "_") if len(name) > (maximal_length - len(unique_sufix)): @@ -850,8 +851,8 @@ def gt_data_descriptor_mapping( def order_nodes( - nodes: Iterable[dace_nodes.Node], - state: dace.SDFGState, + nodes: Iterable[dace_nodes.Node], + state: dace.SDFGState, ) -> list[dace_nodes.Node]: """_Tries_ to order `nodes` in a stable and deterministic way. @@ -883,7 +884,7 @@ def key_fun(node: dace_nodes.Node) -> tuple[str, str, int, int]: def order_edges( - edges: Iterable[dace_graph.MultiConnectorEdge[dace.Memlet]] + edges: Iterable[dace_graph.MultiConnectorEdge[dace.Memlet]], ) -> list[dace_graph.MultiConnectorEdge[dace.Memlet]]: """_Tries_ to order `edges` in a stable and deterministic way. From 9216eb4935676a9bd8685fb3787d5c57f82b8fc3 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 13 Mar 2026 10:59:13 +0100 Subject: [PATCH 17/32] Updated the naming. --- .../runners/dace/transformations/utils.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 2328444414..04dccb685e 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -8,7 +8,6 @@ """Common functionality for the transformations/optimization pipeline.""" -import uuid from typing import Iterable, Optional, Sequence, TypeVar, Union import dace @@ -25,16 +24,20 @@ def unique_name(name: str) -> str: """Adds a unique string to `name`. Note: - The names generates by this function are rather unstable and it should - not be used if a particular order should be enforced. This function is - marked for deprecation. + This function assumes that the "namespace" defined by `__gt4py_unique_name_` + can be used freely. """ - # TODO Stabilize this. maximal_length = 200 - unique_sufix = str(uuid.uuid1()).replace("-", "_") - if len(name) > (maximal_length - len(unique_sufix)): - name = name[: (maximal_length - len(unique_sufix) - 1)] - return f"{name}_{unique_sufix}" + if not hasattr(unique_name, "_counter"): + unique_name._counter = 0 # type: ignore[attr-defined] + + proposed_name = f"__gt4py_unique_name_{name}_{unique_name._counter}" # type: ignore[attr-defined] + unique_name._counter += 1 # type: ignore[attr-defined] + + if len(proposed_name) > maximal_length: + raise ValueError("Name became too long.") + + return proposed_name def gt_make_transients_persistent( From a99a8a5e279e4c086282778c81e2a106e3321a35 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Thu, 19 Mar 2026 10:51:01 +0100 Subject: [PATCH 18/32] WIP --- .../move_dataflow_into_if_body.py | 651 +++++++++--------- .../runners/dace/transformations/utils.py | 5 +- .../test_move_dataflow_into_if_body.py | 7 +- 3 files changed, 324 insertions(+), 339 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 0900bde69c..c4336c8048 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -39,6 +39,9 @@ class MoveDataflowIntoIfBody(dace_transformation.SingleStateTransformation): something to relocate. - Fix the naming issue thing in the if fuser. - Make the `xfail` test run. + - Make Test where something is used in multiple branches. + - Check if symbol renaming is okay. + - test if map outside data is sliced into the if block. Essentially transforms code from this ```python @@ -161,9 +164,7 @@ def can_be_applied( connector_usage_location=connector_usage_location, enclosing_map=enclosing_map, ) - - # If no branch has something to inline then we are done. - if all(len(rel_df) == 0 for rel_df in relocatable_dataflow.values()): + if len(relocatable_dataflow) == 0: return False # Check if relatability is possible. @@ -183,18 +184,17 @@ def can_be_applied( # transformation is applied in a loop until it applies nowhere anymore. # NOTE: This is a restriction due to the current implementation. if not self.ignore_upstream_blocks: - for reloc_dataflow in relocatable_dataflow.values(): - if any( - self._has_if_block_relocatable_dataflow( - sdfg=sdfg, - state=graph, - upstream_if_block=upstream_if_block, - enclosing_map=enclosing_map, - ) - for upstream_if_block in reloc_dataflow - if isinstance(upstream_if_block, dace_nodes.NestedSDFG) - ): - return False + if any( + self._has_if_block_relocatable_dataflow( + sdfg=sdfg, + state=graph, + upstream_if_block=upstream_if_block, + enclosing_map=enclosing_map, + ) + for upstream_if_block in relocated_dataflow + if isinstance(upstream_if_block, dace_nodes.NestedSDFG) + ): + return False return True @@ -222,7 +222,7 @@ def apply( } for conn_names in [relocatable_connectors, non_relocatable_connectors] ) - relocatable_dataflow = self._filter_relocatable_dataflow( + relocatable_dataflow: set = self._filter_relocatable_dataflow( sdfg=sdfg, state=graph, if_block=if_block, @@ -232,28 +232,33 @@ def apply( enclosing_map=enclosing_map, ) - # Since the node sets in `relocatable_dataflow` are not necessarily disjoint, - # we have to keep track of what we have relocated already, such that we do - # not relocate something multiple times. - # TODO(phimuell, iomaganaris): Find out if the state needs to be part of the - # key especially because of the filtering we do. - old_to_new_nodes_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node] = dict() + # We have to bring the nodes in a deterministic order. + nodes_to_move: list[dace_nodes.Node] = gtx_transformations.utils.order_nodes( + nodes_to_move, state + ) + + # For each node we have to find out in which state inside the `if_block` it will + # end up. `relocation_destination` has a fixed order. + relocation_destination: dict[dace_nodes.Node, dace.SDFGState] = {} + for node_to_move in nodes_to_move: + for conn, raw_reloc_dataflow_of_conn in raw_relocatable_dataflow.items(): + if node_to_move in raw_reloc_dataflow_of_conn: + break + else: + raise ValueError("Could not find node '{node_to_move}'") + relocation_destination[node_to_move] = connector_usage_location[conn][0] + + # TODO: LOOK INTO THIS FUNCTION IF IT IS UNSTABLE. + self._update_symbol_mapping(if_block, sdfg, nodes_to_move) # Relocate the dataflow. - # NOTE: Important to iterate over `relocatable_connectors` for stability. - for conn_name in relocatable_connectors: - self._replicate_dataflow_into_branch( - state=graph, - sdfg=sdfg, - if_block=if_block, - enclosing_map=enclosing_map, - nodes_to_move=relocatable_dataflow[conn_name], - connector=conn_name, - connector_usage_location=connector_usage_location, - old_to_new_nodes_map=old_to_new_nodes_map, - ) - - self._update_symbol_mapping(if_block, sdfg) + self._replicate_dataflow_into_branch( + state=graph, + sdfg=sdfg, + if_block=if_block, + enclosing_map=enclosing_map, + relocation_destination=relocation_destination, + ) self._remove_outside_dataflow( sdfg=sdfg, @@ -276,140 +281,129 @@ def _replicate_dataflow_into_branch( state: dace.SDFGState, if_block: dace_nodes.NestedSDFG, enclosing_map: dace_nodes.MapEntry, - nodes_to_move: set[dace_nodes.Node], - connector: str, + relocation_destination: dict[dace_nodes.Node, dace.SDFGState], connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], - old_to_new_nodes_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node], ) -> None: - """Replicate the dataflow in `nodes_to_move` from `state` into `if_block`. - - First the function will determine into which branch, inside `if_block`, - the dataflow has to be replicated. It will then copy the dataflow, nodes - listed in `nodes_to_move` and insert them into that state. - The function will then create the edges to connect them in the same way - as they where outside. If there is an outer data dependency, for example - a read to a global memory, then the function will patch that inside the - `if_block`. - At the end the function will remove the `connector`, but it will not remove - the original dataflow. - In case of nodes existing in multiple `nodes_to_move` sets coming from multiple - connectors, the function will only copy the necessary nodes and edges only once - based on the `old_to_new_nodes_map` keys and values. In case a connector to the - NestedSDFG is exists in the dataflow of another connector, the function - will remove the connection of the original connector and replace the global - AccessNode of the NestedSDFG with a temporary one and remove the original connector. + """Replicate the dataflow in `relocation_destination` into `if_block`. + + The function will replicate the dataflow listed in `relocatable_connectors.keys()`, + that needs + to be connected, in some way, to the `if_block`. It will remove the connectors + that are no longer needed, but it will not remove the original dataflow nor + update the symbol mapping. Args: sdfg: The sdfg that we process, the one that contains `state`. state: The state we operate on, the one that contains `if_block`. if_block: The `if_block` into which we inline. enclosing_map: The enclosing map. - nodes_to_move: The list of nodes that should be removed. - connector: The connector that should be inlined. + nodes_to_move: The list of nodes that should be moved. connector_usage_location: Maps connector names to the state and AccessNode where they appear inside the nested SDFG. - old_to_new_nodes_map: A mapping from the old nodes to the new nodes. - The keys of the mapping are tuples of the old node and the branch - state for which the new node was created. The values are the new nodes. """ - # Nothing to relocate nothing to do. - if len(nodes_to_move) == 0: - return - - inner_sdfg: dace.SDFG = if_block.sdfg - branch_state, _connector_node = connector_usage_location[connector] + inner_sdfg = if_block.sdfg - # Sort the nodes such that we insert them in the same order every time. - nodes_to_move = gtx_transformations.utils.order_nodes(nodes_to_move, state) # type: ignore[assignment] - - # Replicate the nodes and store them in the `old_to_new_nodes_map` mapping. - # Add the SDFGState to the key of the dictionary because we have to create - # new node for the different branches. - unique_old_nodes: list[dace_nodes.Node] = [] - for old_node in nodes_to_move: - if (old_node, branch_state) in old_to_new_nodes_map: - continue - unique_old_nodes.append(old_node) - copy_of_old_node = copy.deepcopy(old_node) - old_to_new_nodes_map[(old_node, branch_state)] = copy_of_old_node - branch_state.add_node(copy_of_old_node) - - # There might be AccessNodes inside `nodes_to_move`, we now have to make sure - # that they are present inside the nested ones. By our base assumption they - # are transients and single use, because they are only used in one place. Make - # sure that we don't add new nodes if the nodes are already in the arrays of - # the inner SDFG. - if not isinstance(old_node, dace_nodes.AccessNode): + # Data that has been fully mapped into the `if_block` and its name. Format is + # because of aliasing. + fully_mapped_in_data: dict[str, set[str]] = collections.defaultdict(set) + for if_iedge in state.in_edges(if_block): + if if_iedge.data.is_empty(): continue - if old_node.data in inner_sdfg.arrays: - continue - assert sdfg.arrays[old_node.data].transient - # TODO(phimuell): Handle the case we need to rename something. - inner_sdfg.add_datadesc( - old_node.data, - sdfg.arrays[old_node.data].clone(), - find_new_name=False, - ) - - # Now add the edges between the edges that have been replicated inside the - # branch state, these are the outgoing edges. - # Now add the outgoing edges between the replicated nodes inside the branch state. - # The data dependencies (incomming edges) of the nodes, i.e. the not relocated dataflow, - # are still missing. - for node in unique_old_nodes: - for oedge in gtx_transformations.utils.order_edges(state.out_edges(node)): + outer_data = if_iedge.data.data + mapped_in_range = if_iedge.data.subset + outer_desc = sdfg.arrays[outer_data] + if mapped_in_range.covers(dace_sbs.Range.from_array(outer_desc)) == True: + fully_mapped_in_data[outer_data].add(if_iedge.dst_conn) + + # Maps old nodes to the new relocated nodes inside the `if_block`. Note that + # the state _inside_ the `if_block` is part of the key. This is needed to + # handle the "outside Map data" which must be mapped into multiple states. + node_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node] = dict() + rename_map: dict[tuple[str, dace.SDFGState], str] = dict() + + # Replicate the nodes into the `if_block` and create the needed data The + # "outside Map data" will be handled when we handle the incoming edges. + for origin_node, branch_state in relocation_destination.items(): + reloc_node = copy.deepcopy(origin_node) + node_map[(origin_node, branch_state)] = reloc_node + branch_state.add_node(reloc_node) + + # If we relocate an AccessNode, we have to make sure that the data descriptor + # is also added to the nested SDFG. We allow renaming of data containers + # but we do not allow renaming of symbols, this is checked by + # `_check_for_data_and_symbol_conflicts()`. + if isinstance(origin_node, dace_nodes.AccessNode): + assert sdfg.arrays[origin_node.data].transient + # TODO(phimuell): Handle the case we need to rename something. + new_data_name = inner_sdfg.add_datadesc( + origin_node.data, + sdfg.arrays[origin_node.data].clone(), + find_new_name=True, + ) + reloc_node.data = new_data_name + rename_map[(origin_node.data, branch_state)] = new_data_name + + # We now recreate the edges + for origin_node, branch_state in relocation_destination.items(): + for oedge in state.out_edges(node): + assert not oedge.data.is_empty() if oedge.dst is if_block: - # Here the destination will change, it is no longer the if block, - # but instead the connector node inside branch. It is important - # that `oedge.dst_conn` is not necessarily `connector` as we allow - # for intersecting relocatable node sets. We have to handle these - # to cases slightly different. - assert connector_usage_location[oedge.dst_conn][0] is branch_state - - # TODO(phimuell): Check if this Memlet is always correct, especially in case of slicing. + # This defines the "argument" to the nested SDFG. This means that + # the new destination now is the single node inside `if_block` + # that represents the argument. + assert not inner_sdfg.arrays[oedge.dst_conn].transient + assert branch_state is connector_usage_location[oedge.dst_conn][0] + assert isinstance(oedge.src, dace_nodes.AccessNode) + assert oedge.data.wcr is None and oedge.data.other_subset is None + # TODO(reviewer): Make sure I added a test. branch_state.add_edge( - old_to_new_nodes_map[(oedge.src, branch_state)], + node_map[(oedge.src, branch_state)], oedge.src_conn, connector_usage_location[oedge.dst_conn][1], None, - # THIS MEMLET IS SO WRONG. - dace.Memlet.from_memlet(oedge.data), + dace.Memlet( + data=rename_map[(oedge.src, branch_state)], + subset=oedge.data.subset, # Is always subset. + other_subset=dace.Memlet.from_array(inner_sdfg.arrays[oedge.dst_conn]), + volume=oedge.data.volume, + dynamic=oedge.data.dynamic, + ), ) - if connector != oedge.dst_conn: - # This is not the master connector. So we have to handle it. - # TODO(phimuell): Try to refactor such that we can remove it. - inner_sdfg.arrays[oedge.dst_conn].transient = True - if_block.remove_in_connector(oedge.dst_conn) + # The inner data is no longer a global but has become a transient. + inner_sdfg.arrays[oedge.dst_conn].transient = True + if_block.remove_in_connector(oedge.dst_conn) else: - # Restore a connection between two nodes that were relocated. - assert oedge.dst in nodes_to_move - branch_state.add_edge( - old_to_new_nodes_map[(oedge.src, branch_state)], + # If it is not going to the `if_block` it must be a connection + # between to relocated nodes, which we can simply copy. + assert origin_node in relocation_destination + new_oedge = branch_state.add_edge( + node_map[(oedge.src, branch_state)], oedge.src_conn, - old_to_new_nodes_map[(oedge.dst, branch_state)], + node_map[(oedge.dst, branch_state)], oedge.dst_conn, dace.Memlet.from_memlet(oedge.data), ) + if not oedge.data.is_empty(): + new_oedge.data.data = rename_map[(oedge.data.data, branch_state)] # Now we have to satisfy the data dependencies, i.e. forward all nodes that # could not have been moved inside `if_block` but are still needed to compute # the final result. We find them by scanning the input edges of the nodes # that have been relocated. - for node in unique_old_nodes: - for iedge in gtx_transformations.utils.order_edges(state.in_edges(node)): - if iedge.src in nodes_to_move: - # Inner data dependency, there is nothing to do and the edge was - # created above. + # TODO(phimuell): Can we merge the two outer loops? + for origin_node, branch_state in relocation_destination.items(): + for iedge in state.in_edges(node): + if iedge.src in relocation_destination: + # Dependency between two relocated nodes: Handled above. continue - if iedge.data.is_empty(): - # Empty Memlets are there to maintain some order relation, "happens - # before". Depending on the situation we can remove or have to - # recreate them. The case where the connection comes from a - # node within the relocated dataflow is handled above. - assert iedge.src is enclosing_map + elif iedge.data.is_empty(): + # This is an empty Memlet that is between a node that is relocated + # and a node that is not relocated. Because we move the destination + # of the edge into the `if_block` the "happens before" relation + # is automatically handled and we have to do nothing. continue # Now we have to figuring out where the data is coming from, since @@ -418,71 +412,90 @@ def _replicate_dataflow_into_branch( # The data is coming from outside the Map scope, i.e. not defined # inside the Map scope, so we have to trace it back. memlet_path = state.memlet_path(iedge) - outer_data = memlet_path[0].src + outer_node = memlet_path[0].src else: # The data is defined somewhere in the Map scope itself. - outer_data = iedge.src + outer_node = iedge.src + # TODO(phimuell): It is possible that this does not lead to an # AccessNode on the outside, but to something inside the Map scope # such as the MapExit of an inner map. To handle such a case we need # to construct the set of nodes to move differently, i.e. # considering this case already there. - if not isinstance(outer_data, dace_nodes.AccessNode): + if not isinstance(outer_node, dace_nodes.AccessNode): raise NotImplementedError() - assert not gtx_transformations.utils.is_view(outer_data, sdfg) - - # If the data is not yet available in the inner SDFG made - # patch it through. - if outer_data.data not in inner_sdfg.arrays: - inner_desc = sdfg.arrays[outer_data.data].clone() - inner_desc.transient = False - # TODO(phimuell): Handle the case we need to rename something. - inner_sdfg.add_datadesc(outer_data.data, inner_desc, False) - # TODO(phimeull): We pass the whole data inside the SDFG. - # Find out if there are cases where this is wrong. - state.add_edge( - iedge.src, - iedge.src_conn, - if_block, - outer_data.data, - dace.Memlet( - data=outer_data.data, subset=dace_sbs.Range.from_array(inner_desc) - ), - ) - if_block.add_in_connector(outer_data.data) - else: - # This is the case that we found a node, that refers to data that - # was already patched into the `if_block`. We would have to remove - # this, but since this function just replicates the dataflow, - # it will not do that. Instead we postpone this to the cleanup - # phase, see `_remove_outside_dataflow()`. + assert not gtx_transformations.utils.is_view(outer_node, sdfg) + + outer_data = outer_node.data + outer_desc = sdfg.arrays[outer_data] + + # Check if the data is already mapped in and if not map it in. + if (outer_node, branch_state) in node_map: + # The node is already mapped into this state, so nothing to do. + assert (outer_data, branch_state) in rename_map + assert not node_map[(outer_node, branch_state)].desc(inner_sdfg).transient pass - if (outer_data, branch_state) not in old_to_new_nodes_map: - assert all( - outer_data.data != mapped_node.data - for mapped_node, mapped_branch_state in old_to_new_nodes_map.keys() - if isinstance(mapped_node, dace_nodes.AccessNode) - and mapped_branch_state == branch_state + elif outer_data in fully_mapped_in_data: + # The data has already been mapped into the `if_block` thus + # check the state if there is a source node or create one. + outer_aliases = fully_mapped_in_data[outer_data] + candidate_nodes: list[dace_nodes.AccessNode] = sorted( + ( + dnode + for dnode in branch_state.data_nodes() + if dnode.data in outer_aliases + ), + key=lambda dnode: dnode.data, ) - assert outer_data.data in inner_sdfg.arrays - assert not inner_sdfg.arrays[outer_data.data].transient - old_to_new_nodes_map[(outer_data, branch_state)] = branch_state.add_access( - outer_data.data, copy.copy(outer_data.debuginfo) + + if len(candidate_source_nodes) == 0: + # There is no AccessNode in the state so we have to create one. + inner_data = sorted(outer_aliases)[0] + inner_node = branch_state.add_access(inner_data) + + else: + # This is to handle a legal but very unlikely case, that we + # do not handle. If there is any non sink node, then we might + # have a read-write conflict. + candidate_source_nodes = [ + dnode for dnode in candidate_nodes if branch_state.in_degree(dnode) == 0 + ] + if len(candidate_source_nodes) != len(candidate_nodes): + raise NotImplementedError() + + # We take the first node, since they are sorted it is deterministic. + inner_node = candidate_source_nodes[0] + + assert (outer_data, branch_state) not in rename_map + assert not inner_sdfg.arrays[inner_node.data].transient + rename_map[(outer_data, branch_state)] = inner_node.data + node_map[(outer_node, branch_state)] = inner_node + + else: + # The data is not already mapped in and is also unknown. + # Here we rely on that we do not have to perform symbol renaming. + inner_data = inner_sdfg.add_datadesc( + outer_data, + outer_desc.clone(), + find_new_name=True, ) + inner_sdfg.arrays[inner_data].transient = False + + inner_node = branch_state.add_access(inner_data) + rename_map[(outer_data, branch_state)] = inner_node.data + node_map[(outer_node, branch_state)] = inner_node + fully_mapped_in_data[outer_data].add(inner_data) # Now create the edge in the inner state. - branch_state.add_edge( - old_to_new_nodes_map[(outer_data, branch_state)], + new_edge = branch_state.add_edge( + node_map[(outer_node, branch_state)], None, - old_to_new_nodes_map[(iedge.dst, branch_state)], + node_map[(iedge.dst, branch_state)], iedge.dst_conn, copy.deepcopy(iedge.data), ) - - # The old connector name is no longer valid. - inner_sdfg.arrays[connector].transient = True - if_block.remove_in_connector(connector) + new_edge.data.data = rename_map[(outer_data, branch_state)] def _remove_outside_dataflow( self, @@ -558,12 +571,14 @@ def _check_for_data_and_symbol_conflicts( self, sdfg: dace.SDFG, state: dace.SDFGState, - relocatable_dataflow: dict[str, set[dace_nodes.Node]], + relocatable_dataflow: set[dace_nodes.Node], if_block: dace_nodes.NestedSDFG, enclosing_map: dace_nodes.MapEntry, ) -> bool: """Check if the relocation would cause any conflict, such as a symbol clash.""" + # TODO: remove check ion data renaming but still require symbols. + # TODO(phimuell): There is an obscure case where the nested SDFG, on its own, # defines a symbol that is also mapped, for example a dynamic Map range. # It is probably not a problem, because of the scopes DaCe adds when @@ -571,16 +586,12 @@ def _check_for_data_and_symbol_conflicts( # Create a subgraph to compute the free symbols, i.e. the symbols that # need to be supplied from the outside. However, this are not all. - # Note, just adding some "well chosen" nodes to the set will not work. - all_relocated_dataflow: set[dace_nodes.Node] = functools.reduce( - lambda s1, s2: s1.union(s2), relocatable_dataflow.values(), set() - ) requiered_symbols: set[str] = dace.sdfg.state.StateSubgraphView( - state, all_relocated_dataflow + state, relocated_dataflow ).free_symbols inner_data_names = if_block.sdfg.arrays.keys() - for node_to_check in all_relocated_dataflow: + for node_to_check in relocated_dataflow: if ( isinstance(node_to_check, dace_nodes.AccessNode) and node_to_check.data in inner_data_names @@ -593,7 +604,7 @@ def _check_for_data_and_symbol_conflicts( for iedge in state.in_edges(node_to_check): src_node = iedge.src - if src_node not in all_relocated_dataflow: + if src_node not in relocated_dataflow: # This means that `src_node` is not relocated but mapped into the # `if` block. This means that `edge` is replicated as well. # NOTE: This code is based on the one found in `DataflowGraphView`. @@ -669,7 +680,7 @@ def _has_if_block_relocatable_dataflow( connector_usage_location=connector_usage_location, enclosing_map=enclosing_map, ) - if all(len(rel_df) == 0 for rel_df in filtered_relocatable_dataflow.values()): + if len(filtered_relocatable_dataflow) == 0: return False return True @@ -683,22 +694,16 @@ def _filter_relocatable_dataflow( non_relocatable_dataflow: dict[str, set[dace_nodes.Node]], connector_usage_location: dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]], enclosing_map: dace_nodes.MapEntry, - ) -> dict[str, set[dace_nodes.Node]]: - """Partition the dependencies. + ) -> set[dace_nodes.Node]: + """Compute the final set of the relocatable nodes. The function expects the dataflow that is upstream of every connector of the `if_block`. The function will then scan the dataflow and compute - the parts that actually can be relocated and returns a `dict` mapping - every relocatable input connector to the set of nodes that can be relocated. - - The returned sets can include duplicate nodes, i.e. a node can be in the - dataflow of multiple connectors. The function that performs the actual - relocation (`_replicate_dataflow_into_branch`) will take care of that and - make sure that such nodes are only copied once. - However, the function will make sure that if the sets of two connectors - intersects then it will be relocated into the same branch. I.e. a node - is only relocated and does not need to be replicated because it appears - in multiple branches. + the parts that actually can be relocated. It will then return a `set` + containing all nodes that can actually be relocated. If this set is empty + then nothing can be relocated. + Note that the returned `set` is in an unspecific order and before processing + should be ordered. Args: state: The state on which we operate. @@ -713,119 +718,100 @@ def _filter_relocatable_dataflow( `if_block` is located in. """ - # Remove the parts of the dataflow that is unrelocatable. + # These are the nodes that can not be relocated anyway. all_non_relocatable_dataflow: set[dace_nodes.Node] = functools.reduce( - lambda s1, s2: s1.union(s2), non_relocatable_dataflow.values(), set() + lambda s1, s2: s1.union(s2), all_non_relocatable_dataflow.values(), set() + ) + + # While we can relocate nodes that are needed by multiple connectors, we can + # not handle the case if they end up in multiple branches. + nodes_in_states: dict[dace.SDFGState, set[dace_nodes.Node]] = collections.defaultdict(set) + for conn_name, rel_df in raw_relocatable_dataflow.items(): + nodes_in_states[connector_usage_location[conn_name][0]].update(rel_df) + state_nodes_sets = list(nodes_per_state.values()) # Order is unimportant here. + for i, state_nodes in enumerate(state_nodes_sets): + for j in range(i + 1, len(state_nodes_sets)): + all_non_relocatable_dataflow.update(state_nodes.intersection(state_nodes_sets[j])) + + # Instead of scanning the nodes associated to each connector separately we will + # process all of them together. We do this because a node can be associated to + # multiple connectors and as such data dependencies can show up. We will, + # after the filtering distribute them back. + nodes_proposed_for_reloc: set[dace_nodes.Node] = functools.reduce( + lambda s1, s2: s1.union(s2), raw_relocatable_dataflow.values(), set() ) - relocatable_dataflow = { - conn_name: rel_df.difference(all_non_relocatable_dataflow) - for conn_name, rel_df in raw_relocatable_dataflow.items() - } - - # For each branch inside the if block find the nodes that are relocated inside it. - known_nodes_coll: dict[dace.SDFGState, set[dace_nodes.Node]] = collections.defaultdict(set) - for conn_name, rel_df in relocatable_dataflow.items(): - known_nodes_coll[connector_usage_location[conn_name][0]].update(rel_df) - known_nodes = list(known_nodes_coll.values()) # Order is unimportant here. - - # We allow that the set of nodes to relocated that are associated to a connector - # can intersect. However, in that case they have to be inside the same branch. - # Thus we now find the nodes that would need to be relocated into different - # branches. - multiple_df_nodes: set[dace_nodes.Node] = set() - for i, known_nodes_set in enumerate(known_nodes): - for j in range(i + 1, len(known_nodes)): - multiple_df_nodes.update(known_nodes_set.intersection(known_nodes[j])) - - if multiple_df_nodes: - # Remove from the relocatable dataflow the nodes that appear in multiple branches - # as it doesn't make sense to relocate them and duplicate them in both branches. - # NOTE: The reason for this filtering is not that it is impossible to do, but - # would be rather complex to do, as we now have to copy nodes instead of - # simply moving them around. - relocatable_dataflow = { - conn_name: rel_df.difference(multiple_df_nodes) - for conn_name, rel_df in relocatable_dataflow.items() - } + + # Filtering out all nodes that can not be relocated anyway. + if all_non_relocatable_dataflow: + nodes_proposed_for_reloc.difference_update(all_non_relocatable_dataflow) # TODO(phimuell): If we operate outside of a Map we also have to make sure that # the data is single use data, is not an AccessNode that refers to global # memory nor is a source AccessNode. - def filter_nodes( - nodes_proposed_for_reloc: set[dace_nodes.Node], - ) -> set[dace_nodes.Node]: - has_been_updated = True - while has_been_updated: - has_been_updated = False - - # TODO(phimuell): Look at me. - # TODO(reviewer): Make sure I looked at it. - - for reloc_node in list(nodes_proposed_for_reloc): - # The node was already handled in a previous iteration. - if reloc_node not in nodes_proposed_for_reloc: - continue + has_been_updated = True + while has_been_updated: + has_been_updated = False - assert ( - state.in_degree(reloc_node) > 0 - ) # Because we are currently always inside a Map - - # If the node is needed by anything that is not also moved - # into the `if` body, then it has to remain outside. For that we - # have to pretend that `if_block` is also relocated. - if any( - oedge.dst not in nodes_proposed_for_reloc - for oedge in state.out_edges(reloc_node) - if oedge.dst is not if_block - ): - nodes_proposed_for_reloc.remove(reloc_node) - has_been_updated = True - continue + # TODO(phimuell): Look at me. + # TODO(reviewer): Make sure I looked at it. - # We do not look at all incoming nodes, but have to ignore some of them. - # We ignore `enclosed_map` because it acts as boundary, and the node - # on the other side of it is mapped into the `if` body anyway. We - # ignore the AccessNodes because they will either be relocated into - # the `if` body or be mapped (remain outside but made accessible - # inside), thus their relocation state is of no concern for - # `reloc_node`. - non_mappable_incoming_nodes: set[dace_nodes.Node] = { - iedge.src - for iedge in state.in_edges(reloc_node) - if not ( - (iedge.src is enclosing_map) - or isinstance(iedge.src, dace_nodes.AccessNode) - ) - } - if non_mappable_incoming_nodes.issubset(nodes_proposed_for_reloc): - # All nodes that can not be mapped into the `if` body are - # currently scheduled to be relocated, thus there is not - # problem. - pass + for reloc_node in list(nodes_proposed_for_reloc): + # The node was already removed in a previous iteration. + if reloc_node not in nodes_proposed_for_reloc: + continue - else: - # Only some of the non mappable nodes are selected to be - # moved inside the `if` body. This means that `reloc_node` - # can also not be moved because of its input dependencies. - # Since we can not relocate `reloc_node` this also implies - # that none of its input can. Thus we remove them from - # `nodes_proposed_for_reloc`. - nodes_proposed_for_reloc.difference_update(non_mappable_incoming_nodes) - nodes_proposed_for_reloc.remove(reloc_node) - has_been_updated = True - - return nodes_proposed_for_reloc - - return { - conn_name: filter_nodes(rel_df) for conn_name, rel_df in relocatable_dataflow.items() - } + # Because we are currently always inside a Map + assert state.in_degree(reloc_node) > 0 + + # If the node is needed by anything that is not also moved + # into the `if` body, then it has to remain outside. For that we + # have to pretend that `if_block` is also relocated. + if any( + oedge.dst not in nodes_proposed_for_reloc + for oedge in state.out_edges(reloc_node) + if oedge.dst is not if_block + ): + nodes_proposed_for_reloc.remove(reloc_node) + has_been_updated = True + continue + + # Empty edges + + # We do not look at incoming edges that comes from nodes that are not + # mappable, i.e. AccessNodes. In addition to AccessNodes we also + # ignore `enclosing_map` because it acts as a boundary anyway and + # on its other side is an AccessNode anyway. + non_mappable_incoming_nodes: set[dace_nodes.Node] = { + iedge.src + for iedge in state.in_edges(reloc_node) + if not ( + (iedge.src is enclosing_map) or isinstance(iedge.src, dace_nodes.AccessNode) + ) + } + if non_mappable_incoming_nodes.issubset(nodes_proposed_for_reloc): + # All nodes that can not be mapped into the `if` body are + # currently scheduled to be relocated, thus there is no + # problem. + pass + + else: + # Only some of the non mappable nodes are selected to be moved + # inside the `if` body. This means that `reloc_node` can also + # not be moved because of its input dependencies. Since we can + # not relocate `reloc_node` this also implies that none of its + # inputs either. + nodes_proposed_for_reloc.difference_update(non_mappable_incoming_nodes) + nodes_proposed_for_reloc.remove(reloc_node) + has_been_updated = True + + return nodes_proposed_for_reloc def _partition_if_block( self, sdfg: dace.SDFG, if_block: dace_nodes.NestedSDFG, ) -> Optional[ - tuple[list[str], set[str], dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]]] + tuple[list[str], [ststr], dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]]] ]: """Check if `if_block` can be processed and partition the input connectors. @@ -835,32 +821,31 @@ def _partition_if_block( be inlined into the `if_block` and which can not. Returns: - If `if_block` is unsuitable the function will return `None`. - If `if_block` meets the structural requirements the function will return - a tuple of length three. The first element is a `list` containing the - connectors whose dataflow can be relocated. The second element is a `set` - containing the connector names whose dataflow can not be relocated. The - third element is a `dict` that maps connectors to a pair containing the - state (inside the nested SDFG) and the `AccessNode` that refers to to the - connector. - Note that only the first element, the `list` containing the relocatable - dataflow, has a stable order that depends on the connector names. All - other elements have an unspecific order! + If `if_block` is unsuitable the function will return `None`. In case the + `if_block` is suitable a `tuple` of length three is returned. + The first element is a `list`, which is never empty, containing all + input connectors that can be relocated. The list is sorted in a stable + order. The second element is a list containing all input connectors that + can not be relocated, it can be empty and is not in a particular order. + The third element is a `dict` that maps connectors to a pair containing + the state (inside the nested SDFG) and the only `AccessNode` that refers + to that connector. + It is important that only the first element of the `tuple` has a guaranteed + order. """ if len(if_block.out_connectors.keys()) == 0: return None + input_names: set[str] = set(if_block.in_connectors.keys()) output_names: set[str] = set(if_block.out_connectors.keys()) # If data is used as input and output we ignore it. # TODO(phimuell): Think if this case can be handled. - input_names: set[str] = set(if_block.in_connectors.keys()) input_names.difference_update(output_names) - if not input_names: + if len(input_names) == 0: return None - # We require that the nested SDFG contains a single node, which is a - # `ConditionalBlock` containing two branches. + # We require that the nested SDFG contains a single node, which is a `ConditionalBlock`. inner_sdfg: dace.SDFG = if_block.sdfg if inner_sdfg.number_of_nodes() != 1: return None @@ -875,9 +860,8 @@ def _partition_if_block( non_relocatable_connectors: set[str] = set() # Now inspect all states. - for _, branch in inner_if_block.branches: - for inner_state in branch.all_states(): - assert isinstance(inner_state, dace.SDFGState) + for _, if_branch in inner_if_block.branches: + for inner_state in if_branch.all_states(): for dnode in inner_state.data_nodes(): node_data = dnode.data @@ -891,25 +875,25 @@ def _partition_if_block( assert node_data in input_names if node_data in connector_usage_location: - # The connectors that can be pulled inside must appear exactly once - # inside a state. In theory they could appear more, but then we - # would have to replicate the dataflow to different locations - # which is not supported. We still allow such situation, but - # consider them as non relocatable. + # There are multiple AccessNodes referring to the same connector + # which is currently not supported. In theory they could appear + # more, but then we would have to replicate the dataflow to + # different locations which is not supported. We allow such + # situations but consider the connector non relocatable. connector_usage_location.pop(node_data) non_relocatable_connectors.add(node_data) elif inner_state.in_degree(dnode) != 0: - # The node is also written to, a strange situation, that is - # however allowed. So we can not handle it. + # The node is also written to, allowed by SDFG grammar, but we + # do not allow it. non_relocatable_connectors.add(node_data) else: # This is a proper input connector node. connector_usage_location[node_data] = (inner_state, dnode) - # All input connectors are now considered non relocatable, thus - # the decomposition does not exist. + # If all input connectors were classified as non relocatable + # then the partition does not exist. if len(non_relocatable_connectors) == len(input_names): assert non_relocatable_connectors == input_names return None @@ -919,8 +903,7 @@ def _partition_if_block( return None # In addition to the non relocatable connectors that were found above, we also - # mark all connectors that were not found as non relocatable. These connectors - # are used for conditions or are not used. + # mark all connectors that were not found as non relocatable. non_relocatable_connectors.update( conn for conn in input_names if conn not in connector_usage_location ) @@ -936,4 +919,4 @@ def _partition_if_block( # everything else has no guaranteed order, even `connector_usage_location`. relocatable_connectors = sorted(connector_usage_location.keys()) - return relocatable_connectors, non_relocatable_connectors, connector_usage_location + return relocatable_connectors, list(non_relocatable_connectors), connector_usage_location diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 04dccb685e..5aeb5188fa 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -27,12 +27,15 @@ def unique_name(name: str) -> str: This function assumes that the "namespace" defined by `__gt4py_unique_name_` can be used freely. """ + + # TODO(phimuell, reviewer): How does this behaves in multiple process scenarios? + # It should be okay as long as the update is atomic, I would say. maximal_length = 200 if not hasattr(unique_name, "_counter"): unique_name._counter = 0 # type: ignore[attr-defined] - proposed_name = f"__gt4py_unique_name_{name}_{unique_name._counter}" # type: ignore[attr-defined] unique_name._counter += 1 # type: ignore[attr-defined] + proposed_name = f"__gt4py_unique_name_{name}_{unique_name._counter}" # type: ignore[attr-defined] if len(proposed_name) > maximal_length: raise ValueError("Name became too long.") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 067ff82cac..12f503da6f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -986,9 +986,6 @@ def test_if_mover_dependent_branch_4(): assert {tlet.label for tlet in inner_tlet} == expected_tlet -@pytest.mark.xfail( - reason="This test is currently expected to fail. For the explanation see: https://github.com/GridTools/gt4py/pull/2514#discussion_r2906948120" -) def test_if_mover_dependent_branch_5(): """ Essentially tests the following situation: @@ -1134,7 +1131,9 @@ def test_if_mover_dependent_branch_5(): mx.add_out_connector("OUT_f") sdfg.validate() - _perform_test(sdfg, explected_applies=2) + # TODO(iomaginaris, phimuell): Why was `explected_applies` set to `2`? This makes + # only sense if we have multiple `if` blocks or do I miss something here? + _perform_test(sdfg, explected_applies=1) # # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) From 4b2922686fa0f993cba5312843393e69b6eb2c16 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 20 Mar 2026 11:02:57 +0100 Subject: [PATCH 19/32] Let's see how this thing works. Took me long enough. --- .../move_dataflow_into_if_body.py | 177 +++++++++--------- 1 file changed, 89 insertions(+), 88 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index c4336c8048..7607a7cdd2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -13,7 +13,6 @@ import dace from dace import ( - data as dace_data, dtypes as dace_dtypes, properties as dace_properties, subsets as dace_sbs, @@ -191,7 +190,7 @@ def can_be_applied( upstream_if_block=upstream_if_block, enclosing_map=enclosing_map, ) - for upstream_if_block in relocated_dataflow + for upstream_if_block in relocatable_dataflow if isinstance(upstream_if_block, dace_nodes.NestedSDFG) ): return False @@ -222,7 +221,7 @@ def apply( } for conn_names in [relocatable_connectors, non_relocatable_connectors] ) - relocatable_dataflow: set = self._filter_relocatable_dataflow( + relocatable_dataflow: set[dace_nodes.Node] = self._filter_relocatable_dataflow( sdfg=sdfg, state=graph, if_block=if_block, @@ -234,22 +233,24 @@ def apply( # We have to bring the nodes in a deterministic order. nodes_to_move: list[dace_nodes.Node] = gtx_transformations.utils.order_nodes( - nodes_to_move, state + relocatable_dataflow, graph ) # For each node we have to find out in which state inside the `if_block` it will # end up. `relocation_destination` has a fixed order. relocation_destination: dict[dace_nodes.Node, dace.SDFGState] = {} for node_to_move in nodes_to_move: + # Although `node_top_move` could be reached through different connectors + # they are all associated to the same branch. + target_state: Optional[dace.SDFGState] = None for conn, raw_reloc_dataflow_of_conn in raw_relocatable_dataflow.items(): if node_to_move in raw_reloc_dataflow_of_conn: + target_state = connector_usage_location[conn][0] break else: raise ValueError("Could not find node '{node_to_move}'") - relocation_destination[node_to_move] = connector_usage_location[conn][0] - - # TODO: LOOK INTO THIS FUNCTION IF IT IS UNSTABLE. - self._update_symbol_mapping(if_block, sdfg, nodes_to_move) + assert target_state is not None + relocation_destination[node_to_move] = target_state # Relocate the dataflow. self._replicate_dataflow_into_branch( @@ -258,12 +259,19 @@ def apply( if_block=if_block, enclosing_map=enclosing_map, relocation_destination=relocation_destination, + connector_usage_location=connector_usage_location, + ) + + # Must be performed after relocation. + self._update_symbol_mapping( + sdfg=sdfg, + if_block=if_block, ) self._remove_outside_dataflow( sdfg=sdfg, state=graph, - relocatable_dataflow=relocatable_dataflow, + relocation_destination=relocation_destination, ) # Because we relocate some node it seems that DaCe gets a bit confused. @@ -303,24 +311,23 @@ def _replicate_dataflow_into_branch( """ inner_sdfg = if_block.sdfg - # Data that has been fully mapped into the `if_block` and its name. Format is - # because of aliasing. + # Maps old nodes to the new relocated nodes inside the `if_block`. Note that + # the state _inside_ the `if_block` is part of the key. This is needed to + # handle the "outside Map data" which must be mapped into multiple states. + node_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node] = dict() + rename_map: dict[tuple[str, dace.SDFGState], str] = dict() + + # Data that has been fully mapped into the `if_block` and its name inside it. fully_mapped_in_data: dict[str, set[str]] = collections.defaultdict(set) for if_iedge in state.in_edges(if_block): if if_iedge.data.is_empty(): continue outer_data = if_iedge.data.data - mapped_in_range = if_iedge.data.subset + mapped_in_range = if_iedge.data.subset # Is always `.subset`. outer_desc = sdfg.arrays[outer_data] - if mapped_in_range.covers(dace_sbs.Range.from_array(outer_desc)) == True: + if mapped_in_range.covers(dace_sbs.Range.from_array(outer_desc)) == True: # noqa: E712 [true-false-comparison] # SymPy comparison fully_mapped_in_data[outer_data].add(if_iedge.dst_conn) - # Maps old nodes to the new relocated nodes inside the `if_block`. Note that - # the state _inside_ the `if_block` is part of the key. This is needed to - # handle the "outside Map data" which must be mapped into multiple states. - node_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node] = dict() - rename_map: dict[tuple[str, dace.SDFGState], str] = dict() - # Replicate the nodes into the `if_block` and create the needed data The # "outside Map data" will be handled when we handle the incoming edges. for origin_node, branch_state in relocation_destination.items(): @@ -344,19 +351,19 @@ def _replicate_dataflow_into_branch( rename_map[(origin_node.data, branch_state)] = new_data_name # We now recreate the edges + # We insert them in the same order as we find them outside. for origin_node, branch_state in relocation_destination.items(): - for oedge in state.out_edges(node): - assert not oedge.data.is_empty() + for oedge in state.out_edges(origin_node): if oedge.dst is if_block: # This defines the "argument" to the nested SDFG. This means that # the new destination now is the single node inside `if_block` # that represents the argument. + assert not oedge.data.is_empty() assert not inner_sdfg.arrays[oedge.dst_conn].transient assert branch_state is connector_usage_location[oedge.dst_conn][0] assert isinstance(oedge.src, dace_nodes.AccessNode) assert oedge.data.wcr is None and oedge.data.other_subset is None - # TODO(reviewer): Make sure I added a test. branch_state.add_edge( node_map[(oedge.src, branch_state)], oedge.src_conn, @@ -372,12 +379,13 @@ def _replicate_dataflow_into_branch( ) # The inner data is no longer a global but has become a transient. + assert oedge.dst_conn in if_block.in_connectors inner_sdfg.arrays[oedge.dst_conn].transient = True if_block.remove_in_connector(oedge.dst_conn) else: - # If it is not going to the `if_block` it must be a connection - # between to relocated nodes, which we can simply copy. + # Edges that do not go to the `if_block` must lead to a node + # that is also relocated. assert origin_node in relocation_destination new_oedge = branch_state.add_edge( node_map[(oedge.src, branch_state)], @@ -395,7 +403,7 @@ def _replicate_dataflow_into_branch( # that have been relocated. # TODO(phimuell): Can we merge the two outer loops? for origin_node, branch_state in relocation_destination.items(): - for iedge in state.in_edges(node): + for iedge in state.in_edges(origin_node): if iedge.src in relocation_destination: # Dependency between two relocated nodes: Handled above. continue @@ -403,7 +411,7 @@ def _replicate_dataflow_into_branch( # This is an empty Memlet that is between a node that is relocated # and a node that is not relocated. Because we move the destination # of the edge into the `if_block` the "happens before" relation - # is automatically handled and we have to do nothing. + # is automatically handled and this edge is no longer needed. continue # Now we have to figuring out where the data is coming from, since @@ -429,16 +437,16 @@ def _replicate_dataflow_into_branch( outer_data = outer_node.data outer_desc = sdfg.arrays[outer_data] - # Check if the data is already mapped in and if not map it in. if (outer_node, branch_state) in node_map: - # The node is already mapped into this state, so nothing to do. + # The node is already mapped into this state. assert (outer_data, branch_state) in rename_map assert not node_map[(outer_node, branch_state)].desc(inner_sdfg).transient pass elif outer_data in fully_mapped_in_data: - # The data has already been mapped into the `if_block` thus - # check the state if there is a source node or create one. + # The data has already been mapped into the `if_block`, but not in + # `branch_state`. We first look if the state contains an AccessNode + # referring to that data. outer_aliases = fully_mapped_in_data[outer_data] candidate_nodes: list[dace_nodes.AccessNode] = sorted( ( @@ -449,15 +457,16 @@ def _replicate_dataflow_into_branch( key=lambda dnode: dnode.data, ) - if len(candidate_source_nodes) == 0: + if len(candidate_nodes) == 0: # There is no AccessNode in the state so we have to create one. inner_data = sorted(outer_aliases)[0] inner_node = branch_state.add_access(inner_data) else: - # This is to handle a legal but very unlikely case, that we - # do not handle. If there is any non sink node, then we might - # have a read-write conflict. + # There is an AccessNode in the state. To handle some legal + # but unlikely case we check that nodes we found are all + # source nodes. We have to do this to prevent read-write + # conflicts. candidate_source_nodes = [ dnode for dnode in candidate_nodes if branch_state.in_degree(dnode) == 0 ] @@ -501,37 +510,37 @@ def _remove_outside_dataflow( self, sdfg: dace.SDFG, state: dace.SDFGState, - relocatable_dataflow: dict[str, set[dace_nodes.Node]], + relocation_destination: dict[dace_nodes.Node, dace.SDFGState], ) -> None: """Removes the original dataflow, that has been relocated. The function will also remove data containers that are no longer in use. """ - # Creating the union of all first ensures that a node is only removed once. - all_relocatable_dataflow: set[dace_nodes.Node] = functools.reduce( - lambda s1, s2: s1.union(s2), relocatable_dataflow.values(), set() - ) - # Before we can clean the original nodes, we must clean the dataflow. If a - # node, that was relocated, has incoming connections we must remove them - # and the parent dataflow. - for node_to_remove in all_relocatable_dataflow: + # Clean up the dataflow first, before removing the nodes. + for node_to_remove in relocation_destination: + # Create all "interface edges", i.e. connecting a relocated node with one + # that is not. This is needed to properly remove dangling Memlet paths. for iedge in list(state.in_edges(node_to_remove)): - if iedge.src in all_relocatable_dataflow: + if iedge.src in relocation_destination: continue dace_sutils.remove_edge_and_dangling_path(state, iedge) if isinstance(node_to_remove, dace_nodes.AccessNode): + # NOTE: We can remove the data here, because by assumption data that is + # referred to by an AccessNode inside a Map is single use data and + # used nowhere else. + # NOTE: This will temporarily create an invalid SDFG. assert node_to_remove.desc(sdfg).transient sdfg.remove_data(node_to_remove.data, validate=False) # Remove the original nodes (data descriptors were deleted in the loop above). - state.remove_nodes_from(all_relocatable_dataflow) + state.remove_nodes_from(relocation_destination.keys()) def _update_symbol_mapping( self, + sdfg: dace.SDFG, if_block: dace_nodes.NestedSDFG, - parent: dace.SDFG, ) -> None: """Updates the symbol mapping of the nested SDFG. @@ -539,14 +548,17 @@ def _update_symbol_mapping( are available in the parent SDFG. """ symbol_mapping = if_block.symbol_mapping - missing_symbols = [ms for ms in if_block.sdfg.free_symbols if ms not in symbol_mapping] + missing_symbols = sorted( + (ms for ms in if_block.sdfg.free_symbols if ms not in symbol_mapping), + key=lambda sym: str(sym), + ) symbol_mapping.update({s: s for s in missing_symbols}) if_block.symbol_mapping = symbol_mapping # Performs conversion. # Add new global symbols to nested SDFG. # The code is based on `SDFGState.add_nested_sdfg()`. if_block_symbols = if_block.sdfg.symbols - parent_symbols = parent.symbols + parent_symbols = sdfg.symbols for new_sym in missing_symbols: if new_sym in if_block_symbols: # The symbol is already known, so we check that it is the same type as in the @@ -577,49 +589,43 @@ def _check_for_data_and_symbol_conflicts( ) -> bool: """Check if the relocation would cause any conflict, such as a symbol clash.""" - # TODO: remove check ion data renaming but still require symbols. - # TODO(phimuell): There is an obscure case where the nested SDFG, on its own, # defines a symbol that is also mapped, for example a dynamic Map range. # It is probably not a problem, because of the scopes DaCe adds when # generating the C++ code. - # Create a subgraph to compute the free symbols, i.e. the symbols that - # need to be supplied from the outside. However, this are not all. + # This will give us the "internal symbols" that need to be mapped into `if_block`. + # It does not include all symbols, see bellow. requiered_symbols: set[str] = dace.sdfg.state.StateSubgraphView( - state, relocated_dataflow + state, relocatable_dataflow ).free_symbols - inner_data_names = if_block.sdfg.arrays.keys() - for node_to_check in relocated_dataflow: - if ( - isinstance(node_to_check, dace_nodes.AccessNode) - and node_to_check.data in inner_data_names - ): - # There is already a data descriptor that is used on the inside as on - # the outside. Thus we would have to perform some renaming, which we - # currently do not. - # TODO(phimell): Handle this case. - return False - + # The internal symbols missing the symbols that are needed by the nodes that + # are just mapped into the `if_block` as well as the connections that connects + # relocated and mapped nodes. + for node_to_check in relocatable_dataflow: for iedge in state.in_edges(node_to_check): - src_node = iedge.src - if src_node not in relocated_dataflow: - # This means that `src_node` is not relocated but mapped into the - # `if` block. This means that `edge` is replicated as well. - # NOTE: This code is based on the one found in `DataflowGraphView`. - # TODO(phimuell): Do we have to inspect the full Memlet path here? - assert isinstance(src_node, dace_nodes.AccessNode) or src_node is enclosing_map + if iedge.src in relocatable_dataflow: + continue # Ignore internal connections. + + if iedge.src is enclosing_map: + # Outside-Map data must be mapped. Here we only have to consider + # the symbols of the node and can ignore the symbols of the edge. + memlet_path = state.memlet_path(iedge) + node_to_map = memlet_path[0].src + else: + # The mapped node is inside the Map this means we replicate this + # edge thus in addition to the symbols of the data, we need the + # symbols needed by the edge. + node_to_map = iedge.src requiered_symbols |= iedge.data.used_symbols(True, edge=iedge) - # The (beyond the enclosing Map) data is also mapped into the `if` block, so we - # have to consider that as well. - for iedge in state.in_edges(if_block): - if iedge.src is enclosing_map and (not iedge.data.is_empty()): - outside_desc = sdfg.arrays[iedge.data.data] - if isinstance(outside_desc, dace_data.View): - return False # Handle this case. - requiered_symbols |= outside_desc.used_symbols(True) + # Only AccessNodes can be mapped into `if_block`. + if not isinstance(node_to_map, dace_nodes.AccessNode): + return False + + # Add the symbols of the data. + requiered_symbols |= sdfg.arrays[node_to_map.data].used_symbols(True) # A conflicting symbol is a free symbol of the relocatable dataflow, that is not a # direct mapping. For example if there is a symbol `n` on the inside and outside @@ -720,7 +726,7 @@ def _filter_relocatable_dataflow( # These are the nodes that can not be relocated anyway. all_non_relocatable_dataflow: set[dace_nodes.Node] = functools.reduce( - lambda s1, s2: s1.union(s2), all_non_relocatable_dataflow.values(), set() + lambda s1, s2: s1.union(s2), non_relocatable_dataflow.values(), set() ) # While we can relocate nodes that are needed by multiple connectors, we can @@ -728,7 +734,7 @@ def _filter_relocatable_dataflow( nodes_in_states: dict[dace.SDFGState, set[dace_nodes.Node]] = collections.defaultdict(set) for conn_name, rel_df in raw_relocatable_dataflow.items(): nodes_in_states[connector_usage_location[conn_name][0]].update(rel_df) - state_nodes_sets = list(nodes_per_state.values()) # Order is unimportant here. + state_nodes_sets = list(nodes_in_states.values()) # Order is unimportant here. for i, state_nodes in enumerate(state_nodes_sets): for j in range(i + 1, len(state_nodes_sets)): all_non_relocatable_dataflow.update(state_nodes.intersection(state_nodes_sets[j])) @@ -752,9 +758,6 @@ def _filter_relocatable_dataflow( while has_been_updated: has_been_updated = False - # TODO(phimuell): Look at me. - # TODO(reviewer): Make sure I looked at it. - for reloc_node in list(nodes_proposed_for_reloc): # The node was already removed in a previous iteration. if reloc_node not in nodes_proposed_for_reloc: @@ -775,8 +778,6 @@ def _filter_relocatable_dataflow( has_been_updated = True continue - # Empty edges - # We do not look at incoming edges that comes from nodes that are not # mappable, i.e. AccessNodes. In addition to AccessNodes we also # ignore `enclosing_map` because it acts as a boundary anyway and @@ -811,7 +812,7 @@ def _partition_if_block( sdfg: dace.SDFG, if_block: dace_nodes.NestedSDFG, ) -> Optional[ - tuple[list[str], [ststr], dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]]] + tuple[list[str], list[str], dict[str, tuple[dace.SDFGState, dace_nodes.AccessNode]]] ]: """Check if `if_block` can be processed and partition the input connectors. From 39ea68ab41b85252bf52f03a9c6767f9924860c9 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 20 Mar 2026 11:21:55 +0100 Subject: [PATCH 20/32] Some small fixes, but still some issues. --- .../transformations/move_dataflow_into_if_body.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 7607a7cdd2..d8d036f38a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -370,9 +370,9 @@ def _replicate_dataflow_into_branch( connector_usage_location[oedge.dst_conn][1], None, dace.Memlet( - data=rename_map[(oedge.src, branch_state)], + data=rename_map[(oedge.data.data, branch_state)], subset=oedge.data.subset, # Is always subset. - other_subset=dace.Memlet.from_array(inner_sdfg.arrays[oedge.dst_conn]), + other_subset=dace_sbs.Range.from_array(inner_sdfg.arrays[oedge.dst_conn]), volume=oedge.data.volume, dynamic=oedge.data.dynamic, ), @@ -491,6 +491,15 @@ def _replicate_dataflow_into_branch( ) inner_sdfg.arrays[inner_data].transient = False + state.add_edge( + iedge.src, + iedge.src_conn, + if_block, + inner_data, + dace.Memlet.from_array(outer_data, outer_desc), + ) + if_block.add_in_connector(inner_data) + inner_node = branch_state.add_access(inner_data) rename_map[(outer_data, branch_state)] = inner_node.data node_map[(outer_node, branch_state)] = inner_node From 0825a449afb7eadde2168a862ca2333a3a4de25d Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 20 Mar 2026 11:23:50 +0100 Subject: [PATCH 21/32] More fixes but still not passing all the tests. --- .../move_dataflow_into_if_body.py | 119 +++++++++--------- 1 file changed, 61 insertions(+), 58 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index d8d036f38a..d652f3f5da 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -350,62 +350,14 @@ def _replicate_dataflow_into_branch( reloc_node.data = new_data_name rename_map[(origin_node.data, branch_state)] = new_data_name - # We now recreate the edges - # We insert them in the same order as we find them outside. - for origin_node, branch_state in relocation_destination.items(): - for oedge in state.out_edges(origin_node): - if oedge.dst is if_block: - # This defines the "argument" to the nested SDFG. This means that - # the new destination now is the single node inside `if_block` - # that represents the argument. - assert not oedge.data.is_empty() - assert not inner_sdfg.arrays[oedge.dst_conn].transient - assert branch_state is connector_usage_location[oedge.dst_conn][0] - assert isinstance(oedge.src, dace_nodes.AccessNode) - assert oedge.data.wcr is None and oedge.data.other_subset is None - - branch_state.add_edge( - node_map[(oedge.src, branch_state)], - oedge.src_conn, - connector_usage_location[oedge.dst_conn][1], - None, - dace.Memlet( - data=rename_map[(oedge.data.data, branch_state)], - subset=oedge.data.subset, # Is always subset. - other_subset=dace_sbs.Range.from_array(inner_sdfg.arrays[oedge.dst_conn]), - volume=oedge.data.volume, - dynamic=oedge.data.dynamic, - ), - ) - - # The inner data is no longer a global but has become a transient. - assert oedge.dst_conn in if_block.in_connectors - inner_sdfg.arrays[oedge.dst_conn].transient = True - if_block.remove_in_connector(oedge.dst_conn) - - else: - # Edges that do not go to the `if_block` must lead to a node - # that is also relocated. - assert origin_node in relocation_destination - new_oedge = branch_state.add_edge( - node_map[(oedge.src, branch_state)], - oedge.src_conn, - node_map[(oedge.dst, branch_state)], - oedge.dst_conn, - dace.Memlet.from_memlet(oedge.data), - ) - if not oedge.data.is_empty(): - new_oedge.data.data = rename_map[(oedge.data.data, branch_state)] - - # Now we have to satisfy the data dependencies, i.e. forward all nodes that - # could not have been moved inside `if_block` but are still needed to compute - # the final result. We find them by scanning the input edges of the nodes - # that have been relocated. - # TODO(phimuell): Can we merge the two outer loops? + # We now create the mapped nodes, i.e. the nodes that are not relocated but + # have to be put inside the `if_block`. We find them by looking at the input + # edges, that do not lead to a node that is relocated. Connections between + # relocated nodes are handled later. for origin_node, branch_state in relocation_destination.items(): for iedge in state.in_edges(origin_node): if iedge.src in relocation_destination: - # Dependency between two relocated nodes: Handled above. + # Dependency between two relocated nodes: Handled below. continue elif iedge.data.is_empty(): # This is an empty Memlet that is between a node that is relocated @@ -492,11 +444,11 @@ def _replicate_dataflow_into_branch( inner_sdfg.arrays[inner_data].transient = False state.add_edge( - iedge.src, - iedge.src_conn, - if_block, - inner_data, - dace.Memlet.from_array(outer_data, outer_desc), + iedge.src, + iedge.src_conn, + if_block, + inner_data, + dace.Memlet.from_array(outer_data, outer_desc), ) if_block.add_in_connector(inner_data) @@ -515,6 +467,57 @@ def _replicate_dataflow_into_branch( ) new_edge.data.data = rename_map[(outer_data, branch_state)] + # Now create the edges between the relocated nodes, which are all the outgoing + # edges, the `if_block` is handled as a special relocated node and its + # connectors (but not the edges) are removed to. + # NOTE: This loop can not be fused with the one above and must run after it. + for origin_node, branch_state in relocation_destination.items(): + for oedge in state.out_edges(origin_node): + if oedge.dst is if_block: + # This defines the "argument" to the nested SDFG. This means that + # the new destination now is the single node inside `if_block` + # that represents the argument. + assert not oedge.data.is_empty() + assert not inner_sdfg.arrays[oedge.dst_conn].transient + assert branch_state is connector_usage_location[oedge.dst_conn][0] + assert isinstance(oedge.src, dace_nodes.AccessNode) + assert oedge.data.wcr is None and oedge.data.other_subset is None + + branch_state.add_edge( + node_map[(oedge.src, branch_state)], + oedge.src_conn, + connector_usage_location[oedge.dst_conn][1], + None, + dace.Memlet( + data=rename_map[(oedge.data.data, branch_state)], + subset=oedge.data.subset, # Is always subset. + other_subset=dace_sbs.Range.from_array( + inner_sdfg.arrays[oedge.dst_conn] + ), + volume=oedge.data.volume, + dynamic=oedge.data.dynamic, + ), + ) + + # The inner data is no longer a global but has become a transient. + assert oedge.dst_conn in if_block.in_connectors + inner_sdfg.arrays[oedge.dst_conn].transient = True + if_block.remove_in_connector(oedge.dst_conn) + + else: + # Edges that do not go to the `if_block` must lead to a node + # that is also relocated. + assert origin_node in relocation_destination + new_oedge = branch_state.add_edge( + node_map[(oedge.src, branch_state)], + oedge.src_conn, + node_map[(oedge.dst, branch_state)], + oedge.dst_conn, + dace.Memlet.from_memlet(oedge.data), + ) + if not oedge.data.is_empty(): + new_oedge.data.data = rename_map[(oedge.data.data, branch_state)] + def _remove_outside_dataflow( self, sdfg: dace.SDFG, From cb717f9041d7f081b0f23ab64d13d10bfd7a4e50 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 20 Mar 2026 11:40:55 +0100 Subject: [PATCH 22/32] Had to modify a test. --- .../test_move_dataflow_into_if_body.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 12f503da6f..fa6ba8101f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -953,10 +953,12 @@ def test_if_mover_dependent_branch_4(): _perform_test(sdfg, explected_applies=1) - # # Examine the structure of the SDFG. + # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) assert {ac.data for ac in top_ac} == set(input_names).union(["c1", "s1"]) assert len(sdfg.arrays) == len(top_ac) + assert all(state.out_degree(ac) == 1 for ac in [s1, c1]) + assert all(oedge.dst_conn == "__arg4" for oedge in state.out_edges(s1)) top_tlet: list[dace_nodes.Tasklet] = util.count_nodes(state, dace_nodes.Tasklet, True) assert len(top_tlet) == 2 @@ -970,12 +972,12 @@ def test_if_mover_dependent_branch_4(): .union(input_names) .union(["__arg1", "__arg2", "__arg3", "__arg4", "__output1", "__output2"]) ) - expected_data.difference_update(["c1", "c", "d", "f", "s"]) + expected_data.difference_update(["c1", "c", "d", "f", "s", "s1"]) assert expected_data == {ac.data for ac in inner_ac} - assert len([ac for ac in inner_ac if ac.data == "s1"]) == 1 + assert len([ac for ac in inner_ac if ac.data == "__arg4"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output1"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output2"]) == 2 - assert len(expected_data) + 3 == len(inner_ac) + assert len(expected_data) + 4 == len(inner_ac) assert if_block.sdfg.arrays.keys() == expected_data.union(["__cond"]) inner_tlet: list[dace_nodes.Tasklet] = util.count_nodes(if_block.sdfg, dace_nodes.Tasklet, True) From 048e2e726567d00643c2f16789850ab42aad2243 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 20 Mar 2026 11:51:44 +0100 Subject: [PATCH 23/32] More changes to the tests. --- .../test_move_dataflow_into_if_body.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index fa6ba8101f..a5e1b165c8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -964,6 +964,9 @@ def test_if_mover_dependent_branch_4(): assert len(top_tlet) == 2 assert {"tasklet_cond", "tasklet_s1"} == {tlet.label for tlet in top_tlet} + all_mapped_in_data = [iedge.data.data for iedge in state.in_edges(if_block)] + assert len(all_mapped_in_data) == len(set(all_mapped_in_data)) + inner_ac: list[dace_nodes.AccessNode] = util.count_nodes( if_block.sdfg, dace_nodes.AccessNode, True ) @@ -1133,19 +1136,22 @@ def test_if_mover_dependent_branch_5(): mx.add_out_connector("OUT_f") sdfg.validate() - # TODO(iomaginaris, phimuell): Why was `explected_applies` set to `2`? This makes - # only sense if we have multiple `if` blocks or do I miss something here? _perform_test(sdfg, explected_applies=1) - # # Examine the structure of the SDFG. + # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) assert {ac.data for ac in top_ac} == set(input_names).union(["c1", "s1"]) assert len(sdfg.arrays) == len(top_ac) + assert all(state.out_degree(ac) == 1 for ac in [s1, c1]) + assert all(oedge.dst_conn == "__arg4" for oedge in state.out_edges(s1)) top_tlet: list[dace_nodes.Tasklet] = util.count_nodes(state, dace_nodes.Tasklet, True) assert len(top_tlet) == 2 assert {"tasklet_cond", "tasklet_s1"} == {tlet.label for tlet in top_tlet} + all_mapped_in_data = [iedge.data.data for iedge in state.in_edges(if_block)] + assert len(all_mapped_in_data) == len(set(all_mapped_in_data)) + inner_ac: list[dace_nodes.AccessNode] = util.count_nodes( if_block.sdfg, dace_nodes.AccessNode, True ) @@ -1154,18 +1160,26 @@ def test_if_mover_dependent_branch_5(): .union(input_names) .union(["__arg1", "__arg2", "__arg3", "__arg4", "__output1", "__output2"]) ) - expected_data.difference_update(["c1", "c", "d", "f", "s"]) + expected_data.difference_update(["c1", "c", "d", "f", "s", "s1"]) assert expected_data == {ac.data for ac in inner_ac} - assert len([ac for ac in inner_ac if ac.data == "s1"]) == 1 + assert len([ac for ac in inner_ac if ac.data == "__arg4"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output1"]) == 2 assert len([ac for ac in inner_ac if ac.data == "__output2"]) == 2 - assert len(expected_data) + 3 == len(inner_ac) + assert len(expected_data) + 4 == len(inner_ac) assert if_block.sdfg.arrays.keys() == expected_data.union(["__cond"]) inner_tlet: list[dace_nodes.Tasklet] = util.count_nodes(if_block.sdfg, dace_nodes.Tasklet, True) - assert len(inner_tlet) == 5 + assert len(inner_tlet) == 6 expected_tlet = { - tlet.label for tlet in [tasklet_a1, tasklet_a2, tasklet_b1, tasklet_b2, tasklet_node_reuse] + tlet.label + for tlet in [ + tasklet_a1, + tasklet_a2, + tasklet_a2a, + tasklet_b1, + tasklet_b2, + tasklet_node_reuse, + ] } assert {tlet.label for tlet in inner_tlet} == expected_tlet From b361a5d273c4b9a3f4d1cbead11f208702d834f7 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Fri, 20 Mar 2026 13:11:54 +0100 Subject: [PATCH 24/32] Now all currently existing tests pass, but we already activated another one, which is good. --- .../transformations/move_dataflow_into_if_body.py | 13 ++++++++++--- .../test_move_dataflow_into_if_body.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index d652f3f5da..e3211d242a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -611,6 +611,7 @@ def _check_for_data_and_symbol_conflicts( requiered_symbols: set[str] = dace.sdfg.state.StateSubgraphView( state, relocatable_dataflow ).free_symbols + assert all(isinstance(sym, str) for sym in requiered_symbols) # The internal symbols missing the symbols that are needed by the nodes that # are just mapped into the `if_block` as well as the connections that connects @@ -618,7 +619,9 @@ def _check_for_data_and_symbol_conflicts( for node_to_check in relocatable_dataflow: for iedge in state.in_edges(node_to_check): if iedge.src in relocatable_dataflow: - continue # Ignore internal connections. + continue # Ignore internal connections, handled in subgraph. + elif iedge.data.is_empty(): + continue # Empty Memlets do not have symbols. if iedge.src is enclosing_map: # Outside-Map data must be mapped. Here we only have to consider @@ -630,14 +633,18 @@ def _check_for_data_and_symbol_conflicts( # edge thus in addition to the symbols of the data, we need the # symbols needed by the edge. node_to_map = iedge.src - requiered_symbols |= iedge.data.used_symbols(True, edge=iedge) + requiered_symbols |= { + str(sym) for sym in iedge.data.used_symbols(True, edge=iedge) + } # Only AccessNodes can be mapped into `if_block`. if not isinstance(node_to_map, dace_nodes.AccessNode): return False # Add the symbols of the data. - requiered_symbols |= sdfg.arrays[node_to_map.data].used_symbols(True) + requiered_symbols |= { + str(sym) for sym in sdfg.arrays[node_to_map.data].used_symbols(True) + } # A conflicting symbol is a free symbol of the relocatable dataflow, that is not a # direct mapping. For example if there is a symbol `n` on the inside and outside diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index a5e1b165c8..6d8d6c262b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -1710,8 +1710,8 @@ def test_if_mover_symbolic_tasklet(): sdfg, explected_applies=1, ) - expected_symb = {"symbol_1", "symbol_2"} + expected_symb = {"symbol_1", "symbol_2"} assert if_block.sdfg.symbols.keys() == expected_symb.union(["__i"]) assert all(if_block.sdfg.symbols[sym] == dace.float64 for sym in expected_symb) assert if_block.sdfg.symbols["__i"] in {dace.int32, dace.int64} From 3819e017290c06e71ddc95bbcda8ebf0c80740af Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Mon, 23 Mar 2026 10:37:31 +0100 Subject: [PATCH 25/32] Undid some optimizations. --- .../runners/dace/transformations/utils.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 5aeb5188fa..5bf4c90db2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -8,6 +8,7 @@ """Common functionality for the transformations/optimization pipeline.""" +import uuid from typing import Iterable, Optional, Sequence, TypeVar, Union import dace @@ -24,23 +25,15 @@ def unique_name(name: str) -> str: """Adds a unique string to `name`. Note: - This function assumes that the "namespace" defined by `__gt4py_unique_name_` - can be used freely. + The names generates by this function are rather unstable and it should + not be used if a particular order should be enforced. This function is + marked for deprecation. """ - - # TODO(phimuell, reviewer): How does this behaves in multiple process scenarios? - # It should be okay as long as the update is atomic, I would say. maximal_length = 200 - if not hasattr(unique_name, "_counter"): - unique_name._counter = 0 # type: ignore[attr-defined] - - unique_name._counter += 1 # type: ignore[attr-defined] - proposed_name = f"__gt4py_unique_name_{name}_{unique_name._counter}" # type: ignore[attr-defined] - - if len(proposed_name) > maximal_length: - raise ValueError("Name became too long.") - - return proposed_name + unique_sufix = str(uuid.uuid1()).replace("-", "_") + if len(name) > (maximal_length - len(unique_sufix)): + name = name[: (maximal_length - len(unique_sufix) - 1)] + return f"{name}_{unique_sufix}" def gt_make_transients_persistent( From ebbf8ac8876a9458e43c9ca4e06966fa2c949063 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Mon, 23 Mar 2026 13:30:45 +0100 Subject: [PATCH 26/32] Small notes. --- .../program_processors/runners/dace/transformations/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 5bf4c90db2..bf4e86face 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -869,6 +869,7 @@ def order_nodes( # Describes when it works and when not. def key_fun(node: dace_nodes.Node) -> tuple[str, str, int, int]: + # TODO(phimuell): Deprecate once [DaCe PR#2320](https://github.com/spcl/dace/pull/2320) is in. if isinstance(node, dace_nodes.AccessNode): nid: str = node.data elif hasattr(node, "label"): @@ -894,6 +895,8 @@ def order_edges( which means their label in most cases. """ + # TODO(phimuell): Deprecate once [DaCe PR#2320](https://github.com/spcl/dace/pull/2320) is in. + # This is probably the best way to sort edge, because it considers the source and # destination node, as tie breaker the connectors are used and as second level of # tie breaker the subsets are used. This means that the specialization level From f3f3cc4e54c061df84880713fe20ae4728058429 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Wed, 25 Mar 2026 10:08:53 +0100 Subject: [PATCH 27/32] Small fixes. Implemented a better sorting mechanism and also handled one case of empty Memlets, however, not all are handled yet. --- .../move_dataflow_into_if_body.py | 29 +++++++--- .../runners/dace/transformations/utils.py | 57 +------------------ 2 files changed, 22 insertions(+), 64 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index e3211d242a..c56c7eaa26 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -231,10 +231,11 @@ def apply( enclosing_map=enclosing_map, ) - # We have to bring the nodes in a deterministic order. - nodes_to_move: list[dace_nodes.Node] = gtx_transformations.utils.order_nodes( - relocatable_dataflow, graph - ) + # Bring the nodes in a deterministic order, which is induced by the underlying state. + # NOTE: The following key function is equivalent to use `lambda n: graph.node_id(n)` + # but instead of O[N^2] it is O[N]. + node_keys = {node: i for i, node in enumerate(graph.nodes())} + nodes_to_move = sorted(relocatable_dataflow, key=lambda n: node_keys[n]) # For each node we have to find out in which state inside the `if_block` it will # end up. `relocation_destination` has a fixed order. @@ -477,7 +478,6 @@ def _replicate_dataflow_into_branch( # This defines the "argument" to the nested SDFG. This means that # the new destination now is the single node inside `if_block` # that represents the argument. - assert not oedge.data.is_empty() assert not inner_sdfg.arrays[oedge.dst_conn].transient assert branch_state is connector_usage_location[oedge.dst_conn][0] assert isinstance(oedge.src, dace_nodes.AccessNode) @@ -758,6 +758,21 @@ def _filter_relocatable_dataflow( for j in range(i + 1, len(state_nodes_sets)): all_non_relocatable_dataflow.update(state_nodes.intersection(state_nodes_sets[j])) + # The dataflow that must happen before the `if_block`, i.e that is connected + # with it by an empty Memlet can not be reconnected. + for if_block_iedge in state.in_edges(if_block): + if if_block_iedge.src is enclosing_map: + continue + elif not if_block_iedge.data.is_empty(): + continue + all_non_relocatable_dataflow.update( + gtx_transformations.utils.find_upstream_nodes( + start=if_block_iedge.src, + state=state, + ) + ) + all_non_relocatable_dataflow.add(if_block_iedge.src) + # Instead of scanning the nodes associated to each connector separately we will # process all of them together. We do this because a node can be associated to # multiple connectors and as such data dependencies can show up. We will, @@ -770,9 +785,7 @@ def _filter_relocatable_dataflow( if all_non_relocatable_dataflow: nodes_proposed_for_reloc.difference_update(all_non_relocatable_dataflow) - # TODO(phimuell): If we operate outside of a Map we also have to make sure that - # the data is single use data, is not an AccessNode that refers to global - # memory nor is a source AccessNode. + # TODO(phimuell): Better screening of empty Memlets. has_been_updated = True while has_been_updated: has_been_updated = False diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index bf4e86face..68a7c33201 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -9,7 +9,7 @@ """Common functionality for the transformations/optimization pipeline.""" import uuid -from typing import Iterable, Optional, Sequence, TypeVar, Union +from typing import Optional, Sequence, TypeVar, Union import dace from dace import data as dace_data, libraries as dace_lib, subsets as dace_sbs, symbolic as dace_sym @@ -847,58 +847,3 @@ def gt_data_descriptor_mapping( name_mapping[data_inside] = data_outside return name_mapping - - -def order_nodes( - nodes: Iterable[dace_nodes.Node], - state: dace.SDFGState, -) -> list[dace_nodes.Node]: - """_Tries_ to order `nodes` in a stable and deterministic way. - - The result should be considered as the best way to order a group of nodes inside - a state. It does, however, not guarantees a stable order in any way. - - The condition this function works best is if the nodes have a unique labels - and AccessNodes referring to the same data have different degrees (one node used - for reading one for writing). - - Known pathological cases: - - Multiple top level AccessNodes referring to the same data that have the same - degree (uncommon in GT4Py). - """ - - # Describes when it works and when not. - def key_fun(node: dace_nodes.Node) -> tuple[str, str, int, int]: - # TODO(phimuell): Deprecate once [DaCe PR#2320](https://github.com/spcl/dace/pull/2320) is in. - if isinstance(node, dace_nodes.AccessNode): - nid: str = node.data - elif hasattr(node, "label"): - # TODO(phimuell): Maybe add `node.code` in case of Tasklets? - nid = node.label - else: - nid = str(node) - - return (type(node).__name__, nid, state.in_degree(node), state.out_degree(node)) - - return sorted(nodes, key=key_fun) - - -def order_edges( - edges: Iterable[dace_graph.MultiConnectorEdge[dace.Memlet]], -) -> list[dace_graph.MultiConnectorEdge[dace.Memlet]]: - """_Tries_ to order `edges` in a stable and deterministic way. - - There is no guarantee of a stable order, although it should be consider more stable - than the one generated by `order_nodes()`. However, the order might depends on the - selected specialization level. Similar to `order_nodes()` this function works best - if the string (not serialization) representation of the involved nodes is unique, - which means their label in most cases. - """ - - # TODO(phimuell): Deprecate once [DaCe PR#2320](https://github.com/spcl/dace/pull/2320) is in. - - # This is probably the best way to sort edge, because it considers the source and - # destination node, as tie breaker the connectors are used and as second level of - # tie breaker the subsets are used. This means that the specialization level - # is also involved because the subset `[a, b]` is different from `[1, 10]`. - return sorted(edges, key=str) From 13f44b509c7eaca455085eeca1cf3d1c8be79d68 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Wed, 25 Mar 2026 10:46:36 +0100 Subject: [PATCH 28/32] Small update --- .../move_dataflow_into_if_body.py | 10 -- .../test_move_dataflow_into_if_body.py | 136 ++++++++++++++++++ 2 files changed, 136 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index c56c7eaa26..66e3bf5fd4 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -32,16 +32,6 @@ class MoveDataflowIntoIfBody(dace_transformation.SingleStateTransformation): """The transformation moves dataflow into the if branches. - ## TODO ## - - Slicing unit test - - Expending the unit test `_4` that Ioannis made such that it `__arg1` also has - something to relocate. - - Fix the naming issue thing in the if fuser. - - Make the `xfail` test run. - - Make Test where something is used in multiple branches. - - Check if symbol renaming is okay. - - test if map outside data is sliced into the if block. - Essentially transforms code from this ```python __arg1 = foo(...) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 6d8d6c262b..7ef95afee9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -1968,3 +1968,139 @@ def test_if_mover_symbol_aliasing(): sdfg=sdfg, explected_applies=0, ) + + +def test_if_mover_slice_input(): + def _make_nested_sdfg(cond_name: str, iter_name: str) -> dace.SDFG: + sdfg = dace.SDFG("If_block") + + sdfg.add_scalar("arg1", dtype=dace.float64, transient=False) + sdfg.add_scalar("out", dtype=dace.float64, transient=False) + sdfg.add_scalar(cond_name, dtype=dace.bool_, transient=False) + sdfg.add_array("arg2", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_symbol(iter_name, stype=dace.int32) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=sdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + tstate.add_edge( + tstate.add_access("arg1"), + None, + tstate.add_access("out"), + None, + dace.Memlet("arg1[0] -> [0]"), + ) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=sdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + f_tasklet = fstate.add_tasklet( + "f_tasklet", inputs={"__in"}, outputs={"__out"}, code="__out = __in + 1.0" + ) + fstate.add_edge( + fstate.add_access("arg2"), None, f_tasklet, "__in", dace.Memlet(f"arg2[{iter_name}]") + ) + fstate.add_edge(f_tasklet, "__out", fstate.add_access("out"), None, dace.Memlet("out[0]")) + + if_region = dace.sdfg.state.ConditionalBlock(gtx_transformations.utils.unique_name("if")) + sdfg.add_node(if_region, is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock(cond_name), then_body) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"not {cond_name}"), else_body) + + sdfg.validate() + return sdfg + + def _make_outer_sdfg( + cond_name: str, iter_name: str + ) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG]: + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_slicing")) + state = sdfg.add_state(is_start_block=True) + + # Inputs + input_names = list("abcd") + for name in input_names: + sdfg.add_array( + name, + shape=((10, 10) if name.startswith("b") else (10,)), + dtype=dace.float64, + transient=False, + ) + + # Temporaries + temporary_names = ["a1", "c1"] + for name in temporary_names: + sdfg.add_scalar( + name, dtype=dace.bool_ if name.startswith("c") else dace.float64, transient=True + ) + a1, c1 = (state.add_access(name) for name in temporary_names) + + me, mx = state.add_map("map", ndrange={iter_name: "0:10"}) + for name in input_names[:-1]: + state.add_edge( + state.add_access(name), + None, + me, + f"IN_{name}", + dace.Memlet(data=name, subset=("0:10, 0:10" if name == "b" else "0:10")), + ) + me.add_scope_connectors(name) + + state.add_edge( + mx, "OUT_d", state.add_access("d"), None, dace.Memlet(data="d", subset="0:10") + ) + mx.add_scope_connectors("d") + + # First branch. + tasklet_a1 = state.add_tasklet( + "tasklet_a1", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code="__out = __in1 + __in2", + ) + + state.add_edge(me, "OUT_a", tasklet_a1, "__in1", dace.Memlet(f"a[{iter_name}]")) + state.add_edge( + me, "OUT_b", tasklet_a1, "__in2", dace.Memlet(f"b[{iter_name}, {iter_name}]") + ) + state.add_edge(tasklet_a1, "__out", a1, None, dace.Memlet("a1[0]")) + + # Second branch + # There is nothing. + + # Condition + tasklet_c1 = state.add_tasklet( + "tasklet_c1", + inputs={"__in"}, + outputs={"__out"}, + code="__out = __in > 0.5", + ) + state.add_edge(me, "OUT_c", tasklet_c1, "__in", dace.Memlet(f"c[{iter_name}]")) + state.add_edge(tasklet_c1, "__out", c1, None, dace.Memlet("c1[0]")) + + # Nested SDFG + nsdfg = state.add_nested_sdfg( + sdfg=_make_nested_sdfg(cond_name=cond_name, iter_name=iter_name), + inputs={"arg1", "arg2", cond_name}, + outputs={"out"}, + symbol_mapping={iter_name: iter_name}, + ) + state.add_edge(a1, None, nsdfg, "arg1", dace.Memlet("a1[0]")) + state.add_edge(me, "OUT_b", nsdfg, "arg2", dace.Memlet(f"b[{iter_name}, 0:10]")) + state.add_edge(c1, None, nsdfg, cond_name, dace.Memlet("c1[0]")) + state.add_edge(nsdfg, "out", mx, "IN_d", dace.Memlet(f"d[{iter_name}]")) + + sdfg.validate() + return sdfg, state, nsdfg + + iter_name = "__i" + cond_name = "cond" + + sdfg, state, nsdfg = _make_outer_sdfg(cond_name=cond_name, iter_name=iter_name) + + _perform_test( + sdfg=sdfg, + explected_applies=1, + ) + + assert False, "Add structural checks." + assert False, ( + "Add checks where `b` is copied fully inside the Map scope and then sliced into the `if_block`." + ) From caf39f36798b510ffb7ef2306208a7288f289eee Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Thu, 26 Mar 2026 11:39:53 +0100 Subject: [PATCH 29/32] Added tests. --- .../move_dataflow_into_if_body.py | 5 +++- .../test_move_dataflow_into_if_body.py | 25 ++++++++++++++++--- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 66e3bf5fd4..7f69bcd547 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -308,7 +308,10 @@ def _replicate_dataflow_into_branch( node_map: dict[tuple[dace_nodes.Node, dace.SDFGState], dace_nodes.Node] = dict() rename_map: dict[tuple[str, dace.SDFGState], str] = dict() - # Data that has been fully mapped into the `if_block` and its name inside it. + # Check what data is already fully mapped into the `if_block`. There might be + # aliasing, i.e. multiple inner names refer to the same outer name. + # TODO(phimuell): Investigate if it would be better if we handle partially + # mapped in data such by fully map it in and perform the slicing outside. fully_mapped_in_data: dict[str, set[str]] = collections.defaultdict(set) for if_iedge in state.in_edges(if_block): if if_iedge.data.is_empty(): diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 7ef95afee9..4e044e5e71 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -2095,12 +2095,29 @@ def _make_outer_sdfg( sdfg, state, nsdfg = _make_outer_sdfg(cond_name=cond_name, iter_name=iter_name) + tlet_before = util.count_nodes(sdfg, dace_nodes.Tasklet, True) + + assert iter_name in nsdfg.symbol_mapping + assert len(tlet_before) == 2 + assert {"tasklet_a1", "tasklet_c1"} == {tlet.label for tlet in tlet_before} + assert state.in_degree(nsdfg) == 3 + assert {"a1", "c1", "b"} == {ie.data.data for ie in state.in_edges(nsdfg)} + assert "b" not in nsdfg.in_connectors + _perform_test( sdfg=sdfg, explected_applies=1, ) - assert False, "Add structural checks." - assert False, ( - "Add checks where `b` is copied fully inside the Map scope and then sliced into the `if_block`." - ) + tlet_after = util.count_nodes(sdfg, dace_nodes.Tasklet, True) + + assert len(tlet_after) == 1 + assert tlet_after[0].label == "tasklet_c1" + assert "b" in nsdfg.in_connectors + assert state.in_degree(nsdfg) == 4 + assert {"c1", "b", "a"} == {ie.data.data for ie in state.in_edges(nsdfg)} + assert sum(1 for ie in state.in_edges(nsdfg) if ie.data.data == "a") == 1 + + # It would be possible to have only one edge that goes into the nested SDFG. + # However, we would need to perform some more modifications. + assert sum(1 for ie in state.in_edges(nsdfg) if ie.data.data == "b") == 2 From 12bcb26773bf69c92fe79c630b87e7298de48529 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Thu, 26 Mar 2026 11:54:43 +0100 Subject: [PATCH 30/32] Updated some tests. --- .../test_move_dataflow_into_if_body.py | 64 +++++++++++++++---- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 4e044e5e71..1f9d1d3a25 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -1970,7 +1970,8 @@ def test_if_mover_symbol_aliasing(): ) -def test_if_mover_slice_input(): +@pytest.mark.parametrize("outer_slice_variable", [True, False]) +def test_if_mover_slice_input(outer_slice_variable: bool): def _make_nested_sdfg(cond_name: str, iter_name: str) -> dace.SDFG: sdfg = dace.SDFG("If_block") @@ -2009,7 +2010,9 @@ def _make_nested_sdfg(cond_name: str, iter_name: str) -> dace.SDFG: return sdfg def _make_outer_sdfg( - cond_name: str, iter_name: str + cond_name: str, + iter_name: str, + outer_slice_variable: bool, ) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG]: sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_slicing")) state = sdfg.add_state(is_start_block=True) @@ -2063,7 +2066,12 @@ def _make_outer_sdfg( state.add_edge(tasklet_a1, "__out", a1, None, dace.Memlet("a1[0]")) # Second branch - # There is nothing. + if outer_slice_variable: + sdfg.add_array("slice_b", shape=(10,), dtype=dace.float64, transient=True) + slice_b = state.add_access("slice_b") + state.add_edge( + me, "OUT_b", slice_b, None, dace.Memlet(f"b[{iter_name}, 0:10] -> [0:10]") + ) # Condition tasklet_c1 = state.add_tasklet( @@ -2083,8 +2091,13 @@ def _make_outer_sdfg( symbol_mapping={iter_name: iter_name}, ) state.add_edge(a1, None, nsdfg, "arg1", dace.Memlet("a1[0]")) - state.add_edge(me, "OUT_b", nsdfg, "arg2", dace.Memlet(f"b[{iter_name}, 0:10]")) state.add_edge(c1, None, nsdfg, cond_name, dace.Memlet("c1[0]")) + + if outer_slice_variable: + state.add_edge(slice_b, None, nsdfg, "arg2", dace.Memlet("slice_b[0:10]")) + else: + state.add_edge(me, "OUT_b", nsdfg, "arg2", dace.Memlet(f"b[{iter_name}, 0:10]")) + state.add_edge(nsdfg, "out", mx, "IN_d", dace.Memlet(f"d[{iter_name}]")) sdfg.validate() @@ -2093,7 +2106,9 @@ def _make_outer_sdfg( iter_name = "__i" cond_name = "cond" - sdfg, state, nsdfg = _make_outer_sdfg(cond_name=cond_name, iter_name=iter_name) + sdfg, state, nsdfg = _make_outer_sdfg( + cond_name=cond_name, iter_name=iter_name, outer_slice_variable=outer_slice_variable + ) tlet_before = util.count_nodes(sdfg, dace_nodes.Tasklet, True) @@ -2101,7 +2116,9 @@ def _make_outer_sdfg( assert len(tlet_before) == 2 assert {"tasklet_a1", "tasklet_c1"} == {tlet.label for tlet in tlet_before} assert state.in_degree(nsdfg) == 3 - assert {"a1", "c1", "b"} == {ie.data.data for ie in state.in_edges(nsdfg)} + assert {"a1", "c1", ("slice_b" if outer_slice_variable else "b")} == { + ie.data.data for ie in state.in_edges(nsdfg) + } assert "b" not in nsdfg.in_connectors _perform_test( @@ -2110,14 +2127,37 @@ def _make_outer_sdfg( ) tlet_after = util.count_nodes(sdfg, dace_nodes.Tasklet, True) + inner_ac = util.count_nodes(nsdfg.sdfg, dace_nodes.AccessNode, True) assert len(tlet_after) == 1 assert tlet_after[0].label == "tasklet_c1" assert "b" in nsdfg.in_connectors - assert state.in_degree(nsdfg) == 4 - assert {"c1", "b", "a"} == {ie.data.data for ie in state.in_edges(nsdfg)} - assert sum(1 for ie in state.in_edges(nsdfg) if ie.data.data == "a") == 1 - # It would be possible to have only one edge that goes into the nested SDFG. - # However, we would need to perform some more modifications. - assert sum(1 for ie in state.in_edges(nsdfg) if ie.data.data == "b") == 2 + if outer_slice_variable: + assert nsdfg.sdfg.arrays["arg1"].transient + assert nsdfg.sdfg.arrays["arg2"].transient + assert nsdfg.sdfg.arrays["slice_b"].transient + assert not nsdfg.sdfg.arrays["b"].transient + assert not nsdfg.sdfg.arrays["a"].transient + + # In this case the slicing was done inside. + assert state.in_degree(nsdfg) == 3 + assert {"c1", "b", "a"} == {ie.data.data for ie in state.in_edges(nsdfg)} + + assert len([ac for ac in inner_ac if ac.data == "b"]) == 2 + assert len([ac for ac in inner_ac if ac.data == "slice_b"]) == 1 + + else: + assert nsdfg.sdfg.arrays["arg1"].transient + assert not nsdfg.sdfg.arrays["arg2"].transient + assert not nsdfg.sdfg.arrays["b"].transient + assert not nsdfg.sdfg.arrays["a"].transient + + assert state.in_degree(nsdfg) == 4 + assert {"c1", "b", "a"} == {ie.data.data for ie in state.in_edges(nsdfg)} + assert sum(1 for ie in state.in_edges(nsdfg) if ie.data.data == "a") == 1 + + # It would be possible to have only one edge that goes into the nested SDFG. + # However, we would need to perform some more modifications. + assert len([ie for ie in state.in_edges(nsdfg) if ie.data.data == "b"]) == 2 + assert len([ac for ac in inner_ac if ac.data == "b"]) == 1 From ef0df6498dbed0620ab1455301b3fe0fcfbe11b3 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Sun, 19 Apr 2026 13:27:41 +0200 Subject: [PATCH 31/32] Applied comments. --- .../move_dataflow_into_if_body.py | 29 ++++++------ .../test_move_dataflow_into_if_body.py | 46 +++++++++---------- 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 7f69bcd547..f339ba9f8a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -131,7 +131,8 @@ def can_be_applied( relocatable_connectors, non_relocatable_connectors, connector_usage_location = if_block_spec # Compute the dataflow that is relocated. - # NOTE: That the nodes sets are not sorted in any way, however, the + # NOTE: That the nodes sets are not sorted in any way, instead we will sort + # them before we iterate over them. raw_relocatable_dataflow, non_relocatable_dataflow = ( { conn_name: gtx_transformations.utils.find_upstream_nodes( @@ -211,7 +212,7 @@ def apply( } for conn_names in [relocatable_connectors, non_relocatable_connectors] ) - relocatable_dataflow: set[dace_nodes.Node] = self._filter_relocatable_dataflow( + relocatable_dataflow = self._filter_relocatable_dataflow( sdfg=sdfg, state=graph, if_block=if_block, @@ -239,7 +240,7 @@ def apply( target_state = connector_usage_location[conn][0] break else: - raise ValueError("Could not find node '{node_to_move}'") + raise ValueError(f"Could not find node '{node_to_move}'") assert target_state is not None relocation_destination[node_to_move] = target_state @@ -254,10 +255,7 @@ def apply( ) # Must be performed after relocation. - self._update_symbol_mapping( - sdfg=sdfg, - if_block=if_block, - ) + self._update_symbol_mapping(sdfg, if_block) self._remove_outside_dataflow( sdfg=sdfg, @@ -296,7 +294,7 @@ def _replicate_dataflow_into_branch( state: The state we operate on, the one that contains `if_block`. if_block: The `if_block` into which we inline. enclosing_map: The enclosing map. - nodes_to_move: The list of nodes that should be moved. + relocation_destination: Maps nodes to the states where they should be relocated. connector_usage_location: Maps connector names to the state and AccessNode where they appear inside the nested SDFG. """ @@ -335,7 +333,6 @@ def _replicate_dataflow_into_branch( # `_check_for_data_and_symbol_conflicts()`. if isinstance(origin_node, dace_nodes.AccessNode): assert sdfg.arrays[origin_node.data].transient - # TODO(phimuell): Handle the case we need to rename something. new_data_name = inner_sdfg.add_datadesc( origin_node.data, sdfg.arrays[origin_node.data].clone(), @@ -601,10 +598,10 @@ def _check_for_data_and_symbol_conflicts( # This will give us the "internal symbols" that need to be mapped into `if_block`. # It does not include all symbols, see bellow. - requiered_symbols: set[str] = dace.sdfg.state.StateSubgraphView( + required_symbols: set[str] = dace.sdfg.state.StateSubgraphView( state, relocatable_dataflow ).free_symbols - assert all(isinstance(sym, str) for sym in requiered_symbols) + assert all(isinstance(sym, str) for sym in required_symbols) # The internal symbols missing the symbols that are needed by the nodes that # are just mapped into the `if_block` as well as the connections that connects @@ -626,8 +623,8 @@ def _check_for_data_and_symbol_conflicts( # edge thus in addition to the symbols of the data, we need the # symbols needed by the edge. node_to_map = iedge.src - requiered_symbols |= { - str(sym) for sym in iedge.data.used_symbols(True, edge=iedge) + required_symbols |= { + str(sym) for sym in iedge.data.used_symbols(all_symbols=True, edge=iedge) } # Only AccessNodes can be mapped into `if_block`. @@ -635,8 +632,8 @@ def _check_for_data_and_symbol_conflicts( return False # Add the symbols of the data. - requiered_symbols |= { - str(sym) for sym in sdfg.arrays[node_to_map.data].used_symbols(True) + required_symbols |= { + str(sym) for sym in sdfg.arrays[node_to_map.data].used_symbols(all_symbols=True) } # A conflicting symbol is a free symbol of the relocatable dataflow, that is not a @@ -644,7 +641,7 @@ def _check_for_data_and_symbol_conflicts( # then everything is okay if the symbol mapping is `{n: n}` i.e. the symbol has the # same meaning inside and outside. Everything else is not okay. symbol_mapping = if_block.symbol_mapping - conflicting_symbols = requiered_symbols.intersection((str(k) for k in symbol_mapping)) + conflicting_symbols = required_symbols.intersection((str(k) for k in symbol_mapping)) for conflicting_symbol in conflicting_symbols: if conflicting_symbol != str(symbol_mapping[conflicting_symbol]): return False diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 1f9d1d3a25..f5b93da9ee 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -217,13 +217,13 @@ def _make_if_block_with_two_args( def _perform_test( sdfg: dace.SDFG, - explected_applies: int, + expected_applies: int, if_block: Optional[dace_nodes.NestedSDFG] = None, ) -> dace.SDFG: if if_block is not None: # The test should be applied in a specific location. - assert 0 <= explected_applies <= 1 - can_be_applied_ref = explected_applies != 0 + assert 0 <= expected_applies <= 1 + can_be_applied_ref = expected_applies != 0 can_be_applied_res = gtx_transformations.MoveDataflowIntoIfBody.can_be_applied_to( sdfg=sdfg, if_block=if_block, @@ -232,7 +232,7 @@ def _perform_test( return sdfg # General case, run the SDFG first and then compare the result. - if explected_applies != 0: + if expected_applies != 0: ref, res = util.make_sdfg_args(sdfg) util.compile_and_run_sdfg(sdfg, **ref) @@ -241,9 +241,9 @@ def _perform_test( validate=True, validate_all=True, ) - assert nb_apply == explected_applies + assert nb_apply == expected_applies - if explected_applies == 0: + if expected_applies == 0: return sdfg util.compile_and_run_sdfg(sdfg, **res) @@ -345,7 +345,7 @@ def test_if_mover_independent_branches(): mx.add_out_connector("OUT_d") sdfg.validate() - _perform_test(sdfg, explected_applies=1) + _perform_test(sdfg, expected_applies=1) # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) @@ -467,7 +467,7 @@ def test_if_mover_invalid_if_block(): mx.add_out_connector("OUT_d") sdfg.validate() - _perform_test(sdfg, explected_applies=0) + _perform_test(sdfg, expected_applies=0) def test_if_mover_dependent_branch_1(): @@ -582,7 +582,7 @@ def test_if_mover_dependent_branch_1(): mx.add_out_connector("OUT_d") sdfg.validate() - _perform_test(sdfg, explected_applies=1) + _perform_test(sdfg, expected_applies=1) # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) @@ -695,7 +695,7 @@ def test_if_mover_dependent_branch_2(): mx.add_in_connector("IN_d") mx.add_out_connector("OUT_d") - _perform_test(sdfg, explected_applies=1) + _perform_test(sdfg, expected_applies=1) # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) @@ -801,7 +801,7 @@ def test_if_mover_dependent_branch_3(): assert util.count_nodes(state, dace_nodes.MapEntry) == 2 assert util.count_nodes(state, dace_nodes.AccessNode) == 7 - _perform_test(sdfg, explected_applies=1) + _perform_test(sdfg, expected_applies=1) # It is unspecific if `IN_b1` or `IN_b2` remains, but `b` should have only one connector. assert any(iconn in me.in_connectors for iconn in ["IN_b1", "IN_b2"]) @@ -951,7 +951,7 @@ def test_if_mover_dependent_branch_4(): mx.add_out_connector("OUT_f") sdfg.validate() - _perform_test(sdfg, explected_applies=1) + _perform_test(sdfg, expected_applies=1) # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) @@ -1136,7 +1136,7 @@ def test_if_mover_dependent_branch_5(): mx.add_out_connector("OUT_f") sdfg.validate() - _perform_test(sdfg, explected_applies=1) + _perform_test(sdfg, expected_applies=1) # Examine the structure of the SDFG. top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) @@ -1294,7 +1294,7 @@ def test_if_mover_dependent_branch_6(): assert util.count_nodes(state, dace_nodes.MapEntry) == 2 assert util.count_nodes(state, dace_nodes.AccessNode) == 9 - _perform_test(sdfg, explected_applies=1) + _perform_test(sdfg, expected_applies=1) # Simplify the SDFG to remove double `b` AccessNode in the false branch. sdfg.simplify() @@ -1400,7 +1400,7 @@ def test_if_mover_no_ops(): sdfg.validate() # This might change if we will move the read fully inside the branches. - _perform_test(sdfg, explected_applies=0) + _perform_test(sdfg, expected_applies=0) def test_if_mover_one_branch_is_nothing(): @@ -1483,7 +1483,7 @@ def test_if_mover_one_branch_is_nothing(): mx.add_out_connector("OUT_d") sdfg.validate() - _perform_test(sdfg, explected_applies=1) + _perform_test(sdfg, expected_applies=1) top_ac: list[dace_nodes.AccessNode] = util.count_nodes(state, dace_nodes.AccessNode, True) assert {ac.data for ac in top_ac} == set(input_names).union(["c1"]) @@ -1605,14 +1605,14 @@ def test_if_mover_chain(): # because it is limited by the top one. _perform_test( sdfg, - explected_applies=0, + expected_applies=0, if_block=bot_if_block, ) # But we are able to inline both. _perform_test( sdfg, - explected_applies=2, + expected_applies=2, ) @@ -1708,7 +1708,7 @@ def test_if_mover_symbolic_tasklet(): sdfg = _perform_test( sdfg, - explected_applies=1, + expected_applies=1, ) expected_symb = {"symbol_1", "symbol_2"} @@ -1823,14 +1823,14 @@ def test_if_mover_access_node_between(): # block that in turn has dataflow that could be relocated. _perform_test( sdfg, - explected_applies=0, + expected_applies=0, if_block=bot_if_block, ) # But we are able to process them that way, starting from the bottom. _perform_test( sdfg, - explected_applies=2, + expected_applies=2, ) expected_top_level_data: set[str] = {"a", "b", "c", "d", "e", "f", "c2"} @@ -1966,7 +1966,7 @@ def test_if_mover_symbol_aliasing(): # to account for the access on the Memlets of the `{true, false}_tlet`. _perform_test( sdfg=sdfg, - explected_applies=0, + expected_applies=0, ) @@ -2123,7 +2123,7 @@ def _make_outer_sdfg( _perform_test( sdfg=sdfg, - explected_applies=1, + expected_applies=1, ) tlet_after = util.count_nodes(sdfg, dace_nodes.Tasklet, True) From 2615c65340f3b9c6d1aa878a7422897323580840 Mon Sep 17 00:00:00 2001 From: "Philip Mueller, CSCS" Date: Mon, 4 May 2026 07:53:25 +0200 Subject: [PATCH 32/32] Applied suggestion. --- .../dace/transformations/move_dataflow_into_if_body.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index f339ba9f8a..d7f7a36b96 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -195,9 +195,9 @@ def apply( ) -> None: if_block: dace_nodes.NestedSDFG = self.if_block enclosing_map = graph.scope_dict()[if_block] - relocatable_connectors, non_relocatable_connectors, connector_usage_location = ( - self._partition_if_block(sdfg, if_block) # type: ignore[misc] # Guaranteed to be not None. - ) + partition_res = self._partition_if_block(sdfg, if_block) + assert partition_res is not None + relocatable_connectors, non_relocatable_connectors, connector_usage_location = partition_res # Find the dataflow that should be relocated. raw_relocatable_dataflow, non_relocatable_dataflow = (