79 commits
- f8180e2: edit (edopao, Nov 11, 2025)
- ef9ef92: edit (edopao, Nov 11, 2025)
- bd1b766: undo extra change (edopao, Nov 11, 2025)
- 25abd36: use library node also in concat_where (edopao, Nov 11, 2025)
- 071f512: edit (edopao, Nov 12, 2025)
- 331bcd3: Merge remote-tracking branch 'upstream/main' into dace-fill_node (edopao, Nov 12, 2025)
- b90976b: edit (edopao, Nov 12, 2025)
- 21a79d2: fix for inf expressions (edopao, Nov 13, 2025)
- a1f6f1a: edit (edopao, Nov 13, 2025)
- 6e97232: edit (edopao, Nov 13, 2025)
- 08484df: Merge branch 'dace-refactor_concat_where' into dace-fill_node (edopao, Nov 13, 2025)
- da12bdb: edit (edopao, Nov 13, 2025)
- 5382dce: Merge branch 'dace-refactor_concat_where' into dace-fill_node (edopao, Nov 13, 2025)
- b04586c: Merge branch 'main' into dace-fill_node (edopao, Nov 17, 2025)
- 4647c7d: Merge branch 'main' into dace-fill_node (edopao, Nov 21, 2025)
- 779b164: remove special handling for inf symbol (edopao, Nov 21, 2025)
- f250040: fix rebase (edopao, Nov 24, 2025)
- 68a417f: fix rebase (edopao, Nov 24, 2025)
- 9d774e4: Merge branch 'main' into dace-fill_node (edopao, Nov 25, 2025)
- c72e75a: Merge branch 'main' into dace-fill_node (edopao, Feb 10, 2026)
- 93b1ffd: Merge branch 'main' into dace-fill_node (edopao, Feb 11, 2026)
- 0915185: edit (edopao, Feb 11, 2026)
- cad3fa0: add library_nodes module folder (edopao, Feb 11, 2026)
- 2947655: edit (edopao, Feb 11, 2026)
- ab1f400: undo extra change (edopao, Feb 11, 2026)
- 007f435: Merge branch 'main' into dace-fill_node (edopao, Feb 11, 2026)
- 16dc60f: edit (edopao, Feb 12, 2026)
- 84b821d: fix (edopao, Feb 12, 2026)
- 9e10019: fix (edopao, Feb 12, 2026)
- 7c54d44: Started with the broadcast inline transformation. (philip-paul-mueller, Apr 27, 2026)
- 0dd25e6: Fixed some wrong names. (philip-paul-mueller, Apr 27, 2026)
- f62bc31: Fix (philip-paul-mueller, Apr 27, 2026)
- 4e7173f: In C++ Speak it compiles but the test does not pass. (philip-paul-mueller, Apr 27, 2026)
- 2c785d9: Fixed two errors. (philip-paul-mueller, Apr 27, 2026)
- f70e5b3: Added a test that ensures that indirect accessing is not triggered. (philip-paul-mueller, Apr 27, 2026)
- 1ec5c31: Continued. (philip-paul-mueller, Apr 27, 2026)
- 9c50896: Some more notes. (philip-paul-mueller, Apr 28, 2026)
- 21bac3c: I think I need to merge these two transformations. (philip-paul-mueller, Apr 28, 2026)
- db00ace: Merged it and now pre-commit is now happy, but does it works. (philip-paul-mueller, Apr 28, 2026)
- 1c93925: Now at least the unit tests passes. (philip-paul-mueller, Apr 28, 2026)
- 35309ef: Added a test for the access mode case. (philip-paul-mueller, Apr 29, 2026)
- 44bee9d: Added a test for the fan out of AccessNodes. (philip-paul-mueller, Apr 29, 2026)
- cc2c710: Added multi connection test. (philip-paul-mueller, Apr 29, 2026)
- 8cf38e6: Added more tests to the broadcast node. (philip-paul-mueller, Apr 29, 2026)
- e7bd089: Started with an inline expansion, but the regular expansion needs to … (philip-paul-mueller, Apr 29, 2026)
- 64c7b87: Updated the Broadcast expansion. Not integrated yet. (philip-paul-mueller, Apr 30, 2026)
- e90ebef: Merge remote-tracking branch 'gt4py/main' into dace-fill_node_philip (philip-paul-mueller, May 4, 2026)
- a2e1451: Fixed an error with the life time. (philip-paul-mueller, May 4, 2026)
- 6d1a358: Made a first integration, that is probably not correct. (philip-paul-mueller, May 4, 2026)
- 6209739: Made the inline expander accessible to the outside world. (philip-paul-mueller, May 4, 2026)
- 67d5ee4: Fixed some issues in the validatiobn function of the broadcast node. … (philip-paul-mueller, May 4, 2026)
- 755afea: I do not fully understand why this is needed, sometimes NumPy is just… (philip-paul-mueller, May 4, 2026)
- b8f7d91: Merge branch 'main' into dace-fill_node (edopao, May 4, 2026)
- 40eb027: Fixed the expander, it should now work. (philip-paul-mueller, May 4, 2026)
- 69cf288: Fixed the broadcasting test since `unique_name()` has been moved. (philip-paul-mueller, May 4, 2026)
- a31e163: Added a test for the library expansion, only scalar yet. (philip-paul-mueller, May 4, 2026)
- 5249529: Update src/gt4py/next/program_processors/runners/dace/lowering/gtir_t… (philip-paul-mueller, May 4, 2026)
- 98e098a: Fixed an error in validation. (philip-paul-mueller, May 4, 2026)
- c7b1f2d: Added a test to operate on vectors. (philip-paul-mueller, May 4, 2026)
- 2444a2d: NOw we also do some slicing. (philip-paul-mueller, May 4, 2026)
- c3cde1e: Discussion points with Edoardo. (philip-paul-mueller, May 4, 2026)
- b8d254e: Fixed a wrong annotation. (philip-paul-mueller, May 4, 2026)
- 5b95cfa: Fixed some issue with renaming. (philip-paul-mueller, May 4, 2026)
- c6c5bda: Something is not working as expected. (philip-paul-mueller, May 4, 2026)
- 0cc7109: Refined how the `MapToCopy` detects broadcasts. (philip-paul-mueller, May 5, 2026)
- bd5b88e: Handled how expansion of librarynodes is currently performed. (philip-paul-mueller, May 5, 2026)
- bec8ab8: Fixed extensive validation in `FuseHorizontalConditionBlocks`. (philip-paul-mueller, May 5, 2026)
- 080bb32: Integrated the correction. (philip-paul-mueller, May 5, 2026)
- 08b1d85: Incorporated the domain correction at the expansion. (philip-paul-mueller, May 5, 2026)
- 44f524c: Made a unit test a bit harder. (philip-paul-mueller, May 5, 2026)
- 13541b0: Makeing the optimnizer temporaraly more strict. (philip-paul-mueller, May 5, 2026)
- 0645d58: Merge branch 'dace-fill_node_philip' into dace-fill_node (philip-paul-mueller, May 5, 2026)
- 9592908: Now also the parameter names are correctly set. (philip-paul-mueller, May 5, 2026)
- 1e2089b: Implemented a missing case. (philip-paul-mueller, May 5, 2026)
- 8d3a89e: Forgot to update them. (philip-paul-mueller, May 5, 2026)
- 3dcf818: git Fixed a wrong check. (philip-paul-mueller, May 6, 2026)
- 11b8651: Updated the transformation to also handle vectors. (philip-paul-mueller, May 7, 2026)
- 2cd3a56: Moved the `ScalarBroadcastInliner` (which still has the wrong name) i… (philip-paul-mueller, May 7, 2026)
- 28e8808: Merge remote-tracking branch 'gt4py/main' into dace-fill_node (philip-paul-mueller, May 7, 2026)
@@ -0,0 +1,30 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Final

