diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c1bfaba8756..92126cf5143 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -77,6 +77,9 @@ Internal Changes - Migrates ``treenode`` functionality into ``xarray/core`` (:pull:`8757`) By `Matt Savoie `_ and `Tom Nicholas `_. +- Migrates ``datatree`` functionality into ``xarray/core``. (:pull: `8789`) + By `Owen Littlejohns `_, `Matt Savoie + `_ and `Tom Nicholas `_. .. _whats-new.2024.02.0: diff --git a/xarray/backends/api.py b/xarray/backends/api.py index d3026a535e2..637eea4d076 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -69,7 +69,7 @@ T_NetcdfTypes = Literal[ "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" ] - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 7d3cc00a52d..f318b4dd42f 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -23,8 +23,8 @@ from netCDF4 import Dataset as ncDataset from xarray.core.dataset import Dataset + from xarray.core.datatree import DataTree from xarray.core.types import NestedSequence - from xarray.datatree_.datatree import DataTree # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -137,8 +137,8 @@ def _open_datatree_netcdf( **kwargs, ) -> DataTree: from xarray.backends.api import open_dataset + from xarray.core.datatree import DataTree from xarray.core.treenode import NodePath - from xarray.datatree_.datatree import DataTree ds = open_dataset(filename_or_obj, **kwargs) tree_root = DataTree.from_dict({"/": ds}) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index b7c1b2a5f03..654086e50c7 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -39,7 +39,7 @@ from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree class H5NetCDFArrayWrapper(BaseNetCDF4Array): diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 6720a67ae2f..ae86c4ce384 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -45,7 +45,7 @@ from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index e9465dc0ba0..13b1819f206 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -34,7 +34,7 @@ from xarray.backends.common import AbstractDataStore from xarray.core.dataset import Dataset - from xarray.datatree_.datatree import DataTree + from xarray.core.datatree import DataTree # need some special secret attributes to tell us the dimensions @@ -1048,8 +1048,8 @@ def open_datatree( import zarr from xarray.backends.api import open_dataset + from xarray.core.datatree import DataTree from xarray.core.treenode import NodePath - from xarray.datatree_.datatree import DataTree zds = zarr.open_group(filename_or_obj, mode="r") ds = open_dataset(filename_or_obj, engine="zarr", **kwargs) diff --git a/xarray/datatree_/datatree/datatree.py b/xarray/core/datatree.py similarity index 89% rename from xarray/datatree_/datatree/datatree.py rename to xarray/core/datatree.py index 10133052185..1b06d87c9fb 100644 --- a/xarray/datatree_/datatree/datatree.py +++ b/xarray/core/datatree.py @@ -2,24 +2,14 @@ import copy import itertools -from collections import OrderedDict +from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping from html import escape from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Generic, - Hashable, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, NoReturn, - Optional, - Set, - Tuple, Union, overload, ) @@ -31,6 +21,7 @@ from xarray.core.indexes import Index, Indexes from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS +from xarray.core.treenode import NamedNode, NodePath, Tree from xarray.core.utils import ( Default, Frozen, @@ -40,17 +31,22 @@ maybe_wrap_array, ) from xarray.core.variable import Variable - -from . import formatting, formatting_html -from .common import TreeAttrAccessMixin -from .mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree -from .ops import ( +from xarray.datatree_.datatree.common import TreeAttrAccessMixin +from xarray.datatree_.datatree.formatting import datatree_repr +from xarray.datatree_.datatree.formatting_html import ( + datatree_repr as datatree_repr_html, +) +from xarray.datatree_.datatree.mapping import ( + TreeIsomorphismError, + check_isomorphic, + map_over_subtree, +) +from xarray.datatree_.datatree.ops import ( DataTreeArithmeticMixin, MappedDatasetMethodsMixin, MappedDataWithCoords, ) -from .render import RenderTree -from xarray.core.treenode import NamedNode, NodePath, Tree +from xarray.datatree_.datatree.render import RenderTree try: from xarray.core.variable import calculate_dimensions @@ -60,6 +56,7 @@ if TYPE_CHECKING: import pandas as pd + from xarray.core.merge import CoercibleValue from xarray.core.types import ErrorOptions @@ -130,9 +127,9 @@ class DatasetView(Dataset): def __init__( self, - data_vars: Optional[Mapping[Any, Any]] = None, - coords: Optional[Mapping[Any, Any]] = None, - attrs: Optional[Mapping[Any, Any]] = None, + data_vars: Mapping[Any, Any] | None = None, + coords: Mapping[Any, Any] | None = None, + attrs: Mapping[Any, Any] | None = None, ): raise AttributeError("DatasetView objects are not to be initialized directly") @@ -169,33 +166,33 @@ def update(self, other) -> NoReturn: ) # FIXME https://github.com/python/mypy/issues/7328 - @overload - def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[misc] + @overload # type: ignore[override] + def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[overload-overlap] ... @overload - def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[misc] + def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[overload-overlap] ... + # See: https://github.com/pydata/xarray/issues/8855 @overload - def __getitem__(self, key: Any) -> Dataset: - ... + def __getitem__(self, key: Any) -> Dataset: ... - def __getitem__(self, key) -> DataArray: + def __getitem__(self, key) -> DataArray | Dataset: # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes # For now just call Dataset.__getitem__ return Dataset.__getitem__(self, key) @classmethod - def _construct_direct( + def _construct_direct( # type: ignore[override] cls, variables: dict[Any, Variable], coord_names: set[Hashable], - dims: Optional[dict[Any, int]] = None, - attrs: Optional[dict] = None, - indexes: Optional[dict[Any, Index]] = None, - encoding: Optional[dict] = None, - close: Optional[Callable[[], None]] = None, + dims: dict[Any, int] | None = None, + attrs: dict | None = None, + indexes: dict[Any, Index] | None = None, + encoding: dict | None = None, + close: Callable[[], None] | None = None, ) -> Dataset: """ Overriding this method (along with ._replace) and modifying it to return a Dataset object @@ -215,13 +212,13 @@ def _construct_direct( obj._encoding = encoding return obj - def _replace( + def _replace( # type: ignore[override] self, - variables: Optional[dict[Hashable, Variable]] = None, - coord_names: Optional[set[Hashable]] = None, - dims: Optional[dict[Any, int]] = None, + variables: dict[Hashable, Variable] | None = None, + coord_names: set[Hashable] | None = None, + dims: dict[Any, int] | None = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: Optional[dict[Hashable, Index]] = None, + indexes: dict[Hashable, Index] | None = None, encoding: dict | None | Default = _default, inplace: bool = False, ) -> Dataset: @@ -244,7 +241,7 @@ def _replace( inplace=inplace, ) - def map( + def map( # type: ignore[override] self, func: Callable, keep_attrs: bool | None = None, @@ -259,7 +256,7 @@ def map( Function which can be called in the form `func(x, *args, **kwargs)` to transform each DataArray `x` in this dataset into another DataArray. - keep_attrs : bool or None, optional + keep_attrs : bool | None, optional If True, both the dataset's and variables' attributes (`attrs`) will be copied from the original objects to the new ones. If False, the new dataset and variables will be returned without copying the attributes. @@ -293,7 +290,7 @@ def map( bar (x) float64 16B 1.0 2.0 """ - # Copied from xarray.Dataset so as not to call type(self), which causes problems (see datatree GH188). + # Copied from xarray.Dataset so as not to call type(self), which causes problems (see https://github.com/xarray-contrib/datatree/issues/188). # TODO Refactor xarray upstream to avoid needing to overwrite this. # TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated variables = { @@ -333,21 +330,19 @@ class DataTree( # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from - # TODO __slots__ - # TODO all groupby classes - _name: Optional[str] - _parent: Optional[DataTree] - _children: OrderedDict[str, DataTree] - _attrs: Optional[Dict[Hashable, Any]] - _cache: Dict[str, Any] - _coord_names: Set[Hashable] - _dims: Dict[Hashable, int] - _encoding: Optional[Dict[Hashable, Any]] - _close: Optional[Callable[[], None]] - _indexes: Dict[Hashable, Index] - _variables: Dict[Hashable, Variable] + _name: str | None + _parent: DataTree | None + _children: dict[str, DataTree] + _attrs: dict[Hashable, Any] | None + _cache: dict[str, Any] + _coord_names: set[Hashable] + _dims: dict[Hashable, int] + _encoding: dict[Hashable, Any] | None + _close: Callable[[], None] | None + _indexes: dict[Hashable, Index] + _variables: dict[Hashable, Variable] __slots__ = ( "_name", @@ -365,10 +360,10 @@ class DataTree( def __init__( self, - data: Optional[Dataset | DataArray] = None, - parent: Optional[DataTree] = None, - children: Optional[Mapping[str, DataTree]] = None, - name: Optional[str] = None, + data: Dataset | DataArray | None = None, + parent: DataTree | None = None, + children: Mapping[str, DataTree] | None = None, + name: str | None = None, ): """ Create a single node of a DataTree. @@ -446,7 +441,9 @@ def ds(self) -> DatasetView: return DatasetView._from_node(self) @ds.setter - def ds(self, data: Optional[Union[Dataset, DataArray]] = None) -> None: + def ds(self, data: Dataset | DataArray | None = None) -> None: + # Known mypy issue for setters with different type to property: + # https://github.com/python/mypy/issues/3004 ds = _coerce_to_dataset(data) _check_for_name_collisions(self.children, ds.variables) @@ -515,15 +512,14 @@ def is_hollow(self) -> bool: def variables(self) -> Mapping[Hashable, Variable]: """Low level interface to node contents as dict of Variable objects. - This ordered dictionary is frozen to prevent mutation that could - violate Dataset invariants. It contains all variable objects - constituting this DataTree node, including both data variables and - coordinates. + This dictionary is frozen to prevent mutation that could violate + Dataset invariants. It contains all variable objects constituting this + DataTree node, including both data variables and coordinates. """ return Frozen(self._variables) @property - def attrs(self) -> Dict[Hashable, Any]: + def attrs(self) -> dict[Hashable, Any]: """Dictionary of global attributes on this node object.""" if self._attrs is None: self._attrs = {} @@ -534,7 +530,7 @@ def attrs(self, value: Mapping[Any, Any]) -> None: self._attrs = dict(value) @property - def encoding(self) -> Dict: + def encoding(self) -> dict: """Dictionary of global encoding attributes on this node object.""" if self._encoding is None: self._encoding = {} @@ -589,7 +585,7 @@ def _item_sources(self) -> Iterable[Mapping[Any, Any]]: # immediate child nodes yield self.children - def _ipython_key_completions_(self) -> List[str]: + def _ipython_key_completions_(self) -> list[str]: """Provide method for the key-autocompletions in IPython. See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion For the details. @@ -636,31 +632,31 @@ def __array__(self, dtype=None): "invoking the `to_array()` method." ) - def __repr__(self) -> str: - return formatting.datatree_repr(self) + def __repr__(self) -> str: # type: ignore[override] + return datatree_repr(self) def __str__(self) -> str: - return formatting.datatree_repr(self) + return datatree_repr(self) def _repr_html_(self): """Make html representation of datatree object""" if XR_OPTS["display_style"] == "text": return f"
{escape(repr(self))}
" - return formatting_html.datatree_repr(self) + return datatree_repr_html(self) @classmethod def _construct_direct( cls, variables: dict[Any, Variable], coord_names: set[Hashable], - dims: Optional[dict[Any, int]] = None, - attrs: Optional[dict] = None, - indexes: Optional[dict[Any, Index]] = None, - encoding: Optional[dict] = None, + dims: dict[Any, int] | None = None, + attrs: dict | None = None, + indexes: dict[Any, Index] | None = None, + encoding: dict | None = None, name: str | None = None, parent: DataTree | None = None, - children: Optional[OrderedDict[str, DataTree]] = None, - close: Optional[Callable[[], None]] = None, + children: dict[str, DataTree] | None = None, + close: Callable[[], None] | None = None, ) -> DataTree: """Shortcut around __init__ for internal use when we want to skip costly validation.""" @@ -670,7 +666,7 @@ def _construct_direct( if indexes is None: indexes = {} if children is None: - children = OrderedDict() + children = dict() obj: DataTree = object.__new__(cls) obj._variables = variables @@ -690,15 +686,15 @@ def _construct_direct( def _replace( self: DataTree, - variables: Optional[dict[Hashable, Variable]] = None, - coord_names: Optional[set[Hashable]] = None, - dims: Optional[dict[Any, int]] = None, + variables: dict[Hashable, Variable] | None = None, + coord_names: set[Hashable] | None = None, + dims: dict[Any, int] | None = None, attrs: dict[Hashable, Any] | None | Default = _default, - indexes: Optional[dict[Hashable, Index]] = None, + indexes: dict[Hashable, Index] | None = None, encoding: dict | None | Default = _default, name: str | None | Default = _default, - parent: DataTree | None = _default, - children: Optional[OrderedDict[str, DataTree]] = None, + parent: DataTree | None | Default = _default, + children: dict[str, DataTree] | None = None, inplace: bool = False, ) -> DataTree: """ @@ -817,7 +813,7 @@ def _copy_node( """Copy just one node of a tree""" new_node: DataTree = DataTree() new_node.name = self.name - new_node.ds = self.to_dataset().copy(deep=deep) + new_node.ds = self.to_dataset().copy(deep=deep) # type: ignore[assignment] return new_node def __copy__(self: DataTree) -> DataTree: @@ -826,9 +822,9 @@ def __copy__(self: DataTree) -> DataTree: def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree: return self._copy_subtree(deep=True, memo=memo) - def get( - self: DataTree, key: str, default: Optional[DataTree | DataArray] = None - ) -> Optional[DataTree | DataArray]: + def get( # type: ignore[override] + self: DataTree, key: str, default: DataTree | DataArray | None = None + ) -> DataTree | DataArray | None: """ Access child nodes, variables, or coordinates stored in this node. @@ -839,7 +835,7 @@ def get( ---------- key : str Name of variable / child within this node. Must lie in this immediate node (not elsewhere in the tree). - default : DataTree | DataArray, optional + default : DataTree | DataArray | None, optional A value to return if the specified key does not exist. Default return value is None. """ if key in self.children: @@ -863,7 +859,7 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: Returns ------- - Union[DataTree, DataArray] + DataTree | DataArray """ # Either: @@ -926,21 +922,38 @@ def __setitem__( else: raise ValueError("Invalid format for key") - def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None: + @overload + def update(self, other: Dataset) -> None: ... + + @overload + def update(self, other: Mapping[Hashable, DataArray | Variable]) -> None: ... + + @overload + def update(self, other: Mapping[str, DataTree | DataArray | Variable]) -> None: ... + + def update( + self, + other: ( + Dataset + | Mapping[Hashable, DataArray | Variable] + | Mapping[str, DataTree | DataArray | Variable] + ), + ) -> None: """ Update this node's children and / or variables. Just like `dict.update` this is an in-place operation. """ # TODO separate by type - new_children = {} + new_children: dict[str, DataTree] = {} new_variables = {} for k, v in other.items(): if isinstance(v, DataTree): # avoid named node being stored under inconsistent key - new_child = v.copy() - new_child.name = k - new_children[k] = new_child + new_child: DataTree = v.copy() + # Datatree's name is always a string until we fix that (#8836) + new_child.name = str(k) + new_children[str(k)] = new_child elif isinstance(v, (DataArray, Variable)): # TODO this should also accommodate other types that can be coerced into Variables new_variables[k] = v @@ -949,7 +962,7 @@ def update(self, other: Dataset | Mapping[str, DataTree | DataArray]) -> None: vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) # TODO are there any subtleties with preserving order of children like this? - merged_children = OrderedDict({**self.children, **new_children}) + merged_children = {**self.children, **new_children} self._replace( inplace=True, children=merged_children, **vars_merge_result._asdict() ) @@ -1027,16 +1040,16 @@ def drop_nodes( if extra: raise KeyError(f"Cannot drop all nodes - nodes {extra} not present") - children_to_keep = OrderedDict( - {name: child for name, child in self.children.items() if name not in names} - ) + children_to_keep = { + name: child for name, child in self.children.items() if name not in names + } return self._replace(children=children_to_keep) @classmethod def from_dict( cls, d: MutableMapping[str, Dataset | DataArray | DataTree | None], - name: Optional[str] = None, + name: str | None = None, ) -> DataTree: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. @@ -1050,7 +1063,7 @@ def from_dict( tree nodes will be constructed as necessary. To assign data to the root node of the tree use "/" as the path. - name : Hashable, optional + name : Hashable | None, optional Name for the root node of the tree. Default is None. Returns @@ -1064,14 +1077,18 @@ def from_dict( # First create the root node root_data = d.pop("/", None) - obj = cls(name=name, data=root_data, parent=None, children=None) + if isinstance(root_data, DataTree): + obj = root_data.copy() + obj.orphan() + else: + obj = cls(name=name, data=root_data, parent=None, children=None) if d: # Populate tree with children determined from data_objects mapping for path, data in d.items(): # Create and set new node node_name = NodePath(path).name - if isinstance(data, cls): + if isinstance(data, DataTree): new_node = data.copy() new_node.orphan() else: @@ -1085,13 +1102,13 @@ def from_dict( return obj - def to_dict(self) -> Dict[str, Dataset]: + def to_dict(self) -> dict[str, Dataset]: """ Create a dictionary mapping of absolute node paths to the data contained in those nodes. Returns ------- - Dict[str, Dataset] + dict[str, Dataset] """ return {node.path: node.to_dataset() for node in self.subtree} @@ -1313,7 +1330,7 @@ def map_over_subtree( func: Callable, *args: Iterable[Any], **kwargs: Any, - ) -> DataTree | Tuple[DataTree]: + ) -> DataTree | tuple[DataTree]: """ Apply a function to every dataset in this subtree, returning a new tree which stores the results. @@ -1336,13 +1353,13 @@ def map_over_subtree( Returns ------- - subtrees : DataTree, Tuple of DataTrees + subtrees : DataTree, tuple of DataTrees One or more subtrees containing results from applying ``func`` to the data at each node. """ # TODO this signature means that func has no way to know which node it is being called upon - change? # TODO fix this typing error - return map_over_subtree(func)(self, *args, **kwargs) # type: ignore[operator] + return map_over_subtree(func)(self, *args, **kwargs) def map_over_subtree_inplace( self, @@ -1449,8 +1466,8 @@ def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree: # TODO some kind of .collapse() or .flatten() method to merge a subtree - def as_array(self) -> DataArray: - return self.ds.as_dataarray() + def to_dataarray(self) -> DataArray: + return self.ds.to_dataarray() @property def groups(self): @@ -1485,7 +1502,7 @@ def to_netcdf( kwargs : Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` """ - from .io import _datatree_to_netcdf + from xarray.datatree_.datatree.io import _datatree_to_netcdf _datatree_to_netcdf( self, @@ -1515,7 +1532,7 @@ def to_zarr( Persistence mode: “w” means create (overwrite if exists); “w-” means create (fail if exists); “a” means override existing variables (create if does not exist); “r+” means modify existing array values only (raise an error if any metadata or shapes would change). The default mode - is “a” if append_dim is set. Otherwise, it is “r+” if region is set and w- otherwise. + is “w-”. encoding : dict, optional Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., @@ -1527,7 +1544,7 @@ def to_zarr( kwargs : Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` """ - from .io import _datatree_to_zarr + from xarray.datatree_.datatree.io import _datatree_to_zarr _datatree_to_zarr( self, diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index b3e6e43f306..8cee3f69d70 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -230,7 +230,7 @@ def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None: pass def _iter_parents(self: Tree) -> Iterator[Tree]: - """Iterate up the tree, starting from the current node.""" + """Iterate up the tree, starting from the current node's parent.""" node: Tree | None = self.parent while node is not None: yield node diff --git a/xarray/datatree_/datatree/__init__.py b/xarray/datatree_/datatree/__init__.py index 071dcbecf8c..f2603b64641 100644 --- a/xarray/datatree_/datatree/__init__.py +++ b/xarray/datatree_/datatree/__init__.py @@ -1,15 +1,11 @@ # import public API -from .datatree import DataTree -from .extensions import register_datatree_accessor from .mapping import TreeIsomorphismError, map_over_subtree from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError __all__ = ( - "DataTree", "TreeIsomorphismError", "InvalidTreeError", "NotFoundInTreeError", "map_over_subtree", - "register_datatree_accessor", ) diff --git a/xarray/datatree_/datatree/extensions.py b/xarray/datatree_/datatree/extensions.py index f6f4e985a79..bf888fc4484 100644 --- a/xarray/datatree_/datatree/extensions.py +++ b/xarray/datatree_/datatree/extensions.py @@ -1,6 +1,6 @@ from xarray.core.extensions import _register_accessor -from .datatree import DataTree +from xarray.core.datatree import DataTree def register_datatree_accessor(name): diff --git a/xarray/datatree_/datatree/formatting.py b/xarray/datatree_/datatree/formatting.py index deba57eb09d..9ebee72d4ef 100644 --- a/xarray/datatree_/datatree/formatting.py +++ b/xarray/datatree_/datatree/formatting.py @@ -2,11 +2,11 @@ from xarray.core.formatting import _compat_to_str, diff_dataset_repr -from .mapping import diff_treestructure -from .render import RenderTree +from xarray.datatree_.datatree.mapping import diff_treestructure +from xarray.datatree_.datatree.render import RenderTree if TYPE_CHECKING: - from .datatree import DataTree + from xarray.core.datatree import DataTree def diff_nodewise_summary(a, b, compat): diff --git a/xarray/datatree_/datatree/io.py b/xarray/datatree_/datatree/io.py index d3d533ee71e..48335ddca70 100644 --- a/xarray/datatree_/datatree/io.py +++ b/xarray/datatree_/datatree/io.py @@ -1,4 +1,4 @@ -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree def _get_nc_dataset_class(engine): diff --git a/xarray/datatree_/datatree/mapping.py b/xarray/datatree_/datatree/mapping.py index 355149060a9..7742ece9738 100644 --- a/xarray/datatree_/datatree/mapping.py +++ b/xarray/datatree_/datatree/mapping.py @@ -156,7 +156,7 @@ def map_over_subtree(func: Callable) -> Callable: @functools.wraps(func) def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: """Internal function which maps func over every node in tree, returning a tree of the results.""" - from .datatree import DataTree + from xarray.core.datatree import DataTree all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [ a for a in kwargs.values() if isinstance(a, DataTree) diff --git a/xarray/datatree_/datatree/render.py b/xarray/datatree_/datatree/render.py index aef327c5c47..e6af9c85ee8 100644 --- a/xarray/datatree_/datatree/render.py +++ b/xarray/datatree_/datatree/render.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .datatree import DataTree + from xarray.core.datatree import DataTree Row = collections.namedtuple("Row", ("pre", "fill", "node")) diff --git a/xarray/datatree_/datatree/testing.py b/xarray/datatree_/datatree/testing.py index 1cbcdf2d4e3..bf54116725a 100644 --- a/xarray/datatree_/datatree/testing.py +++ b/xarray/datatree_/datatree/testing.py @@ -1,6 +1,6 @@ from xarray.testing.assertions import ensure_warnings -from .datatree import DataTree +from xarray.core.datatree import DataTree from .formatting import diff_tree_repr diff --git a/xarray/datatree_/datatree/tests/conftest.py b/xarray/datatree_/datatree/tests/conftest.py index bd2e7ba3247..53a9a72239d 100644 --- a/xarray/datatree_/datatree/tests/conftest.py +++ b/xarray/datatree_/datatree/tests/conftest.py @@ -1,7 +1,7 @@ import pytest import xarray as xr -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree @pytest.fixture(scope="module") diff --git a/xarray/datatree_/datatree/tests/test_dataset_api.py b/xarray/datatree_/datatree/tests/test_dataset_api.py index c3eb74451a6..4ca532ebba4 100644 --- a/xarray/datatree_/datatree/tests/test_dataset_api.py +++ b/xarray/datatree_/datatree/tests/test_dataset_api.py @@ -1,7 +1,7 @@ import numpy as np import xarray as xr -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.datatree_.datatree.testing import assert_equal diff --git a/xarray/datatree_/datatree/tests/test_extensions.py b/xarray/datatree_/datatree/tests/test_extensions.py index 0241e496abf..fb2e82453ec 100644 --- a/xarray/datatree_/datatree/tests/test_extensions.py +++ b/xarray/datatree_/datatree/tests/test_extensions.py @@ -1,6 +1,7 @@ import pytest -from xarray.datatree_.datatree import DataTree, register_datatree_accessor +from xarray.core.datatree import DataTree +from xarray.datatree_.datatree.extensions import register_datatree_accessor class TestAccessor: diff --git a/xarray/datatree_/datatree/tests/test_formatting.py b/xarray/datatree_/datatree/tests/test_formatting.py index b58c02282e7..77f8346ae72 100644 --- a/xarray/datatree_/datatree/tests/test_formatting.py +++ b/xarray/datatree_/datatree/tests/test_formatting.py @@ -2,7 +2,7 @@ from xarray import Dataset -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.datatree_.datatree.formatting import diff_tree_repr diff --git a/xarray/datatree_/datatree/tests/test_formatting_html.py b/xarray/datatree_/datatree/tests/test_formatting_html.py index 943bbab4154..98cdf02bff4 100644 --- a/xarray/datatree_/datatree/tests/test_formatting_html.py +++ b/xarray/datatree_/datatree/tests/test_formatting_html.py @@ -1,7 +1,8 @@ import pytest import xarray as xr -from xarray.datatree_.datatree import DataTree, formatting_html +from xarray.core.datatree import DataTree +from xarray.datatree_.datatree import formatting_html @pytest.fixture(scope="module", params=["some html", "some other html"]) diff --git a/xarray/datatree_/datatree/tests/test_mapping.py b/xarray/datatree_/datatree/tests/test_mapping.py index 53d6e085440..c6cd04887c0 100644 --- a/xarray/datatree_/datatree/tests/test_mapping.py +++ b/xarray/datatree_/datatree/tests/test_mapping.py @@ -2,7 +2,7 @@ import pytest import xarray as xr -from xarray.datatree_.datatree.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree from xarray.datatree_.datatree.testing import assert_equal diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 8590c9fb4e7..a32b0e08bea 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -6,7 +6,7 @@ import xarray as xr from xarray import DataArray, Dataset -from xarray.datatree_.datatree import DataTree +from xarray.core.datatree import DataTree from xarray.tests import create_test_data, requires_dask diff --git a/xarray/datatree_/datatree/tests/test_datatree.py b/xarray/tests/test_datatree.py similarity index 80% rename from xarray/datatree_/datatree/tests/test_datatree.py rename to xarray/tests/test_datatree.py index cfb57470651..c7359b3929e 100644 --- a/xarray/datatree_/datatree/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -2,29 +2,30 @@ import numpy as np import pytest + import xarray as xr +import xarray.datatree_.datatree.testing as dtt import xarray.testing as xrt +from xarray.core.datatree import DataTree +from xarray.core.treenode import NotFoundInTreeError from xarray.tests import create_test_data, source_ndarray -import xarray.datatree_.datatree.testing as dtt -from xarray.datatree_.datatree import DataTree, NotFoundInTreeError - class TestTreeCreation: def test_empty(self): - dt = DataTree(name="root") + dt: DataTree = DataTree(name="root") assert dt.name == "root" assert dt.parent is None assert dt.children == {} xrt.assert_identical(dt.to_dataset(), xr.Dataset()) def test_unnamed(self): - dt = DataTree() + dt: DataTree = DataTree() assert dt.name is None def test_bad_names(self): with pytest.raises(TypeError): - DataTree(name=5) + DataTree(name=5) # type: ignore[arg-type] with pytest.raises(ValueError): DataTree(name="folder/data") @@ -32,7 +33,7 @@ def test_bad_names(self): class TestFamilyTree: def test_setparent_unnamed_child_node_fails(self): - john = DataTree(name="john") + john: DataTree = DataTree(name="john") with pytest.raises(ValueError, match="unnamed"): DataTree(parent=john) @@ -40,8 +41,8 @@ def test_create_two_children(self): root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) set1_data = xr.Dataset({"a": 0, "b": 1}) - root = DataTree(data=root_data) - set1 = DataTree(name="set1", parent=root, data=set1_data) + root: DataTree = DataTree(data=root_data) + set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) DataTree(name="set1", parent=root) DataTree(name="set2", parent=set1) @@ -50,11 +51,11 @@ def test_create_full_tree(self, simple_datatree): set1_data = xr.Dataset({"a": 0, "b": 1}) set2_data = xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}) - root = DataTree(data=root_data) - set1 = DataTree(name="set1", parent=root, data=set1_data) + root: DataTree = DataTree(data=root_data) + set1: DataTree = DataTree(name="set1", parent=root, data=set1_data) DataTree(name="set1", parent=set1) DataTree(name="set2", parent=set1) - set2 = DataTree(name="set2", parent=root, data=set2_data) + set2: DataTree = DataTree(name="set2", parent=root, data=set2_data) DataTree(name="set1", parent=set2) DataTree(name="set3", parent=root) @@ -64,36 +65,36 @@ def test_create_full_tree(self, simple_datatree): class TestNames: def test_child_gets_named_on_attach(self): - sue = DataTree() - mary = DataTree(children={"Sue": sue}) # noqa + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) # noqa assert sue.name == "Sue" class TestPaths: def test_path_property(self): - sue = DataTree() - mary = DataTree(children={"Sue": sue}) - john = DataTree(children={"Mary": mary}) # noqa + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) + john: DataTree = DataTree(children={"Mary": mary}) assert sue.path == "/Mary/Sue" assert john.path == "/" def test_path_roundtrip(self): - sue = DataTree() - mary = DataTree(children={"Sue": sue}) - john = DataTree(children={"Mary": mary}) # noqa + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) + john: DataTree = DataTree(children={"Mary": mary}) assert john[sue.path] is sue def test_same_tree(self): - mary = DataTree() - kate = DataTree() - john = DataTree(children={"Mary": mary, "Kate": kate}) # noqa + mary: DataTree = DataTree() + kate: DataTree = DataTree() + john: DataTree = DataTree(children={"Mary": mary, "Kate": kate}) # noqa assert mary.same_tree(kate) def test_relative_paths(self): - sue = DataTree() - mary = DataTree(children={"Sue": sue}) - annie = DataTree() - john = DataTree(children={"Mary": mary, "Annie": annie}) + sue: DataTree = DataTree() + mary: DataTree = DataTree(children={"Sue": sue}) + annie: DataTree = DataTree() + john: DataTree = DataTree(children={"Mary": mary, "Annie": annie}) result = sue.relative_to(john) assert result == "Mary/Sue" @@ -102,7 +103,7 @@ def test_relative_paths(self): assert sue.relative_to(annie) == "../Mary/Sue" assert sue.relative_to(sue) == "." - evil_kate = DataTree() + evil_kate: DataTree = DataTree() with pytest.raises( NotFoundInTreeError, match="nodes do not lie within the same tree" ): @@ -112,116 +113,117 @@ def test_relative_paths(self): class TestStoreDatasets: def test_create_with_data(self): dat = xr.Dataset({"a": 0}) - john = DataTree(name="john", data=dat) + john: DataTree = DataTree(name="john", data=dat) xrt.assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): - DataTree(name="mary", parent=john, data="junk") # noqa + DataTree(name="mary", parent=john, data="junk") # type: ignore[arg-type] def test_set_data(self): - john = DataTree(name="john") + john: DataTree = DataTree(name="john") dat = xr.Dataset({"a": 0}) - john.ds = dat + john.ds = dat # type: ignore[assignment] xrt.assert_identical(john.to_dataset(), dat) with pytest.raises(TypeError): - john.ds = "junk" + john.ds = "junk" # type: ignore[assignment] def test_has_data(self): - john = DataTree(name="john", data=xr.Dataset({"a": 0})) + john: DataTree = DataTree(name="john", data=xr.Dataset({"a": 0})) assert john.has_data - john = DataTree(name="john", data=None) - assert not john.has_data + john_no_data: DataTree = DataTree(name="john", data=None) + assert not john_no_data.has_data def test_is_hollow(self): - john = DataTree(data=xr.Dataset({"a": 0})) + john: DataTree = DataTree(data=xr.Dataset({"a": 0})) assert john.is_hollow - eve = DataTree(children={"john": john}) + eve: DataTree = DataTree(children={"john": john}) assert eve.is_hollow - eve.ds = xr.Dataset({"a": 1}) + eve.ds = xr.Dataset({"a": 1}) # type: ignore[assignment] assert not eve.is_hollow class TestVariablesChildrenNameCollisions: def test_parent_already_has_variable_with_childs_name(self): - dt = DataTree(data=xr.Dataset({"a": [0], "b": 1})) + dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) with pytest.raises(KeyError, match="already contains a data variable named a"): DataTree(name="a", data=None, parent=dt) def test_assign_when_already_child_with_variables_name(self): - dt = DataTree(data=None) + dt: DataTree = DataTree(data=None) DataTree(name="a", data=None, parent=dt) with pytest.raises(KeyError, match="names would collide"): - dt.ds = xr.Dataset({"a": 0}) + dt.ds = xr.Dataset({"a": 0}) # type: ignore[assignment] - dt.ds = xr.Dataset() + dt.ds = xr.Dataset() # type: ignore[assignment] new_ds = dt.to_dataset().assign(a=xr.DataArray(0)) with pytest.raises(KeyError, match="names would collide"): - dt.ds = new_ds + dt.ds = new_ds # type: ignore[assignment] -class TestGet: - ... +class TestGet: ... class TestGetItem: def test_getitem_node(self): - folder1 = DataTree(name="folder1") - results = DataTree(name="results", parent=folder1) - highres = DataTree(name="highres", parent=results) + folder1: DataTree = DataTree(name="folder1") + results: DataTree = DataTree(name="results", parent=folder1) + highres: DataTree = DataTree(name="highres", parent=results) assert folder1["results"] is results assert folder1["results/highres"] is highres def test_getitem_self(self): - dt = DataTree() + dt: DataTree = DataTree() assert dt["."] is dt def test_getitem_single_data_variable(self): data = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results", data=data) + results: DataTree = DataTree(name="results", data=data) xrt.assert_identical(results["temp"], data["temp"]) def test_getitem_single_data_variable_from_node(self): data = xr.Dataset({"temp": [0, 50]}) - folder1 = DataTree(name="folder1") - results = DataTree(name="results", parent=folder1) + folder1: DataTree = DataTree(name="folder1") + results: DataTree = DataTree(name="results", parent=folder1) DataTree(name="highres", parent=results, data=data) xrt.assert_identical(folder1["results/highres/temp"], data["temp"]) def test_getitem_nonexistent_node(self): - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") DataTree(name="results", parent=folder1) with pytest.raises(KeyError): folder1["results/highres"] def test_getitem_nonexistent_variable(self): data = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results", data=data) + results: DataTree = DataTree(name="results", data=data) with pytest.raises(KeyError): results["pressure"] @pytest.mark.xfail(reason="Should be deprecated in favour of .subset") def test_getitem_multiple_data_variables(self): data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]}) - results = DataTree(name="results", data=data) - xrt.assert_identical(results[["temp", "p"]], data[["temp", "p"]]) + results: DataTree = DataTree(name="results", data=data) + xrt.assert_identical(results[["temp", "p"]], data[["temp", "p"]]) # type: ignore[index] - @pytest.mark.xfail(reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)") + @pytest.mark.xfail( + reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)" + ) def test_getitem_dict_like_selection_access_to_dataset(self): data = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results", data=data) - xrt.assert_identical(results[{"temp": 1}], data[{"temp": 1}]) + results: DataTree = DataTree(name="results", data=data) + xrt.assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index] class TestUpdate: def test_update(self): - dt = DataTree() + dt: DataTree = DataTree() dt.update({"foo": xr.DataArray(0), "a": DataTree()}) expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None}) print(dt) @@ -233,13 +235,13 @@ def test_update(self): def test_update_new_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1.update({"results": da}) expected = da.rename("results") xrt.assert_equal(folder1["results"], expected) def test_update_doesnt_alter_child_name(self): - dt = DataTree() + dt: DataTree = DataTree() dt.update({"foo": xr.DataArray(0), "a": DataTree(name="b")}) assert "a" in dt.children child = dt["a"] @@ -336,8 +338,8 @@ def test_copy_with_data(self, create_test_datatree): class TestSetItem: def test_setitem_new_child_node(self): - john = DataTree(name="john") - mary = DataTree(name="mary") + john: DataTree = DataTree(name="john") + mary: DataTree = DataTree(name="mary") john["mary"] = mary grafted_mary = john["mary"] @@ -345,14 +347,14 @@ def test_setitem_new_child_node(self): assert grafted_mary.name == "mary" def test_setitem_unnamed_child_node_becomes_named(self): - john2 = DataTree(name="john2") + john2: DataTree = DataTree(name="john2") john2["sonny"] = DataTree() assert john2["sonny"].name == "sonny" def test_setitem_new_grandchild_node(self): - john = DataTree(name="john") - mary = DataTree(name="mary", parent=john) - rose = DataTree(name="rose") + john: DataTree = DataTree(name="john") + mary: DataTree = DataTree(name="mary", parent=john) + rose: DataTree = DataTree(name="rose") john["mary/rose"] = rose grafted_rose = john["mary/rose"] @@ -360,98 +362,97 @@ def test_setitem_new_grandchild_node(self): assert grafted_rose.name == "rose" def test_grafted_subtree_retains_name(self): - subtree = DataTree(name="original_subtree_name") - root = DataTree(name="root") + subtree: DataTree = DataTree(name="original_subtree_name") + root: DataTree = DataTree(name="root") root["new_subtree_name"] = subtree # noqa assert subtree.name == "original_subtree_name" def test_setitem_new_empty_node(self): - john = DataTree(name="john") + john: DataTree = DataTree(name="john") john["mary"] = DataTree() mary = john["mary"] assert isinstance(mary, DataTree) xrt.assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self): - john = DataTree(name="john") - mary = DataTree(name="mary", parent=john, data=xr.Dataset()) + john: DataTree = DataTree(name="john") + mary: DataTree = DataTree(name="mary", parent=john, data=xr.Dataset()) john["mary"] = DataTree() xrt.assert_identical(mary.to_dataset(), xr.Dataset()) - john.ds = xr.Dataset() + john.ds = xr.Dataset() # type: ignore[assignment] with pytest.raises(ValueError, match="has no name"): john["."] = DataTree() @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_on_this_node(self): data = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results") + results: DataTree = DataTree(name="results") results["."] = data xrt.assert_identical(results.to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node(self): data = xr.Dataset({"temp": [0, 50]}) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = data xrt.assert_identical(folder1["results"].to_dataset(), data) @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self): data = xr.Dataset({"temp": [0, 50]}) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results/highres"] = data xrt.assert_identical(folder1["results/highres"].to_dataset(), data) def test_setitem_named_dataarray(self): da = xr.DataArray(name="temp", data=[0, 50]) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = da expected = da.rename("results") xrt.assert_equal(folder1["results"], expected) def test_setitem_unnamed_dataarray(self): data = xr.DataArray([0, 50]) - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = data xrt.assert_equal(folder1["results"], data) def test_setitem_variable(self): var = xr.Variable(data=[0, 50], dims="x") - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = var xrt.assert_equal(folder1["results"], xr.DataArray(var)) def test_setitem_coerce_to_dataarray(self): - folder1 = DataTree(name="folder1") + folder1: DataTree = DataTree(name="folder1") folder1["results"] = 0 xrt.assert_equal(folder1["results"], xr.DataArray(0)) def test_setitem_add_new_variable_to_empty_node(self): - results = DataTree(name="results") + results: DataTree = DataTree(name="results") results["pressure"] = xr.DataArray(data=[2, 3]) assert "pressure" in results.ds results["temp"] = xr.Variable(data=[10, 11], dims=["x"]) assert "temp" in results.ds # What if there is a path to traverse first? - results = DataTree(name="results") - results["highres/pressure"] = xr.DataArray(data=[2, 3]) - assert "pressure" in results["highres"].ds - results["highres/temp"] = xr.Variable(data=[10, 11], dims=["x"]) - assert "temp" in results["highres"].ds + results_with_path: DataTree = DataTree(name="results") + results_with_path["highres/pressure"] = xr.DataArray(data=[2, 3]) + assert "pressure" in results_with_path["highres"].ds + results_with_path["highres/temp"] = xr.Variable(data=[10, 11], dims=["x"]) + assert "temp" in results_with_path["highres"].ds def test_setitem_dataarray_replace_existing_node(self): t = xr.Dataset({"temp": [0, 50]}) - results = DataTree(name="results", data=t) + results: DataTree = DataTree(name="results", data=t) p = xr.DataArray(data=[2, 3]) results["pressure"] = p expected = t.assign(pressure=p) xrt.assert_identical(results.to_dataset(), expected) -class TestDictionaryInterface: - ... +class TestDictionaryInterface: ... class TestTreeFromDict: @@ -501,8 +502,8 @@ def test_full(self, simple_datatree): ] def test_datatree_values(self): - dat1 = DataTree(data=xr.Dataset({"a": 1})) - expected = DataTree() + dat1: DataTree = DataTree(data=xr.Dataset({"a": 1})) + expected: DataTree = DataTree() expected["a"] = dat1 actual = DataTree.from_dict({"a": dat1}) @@ -527,7 +528,7 @@ def test_roundtrip_unnamed_root(self, simple_datatree): class TestDatasetView: def test_view_contents(self): ds = create_test_data() - dt = DataTree(data=ds) + dt: DataTree = DataTree(data=ds) assert ds.identical( dt.ds ) # this only works because Dataset.identical doesn't check types @@ -535,7 +536,7 @@ def test_view_contents(self): def test_immutability(self): # See issue https://github.com/xarray-contrib/datatree/issues/38 - dt = DataTree(name="root", data=None) + dt: DataTree = DataTree(name="root", data=None) DataTree(name="a", data=None, parent=dt) with pytest.raises( @@ -553,7 +554,7 @@ def test_immutability(self): def test_methods(self): ds = create_test_data() - dt = DataTree(data=ds) + dt: DataTree = DataTree(data=ds) assert ds.mean().identical(dt.ds.mean()) assert type(dt.ds.mean()) == xr.Dataset @@ -572,7 +573,7 @@ def test_init_via_type(self): dims=["x", "y", "time"], coords={"area": (["x", "y"], np.random.rand(3, 4))}, ).to_dataset(name="data") - dt = DataTree(data=a) + dt: DataTree = DataTree(data=a) def weighted_mean(ds): return ds.weighted(ds.area).mean(["x", "y"]) @@ -643,7 +644,7 @@ def test_drop_nodes(self): assert childless.children == {} def test_assign(self): - dt = DataTree() + dt: DataTree = DataTree() expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "/a": None}) # kwargs form @@ -727,5 +728,5 @@ def test_filter(self): }, name="Abe", ) - elders = simpsons.filter(lambda node: node["age"] > 18) + elders = simpsons.filter(lambda node: node["age"].item() > 18) dtt.assert_identical(elders, expected)