def subset(
    self, keys: str | Sequence[str], *, errors: ErrorOptions = "raise"
) -> DataTree:
    """Select data variables by name on every node of the tree.

    The selection is applied to each node's dataset through
    ``map_over_datasets``.

    Parameters
    ----------
    keys : str or sequence of str
        Name(s) of the data variables to select. A single string is
        treated as a one-element list; any other sequence (for example a
        tuple) is converted to a list before indexing.
    errors : {"raise", "ignore"}, default: "raise"
        If "raise", all of ``keys`` are looked up on every node and a
        missing name raises. If "ignore", names that are not data
        variables of a node are skipped for that node.

    Returns
    -------
    out : DataTree
        The result of indexing every node's dataset with the selected
        names.

    Raises
    ------
    KeyError
        If ``errors="raise"`` and one of ``keys`` is missing on a node.
        NOTE(review): the accompanying test indicates that empty nodes
        are not skipped by ``map_over_datasets``, so a tree with an empty
        root raises in this mode -- confirm this is intended.
    """
    # Normalise once, outside the per-node callback. Indexing a Dataset
    # with a hashable key (a str, but also a tuple) looks up a single
    # variable, so only a list reliably means "these variables".
    names = [keys] if isinstance(keys, str) else list(keys)
    skip_missing = errors == "ignore"

    def getitem(ds):
        if skip_missing:
            # Keep only the requested names this node actually has.
            return ds[[name for name in names if name in ds.data_vars]]
        return ds[names]

    return map_over_datasets(getitem, self)
def test_subset() -> None:
    """Check ``DataTree.subset`` for both values of ``errors``."""
    both_vars = xr.Dataset(data_vars={"var1": ("x", [1, 2]), "var2": ("x", [0, 1])})
    only_var1 = xr.Dataset(data_vars={"var1": ("x", [1, 2])})
    tree = xr.DataTree.from_dict({"ds1": both_vars, "ds2": only_var1})

    # Expected tree once only "var1" is kept on each node.
    tree_var1 = xr.DataTree.from_dict({"ds1": both_vars[["var1"]], "ds2": only_var1})

    # errors as map_over_datasets does not skip empty nodes
    with pytest.raises(KeyError, match="var1"):
        tree.subset("var1")

    # will still error if map_over_datasets will ever skip empty nodes
    with pytest.raises(KeyError, match="var2"):
        tree.subset("var2")

    # With errors="ignore", a str and a one-element list select the same
    # thing, and asking for both variables gives back the full tree.
    ignore_cases = [
        ("var1", tree_var1),
        (["var1"], tree_var1),
        (["var1", "var2"], tree),
    ]
    for keys, expected in ignore_cases:
        result = tree.subset(keys, errors="ignore")
        xr.testing.assert_equal(result, expected)