from dace import nodes as dace_nodes

from gt4py.next.program_processors.runners.dace.library_nodes.broadcast import (
Broadcast,
inplace_broadcast_expander,
)
from gt4py.next.program_processors.runners.dace.library_nodes.reduce_with_skip_values import (
ReduceWithSkipValues,
)


GTIR_LIBRARY_NODES: Final[tuple[dace_nodes.LibraryNode, ...]] = (Broadcast, ReduceWithSkipValues)
"""List of available GTIR library nodes."""


__all__ = [
"Broadcast",
"ReduceWithSkipValues",
"inplace_broadcast_expander",
]
@@ -0,0 +1,302 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import copy
from typing import Any, Final, Iterable, Optional, Sequence

import dace
from dace import (
data as dace_data,
library as dace_library,
nodes as dace_nodes,
properties as dace_properties,
subsets as dace_sbs,
[Review comment (author)] In the transformation module, we use the `dace_sbs` alias, in the lowering module we use `dace_subsets`. It's OK to use `dace_sbs` in this module, but let's try to keep it consistent.

)
from dace.sdfg import graph as dace_graph
from dace.transformation import transformation as dace_transform

from gt4py.next import common as gtx_common
from gt4py.next.program_processors.runners.dace import lowering as gtx_dace_lowering


