Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from typing import Optional, Protocol

from gt4py.next import common, utils
Expand Down Expand Up @@ -113,15 +114,15 @@ def _process_symbolic_domains_option(

if use_max_domain_range_on_unstructured_shift is None:
use_max_domain_range_on_unstructured_shift = _has_dynamic_domains(ir)
elif use_max_domain_range_on_unstructured_shift:
if not _has_dynamic_domains(ir):
warnings.warn(
"You are using static domains together with "
"'use_max_domain_range_on_unstructured_shift'. This is "
"likely not what you wanted.",
stacklevel=2,
)
if use_max_domain_range_on_unstructured_shift:
# TODO(havogt): ICON4Py uses this codepath as default for now. Once we use the minimal domain range, we should re-enable this warning.
# if not _has_dynamic_domains(ir):
# warnings.warn(
# "You are using static domains together with "
# "'use_max_domain_range_on_unstructured_shift'. This is "
# "likely not what you wanted.",
# stacklevel=2, # noqa: ERA001
# ) # noqa: ERA001, RUF100
assert not symbolic_domain_sizes, "Options are mutually exclusive."
symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) # type: ignore[assignment]
return symbolic_domain_sizes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -989,11 +989,18 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:

it = self.visit(node.args[1])
assert isinstance(it, IteratorExpr)
if not all(isinstance(index, SymbolExpr) for index in it.indices.values()):
raise NotImplementedError("Dynamic indices in neighbors expression are not supported.")

# make sure that the field can be dereferenced with the given connectivity type
assert any(dim == conn_type.codomain for dim, _ in it.field_domain)
field_codomain_origin = next(
origin for dim, origin in it.field_domain if dim == conn_type.codomain
)
# make sure that the iterator can access the connectivity table
assert conn_type.source_dim in it.indices
origin_index = it.indices[conn_type.source_dim]
assert isinstance(origin_index, SymbolExpr)
assert all(isinstance(index, SymbolExpr) for index in it.indices.values())
conn_source_index = it.indices[conn_type.source_dim]
assert isinstance(conn_source_index, SymbolExpr)

