# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Final

from dace import nodes as dace_nodes

from gt4py.next.program_processors.runners.dace.library_nodes.broadcast import (
    Broadcast,
    inplace_broadcast_expander,
)
from gt4py.next.program_processors.runners.dace.library_nodes.reduce_with_skip_values import (
    ReduceWithSkipValues,
)


# NOTE: The tuple contains the library-node *classes* (not instances), hence the
#  element type is `type[dace_nodes.LibraryNode]`.
GTIR_LIBRARY_NODES: Final[tuple[type[dace_nodes.LibraryNode], ...]] = (
    Broadcast,
    ReduceWithSkipValues,
)
"""List of available GTIR library nodes."""


__all__ = [
    "Broadcast",
    "ReduceWithSkipValues",
    "inplace_broadcast_expander",
]
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import copy +from typing import Any, Final, Iterable, Optional, Sequence + +import dace +from dace import ( + data as dace_data, + library as dace_library, + nodes as dace_nodes, + properties as dace_properties, + subsets as dace_sbs, +) +from dace.sdfg import graph as dace_graph +from dace.transformation import transformation as dace_transform + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace import lowering as gtx_dace_lowering + + +_INPUT_NAME: Final[str] = "_inp" +_OUTPUT_NAME: Final[str] = "_outp" + + +@dace_library.node +class Broadcast(dace_nodes.LibraryNode): + """Implements write of a scalar value over an array subset. + + Same as XLA. + broadcast_in_dim[i] describes where dimension `i` of the `value_to_broadcast` + goes. In case of a scalar it is empty. + Furthermore the following has to hold: + ```python + for i in range(len(broadcast_in_dim): + assert output.shape[broadcast_in_dim[i]] == value_to_broadcast.shape[i] + ``` + + Args: + broadcast_in_dim: How to broadcast. + params: The parameters that should be used for the expansion. If given one + entry for each dimension of the destination is needed. + + Todo: + - While for the output it is probably okay to always require an adjacent + AccessNode for the input it might be possible to be on the other side + of a Map. + """ + + implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {} + default_implementation: Final[str | None] = "pure" + + brodcast_in_dims = dace_properties.ListProperty(element_type=int) + params = dace_properties.ListProperty(element_type=str, allow_none=True) + + def __init__( + self, + name: str, + broadcast_in_dims: Sequence[int], + params: Optional[Iterable[gtx_common.Dimension | str]], + debuginfo: dace.dtypes.DebugInfo | None = None, + ): + # TODO(philip, edopao): I would propose to drop `value` then. 
This makes it + # simpler to handle in the transformations. + super().__init__(name, inputs={_INPUT_NAME}, outputs={_OUTPUT_NAME}) + + self.brodcast_in_dims = list(broadcast_in_dims) + self.debuginfo = debuginfo + + if params is None: + self.params = None + else: + self.params = [ + gtx_dace_lowering.get_map_variable(param) + if isinstance(param, gtx_common.Dimension) + else param + for param in params + ] + + def validate(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: + if len(self.brodcast_in_dims) != len(set(self.brodcast_in_dims)): + raise ValueError("`Can not broadcast to multiple dimensions at the same time.") + + # TODO(phimuell): Handle empty Memlets in the input. + if state.in_degree(self) != 1 and next(iter(state.in_edges(self))).dst_conn == _INPUT_NAME: + raise ValueError("GT4Py Broadcast node needs exactly one input.") + if ( + state.out_degree(self) != 1 + and next(iter(state.out_edges(self))).src_conn == _OUTPUT_NAME + ): + raise ValueError("GT4Py Broadcast node needs exactly one output.") + + bcast_value_node: dace_nodes.AccessNode = next(iter(state.in_edges(self))).src + if not isinstance(bcast_value_node, dace_nodes.AccessNode): + raise ValueError("Source of broadcasting must be an AccessNode.") + bcast_value_desc = bcast_value_node.desc(sdfg) + if isinstance(bcast_value_desc, dace_data.View): + raise ValueError("Can not broadcast from a view.") + + bcast_result_node: dace_nodes.AccessNode = next(iter(state.out_edges(self))).dst + if not isinstance(bcast_result_node, dace_nodes.AccessNode): + raise ValueError("Broadcast result must be an AccessNode.") + bcast_result_desc = bcast_result_node.desc(sdfg) + if isinstance(bcast_result_desc, dace_data.View): + raise ValueError("Broadcast result can not be a view.") + + if not isinstance(bcast_result_desc, dace_data.Array): + # In fact it would also be possible to broadcast into a Scalar, but this + # does not make much sense. 
+ raise ValueError( + f"Can only broadcast into an array, but target was `{type(bcast_result_desc).__name__}`." + ) + + if (self.params is not None) and (len(self.params) != len(bcast_result_desc.shape)): + raise ValueError( + f"Expected that {len(bcast_result_desc.shape)} parameters are" + f" needed but {len(self.params)} were specified." + ) + + if isinstance(bcast_value_desc, dace_data.Scalar): + if len(self.brodcast_in_dims) != 0: + raise ValueError("For a scalar `broadcast_in_dims` must be empty.") + else: + if len(self.brodcast_in_dims) != len(bcast_value_desc.shape): + raise ValueError( + f"`broadcast_in_dims` has {len(self.brodcast_in_dims)} entries," + f" but the value to broadcast had {len(bcast_value_desc.shape)} dimensions." + ) + if len(bcast_result_desc.shape) < len(bcast_value_desc.shape): + raise ValueError( + f"The value to broadcast has more dimensions ({len(bcast_value_desc.shape)})" + f" than the result ({len(bcast_result_desc.shape)})." + ) + + for src_dim, bcast_dst_dim in enumerate(self.brodcast_in_dims): + if bcast_dst_dim < 0: + raise ValueError("Negative broadcast") + if bcast_dst_dim >= len(bcast_result_desc.shape): + raise ValueError("Out of range broadcast dim found.") + + # Only do the size matching test if the sizes are known, as different + # symbols can have the same value. + src_size = bcast_value_desc.shape[src_dim] + dst_size = bcast_result_desc.shape[bcast_dst_dim] + if str(src_size).isdigit() and str(dst_size).isdigit() and (src_size != dst_size): + raise ValueError("Size mismatch found.") + + +def inplace_broadcast_expander( + bcast_node: Broadcast, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None: + """Perform expansion of `bcast_node` inside `state`. + + The main difference between this and the normal expansion transformation is + that this function does not generate nested SDFG and instead performs the + expansion in place. 
+ """ + + input_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = next(iter(state.in_edges(bcast_node))) + output_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = next( + iter(state.out_edges(bcast_node)) + ) + + # TODO(phimuell): Add warning. + map_params: list[str] = ( + [f"__bcast{dst_dim}" for dst_dim in range(len(output_edge.data.subset))] + if bcast_node.params is None + else list(bcast_node.params) + ) + + output_subset: list[str] = map_params.copy() + map_ranges: dict[str, dace_sbs.Range] = { + map_param: dace_sbs.Range([sbs]) + for map_param, sbs in zip(map_params, output_edge.data.subset) + } + + input_subset: list[str] + if len(bcast_node.brodcast_in_dims) == 0: + input_subset = ["0"] + else: + bcast_value_offset = input_edge.data.subset.min_element() + input_subset = [ + f"{map_params[dst_dim]} + ({offset})" + for dst_dim, offset in zip(bcast_node.brodcast_in_dims, bcast_value_offset) + ] + + me, mx = state.add_map(f"__gt4py_broadcast_map_{bcast_node.name}", ndrange=map_ranges) + bcast_tlet = state.add_tasklet( + f"__gt4py_broadcast_tasklet_{bcast_node.name}", + inputs={"__in"}, + outputs={"__out"}, + code="__out = __in", + ) + + state.add_edge( + input_edge.src, + input_edge.src_conn, + me, + f"IN_{input_edge.data.data}", + dace.Memlet.from_array(input_edge.data.data, sdfg.arrays[input_edge.data.data]), + ) + state.add_edge( + me, + f"OUT_{input_edge.data.data}", + bcast_tlet, + "__in", + dace.Memlet(data=input_edge.data.data, subset=", ".join(input_subset)), + ) + me.add_scope_connectors(input_edge.data.data) + + state.add_edge( + bcast_tlet, + "__out", + mx, + f"IN_{output_edge.data.data}", + dace.Memlet(data=output_edge.data.data, subset=", ".join(output_subset)), + ) + state.add_edge( + mx, + f"OUT_{output_edge.data.data}", + output_edge.dst, + output_edge.dst_conn, + dace.Memlet(data=output_edge.data.data, subset=copy.deepcopy(output_edge.data.subset)), + ) + mx.add_scope_connectors(output_edge.data.data) + + # Now delete the node. 
+ state.remove_node(bcast_node) + + +@dace_library.register_expansion(Broadcast, "pure") +class BroadcastExpandInlined(dace_transform.ExpandTransformation): + """Implements pure expansion of the Broadcast library node.""" + + environments: Final[list[Any]] = [] + + @staticmethod + def expansion(node: Broadcast, state: dace.SDFGState, sdfg: dace.SDFG) -> dace.SDFG: + # TODO: + # - Modify the edges on the outside. + # - Handle the missing symbols. + + # NOTE: We have to cheat a here a bit. Actually only parts of the output + # would be mapped into the nested SDFG. + assert isinstance(node, Broadcast) + assert state.out_degree(node) == 1 and state.in_degree(node) == 1 + + nsdfg = dace.SDFG(f"__gt4py_broadcast_expansion_{node.label}") + bcast_st = nsdfg.add_state(f"__gt4py_broadcast_expansion_{node.label}_state") + + input_edge = next(state.in_edges_by_connector(node, _INPUT_NAME)) + output_edge = next(state.out_edges_by_connector(node, _OUTPUT_NAME)) + bcast_value = input_edge.src + bcast_result = output_edge.dst + + # Creating the input and output data inside the nested SDFG, such that we can + # map them _fully_ (see later) into the nested SDFG. 
+ bcast_value_inner_data = nsdfg.add_datadesc(_INPUT_NAME, bcast_value.desc(sdfg).clone()) + bcast_result_inner_data = nsdfg.add_datadesc(_OUTPUT_NAME, bcast_result.desc(sdfg).clone()) + nsdfg.arrays[bcast_value_inner_data].transient = False + nsdfg.arrays[bcast_result_inner_data].transient = False + + inner_bcast_node = copy.deepcopy(node) + inner_bcast_value_edge = bcast_st.add_edge( + bcast_st.add_access(bcast_value_inner_data), + input_edge.src_conn, + inner_bcast_node, + "_inp", + copy.deepcopy(input_edge.data), + ) + inner_bcast_value_edge.data.data = bcast_value_inner_data + + inner_bcast_result_edge = bcast_st.add_edge( + inner_bcast_node, + "_outp", + bcast_st.add_access(bcast_result_inner_data), + output_edge.dst_conn, + copy.deepcopy(output_edge.data), + ) + inner_bcast_result_edge.data.data = bcast_result_inner_data + + # Now we run the inplace expansion on the node inside the nested SDFG. + inplace_broadcast_expander(inner_bcast_node, bcast_st, nsdfg) + + # To ensure that the full data is passed into the nested SDFG, which we + # assumed because of how we want that everything is mapped into. + input_edge.data.subset = dace_sbs.Range.from_array(bcast_value.desc(sdfg)) + output_edge.data.subset = dace_sbs.Range.from_array(bcast_result.desc(sdfg)) + + # NOTE: We will not update `nsdfg.symbols`, instead we will rely on + # `add_nested_sdfg()` that is called implicitly by the expansion driver. + + return nsdfg diff --git a/src/gt4py/next/program_processors/runners/dace/library_nodes/reduce_with_skip_values.py b/src/gt4py/next/program_processors/runners/dace/library_nodes/reduce_with_skip_values.py new file mode 100644 index 0000000000..135f814890 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/library_nodes/reduce_with_skip_values.py @@ -0,0 +1,176 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from typing import Any, Final + +import dace +from dace import library as dace_library, properties as dace_properties +from dace.sdfg import graph as dace_graph +from dace.transformation import transformation as dace_transform + +from gt4py.next import common as gtx_common + + +@dace.library.node +class ReduceWithSkipValues(dace.sdfg.nodes.LibraryNode): + """Implements reduction with skip values.""" + + implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {} + default_implementation: Final[str | None] = "pure" + + # Properties + wcr = dace_properties.LambdaProperty(default="lambda a, b: a") + identity = dace_properties.SymbolicProperty(default=0, to_json=lambda x: str(x)) + init = dace_properties.SymbolicProperty(default=0, to_json=lambda x: str(x)) + input_conn = dace_properties.Property(default="_in") + output_conn = dace_properties.Property(default="_out") + mask_conn = dace_properties.Property(default="_mask") + + def __init__( + self, + name: str, + wcr: str, + identity: dace.symbolic.SymbolicType, + init: dace.symbolic.SymbolicType, + debuginfo: dace.dtypes.DebugInfo | None = None, + input_conn: str | None = None, + output_conn: str | None = None, + mask_conn: str | None = None, + ) -> None: + if input_conn is None: + input_conn = self.input_conn + else: + self.input_conn = input_conn + + if output_conn is None: + output_conn = self.output_conn + else: + self.output_conn = output_conn + + if mask_conn is None: + mask_conn = self.mask_conn + else: + self.mask_conn = mask_conn + + super().__init__(name, inputs={input_conn, mask_conn}, outputs={output_conn}) + self.wcr = wcr + self.identity = identity + self.init = init + self.debuginfo = debuginfo + + def validate(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: + assert len(list(state.in_edges_by_connector(self, self.input_conn))) == 1 + inedge: dace_graph.MultiConnectorEdge = next( + 
state.in_edges_by_connector(self, self.input_conn) + ) + assert len(list(state.out_edges_by_connector(self, self.output_conn))) == 1 + outedge: dace_graph.MultiConnectorEdge = next( + state.out_edges_by_connector(self, self.output_conn) + ) + assert len(list(state.in_edges_by_connector(self, self.mask_conn))) == 1 + maskedge: dace_graph.MultiConnectorEdge = next( + state.in_edges_by_connector(self, self.mask_conn) + ) + + mask_desc = sdfg.arrays[maskedge.data.data] + if len(mask_desc.shape) != 2: + raise ValueError(f"Invalid shape {mask_desc.shape} of mask array, expected 2d array.") + max_neighbors = mask_desc.shape[1] + if not (isinstance(max_neighbors, int) or str(max_neighbors).isdigit()): + raise ValueError( + f"Invalid shape {mask_desc.shape} of mask array, expected constant neighbors size." + ) + if inedge.data.num_elements() != max_neighbors: + raise ValueError(f"Invalid memlet on input connector {self.input_conn}.") + if maskedge.data.num_elements() != max_neighbors: + raise ValueError(f"Invalid memlet on input connector {self.mask_conn}.") + if outedge.data.num_elements() != 1: + raise ValueError(f"Invalid memlet on output connector {self.output_conn}.") + + +@dace_library.register_expansion(ReduceWithSkipValues, "pure") +class ReduceWithSkipValuesExpandInlined(dace_transform.ExpandTransformation): + """Implements pure expansion of the ReduceWithSkipValues library node.""" + + environments: Final[list[Any]] = [] + + @staticmethod + def expansion(node: ReduceWithSkipValues, state: dace.SDFGState, sdfg: dace.SDFG) -> dace.SDFG: + assert len(list(state.in_edges_by_connector(node, node.input_conn))) == 1 + inedge: dace_graph.MultiConnectorEdge = next( + state.in_edges_by_connector(node, node.input_conn) + ) + assert len(list(state.out_edges_by_connector(node, node.output_conn))) == 1 + outedge: dace_graph.MultiConnectorEdge = next( + state.out_edges_by_connector(node, node.output_conn) + ) + assert len(list(state.in_edges_by_connector(node, node.mask_conn))) 
== 1 + maskedge: dace_graph.MultiConnectorEdge = next( + state.in_edges_by_connector(node, node.mask_conn) + ) + input_desc = sdfg.arrays[inedge.data.data] + output_desc = sdfg.arrays[outedge.data.data] + mask_desc = sdfg.arrays[maskedge.data.data] + assert len(mask_desc.shape) == 2 + max_neighbors = mask_desc.shape[1] + assert isinstance(max_neighbors, int) or str(max_neighbors).isdigit() + + local_dim_index = inedge.data.src_subset.size().index(max_neighbors) + + nsdfg = dace.SDFG(node.label) + inp, _ = nsdfg.add_array( + node.input_conn, + (max_neighbors,), + input_desc.dtype, + strides=(input_desc.strides[local_dim_index],), + ) + mask, _ = nsdfg.add_array( + node.mask_conn, + (max_neighbors,), + mask_desc.dtype, + strides=(mask_desc.strides[1],), + ) + outp, _ = nsdfg.add_scalar(node.output_conn, output_desc.dtype) + st_init = nsdfg.add_state("init") + init_tasklet = st_init.add_tasklet( + name="write", + inputs={}, + outputs={"__tlet_out"}, + code=f"__tlet_out = {input_desc.dtype}({node.init})", + ) + st_init.add_edge( + init_tasklet, + "__tlet_out", + st_init.add_access(outp), + None, + dace.Memlet(data=outp, subset="0"), + ) + st_reduce = nsdfg.add_state_after(st_init, "compute") + # Fill skip values in local dimension with the reduce identity value + skip_value = f"{input_desc.dtype}({node.identity})" + # Since this map operates on a pure local dimension, we explicitly set sequential + # schedule and we set the flag 'wcr_nonatomic=True' on the write memlet. + # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. 
+ st_reduce.add_mapped_tasklet( + name="reduce_with_skip_values", + map_ranges={"i": f"0:{max_neighbors}"}, + inputs={ + "__tlet_inp": dace.Memlet(data=inp, subset="i"), + "__tlet_mask": dace.Memlet(data=mask, subset="i"), + }, + code=f"__tlet_out = __tlet_inp if __tlet_mask != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}", + outputs={ + "__tlet_out": dace.Memlet(data=outp, subset="0", wcr=node.wcr, wcr_nonatomic=True), + }, + external_edges=True, + schedule=dace.dtypes.ScheduleType.Sequential, + ) + + return nsdfg diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index da590d84e0..72bdf4dd56 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -27,6 +27,7 @@ import dace from dace import nodes as dace_nodes, subsets as dace_subsets +from dace.libraries import standard as dace_stdlib from gt4py import eve from gt4py.eve.extended_typing import MaybeNestedInTuple, NestedTuple @@ -34,7 +35,10 @@ from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import symbol_ref_utils -from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args +from gt4py.next.program_processors.runners.dace import ( + library_nodes as gtir_library_nodes, + sdfg_args as gtx_dace_args, +) from gt4py.next.program_processors.runners.dace.lowering import ( gtir_python_codegen, gtir_to_sdfg, @@ -1320,139 +1324,6 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=ts.ListType(node.type.element_type, offset_type), ) - def _make_reduce_with_skip_values( - self, - input_expr: ValueExpr | MemletExpr, - offset_provider_type: gtx_common.NeighborConnectivityType, - reduce_init: SymbolExpr, - reduce_identity: SymbolExpr, - reduce_wcr: 
str, - result_node: dace_nodes.AccessNode, - ) -> None: - """ - Helper method to lower reduction on a local field containing skip values. - - The reduction is implemented as a nested SDFG containing 2 states. In first - state, the result (a scalar data node passed as argumet) is initialized. - In second state, a mapped tasklet uses a write-conflict resolution (wcr) - memlet to update the result. - We use the offset provider as a mask to identify skip values: the value - that is written to the result node is either the input value, when the - corresponding neighbor index in the connectivity table is valid, or the - identity value if the neighbor index is missing. - """ - origin_map_index = gtir_to_sdfg_utils.get_map_variable(offset_provider_type.source_dim) - - assert ( - isinstance(input_expr.gt_dtype, ts.ListType) - and input_expr.gt_dtype.offset_type is not None - ) - offset_type = input_expr.gt_dtype.offset_type - connectivity = gtx_dace_args.connectivity_identifier(offset_type.value) - connectivity_node = self.state.add_access(connectivity) - connectivity_desc = connectivity_node.desc(self.sdfg) - connectivity_desc.transient = False - - desc = input_expr.dc_node.desc(self.sdfg) - if isinstance(input_expr, MemletExpr): - local_dim_indices = [i for i, size in enumerate(input_expr.subset.size()) if size != 1] - else: - local_dim_indices = list(range(len(desc.shape))) - - if len(local_dim_indices) != 1: - raise NotImplementedError( - f"Found {len(local_dim_indices)} local dimensions in reduce expression, expected one." 
- ) - local_dim_index = local_dim_indices[0] - assert desc.shape[local_dim_index] == offset_provider_type.max_neighbors - - # we lower the reduction map with WCR out memlet in a nested SDFG - nsdfg = dace.SDFG(self.subgraph_builder.unique_nsdfg_name("reduce_with_skip_values")) - nsdfg.add_array( - "values", - (desc.shape[local_dim_index],), - desc.dtype, - strides=(desc.strides[local_dim_index],), - ) - nsdfg.add_array( - "neighbor_indices", - (connectivity_desc.shape[1],), - connectivity_desc.dtype, - strides=(connectivity_desc.strides[1],), - ) - nsdfg.add_scalar("acc", desc.dtype) - st_init = nsdfg.add_state(f"{nsdfg.label}_init") - init_tasklet, connector_mapping = self.subgraph_builder.add_tasklet( - name="init_acc", - sdfg=self.sdfg, - state=st_init, - inputs={}, - outputs={"val"}, - code=f"val = {reduce_init.dc_dtype}({reduce_init.value})", - ) - st_init.add_edge( - init_tasklet, - connector_mapping["val"], - st_init.add_access("acc"), - None, - dace.Memlet(data="acc", subset="0"), - ) - st_reduce = nsdfg.add_state_after(st_init, f"{nsdfg.label}_reduce") - # Fill skip values in local dimension with the reduce identity value - skip_value = f"{reduce_identity.dc_dtype}({reduce_identity.value})" - # Since this map operates on a pure local dimension, we explicitly set sequential - # schedule and we set the flag 'wcr_nonatomic=True' on the write memlet. - # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. 
- self.subgraph_builder.add_mapped_tasklet( - name="reduce_with_skip_values", - sdfg=self.sdfg, - state=st_reduce, - map_ranges={"i": f"0:{offset_provider_type.max_neighbors}"}, - inputs={ - "val": dace.Memlet(data="values", subset="i"), - "neighbor_idx": dace.Memlet(data="neighbor_indices", subset="i"), - }, - code=f"out = val if neighbor_idx != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}", - outputs={ - "out": dace.Memlet(data="acc", subset="0", wcr=reduce_wcr, wcr_nonatomic=True), - }, - external_edges=True, - schedule=dace.dtypes.ScheduleType.Sequential, - ) - - nsdfg_node = self.state.add_nested_sdfg( - nsdfg, inputs={"values", "neighbor_indices"}, outputs={"acc"} - ) - - if isinstance(input_expr, MemletExpr): - self._add_input_data_edge(input_expr.dc_node, input_expr.subset, nsdfg_node, "values") - else: - self.state.add_edge( - input_expr.dc_node, - None, - nsdfg_node, - "values", - self.sdfg.make_array_memlet(input_expr.dc_node.data), - ) - # The layout of connectivity tables is known. 
- assert len(offset_provider_type.domain) == 2 - assert offset_provider_type.domain[1].kind == gtx_common.DimensionKind.LOCAL - self._add_input_data_edge( - connectivity_node, - dace_subsets.Range.from_string( - f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" - ), - nsdfg_node, - "neighbor_indices", - ) - self.state.add_edge( - nsdfg_node, - "acc", - result_node, - None, - dace.Memlet(data=result_node.data, subset="0"), - ) - def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type, ts.ScalarType) op_name, reduce_init, reduce_identity = get_reduce_params(node) @@ -1471,27 +1342,73 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value) assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + inp_conn = "_in" + outp_conn = "_out" + mask_conn = "_mask" + if offset_provider_type.has_skip_values: - self._make_reduce_with_skip_values( - input_expr, - offset_provider_type, - reduce_init, - reduce_identity, + name = self.subgraph_builder.unique_nsdfg_name("reduce_with_skip_values") + reduce_node = gtir_library_nodes.ReduceWithSkipValues( + name, + reduce_wcr, + identity=reduce_identity.value, + init=reduce_init.value, + debuginfo=gtir_to_sdfg_utils.debug_info(node), + ) + reduce_node.input_conn = inp_conn + reduce_node.output_conn = outp_conn + reduce_node.mask_conn = mask_conn + else: + reduce_node = dace_stdlib.Reduce( + "reduce", reduce_wcr, - result_node, + axes=None, + identity=reduce_init.value, + schedule=dace.dtypes.ScheduleType.Default, + debuginfo=gtir_to_sdfg_utils.debug_info(node), + inputs={inp_conn}, + outputs={outp_conn}, ) + self.state.add_node(reduce_node) + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge(input_expr.dc_node, input_expr.subset, reduce_node, inp_conn) else: - reduce_node = self.state.add_reduce(reduce_wcr, axes=None, identity=reduce_init.value) - if 
isinstance(input_expr, MemletExpr): - self._add_input_data_edge(input_expr.dc_node, input_expr.subset, reduce_node) - else: - self.state.add_nedge( - input_expr.dc_node, - reduce_node, - self.sdfg.make_array_memlet(input_expr.dc_node.data), - ) - self.state.add_nedge(reduce_node, result_node, dace.Memlet(data=result, subset="0")) + self.state.add_edge( + input_expr.dc_node, + None, + reduce_node, + inp_conn, + self.sdfg.make_array_memlet(input_expr.dc_node.data), + ) + self.state.add_edge( + reduce_node, outp_conn, result_node, None, dace.Memlet(data=result, subset="0") + ) + + if offset_provider_type.has_skip_values: + assert ( + isinstance(input_expr.gt_dtype, ts.ListType) + and input_expr.gt_dtype.offset_type is not None + ) + + offset_type = input_expr.gt_dtype.offset_type + connectivity = gtx_dace_args.connectivity_identifier(offset_type.value) + connectivity_node = self.state.add_access(connectivity) + connectivity_desc = connectivity_node.desc(self.sdfg) + connectivity_desc.transient = False + + # The layout of connectivity tables is known. + assert len(offset_provider_type.domain) == 2 + assert offset_provider_type.domain[1].kind == gtx_common.DimensionKind.LOCAL + origin_map_index = gtir_to_sdfg_utils.get_map_variable(offset_provider_type.source_dim) + self._add_input_data_edge( + connectivity_node, + dace_subsets.Range.from_string( + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" + ), + reduce_node, + mask_conn, + ) return ValueExpr(result_node, node.type) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index cb9b9c6d65..7272562616 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -73,6 +73,9 @@ def unique_tasklet_name(self, name: str) -> str: ... @abc.abstractmethod def unique_temp_name(self) -> str: ... 
+ @abc.abstractmethod + def unique_lib_node_name(self, lib_node_type: str) -> str: ... + def add_temp_array( self, sdfg: dace.SDFG, shape: Sequence[Any], dtype: dace.dtypes.typeclass ) -> tuple[str, dace.data.Scalar]: @@ -118,7 +121,7 @@ def add_tasklet( code: str, language: dace.dtypes.Language = dace.dtypes.Language.Python, **kwargs: Any, - ) -> dace_nodes.Tasklet: + ) -> tuple[dace_nodes.Tasklet, dict[str, str]]: """Wrapper of `dace.SDFGState.add_tasklet` that assigns a unique name. It also modifies the tasklet connectors by adding a prefix string (see @@ -759,6 +762,9 @@ def unique_tasklet_name(self, name: str) -> str: def unique_temp_name(self) -> str: return f"{next(self.uids['gtir_tmp'])}" + def unique_lib_node_name(self, lib_node_type: str) -> str: + return f"{next(self.uids[lib_node_type])}" + def _make_array_shape_and_strides( self, name: str, dims: Sequence[gtx_common.Dimension] ) -> tuple[list[dace.symbolic.SymbolicType], list[dace.symbolic.SymbolicType]]: diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py index 007202a87d..864cc6fe04 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_primitives.py @@ -12,17 +12,16 @@ from typing import TYPE_CHECKING, Iterable, Optional, Protocol import dace -from dace import nodes as dace_nodes, subsets as dace_subsets +from dace import data as dace_data, nodes as dace_nodes, subsets as dace_subsets from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import ( - common_pattern_matcher as cpm, - domain_utils, - ir_makers as im, -) +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.iterator.transforms import infer_domain -from 
gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args +from gt4py.next.program_processors.runners.dace import ( + library_nodes as gtir_library_nodes, + sdfg_args as gtx_dace_args, +) from gt4py.next.program_processors.runners.dace.lowering import ( gtir_dataflow, gtir_domain, @@ -261,8 +260,7 @@ def translate_as_fieldop( if isinstance(arg_type, ts.ScalarType) or arg_type.dims != node.type.dims: # Special usage of 'deref' as argument to fieldop expression, to broadcast # the input value (a scalar or a field slice) on the output domain. - stencil_expr = im.lambda_("a")(im.deref("a")) - stencil_expr.expr.type = node.type.dtype + return translate_broadcast(node, ctx, sdfg_builder) else: # Special usage of 'deref' with field argument, to access the field # on the given domain. It copies a subset of the source field. @@ -292,6 +290,162 @@ def translate_as_fieldop( ) +def translate_broadcast( + node: gtir.Node, + ctx: gtir_to_sdfg.SubgraphContext, + sdfg_builder: gtir_to_sdfg.SDFGBuilder, +) -> gtir_to_sdfg_types.FieldopData: + """Translates a broadcast expression which writes a scalar value on the field domain.""" + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + + if not isinstance(node.type, ts.FieldType): + raise NotImplementedError("Unexpected 'as_filedop' with tuple output in SDFG lowering.") + + assert isinstance(node.type.dtype, ts.ScalarType) + field_dtype = gtx_dace_args.as_dace_type(node.type.dtype) + + assert len(node.args) == 1 + bcast_arg = node.args[0] + + fun_node = node.fun + assert len(fun_node.args) == 2 + fieldop_expr, fieldop_domain_expr = fun_node.args + assert cpm.is_ref_to(fieldop_expr, "deref") + + # TODO: + # - Include the dimensions of the output, such that we can generate the Map correctly + # this is needed for expansion of the broadcast node such that we can order them correctly. + # This information is stored in `field_dims`. + + # Parse the domain of the field operator. 
+ assert isinstance(fieldop_domain_expr.type, ts.DomainType) + field_domain = gtir_domain.get_field_domain( + domain_utils.SymbolicDomain.from_expr(fieldop_domain_expr) + ) + + # The memory layout of the output field follows the field operator compute domain. + field_dims, field_origin, field_shape = gtir_domain.get_field_layout(field_domain) + assert field_dims == node.type.dims + field_name, field_desc = sdfg_builder.add_temp_array(ctx.sdfg, field_shape, field_dtype) + field_node = ctx.state.add_access(field_name) + + # Retrieve the scalar argument, which could be either a literal value or the + # result of a scalar expression. + # TODO: The name should not be derived from a tasklet + bcast_node_name = sdfg_builder.unique_lib_node_name("broadcast") + + # The destination array was allocated to fit everything. + bcast_result_subset = dace_subsets.Range.from_array(field_desc) + bcast_value_subset: dace_subsets.Subset | None = None + + # Which dimensions in `bcast_value` corresponds to the ones in `cast_result`. + broadcast_in_dims: list[int] + + if isinstance(bcast_arg, gtir.Literal): + assert isinstance(bcast_arg.type, ts.ScalarType) + bcast_value_tlet, connector_mapping = sdfg_builder.add_tasklet( + sdfg_builder.unique_tasklet_name(bcast_node_name), + sdfg=ctx.sdfg, + state=ctx.state, + inputs=set(), + outputs={"__out"}, + code=f"__out = {bcast_arg.value}", + ) + bcast_value_name, bcast_value_desc = sdfg_builder.add_temp_scalar( + ctx.sdfg, gtx_dace_args.as_dace_type(bcast_arg.type) + ) + bcast_value = ctx.state.add_access(bcast_value_name) + ctx.state.add_edge( + bcast_value_tlet, + connector_mapping["__out"], + bcast_value, + None, + dace.Memlet.from_array(bcast_value_name, bcast_value_desc), + ) + broadcast_in_dims = [] + + elif isinstance( + arg := _parse_fieldop_arg(bcast_arg, ctx, sdfg_builder, field_domain), + gtir_dataflow.MemletExpr, + ): + if isinstance(arg.gt_dtype, ts.ScalarType): + # Broadcasting a scalar that is not a literal. 
+ assert isinstance(arg.dc_node.desc(ctx.sdfg), dace_data.Scalar) + bcast_value = arg.dc_node + broadcast_in_dims = [] + else: + assert isinstance(arg.gt_dtype, ts.ListType) + raise NotImplementedError("Broadcast of lists is not supported.") + + else: + # TODO: What is the exact difference between this case and the one above? + + bcast_value = arg.field + bcast_value_desc = bcast_value.desc(ctx.sdfg) + if isinstance(bcast_value_desc, dace_data.Scalar): + broadcast_in_dims = [] + else: + # Broadcasting a "vector", i.e. adding missing non local dimensions. + assert isinstance(arg.field.desc(ctx.sdfg), dace_data.Array) + bcast_value_gtdims = arg.get_field_type().dims + bcast_result_gtdim_map = {dim: i for i, dim in enumerate(field_dims)} + bcast_value_origins = [o for _, o in arg.field_domain] + bcast_result_origins = field_origin + bcast_result_subset_size = bcast_result_subset.size() + + # Use the dimensions to find out how we have to broadcast. + broadcast_in_dims = [] + bcast_value_subset_components: list[str] = [] + for bcast_value_dim, bcast_value_gtdim in enumerate(bcast_value_gtdims): + # Which dimension of the input should go to which dimension in the output. + bcast_result_dim = bcast_result_gtdim_map[bcast_value_gtdim] + broadcast_in_dims.append(bcast_result_dim) + + # Find the correction that should be applied to the input. 
+ bcast_value_origin = bcast_value_origins[bcast_value_dim] + bcast_result_origin = bcast_result_origins[bcast_result_dim] + bcast_dim_size = bcast_result_subset_size[bcast_result_dim] + start_pos = f"({bcast_result_origin}) - ({bcast_value_origin})" + + bcast_value_subset_components.append( + f"({start_pos}):(({start_pos}) + ({bcast_dim_size}))" + ) + + bcast_value_subset = dace_subsets.Range.from_string( + ", ".join(bcast_value_subset_components) + ) + + bcast_node = gtir_library_nodes.Broadcast( + name=bcast_node_name, + broadcast_in_dims=broadcast_in_dims, + params=field_dims, + debuginfo=gtir_to_sdfg_utils.debug_info(node), + ) + ctx.state.add_node(bcast_node) + + # If not specified differently use the whole `bcast_value` content. + if bcast_value_subset is None: + bcast_value_subset = dace_subsets.Range.from_array(bcast_value.desc(ctx.sdfg)) + + ctx.state.add_edge( + bcast_value, + None, + bcast_node, + "_inp", + dace.Memlet(data=bcast_value.data, subset=bcast_value_subset), + ) + ctx.state.add_edge( + bcast_node, + "_outp", + field_node, + None, + dace.Memlet(data=field_name, subset=bcast_result_subset), + ) + + return gtir_to_sdfg_types.FieldopData(field_node, node.type, tuple(field_origin)) + + def _construct_if_branch_output( ctx: gtir_to_sdfg.SubgraphContext, sdfg_builder: gtir_to_sdfg.SDFGBuilder, @@ -740,6 +894,7 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, + translate_broadcast, translate_concat_where, translate_if, translate_index, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index 66a24c1ebf..60fe9d4e92 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py 
@@ -19,6 +19,7 @@ GT4PyAutoOptHookStage, gt_auto_optimize, ) +from .broadcast import BrodcastChainRemover, ScalarBrodcastInliner from .concat_where_mapper import ( gt_apply_concat_where_replacement_on_sdfg, gt_check_if_concat_where_node_is_replaceable, @@ -88,6 +89,7 @@ __all__ = [ + "BrodcastChainRemover", "CopyChainRemover", "DoubleWriteRemover", "FuseHorizontalConditionBlocks", @@ -114,6 +116,7 @@ "RemoveAccessNodeCopies", "RemovePointwiseViews", "RemoveScalarCopies", + "ScalarBrodcastInliner", "ScanLoopUnrolling", "SingleStateGlobalDirectSelfCopyElimination", "SingleStateGlobalSelfCopyElimination", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 799e8ad228..adb3c1cd24 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -20,7 +20,10 @@ from dace.transformation.passes import analysis as dace_analysis from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations +from gt4py.next.program_processors.runners.dace import ( + library_nodes as gtir_library_nodes, + transformations as gtx_transformations, +) class GT4PyAutoOptHook(enum.Enum): @@ -234,6 +237,8 @@ def gt_auto_optimize( device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU optimization_hooks = optimization_hooks or {} + validate_all = True + with dace.config.temporary_config(): # Do not store which transformations were applied inside the SDFG. dace.Config.set("store_history", value=False) @@ -369,6 +374,14 @@ def gt_auto_optimize( stacklevel=0, ) + # We now expand all GT4Py specific library nodes. + # We do this such that we have control over all the Maps that are there. + # NOTE: `Broadcast` nodes were already expanded in the top level dataflow phase. 
+ # TODO(phimuell): It is probably the right place, but maybe there is a better one. + for node, state in list(sdfg.all_nodes_recursive()): + if isinstance(node, gtir_library_nodes.GTIR_LIBRARY_NODES): + node.expand(state) + sdfg = _gt_auto_configure_maps_and_strides( sdfg=sdfg, gpu=gpu, @@ -647,13 +660,28 @@ def _gt_auto_process_top_level_maps( # The SDFG was modified by the transformations above. The SDFG was # modified. Call Simplify and try again to further optimize. - gtx_transformations.gt_simplify( + simplify_res = gtx_transformations.gt_simplify( sdfg, validate=False, validate_all=validate_all, skip=gtx_transformations.constants._GT_AUTO_OPT_TOP_LEVEL_STAGE_SIMPLIFY_SKIP_LIST, ) + # We will now perform the expansion of the Broadcast. The main idea is that if + # the broadcast specific transformations have stopped doing work we will now + # expand them and give the splitting transformations a chance of handling them. + if simplify_res: + broadcast_related_results = sum( + simplify_res.get(xtrans_name, 0) + for xtrans_name in ["ScalarBrodcastInliner", "BrodcastChainRemover"] + ) + else: + broadcast_related_results = 0 + if (not disable_splitting) and broadcast_related_results == 0: + for node, state in list(sdfg.all_nodes_recursive()): + if isinstance(node, gtir_library_nodes.Broadcast): + gtir_library_nodes.inplace_broadcast_expander(node, state, state.sdfg) + # Replace `concat_where` nodes # TODO(phimuell): Are there better locations for this transformation? gtx_transformations.gt_apply_concat_where_replacement_on_sdfg( @@ -751,8 +779,8 @@ def _gt_auto_process_dataflow_inside_maps( # NestedSDFGs inside the ConditionalBlocks it fuses. 
sdfg.apply_transformations_repeated( gtx_transformations.FuseHorizontalConditionBlocks(), - validate=True, - validate_all=True, + validate=False, + validate_all=validate_all, ) # Move dataflow into the branches of the `if` such that they are only evaluated diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/broadcast.py b/src/gt4py/next/program_processors/runners/dace/transformations/broadcast.py new file mode 100644 index 0000000000..ecd1ec9b9e --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/broadcast.py @@ -0,0 +1,643 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +""" +TODO: + - Figuring out where to put this file. + - Finalize the broadcast library node implementation. Currently the implementation + kinds of assume that all information is stored inside the destination subset. + Furthermore, at least `InlineBroadcastAccess` assumes that there is a source. +""" + +import copy +from typing import Any, Optional + +import dace +from dace import ( + data as dace_data, + properties as dace_properties, + subsets as dace_sbs, + transformation as dace_transformation, +) +from dace.sdfg import graph as dace_graph, nodes as dace_nodes +from dace.transformation.passes import analysis as dace_analysis + +from gt4py.next import config as gtx_config +from gt4py.next.program_processors.runners.dace import ( + library_nodes as gtx_lib_nodes, + transformations as gtx_xtrans, +) + + +@dace_properties.make_properties +class ScalarBrodcastInliner(dace_transformation.SingleStateTransformation): + """ + + The transformation performs two kind of operations that are related to + broadcast library nodes. However, this transformations only operate on + scalar values that are broadcasted. 
+ + The first one is "static access replacement"; in that operation the pattern: + ``` + (value_to_broadcast) -> [BroadCast] -> (bcast_result) -> MapEntry + ``` + is matched and transformed into: + ``` + (value_to_broadcast) -> MapEntry + ``` + Thus instead of filling an array with a value and then reading from that + big filled array, the Map directly reads from the value that should be + broadcasted. The only restriction is that the accesses are known at + compile time, i.e. no neighbourhood accesses are supported. + + The second kind of operation is "AccessNode bypassing" in which the + following pattern is matched: + ``` + (value_to_broadcast) -> (bcast_result) -> (access_node_consumer) + ``` + to: + ``` + (value_to_broadcast) -> (access_node_consumer) + ``` + Thus, the broadcast is performed directly into the target destination. + + It is important that these two modes can be performed at the same time. + Meaning `bcast_result` can have multiple consumers. + + Args: + clean_dead_dataflow: Perform dead dataflow elimination. + single_use_data: Use this as single use data, if not given + `FindSingleUseData` will be run at every `apply()` call. + + Note: + In the "AccessNode bypassing" mode each (AccessNode) consumer will + have its own broadcast, but all will read from the same value to + broadcast. + + Todo: + - Rename this transformation, as it now supports scalars. + """ + + bcast_value = dace_transformation.PatternNode(dace_nodes.AccessNode) + bcast_node = dace_transformation.PatternNode(gtx_lib_nodes.Broadcast) + bcast_result = dace_transformation.PatternNode(dace_nodes.AccessNode) + + clean_dead_dataflow = dace_properties.Property( + dtype=bool, + allow_none=False, + default=True, + desc="Clean dead dataflow.", + ) + + # Name of all data that is used at only one place. Is computed by the + # `FindSingleUseData` pass and can be passed at construction time. Needed until + # [issue#1911](https://github.com/spcl/dace/issues/1911) has been solved. 
+ _single_use_data: Optional[dict[dace.SDFG, set[str]]] + + @classmethod + def expressions(cls) -> Any: + return [ + dace.sdfg.utils.node_path_graph( + cls.bcast_value, + cls.bcast_node, + cls.bcast_result, + ) + ] + + def __init__( + self, + *args: Any, + clean_dead_dataflow: Optional[bool] = None, + single_use_data: Optional[dict[dace.SDFG, set[str]]] = None, + **kwargs: Any, + ) -> None: + self._single_use_data = single_use_data + if clean_dead_dataflow is not None: + self.clean_dead_dataflow = clean_dead_dataflow + + super().__init__(*args, **kwargs) + + def can_be_applied( + self, + graph: dace.SDFGState, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + bcast_value = self.bcast_value + bcast_node = self.bcast_node + bcast_result = self.bcast_result + bcast_result_desc = bcast_result.desc(sdfg) + bcast_value_desc = bcast_value.desc(sdfg) + assert graph.in_degree(self.bcast_node) == 1 + assert graph.out_degree(self.bcast_node) == 1 + + # A fundamental requirement is that `bcast_result` is only generated by us. + # ADR-18 guarantees us this if it is transient and has a single producer, + # `bcast_node`. However, since we will remove `bcast_result`, we have to + # make sure that it is not used every where else. + if graph.in_degree(bcast_result) != 1: + return False + if not bcast_result_desc.transient: + return False + if bcast_value_desc.dtype != bcast_result_desc.dtype: + return False + + # Check single single use if a cached result is known, otherwise do it at the end. + if self._single_use_data is not None: + if bcast_result.data not in self._single_use_data[sdfg]: + return False + + # We do not allow view! + if gtx_xtrans.utils.is_view(bcast_result_desc): + return False + if gtx_xtrans.utils.is_view(bcast_value_desc): + return False + + # NOTE: The big question is if it favourable to perform the transformation + # every time. 
If there are only AccessNodes consumers then the answer is + # probably yes, as we can remove the read and write of the initial data + # only the write to final destination is left. If the consumers are Maps + # the thing is a bit different. As we have to create the intermediate + # allocation. If the read of the memory is okay the `InlineBroadcastAccess` + # transformation can get rid of it. However, if this is not possible then + # you would need to allocate more memory than before. Thus, we require + # that the Maps consumers can be handled by `InlineBroadcastAccess`. + + # Now we have to inspect all consumers of the result node. + for consumer_edge in graph.out_edges(bcast_result): + if consumer_edge.data.is_empty(): + return False + + match consumer := consumer_edge.dst: + case dace_nodes.AccessNode(): + # TODO(phimuell): Are there more checks needed. + if gtx_xtrans.utils.is_view(consumer, sdfg): + return False + + # Otherwise we would end up with not enough named parameters + # during expansion. + # TODO(phimuell): Since a Memlet is essentially a copy this case + # can be handled by adding dummy dimensions, but we would need + # to run some more analysis. + if bcast_node.params is not None and ( + len(consumer.desc(sdfg).shape) != len(bcast_result_desc.shape) + ): + return False + + case dace_nodes.MapEntry(): + if not self._check_map_consumer( + state=graph, + sdfg=sdfg, + bcast_value=bcast_value, + supplier_edge=consumer_edge, + ): + return False + + case dace_nodes.NestedSDFG(): + # TODO(phimuell): Consider implementing this case. + return False + + case _: + # This kind of node can not be handled. + return False + + # Check single use data if it was not known at the beginning. 
+ if self._single_use_data is None: + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + if bcast_result.data not in single_use_data[sdfg]: + return False + + return True + + def apply( + self, + graph: dace.SDFGState, + sdfg: dace.SDFG, + ) -> None: + # Since the information what should be written is fully encoded in the + # destination subset of the `(bcast_result) -> (access_node)` edge + # we can now copy the edge. + bcast_value = self.bcast_value + bcast_node = self.bcast_node + bcast_result = self.bcast_result + + for consumer_edge in graph.out_edges(bcast_result): + match consumer := consumer_edge.dst: + case dace_nodes.AccessNode(): + self._handle_access_node_consumer( + state=graph, + bcast_value=bcast_value, + bcast_node=bcast_node, + bcast_result=bcast_result, + consumer_edge=consumer_edge, + ) + + case dace_nodes.MapEntry(): + self._handle_map_consumer( + state=graph, + sdfg=sdfg, + bcast_value=bcast_value, + bcast_node=bcast_node, + bcast_result=bcast_result, + consumer_edge=consumer_edge, + ) + + case _: + raise NotImplementedError( + f"Node type `{type(consumer).__name__}` is not supported." + ) + + # Check if we can remove the `bcast_result` node. 
+ if self.clean_dead_dataflow and graph.out_degree(bcast_result) == 0: + # We have to figure out whether `bcast_result` is single-use data before we can delete it. + if self._single_use_data is None: + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + else: + single_use_data = self._single_use_data + + if bcast_result.data in single_use_data[sdfg]: + graph.remove_node(bcast_result) + graph.remove_node(bcast_node) + sdfg.remove_data(bcast_result.data, validate=__debug__) + + @staticmethod + def _handle_access_node_consumer( + state: dace.SDFGState, + bcast_value: dace_nodes.AccessNode, + bcast_node: gtx_lib_nodes.Broadcast, + bcast_result: dace_nodes.AccessNode, + consumer_edge: dace_graph.MultiConnectorEdge[dace.Memlet], + ) -> dace_graph.MultiConnectorEdge[dace.Memlet]: + """Bypass `bcast_result` and perform the broadcast directly into `consumer_edge.dst`. + + The function will replicate the broadcast node but it will be connected to + `bcast_value`. Furthermore, the function will remove `consumer_edge`. + + In case there are multiple edges between `bcast_result` and the destination + of `consumer_edge` the function is called multiple times, once for each + connection. 
+ """ + + new_bcast_node = copy.deepcopy(bcast_node) + state.add_node(new_bcast_node) + + state.add_edge( + bcast_value, + None, + new_bcast_node, + "_inp", + dace.Memlet(data=bcast_value.data, subset="0"), + ) + + new_dst_subset: dace_sbs.Subset = copy.deepcopy( + consumer_edge.data.get_dst_subset(consumer_edge, state) + ) + assert new_dst_subset is not None + new_consumer_edge = state.add_edge( + new_bcast_node, + "_outp", + consumer_edge.dst, + consumer_edge.dst_conn, + dace.Memlet(data=consumer_edge.dst.data, subset=new_dst_subset), + ) + state.remove_edge(consumer_edge) + + return new_consumer_edge + + @staticmethod + def _check_map_consumer( + state: dace.SDFGState, + sdfg: dace.SDFG, + bcast_value: dace_nodes.AccessNode, + supplier_edge: dace_graph.MultiConnectorEdge[dace.Memlet], + ) -> bool: + """Check if the transformation can handle the Map. + + The function expects that `supplier_edge` connects the `bcast_result` to + a Map that consumes it. + The function essentially checks if all accesses are known at compile time + and that they only perform a single access. 
+ """ + assert isinstance(supplier_edge.src, dace_nodes.AccessNode) + assert isinstance(supplier_edge.dst, dace_nodes.MapEntry) + + if supplier_edge.data.is_empty(): + return False + if not supplier_edge.dst_conn.startswith("IN_"): + return False + + bcast_result: dace_nodes.AccessNode = supplier_edge.src + bcast_value_desc = sdfg.arrays[bcast_value.data] + map_entry: dace_nodes.MapEntry = supplier_edge.dst + inner_connector = "OUT_" + supplier_edge.dst_conn[3:] + scope_dict = state.scope_dict() + + for consumer_edge in state.out_edges_by_connector(map_entry, inner_connector): + for final_consumer in state.memlet_tree(consumer_edge).leaves(): + sbs_to_inspect = ( + final_consumer.data.subset + if final_consumer.data.data == bcast_result.data + else final_consumer.data.other_subset + ) + assert sbs_to_inspect is not None + + # Currently the only test we do is, that only one element from + # `bcast_result` is loaded. However, the location of this + # element needs to be known. This rejects neighbourhood accesses. + if (sbs_to_inspect.num_elements() == 1) == False: # noqa: E712 [true-false-comparison] # SymPy comparison + return False + + if isinstance(bcast_value_desc, dace_data.Scalar): + # If `bcast_value` is a scalar then we do not impose any further + # restrictions. + pass + + else: + # In case `bcast_value` is not a scalar, then we require that the + # consumer is not nested. This is a simplification to avoids + # updating the Memlet trees. + # TODO(phimuell): Lift this requirement. + if scope_dict[final_consumer.dst] is not map_entry: + return False + + return True + + @staticmethod + def _handle_map_consumer( + state: dace.SDFGState, + sdfg: dace.SDFG, + bcast_node: gtx_lib_nodes.Broadcast, + bcast_value: dace_nodes.AccessNode, + bcast_result: dace_nodes.AccessNode, + consumer_edge: dace_graph.MultiConnectorEdge[dace.Memlet], + ) -> None: + """Replace accesses to `bcast_result`, performed through `consumer_edge` by accesses to `bcast_value`. 
+ + The function expects that `consumer_edge` ends at a MapEntry node. + + Note: + - The function will ignore all outgoing edges of `bcast_result` that are + not going to `map_entry`. + - If there are multiple connections between `bcast_result` and the + destination of `consumer_edge` then this function is called multiple + times, once for each connection. + """ + + map_entry = consumer_edge.dst + scope_dict = state.scope_dict() + assert isinstance(map_entry, dace_nodes.MapEntry) + assert consumer_edge.dst_conn.startswith("IN_") + assert consumer_edge.src is bcast_result + + # Make `bcast_value` available inside the Map body. + bcast_value_conn: str | None = None + bcast_value_desc = sdfg.arrays[bcast_value.data] + for edge in state.out_edges(bcast_value): + if edge.dst is not map_entry: + continue + if edge.data.is_empty(): + continue + if not edge.dst_conn.startswith("IN_"): + continue + bcast_value_conn = "OUT_" + edge.dst_conn[3:] + edge.data.subset = dace_sbs.Range.from_array(bcast_value_desc) + break + + else: + # There was no connection between them so we have to create one. + bcast_value_conn_raw = map_entry.next_connector(bcast_value.data) + state.add_edge( + bcast_value, + None, + map_entry, + "IN_" + bcast_value_conn_raw, + sdfg.make_array_memlet(bcast_value.data), + ) + bcast_value_conn = "OUT_" + bcast_value_conn_raw + map_entry.add_scope_connectors(bcast_value_conn_raw, force=True) + assert bcast_value_conn in map_entry.out_connectors + + is_vector_broadcast = not isinstance(bcast_value_desc, dace_data.Scalar) + inner_connector = "OUT_" + consumer_edge.dst_conn[3:] + bcast_value_offset = next(iter(state.in_edges(bcast_node))).data.subset.min_element() + bcast_result_offset = next(iter(state.out_edges(bcast_node))).data.subset.min_element() + + # Note regardless if we have a scalar or a vector broadcast, since we have + # ensured that everything is a scalar read, we only need to modify the + # respective subset, there is not need for fancy updates. 
+ for inner_map_edge in list(state.out_edges_by_connector(map_entry, inner_connector)): + new_consumer_sbs: list[str] = [] + for mtree in state.memlet_tree(inner_map_edge).traverse_children(True): + tree_edge = mtree.edge + assert tree_edge.data.wcr is None + + if is_vector_broadcast: + # In vector broadcast we limit ourselves to "direct read", i.e. no + # nesting, we do this because of the new map parameters that are + # potentially introduced. Furthermore, updating the Memlet tress + # would also be a bit difficult. + # Thus, this body is only called once per `inner_map_edge`. + assert scope_dict[tree_edge.dst] is map_entry + assert tree_edge.src is map_entry + assert inner_map_edge is tree_edge + broadcast_in_dims: list[int] = bcast_node.brodcast_in_dims + consumer_sbs: dace_sbs.Subset = ( + tree_edge.data.subset + if tree_edge.data.data == bcast_result.data + else tree_edge.data.other_subset + ) + + for bcast_res_dim, (sbs_acc, _, _) in enumerate(consumer_sbs): + if bcast_res_dim not in broadcast_in_dims: + # This is a replicated dimension, we can ignore it. + continue + else: + bcast_val_dim = broadcast_in_dims.index(bcast_res_dim) + new_consumer_sbs.append( + f"({sbs_acc}) - ({bcast_result_offset[bcast_res_dim]}) + ({bcast_value_offset[bcast_val_dim]})" + ) + assert len(new_consumer_sbs) == len(bcast_value_offset) + + else: + # In the scalar case we allow for very deep nested Memlet tress. + # However, since the subset is always `0` the updates are also + # very simple. + new_consumer_sbs.append("0") + + if tree_edge.data.data == bcast_result.data: + tree_edge.data.data = bcast_value.data + tree_edge.data.subset = dace_sbs.Range.from_string(", ".join(new_consumer_sbs)) + else: + tree_edge.data.other_subset = dace_sbs.Range.from_string( + ", ".join(new_consumer_sbs) + ) + + # Now reroute `inner_map_edge` such that it reads from `bcast_value` + # directly, which is available from `bcast_value_conn` at `map_entry`. 
+ state.add_edge( + map_entry, + bcast_value_conn, + inner_map_edge.dst, + inner_map_edge.dst_conn, + inner_map_edge.data, # Was modified above. + ) + state.remove_edge(inner_map_edge) + map_entry.remove_out_connector(inner_map_edge.src_conn) + + # Now remove the connection `(bcast_result) -> map_entry`. + map_entry.remove_in_connector(consumer_edge.dst_conn) + state.remove_edge(consumer_edge) + + +@dace_properties.make_properties +class BrodcastChainRemover(dace_transformation.SingleStateTransformation): + """ + Removes a chain of broadcasts operation. + + It matches the pattern: + ``` + (bcast_value) -----> [Broadcast1] ----> (bcast_tmp) ----> [Broadcast2] + ``` + and turns it into: + ``` + (bcast_value) ----> [Broadcast2] + ``` + + `bcast_tmp` can have multiple consumers. If after the rerouting `bcast_tmp` is + no longer used it is removed. + + There is currently an implementation limitation, in that `brodcast_in_dims` of + the first broadcast node must be an empty list. Which is roughly equivalent to + say that `bcast_value` is a scalar. + + Args: + single_use_data: Use this as single use data, if not given + `FindSingleUseData` will be run at every `apply()` call. + + Todo: + Lift the limitation on `BroadcastNode1::brodcast_in_dims`. + + Note: + This transformation processes more than was matched. + """ + + bcast_value = dace_transformation.PatternNode(dace_nodes.AccessNode) + bcast_node1 = dace_transformation.PatternNode(gtx_lib_nodes.Broadcast) + bcast_tmp = dace_transformation.PatternNode(dace_nodes.AccessNode) + bcast_node2 = dace_transformation.PatternNode( + gtx_lib_nodes.Broadcast + ) # Only for speed up matching. + + # Name of all data that is used at only one place. Is computed by the + # `FindSingleUseData` pass and be passed at construction time. Needed until + # [issue#1911](https://github.com/spcl/dace/issues/1911) has been solved. 
+ _single_use_data: Optional[dict[dace.SDFG, set[str]]] + + @classmethod + def expressions(cls) -> Any: + return [ + dace.sdfg.utils.node_path_graph( + cls.bcast_value, + cls.bcast_node1, + cls.bcast_tmp, + cls.bcast_node2, + ) + ] + + def __init__( + self, + *args: Any, + single_use_data: Optional[dict[dace.SDFG, set[str]]] = None, + **kwargs: Any, + ) -> None: + self._single_use_data = single_use_data + super().__init__(*args, **kwargs) + + def can_be_applied( + self, + graph: dace.SDFGState, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + bcast_node1 = self.bcast_node1 + bcast_tmp = self.bcast_tmp + + # `bcast_tmp` must be fully generated by the first broadcast node. We do not + # need to check here if it is single use data. This will be done in `apply()`. + if graph.in_degree(bcast_tmp) != 1: + return False + + # This is a pure simplification. + # TODO(phimuell): Implement this case. + if len(bcast_node1.brodcast_in_dims) != 0: + return False + + # At least one consumer of `bcast_tmp` must be a broadcast node as well. NOTE(review): only existence of one Broadcast consumer is checked here; non-Broadcast consumers keep `bcast_tmp` alive in `apply()` — verify this matches the intended contract. + found_other_bcast_node = False + for oedge in graph.out_edges(bcast_tmp): + if isinstance(oedge.dst, gtx_lib_nodes.Broadcast): + found_other_bcast_node = True + + # We must have found at least one other broadcast node. + if not found_other_bcast_node: + return False + + return True + + def apply( + self, + graph: dace.SDFGState, + sdfg: dace.SDFG, + ) -> None: + bcast_node1: gtx_lib_nodes.Broadcast = self.bcast_node1 + bcast_tmp: dace_nodes.AccessNode = self.bcast_tmp + + bcast_value = self.bcast_value + bcast_value_edge = next(iter(graph.in_edges(bcast_node1))) + + # This is possible because we have checked the `brodcast_in_dims`. + assert len(bcast_node1.brodcast_in_dims) == 0 + + # Go through all consumers and bypass the first broadcast node. 
+ for oedge in list(graph.out_edges(bcast_tmp)): + consumer = oedge.dst + if not isinstance(consumer, gtx_lib_nodes.Broadcast): + continue + + consumer.brodcast_in_dims = []  # Possible because of simplification. + graph.add_edge( + bcast_value, + None, + consumer, + "_inp", + copy.deepcopy(bcast_value_edge.data), + ) + graph.remove_edge(oedge) + + # If `bcast_tmp` still has consumers then we are done, because we cannot + # delete it. NOTE(review): the check below inspects `out_degree(bcast_value)`, which is always positive after the rerouting above (the new edges originate at `bcast_value`), so the cleanup below looks unreachable; this presumably should be `out_degree(bcast_tmp)` — verify. + if graph.out_degree(bcast_value) > 0: + return + + # If there are no consumers, we must check if it is single use data, i.e. used + # in another state. + if self._single_use_data is None: + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + else: + single_use_data = self._single_use_data + + if bcast_tmp.data in single_use_data[sdfg]: + # It is single use data and so we can delete it. + graph.remove_node(bcast_tmp) + graph.remove_node(bcast_node1) + sdfg.remove_data(bcast_tmp.data, validate=gtx_config.DEBUG) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/constants.py b/src/gt4py/next/program_processors/runners/dace/transformations/constants.py index c0e6a057c2..a6b83cd1d4 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/constants.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/constants.py @@ -65,6 +65,7 @@ "SingleStateGlobalSelfCopyElimination", "MultiStateGlobalSelfCopyElimination", "MapToCopy", + "ScalarBrodcastInliner", } ) """Simplify stages disabled during the optimization of dataflow inside the Maps.""" diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_to_copy.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_to_copy.py index c8d0417777..92ecaf4595 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_to_copy.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_to_copy.py @@ 
-134,6 +134,8 @@ def can_be_applied( # `concat_where`, but I am not sure how important that case is. if isinstance(sdfg.arrays[src_access_node.data], dace_data.Scalar): return False + if len(src_subset) != len(dst_subset): + return False if (src_subset.num_elements() == src_edge.data.volume) != True: # noqa: E712 [true-false-comparison] # SymPy comparison return False if (dst_subset.num_elements() == dst_edge.data.volume) != True: # noqa: E712 [true-false-comparison] # SymPy comparison diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py index 8c08f3459a..be6b942610 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -58,6 +58,7 @@ def gt_simplify( `SingleStateGlobalSelfCopyElimination`, with the exception that the write to `T`, i.e. `(G) -> (T)` and the write back to `G`, i.e. `(T) -> (G)` might be in different states. + - `ScalarBrodcastInliner`: Handles the GT4Py specific broadcast nodes. - `CopyChainRemover`: Which removes some chains that are introduced by the `concat_where` built-in function. - `GT4PyDeadDataflowElimination`: Run `gt_eliminate_dead_dataflow()` on the SDFG, @@ -149,6 +150,46 @@ def gt_simplify( result["FuseStates"] = 0 result["FuseStates"] += fuse_state_res + # Handling broadcasts. We first run `BrodcastChainRemover` before + # `ScalarBrodcastInliner` to prevent some specific situations that might arise. + # Note because of the similarities the transformations share the result of a + # single use data scan. 
+ broadcast_single_use_data_cache: None | dict[dace.SDFG, set[str]] = None + if "BrodcastChainRemover" not in skip: + find_single_use_data = dace_transformation.passes.analysis.FindSingleUseData() + broadcast_single_use_data_cache = find_single_use_data.apply_pass(sdfg, None) + removed_chains = sdfg.apply_transformations_repeated( + gtx_transformations.BrodcastChainRemover( + single_use_data=broadcast_single_use_data_cache, + ), + validate=False, + validate_all=validate_all, + ) + if removed_chains: + at_least_one_xtrans_run = True + result = result or {} + if "BrodcastChainRemover" not in result: + result["BrodcastChainRemover"] = 0 + result["BrodcastChainRemover"] += removed_chains + if "ScalarBrodcastInliner" not in skip: + if broadcast_single_use_data_cache is None: + find_single_use_data = dace_transformation.passes.analysis.FindSingleUseData() + broadcast_single_use_data_cache = find_single_use_data.apply_pass(sdfg, None) + inlined_broadcasts = sdfg.apply_transformations_repeated( + gtx_transformations.ScalarBrodcastInliner( + clean_dead_dataflow=True, + single_use_data=broadcast_single_use_data_cache, + ), + validate=False, + validate_all=validate_all, + ) + if inlined_broadcasts: + at_least_one_xtrans_run = True + result = result or {} + if "ScalarBrodcastInliner" not in result: + result["ScalarBrodcastInliner"] = 0 + result["ScalarBrodcastInliner"] += inlined_broadcasts + if "MapToCopy" not in skip: find_single_use_data = dace_transformation.passes.analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 68a7c33201..c3f5e332e6 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -17,6 +17,8 @@ from dace.transformation import pass_pipeline as dace_ppl 
from dace.transformation.passes import analysis as dace_analysis +from gt4py.next.program_processors.runners.dace import library_nodes as gtx_lib_nodes + _PassT = TypeVar("_PassT", bound=dace_ppl.Pass) @@ -570,6 +572,33 @@ def reconfigure_dataflow_after_rerouting( other_subset = new_edge.data.src_subset if is_producer_edge else new_edge.data.dst_subset assert other_subset is None + elif isinstance(other_node, gtx_lib_nodes.Broadcast): + # For now we only allow the case where the destination `bacst_result` is replaced + # by another node (`is_producer_edge` is `True`). Furthermore, we only handle the + # case where the dimensionality of data represented by `old_node` and `new_node` + # is the same. This avoids the problems that we need to modify `broadcast_in_dims` + # and `params`. + # But beside these constraints there is nothing to do. + assert isinstance(new_node, dace_nodes.AccessNode) + assert isinstance(old_node, dace_nodes.AccessNode) + + if not is_producer_edge: + raise ValueError("Broadcast nodes are only supported as output.") + + # NOTE: It is possible because to handle this case because we do not need "more" + # Map parameter, we just need to create some dummy ones. This is possible + # because the ones that are "new" will be definition (a Memlet can not + # broadcast) not do any work, i.e. are guaranteed to have size 1. Thus we + # just have to figuring out how to redistribute `params` and `brodcast_in_dims`. + if (other_node.params is None) and (len(other_node.brodcast_in_dims) == 0): + pass # Special case where there is nothing to do. + elif len(new_node.desc(sdfg).shape) != len(old_node.desc(sdfg).shape): + raise NotImplementedError( + "Broadcast reconfiguration only works if `old_node` and `new_node` relresents the same dimensionality." + ) + + # There is nothing else to do in this case. + else: # As we encounter them we should handle them case by case. 
raise NotImplementedError( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 0ec8608a37..d2db01122b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -119,13 +119,13 @@ def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField: out = cases.allocate(cartesian_case, testee, cases.RETURN)() a = cases.allocate( - cartesian_case, testee, "a", domain=gtx.domain({KDim: out.domain.shape[2]}) + cartesian_case, testee, "a", domain=gtx.domain({KDim: (-2, out.domain.shape[2])}) )() b = cases.allocate(cartesian_case, testee, "b", domain=out.domain.slice_at[:, :, 1:])() ref = np.concatenate( ( - np.tile(a.asnumpy()[0], (*b.domain.shape[0:2], 1)), + np.tile(a.asnumpy()[2], (*b.domain.shape[0:2], 1)), b.asnumpy(), ), axis=2, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/library_node_test/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/library_node_test/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/library_node_test/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/library_node_test/test_broadcast.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/library_node_test/test_broadcast.py new file mode 100644 index 0000000000..6c03a36aec --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/library_node_test/test_broadcast.py @@ -0,0 +1,181 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import pytest +import copy +import numpy as np + +dace = pytest.importorskip("dace") + +from dace import data as dace_data, subsets as dace_sbs +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, + library_nodes as gtx_lib_nodes, +) + +from ..transformation_tests import util + + +import dace +import numpy as np + + +def _make_2d_broadcast() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("broadcast_2d")) + state = sdfg.add_state(is_start_block=True) + + sdfg.add_scalar( + "bcast_value", + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "bcast_result", + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + bcast_value = state.add_access("bcast_value") + bcast_result = state.add_access("bcast_result") + + bcast_node1 = gtx_lib_nodes.Broadcast(name="bcast_node1", broadcast_in_dims=[], params=None) + bcast_node2 = gtx_lib_nodes.Broadcast(name="bcast_node2", broadcast_in_dims=[], params=None) + + state.add_edge( + bcast_value, None, bcast_node1, "_inp", dace.Memlet(data=bcast_value.data, subset="0") + ) + state.add_edge( + bcast_node1, + "_outp", + bcast_result, + None, + 
dace.Memlet(data=bcast_result.data, subset="1, 3:7"), + ) + + state.add_edge( + bcast_value, None, bcast_node2, "_inp", dace.Memlet(data=bcast_value.data, subset="0") + ) + state.add_edge( + bcast_node2, + "_outp", + bcast_result, + None, + dace.Memlet(data=bcast_result.data, subset="5:8, 4:6"), + ) + + sdfg.validate() + + return sdfg, state + + +@pytest.mark.parametrize("use_inplace_expansion", [True, False]) +def test_broadcast_expansion_inplace(use_inplace_expansion: bool): + sdfg, state = _make_2d_broadcast() + + assert state.number_of_nodes() == 4 + assert util.count_nodes(sdfg, gtx_lib_nodes.Broadcast) == 2 + assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 2 + + ref, res = util.make_sdfg_args(sdfg) + + ref["bcast_result"][1, 3:7] = ref["bcast_value"] + ref["bcast_result"][5:8, 4:6] = ref["bcast_value"] + + for node in list(state.nodes()): + if isinstance(node, gtx_lib_nodes.Broadcast): + if use_inplace_expansion: + gtx_lib_nodes.inplace_broadcast_expander(node, state, sdfg) + else: + node.expand(state) + + assert util.count_nodes(sdfg, gtx_lib_nodes.Broadcast) == 0 + assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 2 + + if use_inplace_expansion: + assert state.number_of_nodes() == 8 + assert util.count_nodes(sdfg, dace_nodes.Tasklet) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 + else: + assert state.number_of_nodes() == 4 + assert util.count_nodes(sdfg, dace_nodes.NestedSDFG) == 2 + + util.compile_and_run_sdfg(sdfg, **res) + assert util.compare_sdfg_res(ref=ref, res=res) + + +def _make_broadcast_vector( + broadcast_in_dim: int, +) -> dace.SDFG: + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("broadcast_2d")) + state = sdfg.add_state(is_start_block=True) + + sdfg.add_array( + "bcast_value", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "bcast_result", + shape=tuple((10 if broadcast_in_dim == dim else 15) for dim in range(3)), + dtype=dace.float64, + transient=False, + ) + + 
dest_subset = ["1:8", "2:9", "3:11"] + dest_subset[broadcast_in_dim] = "0:10" + + bcast_node = gtx_lib_nodes.Broadcast("bcast", broadcast_in_dims=[broadcast_in_dim], params=None) + state.add_node(bcast_node) + + state.add_edge( + state.add_access("bcast_value"), + None, + bcast_node, + "_inp", + sdfg.make_array_memlet("bcast_value"), + ) + state.add_edge( + bcast_node, + "_outp", + state.add_access("bcast_result"), + None, + dace.Memlet(data="bcast_result", subset=",".join(dest_subset)), + ) + + sdfg.validate() + + return sdfg + + +@pytest.mark.parametrize("broadcast_in_dim", [0, 1, 2]) +def test_vector_broadcast(broadcast_in_dim: int): + sdfg = _make_broadcast_vector(broadcast_in_dim) + + ref, res = util.make_sdfg_args(sdfg) + + expand_dims = list(range(3)) + expand_dims.pop(broadcast_in_dim) + bcast_result_extended = np.expand_dims(ref["bcast_value"].copy(), expand_dims) + bcast_result_ref = ref["bcast_result"].copy() + + if broadcast_in_dim == 0: + bcast_result_ref[:, 2:9, 3:11] = bcast_result_extended + elif broadcast_in_dim == 1: + bcast_result_ref[1:8, :, 3:11] = bcast_result_extended + elif broadcast_in_dim == 2: + bcast_result_ref[1:8, 2:9, :] = bcast_result_extended + ref["bcast_result"] = bcast_result_ref.copy() + + util.compile_and_run_sdfg(sdfg, **res) + assert util.compare_sdfg_res(ref=ref, res=res) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_broadcast.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_broadcast.py new file mode 100644 index 0000000000..2e68692db9 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_broadcast.py @@ -0,0 +1,426 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import pytest +import copy +import numpy as np + +dace = pytest.importorskip("dace") + +from dace import data as dace_data, subsets as dace_sbs +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, + library_nodes as gtx_lib_nodes, +) + +from . import util + + +def _make_broadcast_map_substititution( + multi_edge: bool, +) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.MapEntry]: + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("broadcast_inliner")) + state = sdfg.add_state(is_start_block=True) + + for aname in "abc": + sdfg.add_array( + aname, + shape=(10,), + dtype=dace.float64, + transient=(aname == "b"), + ) + for sname in "dts": + sdfg.add_scalar( + sname, + dtype=dace.float64, + transient=(sname != "d"), + ) + + # TODO: Ask Edoardo how to properly use; Maybe also modify it. + bcast_lib = gtx_lib_nodes.Broadcast(name="bcast", broadcast_in_dims=[], params=None) + + a, b, c, d, t, s = (state.add_access(name) for name in "abcdts") + + me, mx = state.add_map("map", ndrange={"__i": "1:10"}) + + tlet1, tlet2, tlet3 = [ + state.add_tasklet( + f"tlet{i + 1}", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code=f"__out = {op}", + ) + for i, op in enumerate(["__in1 + __in2", "__in1 - __in1", "2.0 * __in1 + __in2"]) + ] + + state.add_edge(d, None, bcast_lib, "_inp", dace.Memlet("d[0]")) + state.add_edge(bcast_lib, "_outp", b, None, dace.Memlet("b[0:10]")) + + if multi_edge: + bcast_res_conn1 = "b1" + bcast_res_conn2 = "b2" + else: + bcast_res_conn1 = "b" + bcast_res_conn2 = bcast_res_conn1 + + for conn in {bcast_res_conn1, bcast_res_conn2}: + state.add_edge(b, None, me, "IN_" + conn, dace.Memlet("b[1:10]")) + me.add_scope_connectors(conn) + + state.add_edge(a, None, me, "IN_a", dace.Memlet("a[1:10]")) + me.add_scope_connectors("a") + + state.add_edge(me, "OUT_a", 
tlet1, "__in1", dace.Memlet("a[__i]")) + state.add_edge(me, "OUT_" + bcast_res_conn1, tlet1, "__in2", dace.Memlet("b[__i - 1]")) + state.add_edge(tlet1, "__out", t, None, dace.Memlet("t[0]")) + + state.add_edge(me, "OUT_a", tlet2, "__in1", dace.Memlet("a[__i - 1]")) + state.add_edge(me, "OUT_" + bcast_res_conn2, tlet2, "__in2", dace.Memlet("b[__i]")) + state.add_edge(tlet2, "__out", s, None, dace.Memlet("s[0]")) + + state.add_edge(s, None, tlet3, "__in1", dace.Memlet("s[0]")) + state.add_edge(t, None, tlet3, "__in2", dace.Memlet("t[0]")) + state.add_edge(tlet3, "__out", mx, "IN_c", dace.Memlet("c[__i]")) + state.add_edge(mx, "OUT_c", c, None, dace.Memlet("c[1:10]")) + mx.add_scope_connectors("c") + + sdfg.validate() + + return sdfg, state, d, b, me + + +@pytest.mark.parametrize("multi_edge", [True, False]) +def test_map_replacement( + multi_edge: bool, +): + sdfg, state, bcast_value, bcast_result, map_entry = _make_broadcast_map_substititution( + multi_edge=multi_edge + ) + + assert util.count_nodes(sdfg, gtx_lib_nodes.Broadcast) == 1 + assert bcast_result in util.count_nodes(sdfg, dace_nodes.AccessNode, True) + assert bcast_result.data in sdfg.arrays + assert state.out_degree(map_entry) == 4 + + if multi_edge: + assert len(map_entry.out_connectors) == 3 + else: + assert len(map_entry.out_connectors) == 2 + + assert not any(iedge.src is bcast_value for iedge in state.in_edges(map_entry)) + + ref, res = util.make_sdfg_args(sdfg) + util.compile_and_run_sdfg(sdfg, **ref) + + nb_applied = sdfg.apply_transformations_repeated( + gtx_transformations.ScalarBrodcastInliner, validate_all=True + ) + assert nb_applied == 1 + assert util.count_nodes(sdfg, gtx_lib_nodes.Broadcast) == 0 + assert bcast_result not in util.count_nodes(sdfg, dace_nodes.AccessNode, True) + assert bcast_result.data not in sdfg.arrays + + # We only have one connection between the `bcast_value` node and the MapEntry + # because the transformation performs duplication. 
+ assert sum([ie.src is bcast_value for ie in state.in_edges(map_entry)]) == 1 + + util.compile_and_run_sdfg(sdfg, **res) + assert util.compare_sdfg_res(ref=ref, res=res) + + +def _make_indirect_access() -> dace.SDFG: + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("broadcast_indirect_access")) + state = sdfg.add_state(is_start_block=True) + + for aname in "abc": + sdfg.add_array( + aname, + shape=(10,), + dtype=dace.float64, + transient=(aname == "b"), + ) + sdfg.add_array( + "idx", + shape=(10,), + dtype=dace.int32, + transient=False, + ) + + for sname in "dts": + sdfg.add_scalar( + sname, + dtype=dace.float64, + transient=(sname != "d"), + ) + + # TODO: Ask Edoardo how to properly use; Maybe also modify it. + bcast_lib = gtx_lib_nodes.Broadcast(name="bcast", broadcast_in_dims=[], params=None) + a, b, c, d, t, s = (state.add_access(name) for name in "abcdts") + idx = state.add_access("idx") + me, mx = state.add_map("map", ndrange={"__i": "0:10"}) + + tlet1, tlet2 = [ + state.add_tasklet( + f"tlet{i + 1}", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code=f"__out = {op}", + ) + for i, op in enumerate(["__in1 + __in2", "2.0 * __in1 + __in2"]) + ] + tlet_idx = state.add_tasklet( + "tlet_indirect", + inputs={"__field", "__idx"}, + outputs={"__out"}, + code="__out = __field[__idx]", + ) + + state.add_edge(d, None, bcast_lib, "_inp", dace.Memlet("d[0]")) + state.add_edge(bcast_lib, "_outp", b, None, dace.Memlet("b[0:10]")) + + state.add_edge(b, None, me, "IN_b", dace.Memlet("b[0:10]")) + state.add_edge(idx, None, me, "IN_idx", dace.Memlet("idx[0:10]")) + state.add_edge(a, None, me, "IN_a", dace.Memlet("a[0:10]")) + me.add_scope_connectors("a") + me.add_scope_connectors("idx") + me.add_scope_connectors("b") + + state.add_edge(me, "OUT_b", tlet_idx, "__field", dace.Memlet("b[0:10]")) + state.add_edge(me, "OUT_idx", tlet_idx, "__idx", dace.Memlet("idx[__i]")) + state.add_edge(tlet_idx, "__out", t, None, dace.Memlet("t[0]")) + + state.add_edge(me, "OUT_a", 
tlet1, "__in1", dace.Memlet("a[__i]")) + state.add_edge(t, None, tlet1, "__in2", dace.Memlet("t[0]")) + state.add_edge(tlet1, "__out", s, None, dace.Memlet("s[0]")) + + state.add_edge(s, None, tlet2, "__in1", dace.Memlet("s[0]")) + state.add_edge(me, "OUT_b", tlet2, "__in2", dace.Memlet("b[__i]")) + state.add_edge(tlet2, "__out", mx, "IN_c", dace.Memlet("c[__i]")) + state.add_edge(mx, "OUT_c", c, None, dace.Memlet("c[0:10]")) + mx.add_scope_connectors("c") + + sdfg.validate() + + return sdfg + + +def test_indirect_access_broadcast(): + sdfg = _make_indirect_access() + + # NOTE: This pattern could be processed, however, we would need to inspect the + # Tasklet to make sure that it is indeed an indirect access. The safest way + # to do it would be to add another Library node for it. + nb_applied = sdfg.apply_transformations_repeated( + gtx_transformations.ScalarBrodcastInliner, validate_all=True + ) + assert nb_applied == 0 + + +def _make_access_node_chain() -> tuple[ + dace.SDFG, dace.SDFGState, gtx_lib_nodes.Broadcast, dace_nodes.AccessNode +]: + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("_broadcast_access_node_chain")) + state = sdfg.add_state(is_start_block=True) + + for aname in "abc": + sdfg.add_array( + aname, + shape=(10,), + dtype=dace.float64, + transient=(aname != "c"), + ) + + sdfg.add_scalar( + "s", + dtype=dace.float64, + transient=False, + ) + + a, b, c, s = (state.add_access(name) for name in "abcs") + + bcast_lib = gtx_lib_nodes.Broadcast(name="bcast", broadcast_in_dims=[], params=None) + + state.add_edge(s, None, bcast_lib, "_inp", dace.Memlet("s[0]")) + state.add_edge(bcast_lib, "_outp", a, None, dace.Memlet("a[1:10]")) + state.add_nedge(a, b, dace.Memlet("a[2:9] -> [0:7]")) + state.add_nedge(b, c, dace.Memlet("c[1:5] -> [2:6]")) + + sdfg.validate() + + return sdfg, state, bcast_lib, c + + +def test_access_node_chain(): + sdfg, state, bcast_lib, c = _make_access_node_chain() + + ac_before = util.count_nodes(sdfg, dace_nodes.AccessNode, 
True) + assert len(ac_before) == 4 + assert c in ac_before + assert isinstance(sdfg.arrays["s"], dace_data.Scalar) + assert c.data in sdfg.arrays + assert list(util.count_nodes(sdfg, gtx_lib_nodes.Broadcast, True)) == [bcast_lib] + assert state.in_degree(c) == 1 + assert all(e.src is not bcast_lib for e in state.in_edges(c)) + + nb_applied = sdfg.apply_transformations_repeated( + gtx_transformations.ScalarBrodcastInliner, validate_all=True + ) + assert nb_applied == 2 + + ac_before = util.count_nodes(sdfg, dace_nodes.AccessNode, True) + assert len(ac_before) == 2 + assert c in ac_before + assert isinstance(sdfg.arrays["s"], dace_data.Scalar) + assert c.data in sdfg.arrays + + # In the AccessNode mode the broadcast node is copied. In this case it is not + # needed, but it is needed in more general cases. + # TODO(phimuell): Once the node is finalized, check if it is copied correctly. + bcast_libs_after = list(util.count_nodes(sdfg, gtx_lib_nodes.Broadcast, True)) + assert len(bcast_libs_after) == 1 + assert bcast_libs_after[0] is not bcast_lib + + bcast_edge = next(iter(state.in_edges(c))) + assert state.in_degree(c) == 1 + assert bcast_edge.src is bcast_libs_after[0] + assert bcast_edge.data.dst_subset == dace_sbs.Range.from_string("1:5") + + +def _make_access_node_fan_out() -> tuple[dace.SDFG, dace.SDFGState, dict[str, str]]: + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("_broadcast_access_node_fan_out")) + state = sdfg.add_state(is_start_block=True) + + bcast_result_name = "bcast_result" + res_names = { + "res1": "1:10", + "res2": "2:8", + } + for aname in [bcast_result_name] + list(res_names.keys()): + sdfg.add_array( + aname, + shape=(10,), + dtype=dace.float64, + transient=(not aname.startswith("res")), + ) + + sdfg.add_scalar( + "s", + dtype=dace.float64, + transient=False, + ) + + bcast_result = state.add_access(bcast_result_name) + s = state.add_access("s") + bcast_lib = gtx_lib_nodes.Broadcast(name="bcast", broadcast_in_dims=[], params=None) + + 
state.add_edge(s, None, bcast_lib, "_inp", dace.Memlet("s[0]")) + state.add_edge(bcast_lib, "_outp", bcast_result, None, dace.Memlet(f"{bcast_result}[0:10]")) + + for dname, sbs in res_names.items(): + state.add_nedge(bcast_result, state.add_access(dname), dace.Memlet(f"{dname}[{sbs}]")) + + return sdfg, state, res_names + + +def test_access_node_fan_out(): + sdfg, state, res_names = _make_access_node_fan_out() + + assert util.count_nodes(sdfg, gtx_lib_nodes.Broadcast) == 1 + assert util.count_nodes(sdfg, dace_nodes.AccessNode) == (len(res_names) + 2) + assert all(res_name in sdfg.arrays for res_name in res_names) + + nb_applied = sdfg.apply_transformations_repeated( + gtx_transformations.ScalarBrodcastInliner, validate_all=True + ) + assert nb_applied == 1 + + # The broadcast nodes are replicated in the AccessNode mode, one for each output. + bcast_libs_after = set(util.count_nodes(sdfg, gtx_lib_nodes.Broadcast, True)) + assert len(bcast_libs_after) == len(res_names) + + ac_after = util.count_nodes(sdfg, dace_nodes.AccessNode, True) + assert len(ac_after) == (1 + len(res_names)) + + for res_ac in ac_after: + if res_ac.data == "s": + assert res_ac.data in sdfg.arrays + assert isinstance(sdfg.arrays[res_ac.data], dace_data.Scalar) + else: + assert res_ac.data in res_names + state.in_degree(res_ac) == 1 + edge = next(iter(state.in_edges(res_ac))) + assert edge.src in bcast_libs_after + assert edge.data.dst_subset == dace_sbs.Range.from_string(res_names[res_ac.data]) + bcast_libs_after.remove(edge.src) + + +def _make_access_node_multi_connection(): + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("_broadcast_access_node_multi_connection") + ) + state = sdfg.add_state(is_start_block=True) + + for aname in "ab": + sdfg.add_array( + aname, + shape=(10,), + dtype=dace.float64, + transient=(aname == "a"), + ) + + sdfg.add_scalar( + "s", + dtype=dace.float64, + transient=False, + ) + + a, b, s = (state.add_access(name) for name in "abs") + bcast_lib = 
gtx_lib_nodes.Broadcast(name="bcast", broadcast_in_dims=[], params=None) + + state.add_edge(s, None, bcast_lib, "_inp", dace.Memlet("s[0]")) + state.add_edge(bcast_lib, "_outp", a, None, dace.Memlet("a[1:10]")) + + state.add_nedge(a, b, dace.Memlet("a[2:5] -> [1:4]")) + state.add_nedge(a, b, dace.Memlet("b[6:9] -> [4:7]")) + + return sdfg, state, b + + +def test_access_node_multi_connection(): + sdfg, state, b = _make_access_node_multi_connection() + + assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 3 + assert state.in_degree(b) == 2 + bcast_libs_before = list(util.count_nodes(sdfg, gtx_lib_nodes.Broadcast, True)) + assert len(bcast_libs_before) == 1 + + nb_applied = sdfg.apply_transformations_repeated( + gtx_transformations.ScalarBrodcastInliner, validate_all=True + ) + assert nb_applied == 1 + + # The broadcast library nodes are replicated and the original one is then removed. + bcast_libs_after = set(util.count_nodes(sdfg, gtx_lib_nodes.Broadcast, True)) + assert len(bcast_libs_after) == 2 + assert bcast_libs_before[0] not in bcast_libs_after + + ac_after = util.count_nodes(sdfg, dace_nodes.AccessNode, True) + assert len(ac_after) == 2 + assert b in ac_after + assert state.in_degree(b) == 2 + assert bcast_libs_after == {e.src for e in state.in_edges(b)} + + expected_sbs = { + dace_sbs.Range.from_string("1:4"), + dace_sbs.Range.from_string("6:9"), + } + assert expected_sbs == {e.data.dst_subset for e in state.in_edges(b)} diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py index 614b937cae..ac38e4d0c8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py @@ -72,7 +72,9 @@ def compile_and_run_sdfg( shared 
objects that are loaded multiple times. """ - with dace.config.set_temporary("compiler.use_cache", value=False): + with dace.config.temporary_config() as config: + config.set("compiler.use_cache", value=False) + config.set("compiler.allow_view_arguments", value=True) sdfg_clone = copy.deepcopy(sdfg) sdfg_clone.name = gtx_transformations_utils.unique_name(sdfg_clone.name)