-
Notifications
You must be signed in to change notification settings - Fork 56
feat[next-dace]: Use SDFG library node for lowering of broadcast and reduce #2386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f8180e2
ef9ef92
bd1b766
25abd36
071f512
331bcd3
b90976b
21a79d2
a1f6f1a
6e97232
08484df
da12bdb
5382dce
b04586c
4647c7d
779b164
f250040
68a417f
9d774e4
c72e75a
93b1ffd
0915185
cad3fa0
2947655
ab1f400
007f435
16dc60f
84b821d
9e10019
7c54d44
0dd25e6
f62bc31
4e7173f
2c785d9
f70e5b3
1ec5c31
9c50896
21bac3c
db00ace
1c93925
35309ef
44bee9d
cc2c710
8cf38e6
e7bd089
64c7b87
e90ebef
a2e1451
6d1a358
6209739
67d5ee4
755afea
b8f7d91
40eb027
69cf288
a31e163
5249529
98e098a
c7b1f2d
2444a2d
c3cde1e
b8d254e
5b95cfa
c6c5bda
0cc7109
bd5b88e
bec8ab8
080bb32
08b1d85
44f524c
13541b0
0645d58
9592908
1e2089b
8d3a89e
3dcf818
11b8651
2cd3a56
28e8808
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from typing import Final | ||
|
|
||
| from dace import nodes as dace_nodes | ||
|
|
||
| from gt4py.next.program_processors.runners.dace.library_nodes.broadcast import ( | ||
| Broadcast, | ||
| inplace_broadcast_expander, | ||
| ) | ||
| from gt4py.next.program_processors.runners.dace.library_nodes.reduce_with_skip_values import ( | ||
| ReduceWithSkipValues, | ||
| ) | ||
|
|
||
|
|
||
| GTIR_LIBRARY_NODES: Final[tuple[dace_nodes.LibraryNode, ...]] = (Broadcast, ReduceWithSkipValues) | ||
| """List of available GTIR library nodes.""" | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "Broadcast", | ||
| "ReduceWithSkipValues", | ||
| "inplace_broadcast_expander", | ||
| ] |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,302 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||
| # GT4Py - GridTools Framework | ||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2014-2024, ETH Zurich | ||||||||||||||||||||||||||||||||||||||||||||||||
| # All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||
| # Please, refer to the LICENSE file in the root directory. | ||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: BSD-3-Clause | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| import copy | ||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Final, Iterable, Optional, Sequence | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| import dace | ||||||||||||||||||||||||||||||||||||||||||||||||
| from dace import ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| data as dace_data, | ||||||||||||||||||||||||||||||||||||||||||||||||
| library as dace_library, | ||||||||||||||||||||||||||||||||||||||||||||||||
| nodes as dace_nodes, | ||||||||||||||||||||||||||||||||||||||||||||||||
| properties as dace_properties, | ||||||||||||||||||||||||||||||||||||||||||||||||
| subsets as dace_sbs, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| from dace.sdfg import graph as dace_graph | ||||||||||||||||||||||||||||||||||||||||||||||||
| from dace.transformation import transformation as dace_transform | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| from gt4py.next import common as gtx_common | ||||||||||||||||||||||||||||||||||||||||||||||||
| from gt4py.next.program_processors.runners.dace import lowering as gtx_dace_lowering | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| _INPUT_NAME: Final[str] = "_inp" | ||||||||||||||||||||||||||||||||||||||||||||||||
| _OUTPUT_NAME: Final[str] = "_outp" | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| @dace_library.node | ||||||||||||||||||||||||||||||||||||||||||||||||
| class Broadcast(dace_nodes.LibraryNode): | ||||||||||||||||||||||||||||||||||||||||||||||||
| """Implements write of a scalar value over an array subset. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Same as XLA. | ||||||||||||||||||||||||||||||||||||||||||||||||
| broadcast_in_dim[i] describes where dimension `i` of the `value_to_broadcast` | ||||||||||||||||||||||||||||||||||||||||||||||||
| goes. In case of a scalar it is empty. | ||||||||||||||||||||||||||||||||||||||||||||||||
| Furthermore the following has to hold: | ||||||||||||||||||||||||||||||||||||||||||||||||
| ```python | ||||||||||||||||||||||||||||||||||||||||||||||||
| for i in range(len(broadcast_in_dim): | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert output.shape[broadcast_in_dim[i]] == value_to_broadcast.shape[i] | ||||||||||||||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||
| broadcast_in_dim: How to broadcast. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| params: The parameters that should be used for the expansion. If given one | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| entry for each dimension of the destination is needed. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Todo: | ||||||||||||||||||||||||||||||||||||||||||||||||
| - While for the output it is probably okay to always require an adjacent | ||||||||||||||||||||||||||||||||||||||||||||||||
| AccessNode for the input it might be possible to be on the other side | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
However, I don't understand how this could happen. |
||||||||||||||||||||||||||||||||||||||||||||||||
| of a Map. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||
| default_implementation: Final[str | None] = "pure" | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| brodcast_in_dims = dace_properties.ListProperty(element_type=int) | ||||||||||||||||||||||||||||||||||||||||||||||||
| params = dace_properties.ListProperty(element_type=str, allow_none=True) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||
| name: str, | ||||||||||||||||||||||||||||||||||||||||||||||||
| broadcast_in_dims: Sequence[int], | ||||||||||||||||||||||||||||||||||||||||||||||||
| params: Optional[Iterable[gtx_common.Dimension | str]], | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| debuginfo: dace.dtypes.DebugInfo | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO(philip, edopao): I would propose to drop `value` then. This makes it | ||||||||||||||||||||||||||||||||||||||||||||||||
| # simpler to handle in the transformations. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+70
to
+71
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, we can remove the todo comment.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(name, inputs={_INPUT_NAME}, outputs={_OUTPUT_NAME}) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| self.brodcast_in_dims = list(broadcast_in_dims) | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.debuginfo = debuginfo | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if params is None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.params = None | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| self.params = [ | ||||||||||||||||||||||||||||||||||||||||||||||||
| gtx_dace_lowering.get_map_variable(param) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(param, gtx_common.Dimension) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else param | ||||||||||||||||||||||||||||||||||||||||||||||||
| for param in params | ||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def validate(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| if len(self.brodcast_in_dims) != len(set(self.brodcast_in_dims)): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("`Can not broadcast to multiple dimensions at the same time.") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO(phimuell): Handle empty Memlets in the input. | ||||||||||||||||||||||||||||||||||||||||||||||||
| if state.in_degree(self) != 1 and next(iter(state.in_edges(self))).dst_conn == _INPUT_NAME: | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("GT4Py Broadcast node needs exactly one input.") | ||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| state.out_degree(self) != 1 | ||||||||||||||||||||||||||||||||||||||||||||||||
| and next(iter(state.out_edges(self))).src_conn == _OUTPUT_NAME | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+95
to
+96
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("GT4Py Broadcast node needs exactly one output.") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_value_node: dace_nodes.AccessNode = next(iter(state.in_edges(self))).src | ||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(bcast_value_node, dace_nodes.AccessNode): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Source of broadcasting must be an AccessNode.") | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_value_desc = bcast_value_node.desc(sdfg) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(bcast_value_desc, dace_data.View): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Can not broadcast from a view.") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_result_node: dace_nodes.AccessNode = next(iter(state.out_edges(self))).dst | ||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(bcast_result_node, dace_nodes.AccessNode): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Broadcast result must be an AccessNode.") | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_result_desc = bcast_result_node.desc(sdfg) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(bcast_result_desc, dace_data.View): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Broadcast result can not be a view.") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(bcast_result_desc, dace_data.Array): | ||||||||||||||||||||||||||||||||||||||||||||||||
| # In fact it would also be possible to broadcast into a Scalar, but this | ||||||||||||||||||||||||||||||||||||||||||||||||
| # does not make much sense. | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"Can only broadcast into an array, but target was `{type(bcast_result_desc).__name__}`." | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if (self.params is not None) and (len(self.params) != len(bcast_result_desc.shape)): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"Expected that {len(bcast_result_desc.shape)} parameters are" | ||||||||||||||||||||||||||||||||||||||||||||||||
| f" needed but {len(self.params)} were specified." | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+123
to
+124
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(bcast_value_desc, dace_data.Scalar): | ||||||||||||||||||||||||||||||||||||||||||||||||
| if len(self.brodcast_in_dims) != 0: | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("For a scalar `broadcast_in_dims` must be empty.") | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| if len(self.brodcast_in_dims) != len(bcast_value_desc.shape): | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+130
to
+131
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"`broadcast_in_dims` has {len(self.brodcast_in_dims)} entries," | ||||||||||||||||||||||||||||||||||||||||||||||||
| f" but the value to broadcast had {len(bcast_value_desc.shape)} dimensions." | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if len(bcast_result_desc.shape) < len(bcast_value_desc.shape): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"The value to broadcast has more dimensions ({len(bcast_value_desc.shape)})" | ||||||||||||||||||||||||||||||||||||||||||||||||
| f" than the result ({len(bcast_result_desc.shape)})." | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| for src_dim, bcast_dst_dim in enumerate(self.brodcast_in_dims): | ||||||||||||||||||||||||||||||||||||||||||||||||
| if bcast_dst_dim < 0: | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Negative broadcast") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if bcast_dst_dim >= len(bcast_result_desc.shape): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Out of range broadcast dim found.") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Only do the size matching test if the sizes are known, as different | ||||||||||||||||||||||||||||||||||||||||||||||||
| # symbols can have the same value. | ||||||||||||||||||||||||||||||||||||||||||||||||
| src_size = bcast_value_desc.shape[src_dim] | ||||||||||||||||||||||||||||||||||||||||||||||||
| dst_size = bcast_result_desc.shape[bcast_dst_dim] | ||||||||||||||||||||||||||||||||||||||||||||||||
| if str(src_size).isdigit() and str(dst_size).isdigit() and (src_size != dst_size): | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("Size mismatch found.") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def inplace_broadcast_expander( | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_node: Broadcast, | ||||||||||||||||||||||||||||||||||||||||||||||||
| state: dace.SDFGState, | ||||||||||||||||||||||||||||||||||||||||||||||||
| sdfg: dace.SDFG, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| """Perform expansion of `bcast_node` inside `state`. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| The main difference between this and the normal expansion transformation is | ||||||||||||||||||||||||||||||||||||||||||||||||
| that this function does not generate nested SDFG and instead performs the | ||||||||||||||||||||||||||||||||||||||||||||||||
| expansion in place. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| input_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = next(iter(state.in_edges(bcast_node))) | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = next( | ||||||||||||||||||||||||||||||||||||||||||||||||
| iter(state.out_edges(bcast_node)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO(phimuell): Add warning. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we should add a warning. Either allow it or not. |
||||||||||||||||||||||||||||||||||||||||||||||||
| map_params: list[str] = ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| [f"__bcast{dst_dim}" for dst_dim in range(len(output_edge.data.subset))] | ||||||||||||||||||||||||||||||||||||||||||||||||
| if bcast_node.params is None | ||||||||||||||||||||||||||||||||||||||||||||||||
| else list(bcast_node.params) | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| output_subset: list[str] = map_params.copy() | ||||||||||||||||||||||||||||||||||||||||||||||||
| map_ranges: dict[str, dace_sbs.Range] = { | ||||||||||||||||||||||||||||||||||||||||||||||||
| map_param: dace_sbs.Range([sbs]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| for map_param, sbs in zip(map_params, output_edge.data.subset) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| input_subset: list[str] | ||||||||||||||||||||||||||||||||||||||||||||||||
| if len(bcast_node.brodcast_in_dims) == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_subset = ["0"] | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_value_offset = input_edge.data.subset.min_element() | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_subset = [ | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"{map_params[dst_dim]} + ({offset})" | ||||||||||||||||||||||||||||||||||||||||||||||||
| for dst_dim, offset in zip(bcast_node.brodcast_in_dims, bcast_value_offset) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| me, mx = state.add_map(f"__gt4py_broadcast_map_{bcast_node.name}", ndrange=map_ranges) | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_tlet = state.add_tasklet( | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"__gt4py_broadcast_tasklet_{bcast_node.name}", | ||||||||||||||||||||||||||||||||||||||||||||||||
| inputs={"__in"}, | ||||||||||||||||||||||||||||||||||||||||||||||||
| outputs={"__out"}, | ||||||||||||||||||||||||||||||||||||||||||||||||
| code="__out = __in", | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| state.add_edge( | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_edge.src, | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_edge.src_conn, | ||||||||||||||||||||||||||||||||||||||||||||||||
| me, | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"IN_{input_edge.data.data}", | ||||||||||||||||||||||||||||||||||||||||||||||||
| dace.Memlet.from_array(input_edge.data.data, sdfg.arrays[input_edge.data.data]), | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| state.add_edge( | ||||||||||||||||||||||||||||||||||||||||||||||||
| me, | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"OUT_{input_edge.data.data}", | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_tlet, | ||||||||||||||||||||||||||||||||||||||||||||||||
| "__in", | ||||||||||||||||||||||||||||||||||||||||||||||||
| dace.Memlet(data=input_edge.data.data, subset=", ".join(input_subset)), | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| me.add_scope_connectors(input_edge.data.data) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+204
to
+218
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| state.add_edge( | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_tlet, | ||||||||||||||||||||||||||||||||||||||||||||||||
| "__out", | ||||||||||||||||||||||||||||||||||||||||||||||||
| mx, | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"IN_{output_edge.data.data}", | ||||||||||||||||||||||||||||||||||||||||||||||||
| dace.Memlet(data=output_edge.data.data, subset=", ".join(output_subset)), | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| state.add_edge( | ||||||||||||||||||||||||||||||||||||||||||||||||
| mx, | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"OUT_{output_edge.data.data}", | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_edge.dst, | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_edge.dst_conn, | ||||||||||||||||||||||||||||||||||||||||||||||||
| dace.Memlet(data=output_edge.data.data, subset=copy.deepcopy(output_edge.data.subset)), | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| mx.add_scope_connectors(output_edge.data.data) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+220
to
+234
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Now delete the node. | ||||||||||||||||||||||||||||||||||||||||||||||||
| state.remove_node(bcast_node) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| @dace_library.register_expansion(Broadcast, "pure") | ||||||||||||||||||||||||||||||||||||||||||||||||
| class BroadcastExpandInlined(dace_transform.ExpandTransformation): | ||||||||||||||||||||||||||||||||||||||||||||||||
| """Implements pure expansion of the Broadcast library node.""" | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| environments: Final[list[Any]] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||
| def expansion(node: Broadcast, state: dace.SDFGState, sdfg: dace.SDFG) -> dace.SDFG: | ||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO: | ||||||||||||||||||||||||||||||||||||||||||||||||
| # - Modify the edges on the outside. | ||||||||||||||||||||||||||||||||||||||||||||||||
| # - Handle the missing symbols. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The missing symbols are automatically mapped to symbols with the same name, aren't they? |
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # NOTE: We have to cheat a here a bit. Actually only parts of the output | ||||||||||||||||||||||||||||||||||||||||||||||||
| # would be mapped into the nested SDFG. | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(node, Broadcast) | ||||||||||||||||||||||||||||||||||||||||||||||||
| assert state.out_degree(node) == 1 and state.in_degree(node) == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| nsdfg = dace.SDFG(f"__gt4py_broadcast_expansion_{node.label}") | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_st = nsdfg.add_state(f"__gt4py_broadcast_expansion_{node.label}_state") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| input_edge = next(state.in_edges_by_connector(node, _INPUT_NAME)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_edge = next(state.out_edges_by_connector(node, _OUTPUT_NAME)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_value = input_edge.src | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_result = output_edge.dst | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Creating the input and output data inside the nested SDFG, such that we can | ||||||||||||||||||||||||||||||||||||||||||||||||
| # map them _fully_ (see later) into the nested SDFG. | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_value_inner_data = nsdfg.add_datadesc(_INPUT_NAME, bcast_value.desc(sdfg).clone()) | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_result_inner_data = nsdfg.add_datadesc(_OUTPUT_NAME, bcast_result.desc(sdfg).clone()) | ||||||||||||||||||||||||||||||||||||||||||||||||
| nsdfg.arrays[bcast_value_inner_data].transient = False | ||||||||||||||||||||||||||||||||||||||||||||||||
| nsdfg.arrays[bcast_result_inner_data].transient = False | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| inner_bcast_node = copy.deepcopy(node) | ||||||||||||||||||||||||||||||||||||||||||||||||
| inner_bcast_value_edge = bcast_st.add_edge( | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_st.add_access(bcast_value_inner_data), | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_edge.src_conn, | ||||||||||||||||||||||||||||||||||||||||||||||||
| inner_bcast_node, | ||||||||||||||||||||||||||||||||||||||||||||||||
| "_inp", | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy.deepcopy(input_edge.data), | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| inner_bcast_value_edge.data.data = bcast_value_inner_data | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| inner_bcast_result_edge = bcast_st.add_edge( | ||||||||||||||||||||||||||||||||||||||||||||||||
| inner_bcast_node, | ||||||||||||||||||||||||||||||||||||||||||||||||
| "_outp", | ||||||||||||||||||||||||||||||||||||||||||||||||
| bcast_st.add_access(bcast_result_inner_data), | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_edge.dst_conn, | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy.deepcopy(output_edge.data), | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| inner_bcast_result_edge.data.data = bcast_result_inner_data | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Now we run the inplace expansion on the node inside the nested SDFG. | ||||||||||||||||||||||||||||||||||||||||||||||||
| inplace_broadcast_expander(inner_bcast_node, bcast_st, nsdfg) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # To ensure that the full data is passed into the nested SDFG, which we | ||||||||||||||||||||||||||||||||||||||||||||||||
| # assumed because of how we want that everything is mapped into. | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_edge.data.subset = dace_sbs.Range.from_array(bcast_value.desc(sdfg)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_edge.data.subset = dace_sbs.Range.from_array(bcast_result.desc(sdfg)) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # NOTE: We will not update `nsdfg.symbols`, instead we will rely on | ||||||||||||||||||||||||||||||||||||||||||||||||
| # `add_nested_sdfg()` that is called implicitly by the expansion driver. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return nsdfg | ||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the transformation module, we use the
dace_sbsalias, in the lowering module we usedace_subsets. It's OK to usedace_sbsin this module, but let's try to keep it consistent.