diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index a6228c6125..1236a3209a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import warnings from typing import Optional, Protocol from gt4py.next import common, utils @@ -113,15 +114,15 @@ def _process_symbolic_domains_option( if use_max_domain_range_on_unstructured_shift is None: use_max_domain_range_on_unstructured_shift = _has_dynamic_domains(ir) + elif use_max_domain_range_on_unstructured_shift: + if not _has_dynamic_domains(ir): + warnings.warn( + "You are using static domains together with " + "'use_max_domain_range_on_unstructured_shift'. This is " + "likely not what you wanted.", + stacklevel=2, + ) if use_max_domain_range_on_unstructured_shift: - # TODO(havogt): ICON4Py uses this codepath as default for now. Once we use the minimal domain range, we should re-enable this warning. - # if not _has_dynamic_domains(ir): - # warnings.warn( - # "You are using static domains together with " - # "'use_max_domain_range_on_unstructured_shift'. This is " - # "likely not what you wanted.", - # stacklevel=2, # noqa: ERA001 - # ) # noqa: ERA001, RUF100 assert not symbolic_domain_sizes, "Options are mutually exclusive." symbolic_domain_sizes = _max_domain_range_sizes(offset_provider) # type: ignore[assignment] return symbolic_domain_sizes diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index da590d84e0..e7dc855428 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -989,11 +989,18 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) + if not all(isinstance(index, SymbolExpr) for index in it.indices.values()): + raise NotImplementedError("Dynamic indices in neighbors expression are not supported.") + + # make sure that the field can be dereferenced with the given connectivity type assert any(dim == conn_type.codomain for dim, _ in it.field_domain) + field_codomain_origin = next( + origin for dim, origin in it.field_domain if dim == conn_type.codomain + ) + # make sure that the iterator can access the connectivity table assert conn_type.source_dim in it.indices - origin_index = it.indices[conn_type.source_dim] - assert isinstance(origin_index, SymbolExpr) - assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) + conn_source_index = it.indices[conn_type.source_dim] + assert isinstance(conn_source_index, SymbolExpr) # initially, the storage for the connectivty tables is created as transient; # when the tables are used, the storage is changed to non-transient, @@ -1040,7 +1047,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: ), ), subset=dace_subsets.Range.from_string( - f"{origin_index.value}, 0:{conn_type.max_neighbors}" + f"{conn_source_index.value}, 0:{conn_type.max_neighbors}" ), ) ) @@ -1055,7 +1062,9 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: index_connector = "__index" field_connector = "__field" output_connector = "__val" - tasklet_expression = f"{output_connector} = {field_connector}[{index_connector}]" + tasklet_expression = ( + f"{output_connector} = {field_connector}[{index_connector} - {field_codomain_origin}]" + ) input_memlets = { field_connector: self.sdfg.make_array_memlet(field_slice.dc_node.data), index_connector: dace.Memlet(data=conn_slice.dc_node.data, subset=neighbor_idx), @@ -1651,29 +1660,34 @@ def _make_unstructured_shift( offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" + # make sure that the field can be dereferenced with the given connectivity type assert any(dim == conn_type.codomain for dim, _ in it.field_domain) - neighbor_dim = conn_type.codomain - origin_dim = conn_type.source_dim - origin_index = it.indices[origin_dim] - assert isinstance(origin_index, SymbolExpr) + # make sure that the iterator can access the connectivity table + assert conn_type.source_dim in it.indices + conn_source_index = it.indices[conn_type.source_dim] + assert isinstance(conn_source_index, SymbolExpr) - shifted_indices = {dim: idx for dim, idx in it.indices.items() if dim != origin_dim} + shifted_indices = { + dim: idx for dim, idx in it.indices.items() if dim != conn_type.source_dim + } if isinstance(offset_expr, SymbolExpr): # use memlet to retrieve the neighbor index - shifted_indices[neighbor_dim] = MemletExpr( + shifted_indices[conn_type.codomain] = MemletExpr( dc_node=conn_node, gt_field=ts.FieldType( - dims=[origin_dim], + dims=[conn_type.source_dim], dtype=ts.ListType( element_type=tt.from_dtype(conn_type.dtype), offset_type=_CONST_DIM ), ), - subset=dace_subsets.Range.from_string(f"{origin_index.value}, {offset_expr.value}"), + subset=dace_subsets.Range.from_string( + f"{conn_source_index.value}, {offset_expr.value}" + ), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node - shifted_indices[neighbor_dim] = self._make_dynamic_neighbor_offset( - offset_expr, conn_node, origin_index + shifted_indices[conn_type.codomain] = self._make_dynamic_neighbor_offset( + offset_expr, conn_node, conn_source_index ) return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index c58ac5f497..af1410b2be 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -140,6 +140,29 @@ def testee(a: cases.VField) -> cases.EField: ) +@pytest.mark.uses_unstructured_shift +@pytest.mark.uses_program_with_sliced_out_arguments +def test_unstructured_shift_with_non_zero_origin(unstructured_case): + if unstructured_case.backend is None: + pytest.xfail("Embedded backend requires contiguous inverse image.") + + @gtx.field_operator + def testee(a: cases.VField) -> cases.EField: + return a(E2V[0]) + + a = cases.allocate(unstructured_case, testee, "a")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + ORIGIN = 2 + e2v_table = unstructured_case.offset_provider["E2V"].asnumpy() + neighbor_0_iter = iter(enumerate(e2v_table[:, 0])) + edge_start = next(i for i, v in neighbor_0_iter if v >= ORIGIN) + edge_stop = next(i for i, v in neighbor_0_iter if v < ORIGIN) + + ref = a.ndarray[e2v_table[edge_start:edge_stop, 0]] + cases.verify(unstructured_case, testee, a[ORIGIN:], out=out[edge_start:edge_stop], ref=ref) + + def test_horizontal_only_with_3d_mesh(unstructured_case_3d): # test field operator operating only on horizontal fields while using an offset provider # including a vertical dimension. @@ -724,6 +747,29 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + a + b) +@pytest.mark.uses_unstructured_shift +@pytest.mark.uses_program_with_sliced_out_arguments +def test_neighbor_sum_with_non_zero_origin(unstructured_case): + if unstructured_case.backend is None: + pytest.xfail("Embedded backend requires contiguous inverse image.") + + @gtx.field_operator + def testee(a: cases.VField) -> cases.EField: + return neighbor_sum(a(E2V), axis=E2VDim) + + a = cases.allocate(unstructured_case, testee, "a")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + ORIGIN = 2 + e2v_table = unstructured_case.offset_provider["E2V"].asnumpy() + neighbor_iter = iter(enumerate(e2v_table)) + edge_start = next(i for i, v in neighbor_iter if all(v >= ORIGIN)) + edge_stop = next(i for i, v in neighbor_iter if any(v < ORIGIN)) + + ref = np.sum(a.ndarray[e2v_table[edge_start:edge_stop,]], axis=1) + cases.verify(unstructured_case, testee, a[ORIGIN:], out=out[edge_start:edge_stop], ref=ref) + + @pytest.mark.uses_unstructured_shift def test_nested_reduction(unstructured_case): @gtx.field_operator diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py index 855b315009..25dae344f2 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py @@ -157,7 +157,7 @@ def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args, offset_provid def _binding_source_unstructured(use_metrics: bool) -> str: metrics_arg_index = 2 - idx = [0, 4, 1, 5, 6, 7, 2, 9, 8, 3, 11, 10] + idx = [0, 4, 5, 1, 6, 7, 8, 2, 10, 9, 3, 12, 11] if use_metrics: idx = [idx + 1 if idx >= metrics_arg_index else idx for idx in idx] return ( @@ -169,19 +169,20 @@ def {_bind_func_name}(device, sdfg_argtypes, args, sdfg_call_args, offset_provid args_1, ) = args sdfg_call_args[{idx[0]}].value = args_0.__gt_buffer_info__.data_ptr - sdfg_call_args[{idx[1]}] = ctypes.c_int(args_0.__gt_buffer_info__.elem_strides[0]) - sdfg_call_args[{idx[2]}].value = args_1.__gt_buffer_info__.data_ptr - sdfg_call_args[{idx[3]}] = ctypes.c_int(args_1.domain.ranges[0].start) - sdfg_call_args[{idx[4]}] = ctypes.c_int(args_1.domain.ranges[0].stop) - sdfg_call_args[{idx[5]}] = ctypes.c_int(args_1.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[1]}] = ctypes.c_int(args_0.domain.ranges[0].start) + sdfg_call_args[{idx[2]}] = ctypes.c_int(args_0.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[3]}].value = args_1.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[4]}] = ctypes.c_int(args_1.domain.ranges[0].start) + sdfg_call_args[{idx[5]}] = ctypes.c_int(args_1.domain.ranges[0].stop) + sdfg_call_args[{idx[6]}] = ctypes.c_int(args_1.__gt_buffer_info__.elem_strides[0]) table_E2V = offset_provider["E2V"] - sdfg_call_args[{idx[6]}].value = table_E2V.__gt_buffer_info__.data_ptr - sdfg_call_args[{idx[7]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[0]) - sdfg_call_args[{idx[8]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[1]) + sdfg_call_args[{idx[7]}].value = table_E2V.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[8]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[9]}] = ctypes.c_int(table_E2V.__gt_buffer_info__.elem_strides[1]) table_V2E = offset_provider["V2E"] - sdfg_call_args[{idx[9]}].value = table_V2E.__gt_buffer_info__.data_ptr - sdfg_call_args[{idx[10]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[0]) - sdfg_call_args[{idx[11]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[1]) + sdfg_call_args[{idx[10]}].value = table_V2E.__gt_buffer_info__.data_ptr + sdfg_call_args[{idx[11]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[0]) + sdfg_call_args[{idx[12]}] = ctypes.c_int(table_V2E.__gt_buffer_info__.elem_strides[1]) """ )