diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 799e8ad228..e390f61d8d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -10,7 +10,7 @@ import enum import warnings -from typing import Any, Callable, Optional, Sequence, TypeAlias, Union +from typing import Any, Callable, List, Optional, Sequence, TypeAlias, Union import dace from dace import data as dace_data @@ -118,9 +118,11 @@ def gt_auto_optimize( gpu_block_size_2d: Optional[Sequence[int | str] | str] = None, gpu_block_size_3d: Optional[Sequence[int | str] | str] = None, gpu_maxnreg: Optional[int] = None, - blocking_dim: Optional[gtx_common.Dimension] = None, + blocking_dims: Optional[List[gtx_common.Dimension]] = None, blocking_size: int = 10, blocking_only_if_independent_nodes: bool = True, + promote_independent_memlets_for_blocking: bool = False, + blocking_independent_node_threshold: Optional[int] = None, scan_loop_unrolling: bool = False, scan_loop_unrolling_factor: int = 0, disable_splitting: bool = False, @@ -179,7 +181,7 @@ def gt_auto_optimize( gpu_block_size_{1, 2, 3}d: Allows to specify the GPU thread block size for 1, 2 and 3 dimension Maps individually. See the `gpu_block_size_spec` argument of `gt_gpu_transformation()` for more. - blocking_dim: On which dimension blocking should be applied. + blocking_dims: On which dimensions blocking should be applied. Priority based on the order of the passed dimensions. blocking_size: How many elements each block should process. 
blocking_only_if_independent_nodes: If `True`, the default, only apply loop blocking if there are independent nodes in the Map, see the @@ -323,9 +325,11 @@ def gt_auto_optimize( # Optimize the interior of the Maps: sdfg = _gt_auto_process_dataflow_inside_maps( sdfg=sdfg, - blocking_dim=blocking_dim, + blocking_dims=blocking_dims, blocking_size=blocking_size, blocking_only_if_independent_nodes=blocking_only_if_independent_nodes, + promote_independent_memlets_for_blocking=promote_independent_memlets_for_blocking, + blocking_independent_node_threshold=blocking_independent_node_threshold, scan_loop_unrolling=scan_loop_unrolling, scan_loop_unrolling_factor=scan_loop_unrolling_factor, fuse_tasklets=fuse_tasklets, @@ -671,9 +675,11 @@ def _gt_auto_process_top_level_maps( def _gt_auto_process_dataflow_inside_maps( sdfg: dace.SDFG, - blocking_dim: Optional[gtx_common.Dimension], + blocking_dims: Optional[list[gtx_common.Dimension]], blocking_size: int, blocking_only_if_independent_nodes: Optional[bool], + promote_independent_memlets_for_blocking: Optional[bool], + blocking_independent_node_threshold: Optional[int], scan_loop_unrolling: bool, scan_loop_unrolling_factor: int, fuse_tasklets: bool, @@ -694,12 +700,14 @@ def _gt_auto_process_dataflow_inside_maps( # Separate Tasklets into dependent and independent parts to promote data # reusability. It is important that this step has to be performed before # `TaskletFusion` is used. 
- if blocking_dim is not None: + if blocking_dims is not None: sdfg.apply_transformations_once_everywhere( gtx_transformations.LoopBlocking( blocking_size=blocking_size, - blocking_parameter=blocking_dim, + blocking_parameters=blocking_dims, require_independent_nodes=blocking_only_if_independent_nodes, + promote_independent_memlets=promote_independent_memlets_for_blocking, + independent_node_threshold=blocking_independent_node_threshold, ), validate=False, validate_all=validate_all, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index aa34736c8a..6344b3f1eb 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -762,7 +762,7 @@ def apply( block_size[i] = map_size[map_dim_idx_to_inspect] gpu_map.gpu_block_size = tuple(block_size) - if self.maxnreg is not None: + if self.maxnreg is not None and gpu_map.gpu_maxnreg == 0: gpu_map.gpu_maxnreg = self.maxnreg elif launch_bounds is not None: # Note: empty string has a meaning in DaCe gpu_map.gpu_launch_bounds = launch_bounds diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py index 2f25d4f1c3..69616f1f1e 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py @@ -7,6 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +import warnings +from collections.abc import Sequence from typing import Any, Optional, Union import dace @@ -52,6 +54,8 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): blocking_parameter: On which parameter should we block. 
require_independent_nodes: If `True` only apply loop blocking if the Map actually contains independent nodes. Defaults to `True`. + promote_independent_memlets: If `True` then memlets with independent data are promoted to the outer map. + independent_node_threshold: Minimum number of independent nodes required to apply blocking (non-inclusive). Todo: - Modify the inner map such that it always starts at zero. @@ -63,46 +67,87 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): allow_none=True, desc="Size of the inner blocks; 'B' in the above description.", ) - blocking_parameter = dace_properties.Property( - dtype=str, + blocking_parameters = dace_properties.Property( + dtype=list, allow_none=True, - desc="Name of the iteration variable on which to block (must be an exact match);" - " 'I' in the above description.", + desc="Names of the iteration variables on which to block. The first one that is found in the map parameters is used", ) require_independent_nodes = dace_properties.Property( dtype=bool, default=True, desc="If 'True' then blocking is only applied if there are independent nodes.", ) + promote_independent_memlets = dace_properties.Property( + dtype=bool, + default=False, + desc="If 'True' then memlets with independent data are promoted to the outer map.", + ) + independent_node_threshold = dace_properties.Property( + dtype=int, + default=0, + desc="Minimum number of independent nodes required to apply blocking (non-inclusive).", + ) # Set of nodes that are independent of the blocking parameter. 
_independent_nodes: Optional[set[dace_nodes.AccessNode]] _dependent_nodes: Optional[set[dace_nodes.AccessNode]] + _memlet_to_promote: Optional[list[dace_graph.MultiConnectorEdge[dace.Memlet]]] outer_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, blocking_size: Optional[int] = None, - blocking_parameter: Optional[Union[gtx_common.Dimension, str]] = None, + blocking_parameters: Optional[Sequence[Union[gtx_common.Dimension, str]]] = None, require_independent_nodes: Optional[bool] = None, + promote_independent_memlets: Optional[bool] = None, + independent_node_threshold: Optional[int] = None, ) -> None: super().__init__() - if isinstance(blocking_parameter, gtx_common.Dimension): - blocking_parameter = gtx_dace_lowering.get_map_variable(blocking_parameter) - if blocking_parameter is not None: - self.blocking_parameter = blocking_parameter + if blocking_parameters is not None: + self.blocking_parameters = [ + gtx_dace_lowering.get_map_variable(p) if isinstance(p, gtx_common.Dimension) else p + for p in blocking_parameters + ] + else: + self.blocking_parameters = None if blocking_size is not None: self.blocking_size = blocking_size if require_independent_nodes is not None: self.require_independent_nodes = require_independent_nodes + if promote_independent_memlets is not None: + self.promote_independent_memlets = promote_independent_memlets + if independent_node_threshold is not None: + self.independent_node_threshold = independent_node_threshold self._independent_nodes = None self._dependent_nodes = None + self._memlet_to_promote = None @classmethod def expressions(cls) -> Any: return [dace.sdfg.utils.node_path_graph(cls.outer_entry)] + def _populate_memlet_to_promote( + self, + matched_blocking_var: str, + graph: Union[dace.SDFGState, dace.SDFG], + outer_entry: dace_nodes.MapEntry, + ) -> None: + self._memlet_to_promote = [] + for out_edge_outer_entry in graph.out_edges(outer_entry): + if self._check_if_edge_can_be_promoted( + 
matched_blocking_var, out_edge_outer_entry, outer_entry + ): + self._memlet_to_promote.append(out_edge_outer_entry) + + def _get_blocking_parameter(self, map_params: list[str]) -> Optional[str]: + if self.blocking_parameters is None: + return None + for p in self.blocking_parameters: + if p in map_params: + return p + return None + def can_be_applied( self, graph: Union[dace.SDFGState, dace.SDFG], @@ -118,7 +163,7 @@ def can_be_applied( - The map range must have step size of 1. - The partition must exists (see `partition_map_output()`). """ - if self.blocking_parameter is None: + if self.blocking_parameters is None: raise ValueError("The blocking dimension was not specified.") elif self.blocking_size is None: raise ValueError("The blocking size was not specified.") @@ -126,17 +171,21 @@ def can_be_applied( outer_entry: dace_nodes.MapEntry = self.outer_entry map_params: list[str] = outer_entry.map.params map_range: dace_subsets.Range = outer_entry.map.range - block_var: str = self.blocking_parameter + + matched_blocking_var: str | None = self._get_blocking_parameter(map_params) + if matched_blocking_var is None: + return False scope = graph.scope_dict() if scope[outer_entry] is not None: return False - if block_var not in map_params: + block_var_idx = map_params.index(matched_blocking_var) + map_range_size = map_range.size() + + if all((map_range_size_i == 1) == True for map_range_size_i in map_range_size): # noqa: E712 [true-false-comparison] # SymPy fuzzy bools. return False - block_var_idx = map_params.index(block_var) - map_range_size = map_range.size() if map_range[block_var_idx][2] != 1: return False @@ -146,10 +195,32 @@ def can_be_applied( if (map_range_size[block_var_idx] <= self.blocking_size) == True: # noqa: E712 [true-false-comparison] # SymPy Fancy comparison. 
return False - if not self.partition_map_output(graph, sdfg): + if self.promote_independent_memlets: + self._populate_memlet_to_promote(matched_blocking_var, graph, outer_entry) + + if ( + not self.partition_map_output(matched_blocking_var, graph, sdfg, outer_entry) + and self.require_independent_nodes + and ( + not self.promote_independent_memlets + or (self._memlet_to_promote is not None and len(self._memlet_to_promote) == 0) + ) + ): + return False + + # Check if the blocking is a good idea. + if not self._check_if_blocking_can_promote_anything(state=graph): return False + + # Disable by default scans because there is `ScanLoopUnrolling` for them (if the blocking is done in the same dimension as the scan). + # Otherwise blocking the loop in the other dimension shouldn't be beneficial. + for node in graph.scope_subgraph(outer_entry).nodes(): + if isinstance(node, dace_nodes.NestedSDFG) and node.label.startswith("scan_"): + return False + self._independent_nodes = None self._dependent_nodes = None + self._memlet_to_promote = None return True @@ -162,11 +233,21 @@ def apply( Performs the operation described in the doc string. """ + outer_entry: dace_nodes.MapEntry = self.outer_entry + map_params: list[str] = outer_entry.map.params + matched_blocking_var: str | None = self._get_blocking_parameter(map_params) + assert matched_blocking_var is not None + # Now compute the partitions of the nodes. - self.partition_map_output(graph, sdfg) + self.partition_map_output(matched_blocking_var, graph, sdfg, outer_entry) + + if self.promote_independent_memlets: + self._prepare_independent_memlets(matched_blocking_var, graph, sdfg, outer_entry) # Modify the outer map and create the inner map. 
- (outer_entry, outer_exit), (inner_entry, inner_exit) = self._prepare_inner_outer_maps(graph) + (outer_entry, outer_exit), (inner_entry, inner_exit) = self._prepare_inner_outer_maps( + matched_blocking_var, graph, outer_entry + ) # Reconnect the edges self._rewire_map_scope( @@ -177,20 +258,26 @@ def apply( state=graph, sdfg=sdfg, ) + inner_entry.map.unroll = True + # TODO(iomaganaris): By default unroll the inner loop with the blocking size, but it might be interesting to have this as a separate parameter. + inner_entry.map.unroll_factor = self.blocking_size self._independent_nodes = None self._dependent_nodes = None + self._memlet_to_promote = None def _prepare_inner_outer_maps( self, + matched_blocking_var: str, state: dace.SDFGState, + outer_entry: dace_nodes.MapEntry, ) -> tuple[ tuple[dace_nodes.MapEntry, dace_nodes.MapExit], tuple[dace_nodes.MapEntry, dace_nodes.MapExit], ]: """Prepare the maps for the blocking. - The function modifies the outer map, `self.outer_entry`, by replacing the - blocking parameter, `self.blocking_parameter`, with a coarsened version + The function modifies the `outer_map` by replacing the + blocking parameter, `matched_blocking_var`, with a coarsened version of it. In addition the function will then create the inner map, that iterates over the blocking parameter, and these bounds are determined by the coarsened blocking parameter of the outer map. @@ -204,12 +291,11 @@ def _prepare_inner_outer_maps( inner map. Each element consist of a pair containing the map entry and map exit nodes of the corresponding maps. 
""" - outer_entry: dace_nodes.MapEntry = self.outer_entry outer_exit: dace_nodes.MapExit = state.exit_node(outer_entry) outer_map: dace_nodes.Map = outer_entry.map outer_range: dace_subsets.Range = outer_entry.map.range outer_params: list[str] = outer_entry.map.params - blocking_parameter_dim = outer_params.index(self.blocking_parameter) + blocking_parameter_dim = outer_params.index(matched_blocking_var) # This is the name of the iterator that we use in the outer map for the # blocked dimension @@ -217,14 +303,14 @@ def _prepare_inner_outer_maps( # could be matched through the `unit_strides_kind` argument, which is # the case with this approach. But it makes the `unit_strides_dim` # argument `gt_set_iteration_order()` inapplicable. - coarse_block_var = "__gtx_coarse_" + self.blocking_parameter + coarse_block_var = "__gtx_coarse_" + matched_blocking_var # Generate the sequential inner map rng_start = outer_range[blocking_parameter_dim][0] rng_stop = outer_range[blocking_parameter_dim][1] inner_label = f"inner_{outer_map.label}" inner_range = { - self.blocking_parameter: dace_subsets.Range.from_string( + matched_blocking_var: dace_subsets.Range.from_string( f"(({rng_start}) + ({coarse_block_var}) * ({self.blocking_size})):" f"min(({rng_start}) + ({coarse_block_var} + 1) * ({self.blocking_size}), ({rng_stop}) + 1)" ) @@ -249,8 +335,10 @@ def _prepare_inner_outer_maps( def partition_map_output( self, + matched_blocking_var: str, state: dace.SDFGState, sdfg: dace.SDFG, + outer_entry: dace_nodes.MapEntry, ) -> bool: """Computes the partition the of the nodes of the Map. @@ -301,13 +389,14 @@ def partition_map_output( self._independent_nodes = set() self._dependent_nodes = None + # We only need to do the following if we require some independent nodes to exist while True: # Find all the nodes that we have to classify in this iteration. # - All nodes adjacent to `outer_entry` (which is # independent by definition). # - All nodes adjacent to independent nodes. 
nodes_to_classify: set[dace_nodes.Node] = { - edge.dst for edge in state.out_edges(self.outer_entry) + edge.dst for edge in state.out_edges(outer_entry) } for independent_node in self._independent_nodes: nodes_to_classify.update({edge.dst for edge in state.out_edges(independent_node)}) @@ -317,9 +406,11 @@ def partition_map_output( found_new_independent_node = False for node_to_classify in nodes_to_classify: class_res = self._classify_node( + matched_blocking_var=matched_blocking_var, node_to_classify=node_to_classify, state=state, sdfg=sdfg, + outer_entry=outer_entry, ) # Check if the partition exists. @@ -338,22 +429,17 @@ def partition_map_output( assert all( all( - iedge.src in self._independent_nodes or iedge.src is self.outer_entry + iedge.src in self._independent_nodes or iedge.src is outer_entry for iedge in state.in_edges(inode) ) for inode in self._independent_nodes ) - # If requested check if the blocking is a good idea. - if self.require_independent_nodes and (not self._check_if_blocking_is_favourable(state)): - self._independent_nodes = None - return False - # After the independent set is computed compute the set of dependent nodes # as the set of all nodes adjacent to `outer_entry` that are not independent. self._dependent_nodes = { edge.dst - for edge in state.out_edges(self.outer_entry) + for edge in state.out_edges(outer_entry) if edge.dst not in self._independent_nodes } @@ -361,9 +447,11 @@ def partition_map_output( def _classify_node( self, + matched_blocking_var: str, node_to_classify: dace_nodes.Node, state: dace.SDFGState, sdfg: dace.SDFG, + outer_entry: dace_nodes.MapEntry, ) -> bool | None: """Internal function for classifying a single node. @@ -374,7 +462,7 @@ def _classify_node( - All incoming _empty_ edges must be connected to the map entry. - A node either has only empty Memlets or none of them. - Incoming Memlets does not depend on the blocking parameter. 
- - All incoming edges must start either at `self.outer_entry` or at dependent nodes. + - All incoming edges must start either at `outer_entry` or at dependent nodes. - All output Memlets are non empty. It is important that to realize that the function will add the node to @@ -393,7 +481,6 @@ def _classify_node( sdfg: The SDFG that is processed. """ assert self._independent_nodes is not None # silence MyPy - outer_entry: dace_nodes.MapEntry = self.outer_entry # for caching. outer_exit: dace_nodes.MapExit = state.exit_node(outer_entry) # The node needs to have an input and output. @@ -416,7 +503,7 @@ def _classify_node( # Despite its type the node's free symbols can not contain the blocking # parameter. In case of a Tasklet this would be the body of the Tasklet. - if self.blocking_parameter in node_to_classify.free_symbols: + if matched_blocking_var in node_to_classify.free_symbols: return False # If the test succeed then these are the nodes we additionally consider @@ -451,7 +538,7 @@ def _classify_node( # Additionally, test if the symbol mapping depends on the block variable. for v in node_to_classify.symbol_mapping.values(): - if self.blocking_parameter in v.free_symbols: + if matched_blocking_var in v.free_symbols: return False elif isinstance(node_to_classify, dace_nodes.AccessNode): @@ -477,7 +564,7 @@ def _classify_node( # The blocking parameter can not be used inside the Map scope. map_scope = state.scope_subgraph(node_to_classify) - if self.blocking_parameter in map_scope.free_symbols: + if matched_blocking_var in map_scope.free_symbols: return False # There is an obscure case, where the Memlet on the inside of a Map scope @@ -485,7 +572,7 @@ def _classify_node( # check that here. Note that we only have to do it here in this case # because normally it would be spotted above where we checked the input. 
out_edges = list(state.out_edges(map_exit)) - if any(self.blocking_parameter in out_edge.data.free_symbols for out_edge in out_edges): + if any(matched_blocking_var in out_edge.data.free_symbols for out_edge in out_edges): return False # Add all nodes of the Map scope, including entry and exit node to the @@ -544,7 +631,7 @@ def _classify_node( # If a subset needs the block variable then the node is not independent # but dependent. - if any(self.blocking_parameter in subset.free_symbols for subset in subsets_to_inspect): + if any(matched_blocking_var in subset.free_symbols for subset in subsets_to_inspect): return False # The edge must either originate from `outer_entry` or from an independent @@ -600,6 +687,225 @@ def _post_process_independent_nodes( independent_nodes_were_updated = True break + def _check_if_edge_can_be_promoted( + self, + matched_blocking_var: str, + edge: dace_graph.MultiConnectorEdge[dace.Memlet], + outer_entry: dace_nodes.MapEntry, + ) -> bool: + """Check if a memlet can be promoted to the outer map. + + The function checks if the memlet can be promoted to the outer map. + This is the case if the memlet does not depend on the blocking parameter. + + Args: + edge: The edge containing the memlet to inspect. + Returns: + The function returns `True` if the memlet can be promoted. + """ + assert edge.src == outer_entry + + memlet: dace.Memlet = edge.data + src_subset: dace_subsets.Subset | None = memlet.src_subset + dst_subset: dace_subsets.Subset | None = memlet.dst_subset + + if self._independent_nodes is not None and edge.dst in self._independent_nodes: + # If the memlet is connected to an independent node, then we can not promote it, since it would be redundant. + return False + + if ( + memlet.is_empty() + ): # Empty Memlets should already be in independent nodes and don't have read dependencies + return False + + # Now we have to look at the source and destination set of the Memlet. 
+ subsets_to_inspect: list[dace_subsets.Subset] = [] + if dst_subset is not None: + subsets_to_inspect.append(dst_subset) + if src_subset is not None: + subsets_to_inspect.append(src_subset) + + if any(matched_blocking_var in subset.free_symbols for subset in subsets_to_inspect): + return False + + # If the memlet is connected to a MapEntry and the MapEntry parameters contain the blocking parameter, + # we could promote the memlet to the outer map but at the moment there is no case where this can happen + # so we leave it as future work. + # TODO(iomaganaris): Implement this case if it turns out to be relevant. + if isinstance(edge.dst, dace_nodes.MapEntry): + if matched_blocking_var in edge.dst.params: + return False + + if isinstance(edge.dst, dace_nodes.Tasklet) and not edge.data.data.startswith("gt_conn_"): + # TODO(iomaganaris): This check is done for tasklets that have as input a field + # (connection name: `__tlet_field`) and one or two offsets + # (connection name: `__tlet_index_Cell` and `__tlet_index_K`). + # If there is no `K` dependent offset and we block on `K`, we can promote the memlet. + # If there is no `Cell` dependent offset and we block on `Cell`, we can promote the memlet. + # However, the above requires checking the tasklet internals without being sure of these conventions. + # For that reason we promote memlets that are passed to tasklets only + # if they are connectivities. + return False + + if isinstance(edge.dst, dace_nodes.LibraryNode): + # We currently do not handle promotion of memlets to library nodes, since it is not clear if this can actually happen in the cases we care about and it would require some work to handle the different cases. 
+ warnings.warn( + "LoopBlocking: Memlet promotion to LibraryNode is not supported and will be skipped.", + stacklevel=2, + ) + return False + + if isinstance(edge.dst, dace_nodes.NestedSDFG): + # We currently do not handle promotion of memlets to nested SDFGs, since it is not clear if this can actually happen in the cases we care about and it would require some work to handle the different cases. + warnings.warn( + "LoopBlocking: Memlet promotion to NestedSDFG is not supported and will be skipped.", + stacklevel=2, + ) + return False + + return True + + def _prepare_independent_memlets( + self, + matched_blocking_var: str, + state: dace.SDFGState, + sdfg: dace.SDFG, + outer_map_entry: dace_nodes.MapEntry, + ) -> None: + assert self._independent_nodes is not None # silence MyPy + assert self._memlet_to_promote is None # silence MyPy + + _ = sdfg.reset_cfg_list() + dace_sdutils.canonicalize_memlet_trees_for_map(state=state, map_node=outer_map_entry) + dace_propagation.propagate_memlets_map_scope(sdfg, state, outer_map_entry) + + self._populate_memlet_to_promote(matched_blocking_var, state, outer_map_entry) + # Below checks are necessary for MyPy + if self._memlet_to_promote and len(self._memlet_to_promote) == 0: + return + assert self._memlet_to_promote is not None + + for in_edge in self._memlet_to_promote: + if isinstance(in_edge.dst, dace_nodes.AccessNode): + raise NotImplementedError( + "Promotion of memlets to AccessNodes is not implemented because " + "this case should already be handled since the destination AccesesNode " + "should already be in the set of independent nodes." + ) + # Create a temporary AccessNode that will be used to promote the memlet + promoted_accessnode_shape = [] + # We have to adjust the subsets of the inner map out edges that are connected to the memlet we want to promote in case the destination of the in_edge is a MapEntry. 
+ corresponding_inner_map_out_edges = list( + state.out_edges_by_connector(in_edge.dst, "OUT_" + in_edge.dst_conn[3:]) + ) + assert ( + len(corresponding_inner_map_out_edges) > 0 + if isinstance(in_edge.dst, dace_nodes.MapEntry) + else True + ), ( + "If the destination of the memlet we want to promote is a MapEntry, there should be at least one corresponding inner map out edge." + ) + # Store old subsets of the inner map out edges to be able to adjust them later. + inner_map_out_edges_and_old_subsets = dict() + for inner_map_out_edge in corresponding_inner_map_out_edges: + inner_map_out_edges_and_old_subsets[inner_map_out_edge] = ( + inner_map_out_edge.data.subset + ) + # Dict of new subsets for the inner map out edges after removing the independent dimensions. + inner_map_out_edges_and_new_subsets = dict() + for inner_map_out_edge in corresponding_inner_map_out_edges: + inner_map_out_edges_and_new_subsets[inner_map_out_edge] = [] + # Make sure that the subsets of the inner map out edges have the same number of dimensions as the memlet we want to promote. + assert all( + len(corresponding_inner_map_out_edge.data.subset) == len(in_edge.data.subset) + for corresponding_inner_map_out_edge in corresponding_inner_map_out_edges + ) + # Go through the subsets of the independent memlet to find out what should be the size of the temporary access node. + # If the in_edge.dst is a MapEntry, we have to adjust the subsets of the inner map out edges that are connected to the memlet we want to promote by removing the independent dimensions. 
+ for i, subset_range in enumerate(in_edge.data.subset.ranges): + start, end, step = subset_range + subset_free_symbols = start.free_symbols.union(end.free_symbols).union( + step.free_symbols + ) + subset_free_symbols = set(str(s) for s in subset_free_symbols) + if not subset_free_symbols.intersection(set(outer_map_entry.map.params)): + for inner_map_out_edge in corresponding_inner_map_out_edges: + inner_map_out_edges_and_new_subsets[inner_map_out_edge].append( + inner_map_out_edges_and_old_subsets[inner_map_out_edge][i] + ) + promoted_accessnode_shape.append(dace_subsets.Range([subset_range]).size()[0]) + # The subsets of the inner map out edges should have at least one dimension left after removing the independent dimensions. + assert all( + len(inner_map_out_edges_and_new_subset) > 0 + for inner_map_out_edges_and_new_subset in inner_map_out_edges_and_new_subsets.values() + ), ( + "After removing the independent dimensions there should be at least one dimension left to promote." + ) + # The subset of the memlet from inner MapEntry. + assert all( + len(inner_map_out_edges_and_new_subset) <= len(in_edge.data.subset.ranges) - 1 + for inner_map_out_edges_and_new_subset in inner_map_out_edges_and_new_subsets.values() + ), ( + "After removing the independent dimensions there should be at least one dimension smaller than the outer map." 
+ ) + promoted_name, promoted_desc = sdfg.add_temp_transient( + shape=promoted_accessnode_shape, + dtype=sdfg.arrays[in_edge.data.data].dtype, + ) + promoted_anode = state.add_access(promoted_name) + original_dst_of_in_edge = in_edge.dst + original_dst_conn_of_in_edge = in_edge.dst_conn + original_dst_other_subset_of_in_edge = in_edge.data.other_subset + # Redirect the memlet to the temporary AccessNode + dace_helpers.redirect_edge( + state=state, + edge=in_edge, + new_dst=promoted_anode, + new_dst_conn=None, + new_memlet=dace.Memlet( + data=in_edge.data.data, + subset=in_edge.data.subset, + other_subset=dace_subsets.Range.from_array(sdfg.arrays[promoted_name]), + ), + ) + + # Create a new memlet from the temporary AccessNode to the original destination + state.add_edge( + promoted_anode, + None, + original_dst_of_in_edge, + original_dst_conn_of_in_edge, + memlet=dace.Memlet( + data=promoted_name, + subset=dace_subsets.Range.from_array(promoted_desc), + other_subset=original_dst_other_subset_of_in_edge, + ), + ) + + if isinstance(original_dst_of_in_edge, dace_nodes.MapEntry): + # The following logic works only if the level of Map nesting is up to 2. + # This should be the usual case in our applications but it is not guaranteed in general. + # In case we have more than 2 levels of Maps we have to apply the same logic recursively + # for each memlet with the `in_edge.data.data` that is connecting MapEntry to MapEntry + # and creates the nesting. + for inner_map_out_edge in state.out_edges(original_dst_of_in_edge): + if inner_map_out_edge.data.data == in_edge.data.data: + for memlet_tree in state.memlet_tree(inner_map_out_edge).traverse_children( + include_self=True + ): + edge_to_adjust = memlet_tree.edge + assert edge_to_adjust.data.data == in_edge.data.data + edge_to_adjust.data.data = promoted_name + assert len(original_dst_of_in_edge.params) == 1, ( + "Independent memlets should only be inputs to maps that have a single parameter. 
" + "Those should always be neighbor reductions." + ) + edge_to_adjust.data.subset = dace_subsets.Range( + inner_map_out_edges_and_new_subsets[inner_map_out_edge] + ) + + self._independent_nodes.add(promoted_anode) + def _rewire_map_scope( self, outer_entry: dace_nodes.MapEntry, @@ -845,26 +1151,16 @@ def _rewire_map_scope( dace_sdutils.canonicalize_memlet_trees_for_map(state=state, map_node=outer_entry) dace_propagation.propagate_memlets_map_scope(sdfg, state, outer_entry) - def _check_if_blocking_is_favourable( + def _check_if_blocking_can_promote_anything( self, state: dace.SDFGState, ) -> bool: - """Test if the nodes are really independent nodes. - - After the classification the function will examine the set to see if some - nodes were found that brings no benefit to move out. The classical example - is a Tasklet that writes a constant into an AccessNode. These kind of - nodes are filtered out. + """Test if there can be any memory access promoted outside the inner loop. - The function returns `True` if it decides that blocking is good and `False` - otherwise. The function will not modify `self._independent_nodes`. + If there are any memlets that have independent data or any independent nodes, + assume that blocking is favourable. """ assert self._independent_nodes is not None - assert self._dependent_nodes is None - - # There is nothing to move out so ignore it. - if len(self._independent_nodes) == 0: - return False # Currently we only filter out Tasklets that do not read any data, which # is the example above, Because of how DaCe works we also subtract all @@ -883,4 +1179,10 @@ def _check_if_blocking_is_favourable( nb_independent_nodes -= 1 assert nb_independent_nodes >= 0 - return nb_independent_nodes > 0 + # TODO(iomaganaris): Figure out how many memlets and nodes minimum there need to be + # to make blocking worthwhile. 
+ return not self.require_independent_nodes or ( + nb_independent_nodes + + (len(self._memlet_to_promote) if self._memlet_to_promote is not None else 0) + > self.independent_node_threshold + ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index 1aed515aa1..27a598761a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -231,7 +231,7 @@ def test_only_dependent(): # By default there must be dependent nodes, so by default it will not apply. count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameters=["j"]), validate=True, validate_all=True, ) @@ -241,7 +241,7 @@ def test_only_dependent(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=10, - blocking_parameter="j", + blocking_parameters=["j"], require_independent_nodes=False, ), validate=True, @@ -308,7 +308,7 @@ def test_intermediate_access_node(): # Apply the transformation. count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameters=["j"]), validate=True, validate_all=True, ) @@ -347,7 +347,7 @@ def test_chained_access() -> None: # Apply the transformation. 
count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameters=["j"]), validate=True, validate_all=True, ) @@ -423,7 +423,7 @@ def test_direct_map_exit_connection() -> dace.SDFG: # Because there are no independent nodes, the transformation will not apply # and we have to explicitly enable it. count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameters=["j"]), validate=True, validate_all=True, ) @@ -432,7 +432,7 @@ def test_direct_map_exit_connection() -> dace.SDFG: count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=5, - blocking_parameter="j", + blocking_parameters=["j"], require_independent_nodes=False, ), validate=True, @@ -455,7 +455,7 @@ def test_empty_memlet_1(): # thus filtered out, together with the output. Thus we have to disable this # filtering. count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameters=["j"]), validate=True, validate_all=True, ) @@ -464,7 +464,7 @@ def test_empty_memlet_1(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=5, - blocking_parameter="j", + blocking_parameters=["j"], require_independent_nodes=False, ), validate=True, @@ -494,7 +494,7 @@ def test_empty_memlet_2(): # Because there are no independent node so by default the transformation should # not apply. 
count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameters=["j"]), validate=True, validate_all=True, ) @@ -503,7 +503,7 @@ def test_empty_memlet_2(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=5, - blocking_parameter="j", + blocking_parameters=["j"], require_independent_nodes=False, ), validate=True, @@ -535,7 +535,7 @@ def test_empty_memlet_3(): # There are no independent nodes so we must force the blocking. count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameters=["j"]), validate=True, validate_all=True, ) @@ -544,7 +544,7 @@ def test_empty_memlet_3(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=5, - blocking_parameter="j", + blocking_parameters=["j"], require_independent_nodes=False, ), validate=True, @@ -678,7 +678,7 @@ def test_loop_blocking_inner_map(): # Because there is no independent part, we have to force the application of the # transformation. 
count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameters=["__i0"]), validate=True, validate_all=True, ) @@ -687,7 +687,7 @@ def test_loop_blocking_inner_map(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=5, - blocking_parameter="__i0", + blocking_parameters=["__i0"], require_independent_nodes=False, ), validate=True, @@ -731,7 +731,7 @@ def test_loop_blocking_inner_map_with_independent_part(independent_part): assert all(oedge.dst is outer_map_exit for oedge in state.out_edges(i_access_node)) count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameters=["__i0"]), validate=True, validate_all=True, ) @@ -811,7 +811,7 @@ def test_loop_blocking_sdfg_with_independent_inner_map(): sdfg, state, outer_me, inner_me = _make_loop_blocking_sdfg_with_independent_inner_map() count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameter="__i1"), + gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameters=["__i1"]), validate=True, validate_all=True, ) @@ -887,7 +887,7 @@ def test_loop_blocking_dependent_reduction(): # Because there is no independent part, we have to force the application of the # transformation. 
count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameter="__i1"), + gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameters=["__i1"]), validate=True, validate_all=True, ) @@ -896,7 +896,7 @@ def test_loop_blocking_dependent_reduction(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=2, - blocking_parameter="__i1", + blocking_parameters=["__i1"], require_independent_nodes=False, ), validate=True, @@ -915,7 +915,7 @@ def test_loop_blocking_dependent_reduction(): def test_loop_blocking_independent_reduction(): sdfg, state, me, red = _make_loop_blocking_with_reduction(reduction_is_dependent=False) count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameter="__i1"), + gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameters=["__i1"]), validate=True, validate_all=True, ) @@ -1034,7 +1034,7 @@ def _apply_and_run_mixed_memlet_sdfg( require_independent_nodes = True if not tskl1_independent: count = sdfg.apply_transformations_repeated( - gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameter="j"), + gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameters=["j"]), validate=True, validate_all=True, ) @@ -1044,7 +1044,7 @@ def _apply_and_run_mixed_memlet_sdfg( count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=2, - blocking_parameter="j", + blocking_parameters=["j"], require_independent_nodes=require_independent_nodes, ), validate=True, @@ -1188,7 +1188,7 @@ def test_loop_blocking_no_independent_nodes(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=2, - blocking_parameter="__i1", + blocking_parameters=["__i1"], require_independent_nodes=True, ), validate=True, @@ -1200,7 +1200,7 @@ def test_loop_blocking_no_independent_nodes(): count = sdfg.apply_transformations_repeated( 
gtx_transformations.LoopBlocking( blocking_size=2, - blocking_parameter="__i1", + blocking_parameters=["__i1"], require_independent_nodes=False, ), validate=True, @@ -1250,7 +1250,7 @@ def ref_comp(a, b, c, B, N, M): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=1, - blocking_parameter="k", + blocking_parameters=["k"], require_independent_nodes=False, ), validate=True, @@ -1301,7 +1301,7 @@ def test_blocking_size_too_big(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=30, - blocking_parameter="j", + blocking_parameters=["j"], require_independent_nodes=False, ), validate=True, @@ -1313,7 +1313,7 @@ def test_blocking_size_too_big(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=5, - blocking_parameter="j", + blocking_parameters=["j"], require_independent_nodes=False, ), validate=True, @@ -1399,7 +1399,7 @@ def test_loop_blocking_sdfg_with_semi_independent_map(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=2, - blocking_parameter="__i1", + blocking_parameters=["__i1"], require_independent_nodes=False, ), validate=True, @@ -1461,7 +1461,7 @@ def test_loop_blocking_only_independent_inner_map(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=2, - blocking_parameter="__i1", + blocking_parameters=["__i1"], require_independent_nodes=False, ), validate=True, @@ -1535,7 +1535,7 @@ def test_loop_blocking_direct_access_node_array(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=2, - blocking_parameter="__i1", + blocking_parameters=["__i1"], require_independent_nodes=True, ), validate=True, @@ -1546,7 +1546,7 @@ def test_loop_blocking_direct_access_node_array(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=2, - blocking_parameter="__i1", + 
blocking_parameters=["__i1"], require_independent_nodes=False, ), validate=True, @@ -1586,7 +1586,7 @@ def test_loop_blocking_direct_access_node_scalar(): count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking( blocking_size=2, - blocking_parameter="__i1", + blocking_parameters=["__i1"], require_independent_nodes=True, ), validate=True, @@ -1608,3 +1608,244 @@ def test_loop_blocking_direct_access_node_scalar(): util.compile_and_run_sdfg(sdfg, **res) assert all(np.allclose(ref[name], res[name]) for name in ref) + + +def _make_loop_blocking_sdfg_with_everything() -> tuple[ + dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry +]: + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_inner_semi_independent_map")) + state = sdfg.add_state(is_start_block=True) + + sdfg.add_array("A", shape=(40, 8), dtype=dace.float64, transient=False) + for name in "BC": + sdfg.add_array(name, shape=(40, 8), dtype=dace.float64, transient=False) + sdfg.add_array("inc", shape=(40, 3), dtype=dace.float64, transient=False) + sdfg.add_array("inc2", shape=(40, 3), dtype=dace.float64, transient=False) + sdfg.add_array("gt_conn_dummy", shape=(40, 2), dtype=dace.int32, transient=False) + sdfg.add_array("ikoffset", shape=(40,), dtype=dace.int32, transient=False) + sdfg.add_array("S", shape=(40,), dtype=dace.float64, transient=False) + sdfg.add_scalar("t", dtype=dace.float64, transient=True) + sdfg.add_scalar("tt", dtype=dace.float64, transient=True) + sdfg.add_scalar("ttt", dtype=dace.float64, transient=True) + sdfg.add_scalar("tttt", dtype=dace.float64, transient=True) + sdfg.add_scalar("ttttt", dtype=dace.float64, transient=True) + sdfg.add_scalar("inc2_tmp", dtype=dace.float64, transient=True) + + A, B, C, T, t, S = (state.add_access(name) for name in "ABCTtS") + inc = state.add_access("inc") + inc2 = state.add_access("inc2") + gt_conn_dummy = state.add_access("gt_conn_dummy") + ikoffset = state.add_access("ikoffset") + tt = 
state.add_access("tt") + ttt = state.add_access("ttt") + tttt = state.add_access("tttt") + ttttt = state.add_access("ttttt") + inc2_tmp = state.add_access("inc2_tmp") + + # Note that creating `T` as an array is not useful at all, a scalar would be + # enough. The only reason for doing it is, that the Memlets inside and outside + # the inner Map scope can refer to different data, and the outside Memlet + # is dependent. + sdfg.add_array("T", shape=(8,), dtype=dace.float64, transient=True) + + me, mx = state.add_map("main_comp", ndrange={"__i0": "0:40", "__i1": "0:8"}) + + # The inner computation on its own is not useful. + ime, imx = state.add_map("inner_comp", ndrange={"__inner": "0:3"}) + itlet = state.add_tasklet( + "inner_tasklet", + inputs={"__in", "__inc", "__inc2"}, + outputs={"__out"}, + code="__out = __in + __inc + __inc2", + ) + + indirectaccesstlet = state.add_tasklet( + "indirect_access_tlet", + inputs={"__in", "__gt_conn_dummy"}, + outputs={"__out"}, + code="__out = __in[(__gt_conn_dummy - 0), (-0) + __i1]", + ) + + indirectaccesstletkoffset = state.add_tasklet( + "indirect_access_tlet_ikoffset", + inputs={"__in", "__gt_conn_dummy", "__koffset"}, + outputs={"__out"}, + code="__out = __in[(__gt_conn_dummy - 0), (-0) + __koffset]", + ) + + dtletB = state.add_tasklet( + "dependent_tlet_B", + inputs={"__in"}, + outputs={"__out"}, + code="__out = __in + 5.0", + ) + + idtlet = state.add_tasklet( + "independent_tlet", + inputs={"__in"}, + outputs={"__out"}, + code="__out = __in + 2.0", + ) + + # This is the dependent Tasklet. 
+ dtlet = state.add_tasklet( + "dependent_tlet", + inputs={"__in1", "__in2", "__in3", "__in4", "__in5"}, + outputs={"__out"}, + code="__out = __in1 + __in2 + __in3 + __in4 + __in5", + ) + + state.add_edge(A, None, me, "IN_A", dace.Memlet("A[0:40, 0:8]")) + state.add_edge(me, "OUT_A", ime, "IN_A", dace.Memlet("A[__i0, __i1]")) + me.add_scope_connectors("A") + + state.add_edge(inc, None, me, "IN_inc", dace.Memlet("inc[0:40, 0:3]")) + state.add_edge(me, "OUT_inc", ime, "IN_inc", dace.Memlet("inc[__i0, 0:3]")) + me.add_scope_connectors("inc") + + state.add_edge(inc2, None, me, "IN_inc2", dace.Memlet("inc2[0:40, 0:3]")) + state.add_edge(me, "OUT_inc2", ime, "IN_inc2", dace.Memlet("inc2[__i0, 0:3]")) + me.add_scope_connectors("inc2") + + state.add_edge(ime, "OUT_A", itlet, "__in", dace.Memlet("A[__i0, __i1] -> __in")) + ime.add_scope_connectors("A") + state.add_edge(ime, "OUT_inc", itlet, "__inc", dace.Memlet("inc[__i0, __inner] -> __inc")) + ime.add_scope_connectors("inc") + state.add_edge(ime, "OUT_inc2", inc2_tmp, None, dace.Memlet("[__i0, __inner] -> inc2_tmp[0]")) + state.add_edge(inc2_tmp, None, itlet, "__inc2", dace.Memlet("inc2_tmp[0] -> __inc2")) + ime.add_scope_connectors("inc2") + + # Here is the interesting part, on the inside of the inner Map scope, the output + # Memlet refers to `t` the scalar, but on the outside the Memlet refers to + # `T` and this also includes the blocking variable `__i1`. 
+ state.add_edge(t, None, imx, "IN_t", dace.Memlet("t[0]")) + state.add_edge(imx, "OUT_t", T, None, dace.Memlet("T[__i1]")) + imx.add_scope_connectors("t") + + state.add_edge(T, None, dtlet, "__in1", dace.Memlet("T[__i1]")) + + state.add_edge(itlet, "__out", t, None, dace.Memlet("t[0]")) + + state.add_edge(B, None, me, "IN_B", dace.Memlet("B[0:40, 0:8]")) + state.add_edge( + gt_conn_dummy, None, me, "IN_gt_conn_dummy", dace.Memlet("gt_conn_dummy[0:40, 0:2]") + ) + state.add_edge(me, "OUT_B", indirectaccesstlet, "__in", dace.Memlet("B[0:40, 0:8]")) + state.add_edge( + me, + "OUT_gt_conn_dummy", + indirectaccesstlet, + "__gt_conn_dummy", + dace.Memlet("gt_conn_dummy[__i0, 0]"), + ) + me.add_scope_connectors("B") + me.add_scope_connectors("gt_conn_dummy") + + state.add_edge(indirectaccesstlet, "__out", tt, None, dace.Memlet("tt[0]")) + state.add_edge(tt, None, dtlet, "__in2", dace.Memlet("tt[0]")) + + state.add_edge(me, "OUT_B", dtletB, "__in", dace.Memlet("B[__i0, __i1]")) + state.add_edge(dtletB, "__out", ttt, None, dace.Memlet("ttt[0]")) + state.add_edge(ttt, None, dtlet, "__in3", dace.Memlet("ttt[0]")) + + state.add_edge(S, None, me, "IN_S", dace.Memlet("S[__i0]")) + state.add_edge(me, "OUT_S", idtlet, "__in", dace.Memlet("S[__i0]")) + me.add_scope_connectors("S") + + state.add_edge(idtlet, "__out", tttt, None, dace.Memlet("tttt[0]")) + state.add_edge(tttt, None, dtlet, "__in4", dace.Memlet("tttt[0]")) + + state.add_edge(dtlet, "__out", mx, "IN_C", dace.Memlet("C[__i0, __i1]")) + state.add_edge(mx, "OUT_C", C, None, dace.Memlet("C[0:40, 0:8]")) + mx.add_scope_connectors("C") + + state.add_edge(ikoffset, None, me, "IN_ikoffset", dace.Memlet("ikoffset[0:40]")) + state.add_edge(me, "OUT_B", indirectaccesstletkoffset, "__in", dace.Memlet("B[0:40, 0:8]")) + state.add_edge( + me, + "OUT_gt_conn_dummy", + indirectaccesstletkoffset, + "__gt_conn_dummy", + dace.Memlet("gt_conn_dummy[__i0, 1]"), + ) + state.add_edge( + me, + "OUT_ikoffset", + indirectaccesstletkoffset, + 
"__koffset", + dace.Memlet("ikoffset[__i1]"), + ) + me.add_scope_connectors("ikoffset") + state.add_edge(indirectaccesstletkoffset, "__out", ttttt, None, dace.Memlet("ttttt[0]")) + state.add_edge(ttttt, None, dtlet, "__in5", dace.Memlet("ttttt[0]")) + + sdfg.validate() + + return sdfg, state, me, ime + + +@pytest.mark.parametrize( + "require_independent_nodes,promote_independent_memlets,independent_node_threshold", + [ + (True, True, 0), + (True, False, 0), + (False, True, 0), + (False, False, 0), + (True, True, 6), + (False, True, 6), + ], +) +def test_loop_blocking_sdfg_with_everything( + require_independent_nodes: bool, + promote_independent_memlets: bool, + independent_node_threshold: int, +): + sdfg, state, me, ime = _make_loop_blocking_sdfg_with_everything() + + scope_dict_before = state.scope_dict() + assert scope_dict_before[ime] is me + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=2, + blocking_parameters=["__i1"], + require_independent_nodes=require_independent_nodes, + promote_independent_memlets=promote_independent_memlets, + independent_node_threshold=independent_node_threshold, + ), + validate=True, + validate_all=True, + ) + assert count == (1 if (independent_node_threshold <= 5 or not require_independent_nodes) else 0) + + assert state.out_degree(me) == 10 + + me_tasklet_out_edges = [ + edge for edge in state.out_edges(me) if isinstance(edge.dst, dace_nodes.Tasklet) + ] + me_access_node_out_edges = [ + edge for edge in state.out_edges(me) if isinstance(edge.dst, dace_nodes.AccessNode) + ] + + if count == 1 and promote_independent_memlets: + assert len(me_tasklet_out_edges) == 1 + assert next(iter(me_tasklet_out_edges)).data.data == "S" + assert len(me_access_node_out_edges) == 4 + assert all( + [ + edge.data.data in {"inc", "inc2", "gt_conn_dummy"} + for edge in me_access_node_out_edges + ] + ) + elif count == 1 and not promote_independent_memlets: + # Tasklet that reads 'S' is independent + 
assert len(me_tasklet_out_edges) == 1 + assert next(iter(me_tasklet_out_edges)).data.data == "S" + assert len(me_access_node_out_edges) == 0 + else: + assert len(me_tasklet_out_edges) == 7 + assert len(me_access_node_out_edges) == 0 + + new_scope_of_inner_map = state.scope_dict()[ime] + assert isinstance(new_scope_of_inner_map, dace_nodes.MapEntry) + assert new_scope_of_inner_map is not (me if count == 1 else None)