_INPUT_NAME: Final[str] = "_inp"
_OUTPUT_NAME: Final[str] = "_outp"


@dace_library.node
class Broadcast(dace_nodes.LibraryNode):
"""Implements write of a scalar value over an array subset.

Same as XLA.
broadcast_in_dim[i] describes where dimension `i` of the `value_to_broadcast`
goes. In case of a scalar it is empty.
Furthermore the following has to hold:
```python
for i in range(len(broadcast_in_dim)):
assert output.shape[broadcast_in_dim[i]] == value_to_broadcast.shape[i]
```
[Review comment (author)] In other words, the result array shape has the same size as the broadcast domain.


Args:
broadcast_in_dim: How to broadcast.
[Review comment (author), suggested change]
- broadcast_in_dim: How to broadcast.
+ broadcast_in_dim: How to broadcast, see the class documentation.

params: The parameters that should be used for the expansion. If given one
[Review comment (author), suggested change]
- params: The parameters that should be used for the expansion. If given one
+ params: The parameters that should be used for the expansion. If given, one

entry for each dimension of the destination is needed.

Todo:
- While for the output it is probably okay to always require an adjacent
AccessNode for the input it might be possible to be on the other side
[Review comment (author), suggested change]
- AccessNode for the input it might be possible to be on the other side
+ AccessNode, the input nodes might be outside a map scope.
However, I don't understand how this could happen.

of a Map.
"""

implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {}
default_implementation: Final[str | None] = "pure"

brodcast_in_dims = dace_properties.ListProperty(element_type=int)
params = dace_properties.ListProperty(element_type=str, allow_none=True)

def __init__(
self,
name: str,
broadcast_in_dims: Sequence[int],
params: Optional[Iterable[gtx_common.Dimension | str]],
[Review comment (author), suggested change]
- params: Optional[Iterable[gtx_common.Dimension | str]],
+ params: Iterable[gtx_common.Dimension | str] | None,

debuginfo: dace.dtypes.DebugInfo | None = None,
):
# TODO(philip, edopao): I would propose to drop `value` then. This makes it
# simpler to handle in the transformations.
[Review comment (author), on lines +70 to +71] I agree, we can remove the todo comment.
[Suggested change] Delete the TODO comment.

super().__init__(name, inputs={_INPUT_NAME}, outputs={_OUTPUT_NAME})

self.brodcast_in_dims = list(broadcast_in_dims)
self.debuginfo = debuginfo