# initially, the storage for the connectivty tables is created as transient;
# when the tables are used, the storage is changed to non-transient,
Expand Down Expand Up @@ -1040,7 +1047,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
),
),
subset=dace_subsets.Range.from_string(
f"{origin_index.value}, 0:{conn_type.max_neighbors}"
f"{conn_source_index.value}, 0:{conn_type.max_neighbors}"
),
)
)
Expand All @@ -1055,7 +1062,9 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr:
index_connector = "__index"
field_connector = "__field"
output_connector = "__val"
tasklet_expression = f"{output_connector} = {field_connector}[{index_connector}]"
tasklet_expression = (
f"{output_connector} = {field_connector}[{index_connector} - {field_codomain_origin}]"
)
input_memlets = {
field_connector: self.sdfg.make_array_memlet(field_slice.dc_node.data),
index_connector: dace.Memlet(data=conn_slice.dc_node.data, subset=neighbor_idx),
Expand Down Expand Up @@ -1651,29 +1660,34 @@ def _make_unstructured_shift(
offset_expr: DataExpr,
) -> IteratorExpr:
"""Implements shift in unstructured domain by means of a neighbor table."""
# make sure that the field can be dereferenced with the given connectivity type
assert any(dim == conn_type.codomain for dim, _ in it.field_domain)
neighbor_dim = conn_type.codomain
origin_dim = conn_type.source_dim
origin_index = it.indices[origin_dim]
assert isinstance(origin_index, SymbolExpr)
# make sure that the iterator can access the connectivity table
assert conn_type.source_dim in it.indices
conn_source_index = it.indices[conn_type.source_dim]
assert isinstance(conn_source_index, SymbolExpr)

shifted_indices = {dim: idx for dim, idx in it.indices.items() if dim != origin_dim}
shifted_indices = {
dim: idx for dim, idx in it.indices.items() if dim != conn_type.source_dim
}
if isinstance(offset_expr, SymbolExpr):
# use memlet to retrieve the neighbor index
shifted_indices[neighbor_dim] = MemletExpr(
shifted_indices[conn_type.codomain] = MemletExpr(
dc_node=conn_node,
gt_field=ts.FieldType(
dims=[origin_dim],
dims=[conn_type.source_dim],
dtype=ts.ListType(
element_type=tt.from_dtype(conn_type.dtype), offset_type=_CONST_DIM
),
),
subset=dace_subsets.Range.from_string(f"{origin_index.value}, {offset_expr.value}"),
subset=dace_subsets.Range.from_string(
f"{conn_source_index.value}, {offset_expr.value}"
),
)
else:
# dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node
shifted_indices[neighbor_dim] = self._make_dynamic_neighbor_offset(
offset_expr, conn_node, origin_index
shifted_indices[conn_type.codomain] = self._make_dynamic_neighbor_offset(
offset_expr, conn_node, conn_source_index
)

return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,29 @@ def testee(a: cases.VField) -> cases.EField:
)


@pytest.mark.uses_unstructured_shift
@pytest.mark.uses_program_with_sliced_out_arguments
def test_unstructured_shift_with_non_zero_origin(unstructured_case):
if unstructured_case.backend is None:
pytest.xfail("Embedded backend requires contiguous inverse image.")

@gtx.field_operator
def testee(a: cases.VField) -> cases.EField:
return a(E2V[0])

a = cases.allocate(unstructured_case, testee, "a")()
out = cases.allocate(unstructured_case, testee, cases.RETURN)()

ORIGIN = 2
e2v_table = unstructured_case.offset_provider["E2V"].asnumpy()
neighbor_0_iter = iter(enumerate(e2v_table[:, 0]))
edge_start = next(i for i, v in neighbor_0_iter if v >= ORIGIN)
edge_stop = next(i for i, v in neighbor_0_iter if v < ORIGIN)

ref = a.ndarray[e2v_table[edge_start:edge_stop, 0]]
cases.verify(unstructured_case, testee, a[ORIGIN:], out=out[edge_start:edge_stop], ref=ref)


def test_horizontal_only_with_3d_mesh(unstructured_case_3d):
# test field operator operating only on horizontal fields while using an offset provider
# including a vertical dimension.
Expand Down Expand Up @@ -724,6 +747,29 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField:
cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + a + b)


@pytest.mark.uses_unstructured_shift
@pytest.mark.uses_program_with_sliced_out_arguments
def test_neighbor_sum_with_non_zero_origin(unstructured_case):
if unstructured_case.backend is None:
pytest.xfail("Embedded backend requires contiguous inverse image.")

@gtx.field_operator
def testee(a: cases.VField) -> cases.EField:
return neighbor_sum(a(E2V), axis=E2VDim)

a = cases.allocate(unstructured_case, testee, "a")()
out = cases.allocate(unstructured_case, testee, cases.RETURN)()

ORIGIN = 2
e2v_table = unstructured_case.offset_provider["E2V"].asnumpy()
neighbor_iter = iter(enumerate(e2v_table))
edge_start = next(i for i, v in neighbor_iter if all(v >= ORIGIN))
edge_stop = next(i for i, v in neighbor_iter if any(v < ORIGIN))

ref = np.sum(a.ndarray[e2v_table[edge_start:edge_stop,]], axis=1)
cases.verify(unstructured_case, testee, a[ORIGIN:], out=out[edge_start:edge_stop], ref=ref)


@pytest.mark.uses_unstructured_shift
def test_nested_reduction(unstructured_case):
@gtx.field_operator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args, offset_provid

def _binding_source_unstructured(use_metrics: bool) -> str:
metrics_arg_index = 2
idx = [0, 4, 1, 5, 6, 7, 2, 9, 8, 3, 11, 10]
idx = [0, 4, 5, 1, 6, 7, 8, 2, 10, 9, 3, 12, 11]
if use_metrics:
idx = [idx + 1 if idx >= metrics_arg_index else idx for idx in idx]
return (
Expand All @@ -169,19 +169,20 @@ def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args, offset_provid
args_1,
) = args
sdfg_call_args[{idx[0]}].value = args_0.__gt_buffer_info__.data_ptr
sdfg_call_args[{idx[1]}] = ctypes.c_int(args_0.__gt_buffer_info__.elem_strides[0])
sdfg_call_args[{idx[2]}].value = args_1.__gt_buffer_info__.data_ptr
sdfg_call_args[{idx[3]}] = ctypes.c_int(args_1.domain.ranges[0].start)
sdfg_call_args[{idx[4]}] = ctypes.c_int(args_1.domain.ranges[0].stop)
sdfg_call_args[{idx[5]}] = ctypes.c_int(args_1.__gt_buffer_info__.elem_strides[0])
sdfg_call_args[{idx[1]}] = ctypes.c_int(args_0.domain.ranges[0].start)
sdfg_call_args[{idx[2]}] = ctypes.c_int(args_0.__gt_buffer_info__.elem_strides[0])
sdfg_call_args[{idx[3]}].value = args_1.__gt_buffer_info__.data_ptr
sdfg_call_args[{idx[4]}] = ctypes.c_int(args_1.domain.ranges[0].start)
sdfg_call_args[{idx[5]}] = ctypes.c_int(args_1.domain.ranges[0].stop)
sdfg_call_args[{idx[6]}] = ctypes.c_int(args_1.__gt_buffer_info__.elem_strides[0])
table_E2V = offset_provider["E2V"]
sdfg_call_args[{idx[6]}].value = table_E2V.__gt_buffer_info__.data_ptr
sdfg_call_args[{idx[7]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[0])
sdfg_call_args[{idx[8]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[1])
sdfg_call_args[{idx[7]}].value = table_E2V.__gt_buffer_info__.data_ptr
sdfg_call_args[{idx[8]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[0])
sdfg_call_args[{idx[9]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[1])
table_V2E = offset_provider["V2E"]
sdfg_call_args[{idx[9]}].value = table_V2E.__gt_buffer_info__.data_ptr
sdfg_call_args[{idx[10]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[0])
sdfg_call_args[{idx[11]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[1])
sdfg_call_args[{idx[10]}].value = table_V2E.__gt_buffer_info__.data_ptr
sdfg_call_args[{idx[11]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[0])
sdfg_call_args[{idx[12]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[1])
"""
)

Expand Down
Loading