Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ markers = [
'uses_max_over: tests that use the max_over builtin',
'uses_mesh_with_skip_values: tests that use a mesh with skip values',
'uses_concat_where: tests that use the concat_where builtin',
'embedded_concat_where_infinite_domain: tests with concat_where resulting in an infinite domain',
'embedded_concat_where_non_contiguous_domain: tests with concat_where on non-contiguous domains',
'uses_program_metrics: tests that require backend support for program metrics',
'uses_program_with_sliced_out_arguments: tests that use a sliced argument which is not supported for non-mutable arrays, e.g. JAX',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
Expand Down
171 changes: 70 additions & 101 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,15 @@
import collections
import dataclasses
import functools
import itertools
from collections.abc import Callable, Sequence
from types import ModuleType

import numpy as np
from numpy import typing as npt

from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import (
ClassVar,
Iterable,
Never,
Optional,
ParamSpec,
TypeAlias,
TypeVar,
cast,
)
from gt4py.eve.extended_typing import ClassVar, Never, Optional, ParamSpec, TypeAlias, TypeVar, cast
from gt4py.next import common, utils
from gt4py.next.embedded import (
common as embedded_common,
Expand Down Expand Up @@ -820,39 +812,6 @@ def _hyperslice(
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _compute_mask_slices(
    mask: core_defs.NDArrayObject,
) -> list[tuple[bool, slice]]:
    """Split a 1-dimensional mask into maximal runs of equal boolean values.

    Args:
        mask: 1-dimensional array; every element is converted with
            ``bool(element.item())``.

    Returns:
        A list of ``(value, slice)`` pairs in ascending index order. Each slice
        covers one maximal run of consecutive entries whose boolean value is
        ``value``; together the slices cover ``range(mask.shape[0])``. An empty
        mask yields an empty list.
    """
    # TODO: does it make sense to upgrade this naive algorithm to numpy?
    assert mask.ndim == 1
    size = mask.shape[0]
    if size == 0:
        # There is no first element to seed the run detection with.
        return []
    # Use `.item()` to extract the scalar from a 0-d array in case of e.g. cupy
    run_value = bool(mask[0].item())
    run_start = 0
    runs: list[tuple[bool, slice]] = []
    for i in range(1, size):
        value = bool(mask[i].item())
        if value != run_value:
            # The value flipped at `i`: close the current run and start a new one.
            runs.append((run_value, slice(run_start, i)))
            run_value, run_start = value, i
    # The last run always extends to the end of the mask.
    runs.append((run_value, slice(run_start, size)))
    return runs


def _trim_empty_domains(
    lst: Iterable[tuple[bool, common.Domain]],
) -> list[tuple[bool, common.Domain]]:
    """Remove empty domains from the beginning and end of the list.

    Entries whose domain is empty but which lie between two non-empty entries
    are kept.

    Args:
        lst: Iterable of ``(mask_value, domain)`` pairs; only
            ``domain.is_empty()`` is consulted.

    Returns:
        A new list without the leading and trailing entries whose domain is
        empty.
    """
    items = list(lst)
    first, last = 0, len(items)
    # Move both ends inwards instead of recursing on list copies: this is linear
    # and cannot hit the recursion limit on long runs of empty domains.
    while first < last and items[first][1].is_empty():
        first += 1
    while last > first and items[last - 1][1].is_empty():
        last -= 1
    return items[first:last]


def _to_field(
value: common.Field | core_defs.Scalar, nd_array_field_type: type[NdArrayField]
) -> common.Field:
Expand Down Expand Up @@ -906,85 +865,95 @@ def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[c

def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field:
# TODO(havogt): this function could be extended to a general concat
# currently only concatenate along the given dimension and requires the fields to be ordered

if (
len(fields) > 1
and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty()
):
raise ValueError("Fields to concatenate must not overlap.")
new_domain = _stack_domains(*[f.domain for f in fields], dim=dim)
# currently only concatenate along the given dimension
sorted_fields = sorted(fields, key=lambda f: f.domain[dim].unit_range.start)

for prev, curr in itertools.pairwise(sorted_fields):
if curr.domain[dim].unit_range.start < prev.domain[dim].unit_range.stop:
raise ValueError("Fields to concatenate must not overlap.")
Comment on lines +872 to +873
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully understand where the input domains are supposed to be validated. Most of the private helper functions are only called from one place, so they could assume that the input domains have already been validated at the public entry point and skip further checks. However, the overlap check is done here, while the contiguity check only happens later, when another call (`_stack_domains`) returns `None`. I may be missing something, but, if possible, I'd suggest concentrating the validation of the domains in a single place.

For example, one possibility

Suggested change
if curr.domain[dim].unit_range.start < prev.domain[dim].unit_range.stop:
raise ValueError("Fields to concatenate must not overlap.")
if (left := prev.domain[dim].unit_range.stop) != (right := curr.domain[dim].unit_range.start):
if left > right:
raise ValueError("Fields to concatenate must not overlap.")
if left < right:
raise NotImplementedError("Fields to concatenate must be contiguous.")

new_domain = _stack_domains(*[f.domain for f in sorted_fields], dim=dim)
if new_domain is None:
raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.")
nd_array_class = _get_nd_array_class(*fields)
nd_array_class = _get_nd_array_class(*sorted_fields)
return nd_array_class.from_array(
nd_array_class.array_ns.concatenate(
[nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields],
[
nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape)
for f in sorted_fields
],
axis=new_domain.dim_index(dim, allow_missing=False),
),
domain=new_domain,
)


def _concat_where(
mask_field: common.Field, true_field: common.Field, false_field: common.Field
) -> common.Field:
cls_ = _get_nd_array_class(mask_field, true_field, false_field)
xp = cls_.array_ns
if mask_field.domain.ndim != 1:
raise NotImplementedError(
"'concat_where': Can only concatenate fields with a 1-dimensional mask."
def _invert_domain(domain: common.Domain) -> tuple[common.Domain, ...]:
assert domain.ndim == 1
dim = domain.dims[0]
rng = domain.ranges[0]

result = []
if rng.start is not common.Infinity.NEGATIVE:
result.append(
common.Domain(
dims=(dim,), ranges=(common.UnitRange(common.Infinity.NEGATIVE, rng.start),)
)
)
mask_dim = mask_field.domain.dims[0]
if rng.stop is not common.Infinity.POSITIVE:
result.append(
common.Domain(
dims=(dim,), ranges=(common.UnitRange(rng.stop, common.Infinity.POSITIVE),)
)
)
return tuple(result)

# intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain
t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim)

# TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils
# compute the consecutive ranges (first relative, then domain) of true and false values
mask_values_to_slices_mapping: Iterable[tuple[bool, slice]] = _compute_mask_slices(
mask_field.ndarray
)
mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = (
(mask, mask_field.domain.slice_at[domain_slice])
for mask, domain_slice in mask_values_to_slices_mapping
def _size0_field(
    nd_array_class: type[NdArrayField], dims: tuple[common.Dimension, ...], dtype: core_defs.DType
) -> NdArrayField:
    """Create a field of type ``nd_array_class`` over ``dims`` that holds no elements.

    The buffer has extent 0 along every dimension and uses ``dtype.scalar_type``
    as element type; the domain assigns the range ``UnitRange(0, 0)`` to each
    dimension.
    """
    rank = len(dims)
    empty_buffer = nd_array_class.array_ns.empty((0,) * rank, dtype=dtype.scalar_type)
    empty_domain = common.Domain(dims=dims, ranges=(common.UnitRange(0, 0),) * rank)
    return nd_array_class.from_array(empty_buffer, domain=empty_domain)
# mask domains intersected with the respective fields
mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = (
(
mask_value,
embedded_common.domain_intersection(
t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain
),


def _concat_where(
domain: common.Domain,
true_field: common.Field,
false_field: common.Field,
) -> common.Field:
if domain.ndim != 1:
raise NotImplementedError(
"'concat_where': Can only concatenate fields with a 1-dimensional domain."
)
for mask_value, mask_domain in mask_values_to_domain_mapping
)
domain_dim = domain.dims[0]

# remove the empty domains from the beginning and end
mask_values_to_intersected_domains_mapping = _trim_empty_domains(
mask_values_to_intersected_domains_mapping
# intersect the field in dimensions orthogonal to the domain, then all slices in the domain field have same domain
t_broadcasted, f_broadcasted = _intersect_fields(
true_field, false_field, ignore_dims=domain_dim
)
if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping):
raise embedded_exceptions.NonContiguousDomain(
f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}."
)

# slice the fields with the domain ranges
transformed = [
t_broadcasted[d] if v else f_broadcasted[d]
for v, d in mask_values_to_intersected_domains_mapping
]
true_domain = embedded_common.domain_intersection(t_broadcasted.domain, domain)
t_slices = () if true_domain.is_empty() else (t_broadcasted[true_domain],)

# stack the fields together
if transformed:
return _concat(*transformed, dim=mask_dim)
else:
result_domain = common.Domain(common.NamedRange(mask_dim, common.UnitRange(0, 0)))
result_array = xp.empty(result_domain.shape)
return cls_.from_array(result_array, domain=result_domain)
inverted_domains = _invert_domain(domain)
false_domains = tuple(
intersection
for d in inverted_domains
if not (
intersection := embedded_common.domain_intersection(f_broadcasted.domain, d)
).is_empty()
)
f_slices = tuple(f_broadcasted[d] for d in false_domains)

if len(t_slices) + len(f_slices) == 0:
# no data to concatenate, return an empty field
nd_array_class = _get_nd_array_class(true_field, false_field)
return _size0_field(nd_array_class, dims=t_broadcasted.domain.dims, dtype=true_field.dtype)
return _concat(*f_slices, *t_slices, dim=domain_dim)


NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # TODO(havogt): this is still the "old" concat_where, needs to be replaced in a next PR
NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type]


def _make_reduction(
Expand Down
39 changes: 20 additions & 19 deletions src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,31 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi

@WhereBuiltinFunction
def concat_where(
cond: common.Domain,
domain: common.Domain,
true_field: common.Field | core_defs.ScalarT | Tuple | named_collections.CustomNamedCollection,
false_field: common.Field | core_defs.ScalarT | Tuple | named_collections.CustomNamedCollection,
/,
) -> common.Field | Tuple:
"""
Concatenates two field fields based on a 1D mask.

The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields.
Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain.

TODO(havogt): I can't get this doctest to run, even after copying the __doc__ in the decorator
Example:
>>> I = common.Dimension("I")
>>> mask = common._field([True, False, True], domain={I: (0, 3)})
>>> true_field = common._field([1, 2], domain={I: (0, 2)})
>>> false_field = common._field([3, 4, 5], domain={I: (1, 4)})
>>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)})

>>> mask = common._field([True, False, True], domain={I: (0, 3)})
>>> true_field = common._field([1, 2, 3], domain={I: (0, 3)})
>>> false_field = common._field(
... [4], domain={I: (2, 3)}
... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values
Assemble a field by selecting from ``true_field`` where ``domain`` applies and from ``false_field`` elsewhere.

Unlike ``where`` (element-wise selection via a boolean mask field), ``concat_where``
works on **domain regions**: the condition is a ``Domain`` (not a ``Field``), and the
result is the concatenation of slices from the two fields along one dimension.
Each field only needs to cover its own region — they may be non-overlapping.

The condition must be a 1D ``Domain`` (e.g. ``I < 5``).

Args:
domain: 1D Domain specifying the "true" region.
true_field: Field (or scalar) providing values inside the domain region.
false_field: Field (or scalar) providing values outside the domain region.

Returns:
A new field whose domain is the concatenation of the contributed regions.

Raises:
NonContiguousDomain: If the resulting domain has interior gaps.
"""
raise NotImplementedError()

Expand Down
9 changes: 5 additions & 4 deletions tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_PROGRAM_METRICS = "uses_program_metrics"
USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo"
USES_CONCAT_WHERE = "uses_concat_where"
EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN = "embedded_concat_where_infinite_domain"
EMBEDDED_CONCAT_WHERE_NON_CONTIGUOUS_DOMAIN = "embedded_concat_where_non_contiguous_domain"
USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS = "uses_program_with_sliced_out_arguments"
CHECKS_SPECIFIC_ERROR = "checks_specific_error"

Expand Down Expand Up @@ -167,7 +169,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
XFAIL,
UNSUPPORTED_MESSAGE,
), # we can't extract the field type from scan args
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
(EMBEDDED_CONCAT_WHERE_INFINITE_DOMAIN, XFAIL, UNSUPPORTED_MESSAGE),
(EMBEDDED_CONCAT_WHERE_NON_CONTIGUOUS_DOMAIN, XFAIL, UNSUPPORTED_MESSAGE),
]
JAX_EMBEDDED_SKIP_LIST = EMBEDDED_SKIP_LIST + [
(USES_PROGRAM_WITH_SLICED_OUT_ARGUMENTS, XFAIL, UNSUPPORTED_MESSAGE),
Expand All @@ -178,9 +181,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
]
GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + [
(USES_CONCAT_WHERE, XFAIL, UNSUPPORTED_MESSAGE),
]
GTIR_EMBEDDED_SKIP_LIST = ROUNDTRIP_SKIP_LIST + []
GTFN_SKIP_TEST_LIST = (
COMMON_SKIP_TEST_LIST
+ DOMAIN_INFERENCE_SKIP_LIST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def testee(a: cases.IJKField, b: cases.IJKField, N: np.int32) -> cases.IJKField:
cases.verify(cartesian_case, testee, a, b, N, out=out, ref=a.asnumpy())


@pytest.mark.embedded_concat_where_infinite_domain
def test_concat_where_scalar_broadcast(cartesian_case):
@gtx.field_operator
def testee(a: np.int32, b: cases.IJKField, N: np.int32) -> cases.IJKField:
Expand All @@ -97,6 +98,7 @@ def testee(a: np.int32, b: cases.IJKField, N: np.int32) -> cases.IJKField:
cases.verify(cartesian_case, testee, a, b, cartesian_case.default_sizes[KDim], out=out, ref=ref)


@pytest.mark.embedded_concat_where_infinite_domain
def test_concat_where_scalar_broadcast_on_empty_branch(cartesian_case):
"""Output domain such that the scalar branch is never active."""

Expand Down Expand Up @@ -253,6 +255,7 @@ def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField:
cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref)


@pytest.mark.embedded_concat_where_non_contiguous_domain
def test_dimension_two_conditions_or(cartesian_case):
@gtx.field_operator
def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField:
Expand All @@ -272,11 +275,19 @@ def test_lap_like(cartesian_case):
def testee(
inp: cases.IJField, boundary: np.int32, shape: tuple[np.int32, np.int32]
) -> cases.IJField:
# TODO add support for multi-dimensional concat_where masks
# TODO(havogt) add support for multi-dimensional concat_where and non-contiguous unions
return concat_where(
(IDim == 0) | (IDim == shape[0] - 1),
(IDim == 0),
boundary,
concat_where((JDim == 0) | (JDim == shape[1] - 1), boundary, inp),
concat_where(
IDim == shape[0] - 1,
boundary,
concat_where(
JDim == 0,
boundary,
concat_where(JDim == shape[1] - 1, boundary, inp),
),
),
)

out = cases.allocate(cartesian_case, testee, cases.RETURN)()
Expand Down
Loading