if params is None:
self.params = None
else:
self.params = [
gtx_dace_lowering.get_map_variable(param)
if isinstance(param, gtx_common.Dimension)
else param
for param in params
]

def validate(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None:
if len(self.brodcast_in_dims) != len(set(self.brodcast_in_dims)):
raise ValueError("`Can not broadcast to multiple dimensions at the same time.")
[Review comment (author), suggested change]
- raise ValueError("`Can not broadcast to multiple dimensions at the same time.")
+ raise ValueError("Broadcast dimensions must be unique")


# TODO(phimuell): Handle empty Memlets in the input.
if state.in_degree(self) != 1 and next(iter(state.in_edges(self))).dst_conn == _INPUT_NAME:
[Review comment (author), suggested change]
- if state.in_degree(self) != 1 and next(iter(state.in_edges(self))).dst_conn == _INPUT_NAME:
+ if state.in_degree(self) != 1 or next(iter(state.in_edges(self))).dst_conn != _INPUT_NAME:

raise ValueError("GT4Py Broadcast node needs exactly one input.")
if (
state.out_degree(self) != 1
and next(iter(state.out_edges(self))).src_conn == _OUTPUT_NAME
[Review comment (author), on lines +95 to +96, suggested change]
- state.out_degree(self) != 1
- and next(iter(state.out_edges(self))).src_conn == _OUTPUT_NAME
+ state.out_degree(self) != 1
+ or next(iter(state.out_edges(self))).src_conn != _OUTPUT_NAME

):
raise ValueError("GT4Py Broadcast node needs exactly one output.")

bcast_value_node: dace_nodes.AccessNode = next(iter(state.in_edges(self))).src
if not isinstance(bcast_value_node, dace_nodes.AccessNode):
raise ValueError("Source of broadcasting must be an AccessNode.")
bcast_value_desc = bcast_value_node.desc(sdfg)
if isinstance(bcast_value_desc, dace_data.View):
raise ValueError("Can not broadcast from a view.")

bcast_result_node: dace_nodes.AccessNode = next(iter(state.out_edges(self))).dst
if not isinstance(bcast_result_node, dace_nodes.AccessNode):
raise ValueError("Broadcast result must be an AccessNode.")
bcast_result_desc = bcast_result_node.desc(sdfg)
if isinstance(bcast_result_desc, dace_data.View):
raise ValueError("Broadcast result can not be a view.")

if not isinstance(bcast_result_desc, dace_data.Array):
# In fact it would also be possible to broadcast into a Scalar, but this
# does not make much sense.
raise ValueError(
f"Can only broadcast into an array, but target was `{type(bcast_result_desc).__name__}`."
)

if (self.params is not None) and (len(self.params) != len(bcast_result_desc.shape)):
raise ValueError(
f"Expected that {len(bcast_result_desc.shape)} parameters are"
f" needed but {len(self.params)} were specified."
[Review comment (author), on lines +123 to +124] Suggested message: "The number of dimensions of the result shape {len(bcast_result_desc.shape)} does not match the number of broadcast parameters {len(self.params)}."

)

if isinstance(bcast_value_desc, dace_data.Scalar):
if len(self.brodcast_in_dims) != 0:
raise ValueError("For a scalar `broadcast_in_dims` must be empty.")
else:
if len(self.brodcast_in_dims) != len(bcast_value_desc.shape):
[Review comment (author), on lines +130 to +131, suggested change]
- else:
-     if len(self.brodcast_in_dims) != len(bcast_value_desc.shape):
+ elif len(self.brodcast_in_dims) != len(bcast_value_desc.shape):

raise ValueError(
f"`broadcast_in_dims` has {len(self.brodcast_in_dims)} entries,"
f" but the value to broadcast had {len(bcast_value_desc.shape)} dimensions."
)
if len(bcast_result_desc.shape) < len(bcast_value_desc.shape):
raise ValueError(
f"The value to broadcast has more dimensions ({len(bcast_value_desc.shape)})"
f" than the result ({len(bcast_result_desc.shape)})."
)

for src_dim, bcast_dst_dim in enumerate(self.brodcast_in_dims):
if bcast_dst_dim < 0:
raise ValueError("Negative broadcast")
[Review comment (author), suggested change]
- raise ValueError("Negative broadcast")
+ raise ValueError("Negative dimension index is invalid.")

if bcast_dst_dim >= len(bcast_result_desc.shape):
raise ValueError("Out of range broadcast dim found.")

# Only do the size matching test if the sizes are known, as different
# symbols can have the same value.
src_size = bcast_value_desc.shape[src_dim]
dst_size = bcast_result_desc.shape[bcast_dst_dim]
if str(src_size).isdigit() and str(dst_size).isdigit() and (src_size != dst_size):
raise ValueError("Size mismatch found.")


def inplace_broadcast_expander(
bcast_node: Broadcast,
state: dace.SDFGState,
sdfg: dace.SDFG,
) -> None:
"""Perform expansion of `bcast_node` inside `state`.

The main difference between this and the normal expansion transformation is
that this function does not generate nested SDFG and instead performs the
expansion in place.
"""

input_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = next(iter(state.in_edges(bcast_node)))
output_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = next(
iter(state.out_edges(bcast_node))
)

# TODO(phimuell): Add warning.
[Review comment (author)] I don't think we should add a warning. Either allow it or not.

map_params: list[str] = (
[f"__bcast{dst_dim}" for dst_dim in range(len(output_edge.data.subset))]
if bcast_node.params is None
else list(bcast_node.params)
)

output_subset: list[str] = map_params.copy()
map_ranges: dict[str, dace_sbs.Range] = {
map_param: dace_sbs.Range([sbs])
for map_param, sbs in zip(map_params, output_edge.data.subset)
[Review comment (author), suggested change]
- for map_param, sbs in zip(map_params, output_edge.data.subset)
+ for map_param, sbs in zip(map_params, output_edge.data.subset, strict=True)

}

input_subset: list[str]
if len(bcast_node.brodcast_in_dims) == 0:
input_subset = ["0"]
else:
bcast_value_offset = input_edge.data.subset.min_element()
input_subset = [
f"{map_params[dst_dim]} + ({offset})"
for dst_dim, offset in zip(bcast_node.brodcast_in_dims, bcast_value_offset)
[Review comment (author), suggested change]
- for dst_dim, offset in zip(bcast_node.brodcast_in_dims, bcast_value_offset)
+ for dst_dim, offset in zip(bcast_node.brodcast_in_dims, bcast_value_offset, strict=True)

]

me, mx = state.add_map(f"__gt4py_broadcast_map_{bcast_node.name}", ndrange=map_ranges)
bcast_tlet = state.add_tasklet(
f"__gt4py_broadcast_tasklet_{bcast_node.name}",
inputs={"__in"},
outputs={"__out"},
code="__out = __in",
)

state.add_edge(
input_edge.src,
input_edge.src_conn,
me,
f"IN_{input_edge.data.data}",
dace.Memlet.from_array(input_edge.data.data, sdfg.arrays[input_edge.data.data]),
)
state.add_edge(
me,
f"OUT_{input_edge.data.data}",
bcast_tlet,
"__in",
dace.Memlet(data=input_edge.data.data, subset=", ".join(input_subset)),
)
me.add_scope_connectors(input_edge.data.data)
[Review comment (author), on lines +204 to +218, suggested change] Replace the two `add_edge` calls and `add_scope_connectors` with a single `add_memlet_path`:
state.add_memlet_path(
    input_edge.src,
    me,
    bcast_tlet,
    src_conn=input_edge.src_conn,
    dst_conn="__in",
    memlet=dace.Memlet(data=input_edge.data.data, subset=", ".join(input_subset)),
)


state.add_edge(
bcast_tlet,
"__out",
mx,
f"IN_{output_edge.data.data}",
dace.Memlet(data=output_edge.data.data, subset=", ".join(output_subset)),
)
state.add_edge(
mx,
f"OUT_{output_edge.data.data}",
output_edge.dst,
output_edge.dst_conn,
dace.Memlet(data=output_edge.data.data, subset=copy.deepcopy(output_edge.data.subset)),
)
mx.add_scope_connectors(output_edge.data.data)
[Review comment (author), on lines +220 to +234, suggested change] Replace the two `add_edge` calls and `add_scope_connectors` with a single `add_memlet_path`:
state.add_memlet_path(
    bcast_tlet,
    mx,
    output_edge.dst,
    src_conn="__out",
    dst_conn=output_edge.dst_conn,
    memlet=dace.Memlet(data=output_edge.data.data, subset=copy.deepcopy(output_edge.data.subset)),
)


# Now delete the node.
state.remove_node(bcast_node)


@dace_library.register_expansion(Broadcast, "pure")
class BroadcastExpandInlined(dace_transform.ExpandTransformation):
"""Implements pure expansion of the Broadcast library node."""

environments: Final[list[Any]] = []

@staticmethod
def expansion(node: Broadcast, state: dace.SDFGState, sdfg: dace.SDFG) -> dace.SDFG:
# TODO:
# - Modify the edges on the outside.
# - Handle the missing symbols.
[Review comment (author)] The missing symbols are automatically mapped to symbols with the same name, aren't they?


# NOTE: We have to cheat here a bit. Actually only parts of the output
# would be mapped into the nested SDFG.
assert isinstance(node, Broadcast)
assert state.out_degree(node) == 1 and state.in_degree(node) == 1

nsdfg = dace.SDFG(f"__gt4py_broadcast_expansion_{node.label}")
bcast_st = nsdfg.add_state(f"__gt4py_broadcast_expansion_{node.label}_state")

input_edge = next(state.in_edges_by_connector(node, _INPUT_NAME))
output_edge = next(state.out_edges_by_connector(node, _OUTPUT_NAME))
bcast_value = input_edge.src
bcast_result = output_edge.dst

# Creating the input and output data inside the nested SDFG, such that we can
# map them _fully_ (see later) into the nested SDFG.
bcast_value_inner_data = nsdfg.add_datadesc(_INPUT_NAME, bcast_value.desc(sdfg).clone())
bcast_result_inner_data = nsdfg.add_datadesc(_OUTPUT_NAME, bcast_result.desc(sdfg).clone())
nsdfg.arrays[bcast_value_inner_data].transient = False
nsdfg.arrays[bcast_result_inner_data].transient = False

inner_bcast_node = copy.deepcopy(node)
inner_bcast_value_edge = bcast_st.add_edge(
bcast_st.add_access(bcast_value_inner_data),
input_edge.src_conn,
inner_bcast_node,
"_inp",
copy.deepcopy(input_edge.data),
)
inner_bcast_value_edge.data.data = bcast_value_inner_data

inner_bcast_result_edge = bcast_st.add_edge(
inner_bcast_node,
"_outp",
bcast_st.add_access(bcast_result_inner_data),
output_edge.dst_conn,
copy.deepcopy(output_edge.data),
)
inner_bcast_result_edge.data.data = bcast_result_inner_data

# Now we run the inplace expansion on the node inside the nested SDFG.
inplace_broadcast_expander(inner_bcast_node, bcast_st, nsdfg)

# To ensure that the full data is passed into the nested SDFG, which we
# assumed because of how we want that everything is mapped into.
input_edge.data.subset = dace_sbs.Range.from_array(bcast_value.desc(sdfg))
output_edge.data.subset = dace_sbs.Range.from_array(bcast_result.desc(sdfg))

# NOTE: We will not update `nsdfg.symbols`, instead we will rely on
# `add_nested_sdfg()` that is called implicitly by the expansion driver.

return nsdfg