diff --git a/ci/minimum_versions.py b/ci/minimum_versions.py
index 08808d002d9..21123bffcd6 100644
--- a/ci/minimum_versions.py
+++ b/ci/minimum_versions.py
@@ -30,6 +30,7 @@
     "coveralls",
     "pip",
     "pytest",
+    "pytest-asyncio",
     "pytest-cov",
     "pytest-env",
     "pytest-mypy-plugins",
diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml
index 5f5db4a0f18..65780d91949 100644
--- a/ci/requirements/all-but-dask.yml
+++ b/ci/requirements/all-but-dask.yml
@@ -28,6 +28,7 @@ dependencies:
   - pip
   - pydap
   - pytest
+  - pytest-asyncio
   - pytest-cov
   - pytest-env
   - pytest-mypy-plugins
diff --git a/ci/requirements/all-but-numba.yml b/ci/requirements/all-but-numba.yml
index 7c492aec704..cf62b42b41e 100644
--- a/ci/requirements/all-but-numba.yml
+++ b/ci/requirements/all-but-numba.yml
@@ -41,6 +41,7 @@ dependencies:
   - pyarrow  # pandas raises a deprecation warning without this, breaking doctests
   - pydap
   - pytest
+  - pytest-asyncio
   - pytest-cov
   - pytest-env
   - pytest-mypy-plugins
diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml
index 02e99d34af2..cc34a6e4824 100644
--- a/ci/requirements/bare-minimum.yml
+++ b/ci/requirements/bare-minimum.yml
@@ -7,6 +7,7 @@ dependencies:
   - coveralls
   - pip
   - pytest
+  - pytest-asyncio
   - pytest-cov
   - pytest-env
   - pytest-mypy-plugins
diff --git a/ci/requirements/environment-3.14.yml b/ci/requirements/environment-3.14.yml
index 06c4df82663..d4d47d85536 100644
--- a/ci/requirements/environment-3.14.yml
+++ b/ci/requirements/environment-3.14.yml
@@ -37,6 +37,7 @@ dependencies:
   - pyarrow  # pandas raises a deprecation warning without this, breaking doctests
   - pydap
   - pytest
+  - pytest-asyncio
   - pytest-cov
   - pytest-env
   - pytest-mypy-plugins
diff --git a/ci/requirements/environment-windows-3.14.yml b/ci/requirements/environment-windows-3.14.yml
index dd48add6b73..e86d57beb95 100644
--- a/ci/requirements/environment-windows-3.14.yml
+++ b/ci/requirements/environment-windows-3.14.yml
@@ -32,6 +32,7 @@ dependencies:
   - pyarrow  # importing dask.dataframe raises an ImportError without this
   - pydap
   - pytest
+  - pytest-asyncio
   - pytest-cov
   - pytest-env
   - pytest-mypy-plugins
diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml
index 3213ef687d3..7c0d4dd9231 100644
--- a/ci/requirements/environment-windows.yml
+++ b/ci/requirements/environment-windows.yml
@@ -32,6 +32,7 @@ dependencies:
   - pyarrow  # importing dask.dataframe raises an ImportError without this
   - pydap
   - pytest
+  - pytest-asyncio
   - pytest-cov
   - pytest-env
   - pytest-mypy-plugins
diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml
index fc54b6600fe..1deffcaeacd 100644
--- a/ci/requirements/environment.yml
+++ b/ci/requirements/environment.yml
@@ -39,6 +39,7 @@ dependencies:
   - pydap
   - pydap-server
   - pytest
+  - pytest-asyncio
   - pytest-cov
   - pytest-env
   - pytest-mypy-plugins
diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml
index 03e14773d53..1293f4d78d6 100644
--- a/ci/requirements/min-all-deps.yml
+++ b/ci/requirements/min-all-deps.yml
@@ -44,6 +44,7 @@ dependencies:
   - pip
   - pydap=3.5
   - pytest
+  - pytest-asyncio
   - pytest-cov
   - pytest-env
   - pytest-mypy-plugins
diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst
index 9a6037cf3c4..98d3704de9b 100644
--- a/doc/api-hidden.rst
+++ b/doc/api-hidden.rst
@@ -228,6 +228,7 @@
    Variable.isnull
    Variable.item
    Variable.load
+   Variable.load_async
    Variable.max
    Variable.mean
    Variable.median
diff --git a/doc/api.rst b/doc/api.rst
index b6023866eb8..80715555e56 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -1122,6 +1122,7 @@ Dataset methods
    Dataset.filter_by_attrs
    Dataset.info
    Dataset.load
+   Dataset.load_async
    Dataset.persist
    Dataset.unify_chunks
 
@@ -1154,6 +1155,7 @@ DataArray methods
    DataArray.compute
    DataArray.persist
    DataArray.load
+   DataArray.load_async
    DataArray.unify_chunks
 
 DataTree methods
diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst
index d3b5c3a9267..b5dfe3b5f8e 100644
--- a/doc/internals/how-to-add-new-backend.rst
+++ b/doc/internals/how-to-add-new-backend.rst
@@ -331,10 +331,12 @@ information on plugins.
 How to support lazy loading
 +++++++++++++++++++++++++++
 
-If you want to make your backend effective with big datasets, then you should
-support lazy loading.
-Basically, you shall replace the :py:class:`numpy.ndarray` inside the
-variables with a custom class that supports lazy loading indexing.
+If you want to make your backend effective with big datasets, then you should
+take advantage of xarray's support for lazy loading and indexing.
+
+Basically, when your backend constructs the ``Variable`` objects,
+you need to replace the :py:class:`numpy.ndarray` inside the variables with a
+custom :py:class:`~xarray.backends.BackendArray` subclass that supports lazy
+loading and indexing.
 See the example below:
 
 .. code-block:: python
@@ -345,25 +347,27 @@ See the example below:
 
 Where:
 
-- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a class
-  provided by Xarray that manages the lazy loading.
-- ``MyBackendArray`` shall be implemented by the backend and shall inherit
+- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a wrapper class
+  provided by Xarray that manages the lazy loading and indexing.
+- ``MyBackendArray`` should be implemented by the backend and must inherit
   from :py:class:`~xarray.backends.BackendArray`.
 
 BackendArray subclassing
 ^^^^^^^^^^^^^^^^^^^^^^^^
 
-The BackendArray subclass shall implement the following method and attributes:
+The BackendArray subclass must implement the following method and attributes:
 
-- the ``__getitem__`` method that takes in input an index and returns a
-  `NumPy <https://numpy.org/>`__ array
-- the ``shape`` attribute
+- the ``__getitem__`` method that takes an index as an input and returns a
+  `NumPy <https://numpy.org/>`__ array,
+- the ``shape`` attribute,
 - the ``dtype`` attribute.
 
-Xarray supports different type of :doc:`/user-guide/indexing`, that can be
-grouped in three types of indexes
+It may also optionally implement an additional ``async_getitem`` method.
+
+Xarray supports different types of :doc:`/user-guide/indexing`, which can be
+grouped into three types of indexes:
 :py:class:`~xarray.core.indexing.BasicIndexer`,
-:py:class:`~xarray.core.indexing.OuterIndexer` and
+:py:class:`~xarray.core.indexing.OuterIndexer`, and
 :py:class:`~xarray.core.indexing.VectorizedIndexer`.
 This implies that the implementation of the method ``__getitem__`` can be
 tricky.
 In order to simplify this task, Xarray provides a helper function,
@@ -419,8 +423,22 @@ input the ``key``, the array ``shape`` and the following parameters:
 For more details see
 :py:class:`~xarray.core.indexing.IndexingSupport` and :ref:`RST indexing`.
 
+Async support
+^^^^^^^^^^^^^
+
+Backends can also optionally support loading data asynchronously via xarray's
+asynchronous loading methods (e.g. :py:meth:`~xarray.Dataset.load_async`).
+To support async loading, the ``BackendArray`` subclass must additionally
+implement the ``BackendArray.async_getitem`` method.
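+
+A minimal sketch of what this can look like (``_fetch_async`` here is a
+hypothetical helper standing in for whatever native async API your storage
+layer exposes):
+
+.. code-block:: python
+
+    class MyBackendArray(BackendArray):
+        async def async_getitem(
+            self, key: indexing.ExplicitIndexer
+        ) -> np.typing.ArrayLike:
+            # await the storage layer's asynchronous fetch for this selection
+            return await self._fetch_async(key.tuple)
+
+Xarray will then await ``async_getitem`` when users call an asynchronous
+loading method such as :py:meth:`~xarray.Dataset.load_async`.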
+
+Note that implementing this method is only necessary if you want to be able to
+load data from different xarray objects concurrently.
+Even without this method your ``BackendArray`` implementation is still free to
+concurrently load chunks of data for a single ``Variable`` itself,
+so long as it does so behind the synchronous ``__getitem__`` interface.
+
+Dask support
+^^^^^^^^^^^^
+
 In order to support `Dask Distributed <https://distributed.dask.org>`__ and
-:py:mod:`multiprocessing`, ``BackendArray`` subclass should be serializable
+:py:mod:`multiprocessing`, the ``BackendArray`` subclass should be serializable
 either with :ref:`io.pickle` or
 `cloudpickle <https://github.com/cloudpipe/cloudpickle>`__.
 That implies that all the reference to open files should be dropped. For
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 607661ed30b..8960b1085a2 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -13,6 +13,13 @@ v2025.07.0 (unreleased)
 New Features
 ~~~~~~~~~~~~
 
+- Added new asynchronous loading methods :py:meth:`~xarray.Dataset.load_async`,
+  :py:meth:`~xarray.DataArray.load_async` and :py:meth:`~xarray.Variable.load_async`.
+  (:issue:`10326`, :pull:`10327`)
+  By `Tom Nicholas <https://github.com/TomNicholas>`_.
+- Allow an Xarray index that uses multiple dimensions to check equality with another
+  index for only a subset of those dimensions (i.e., ignoring the dimensions
+  that are excluded from alignment).
+  (:issue:`10243`, :pull:`10293`)
+  By `Benoit Bovy <https://github.com/benbovy>`_.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
@@ -28,6 +35,8 @@ Bug fixes
   By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
 - Fix the SciPy backend for netCDF3 files .
   (:issue:`8909`, :pull:`10376`)
   By `Deepak Cherian <https://github.com/dcherian>`_.
+- Allow accessing arbitrary attributes on Pandas ExtensionArrays.
+  By `Deepak Cherian <https://github.com/dcherian>`_.
 - Check and fix character array string dimension names, issue warnings as needed (:issue:`6352`, :pull:`10395`).
   By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
diff --git a/pyproject.toml b/pyproject.toml
index c980c204b5f..9995eec97a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,6 +73,7 @@ dev = [
   "pytest-mypy-plugins",
   "pytest-timeout",
   "pytest-xdist",
+  "pytest-asyncio",
   "ruff>=0.8.0",
   "sphinx",
   "sphinx_autosummary_accessors",
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index e1f8dc5cecd..f64aa001c22 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -270,10 +270,17 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
 class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
     __slots__ = ()
 
+    async def async_getitem(self, key: indexing.ExplicitIndexer) -> np.typing.ArrayLike:
+        raise NotImplementedError("Backend does not support asynchronous loading")
+
     def get_duck_array(self, dtype: np.typing.DTypeLike = None):
         key = indexing.BasicIndexer((slice(None),) * self.ndim)
         return self[key]  # type: ignore[index]
 
+    async def async_get_duck_array(self, dtype: np.typing.DTypeLike = None):
+        key = indexing.BasicIndexer((slice(None),) * self.ndim)
+        return await self.async_getitem(key)  # type: ignore[index]
+
 
 class AbstractDataStore:
     __slots__ = ()
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 54ff419b2f2..1d8d9ff46a4 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -186,6 +186,8 @@ class ZarrArrayWrapper(BackendArray):
     def __init__(self, zarr_array):
         # some callers attempt to evaluate an array if an `array` property exists on the object.
         # we prefix with _ to avoid this inference.
+
+        # TODO type hint this?
         self._array = zarr_array
         self.shape = self._array.shape
@@ -213,6 +215,18 @@ def _vindex(self, key):
     def _getitem(self, key):
         return self._array[key]
 
+    async def _async_getitem(self, key):
+        async_array = self._array._async_array
+        return await async_array.getitem(key)
+
+    async def _async_oindex(self, key):
+        async_array = self._array._async_array
+        return await async_array.oindex.getitem(key)
+
+    async def _async_vindex(self, key):
+        async_array = self._array._async_array
+        return await async_array.vindex.getitem(key)
+
     def __getitem__(self, key):
         array = self._array
         if isinstance(key, indexing.BasicIndexer):
@@ -228,6 +242,19 @@ def __getitem__(self, key):
         # if self.ndim == 0:
         # could possibly have a work-around for 0d data here
 
+    async def async_getitem(self, key):
+        array = self._array
+        if isinstance(key, indexing.BasicIndexer):
+            method = self._async_getitem
+        elif isinstance(key, indexing.VectorizedIndexer):
+            method = self._async_vindex
+        elif isinstance(key, indexing.OuterIndexer):
+            method = self._async_oindex
+        return await indexing.async_explicit_indexing_adapter(
+            key, array.shape, indexing.IndexingSupport.VECTORIZED, method
+        )
+
 
 def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name):
     """
diff --git a/xarray/coding/common.py b/xarray/coding/common.py
index 0e8d7e1955e..79e5e7502b3 100644
--- a/xarray/coding/common.py
+++ b/xarray/coding/common.py
@@ -79,6 +79,9 @@ def __getitem__(self, key):
     def get_duck_array(self):
         return self.func(self.array.get_duck_array())
 
+    async def async_get_duck_array(self):
+        return self.func(await self.array.async_get_duck_array())
+
     def __repr__(self) -> str:
         return f"{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})"
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index c13d33872b6..4bb2e06c4e7 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1160,6 +1160,14 @@ def load(self, **kwargs) -> Self:
         self._coords = new._coords
         return self
 
+    async def load_async(self, **kwargs) -> Self:
+        temp_ds = self._to_temp_dataset()
+        ds = await temp_ds.load_async(**kwargs)
+        new = self._from_temp_dataset(ds)
+        self._variable = new._variable
+        self._coords = new._coords
+        return self
+
     def compute(self, **kwargs) -> Self:
         """Manually trigger loading of this array's data from disk or a
         remote source into memory and return a new array.
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 367da2f60a5..cab08b6230c 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import copy
 import datetime
 import math
@@ -532,24 +533,50 @@ def load(self, **kwargs) -> Self:
         dask.compute
         """
         # access .data to coerce everything to numpy or dask arrays
-        lazy_data = {
+        chunked_data = {
             k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
         }
-        if lazy_data:
-            chunkmanager = get_chunked_array_type(*lazy_data.values())
+        if chunked_data:
+            chunkmanager = get_chunked_array_type(*chunked_data.values())
 
             # evaluate all the chunked arrays simultaneously
             evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
-                *lazy_data.values(), **kwargs
+                *chunked_data.values(), **kwargs
             )
 
-            for k, data in zip(lazy_data, evaluated_data, strict=False):
+            for k, data in zip(chunked_data, evaluated_data, strict=False):
                 self.variables[k].data = data
 
         # load everything else sequentially
-        for k, v in self.variables.items():
-            if k not in lazy_data:
-                v.load()
+        for k, v in self.variables.items():
+            if k not in chunked_data:
+                v.load()
+
+        return self
+
+    async def load_async(self, **kwargs) -> Self:
+        # TODO refactor this to pull out the common chunked_data codepath
+
+        # this blocks on chunked arrays but not on lazily indexed arrays
+
+        # access .data to coerce everything to numpy or dask arrays
+        chunked_data = {
+            k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
+        }
+        if chunked_data:
+            chunkmanager = get_chunked_array_type(*chunked_data.values())
+
+            # evaluate all the chunked arrays simultaneously
+            evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
+                *chunked_data.values(), **kwargs
+            )
+
+            for k, data in zip(chunked_data, evaluated_data, strict=False):
+                self.variables[k].data = data
+
+        # load everything else concurrently
+        coros = [
+            v.load_async() for k, v in self.variables.items() if k not in chunked_data
+        ]
+        await asyncio.gather(*coros)
 
         return self
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index e14543e646f..fe8aa99691d 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -516,13 +516,31 @@ def get_duck_array(self):
         return self.array
 
 
-class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed):
-    __slots__ = ()
+class IndexingAdapter:
+    """Marker class for indexing adapters.
+
+    These classes translate between Xarray's indexing semantics and the
+    underlying array's indexing semantics.
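+
+    Since these adapters wrap fully in-memory array types there is nothing to
+    await when loading, so ``async_get_duck_array`` simply defers to the
+    synchronous ``get_duck_array``.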
+ """ def get_duck_array(self): key = BasicIndexer((slice(None),) * self.ndim) return self[key] + async def async_get_duck_array(self): + """These classes are applied to in-memory arrays, so specific async support isn't needed.""" + return self.get_duck_array() + + +class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed): + __slots__ = () + + def get_duck_array(self): + raise NotImplementedError + + async def async_get_duck_array(self): + raise NotImplementedError + def _oindex_get(self, indexer: OuterIndexer): raise NotImplementedError( f"{self.__class__.__name__}._oindex_get method should be overridden" @@ -646,19 +664,25 @@ def shape(self) -> _Shape: return self._shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): - array = apply_indexer(self.array, self.key) - else: - # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a BackendArray so use its __getitem__ + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): array = self.array[self.key] + else: + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) - # self.array[self.key] is now a numpy array when - # self.array is a BackendArray subclass - # and self.key is BasicIndexer((slice(None, None, None),)) - # so we need the explicit check for ExplicitlyIndexed - if isinstance(array, ExplicitlyIndexed): - array = array.get_duck_array() + async def async_get_duck_array(self): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = await self.array.async_getitem(self.key) + else: + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = await array.async_get_duck_array() return _wrap_numpy_scalars(array) def transpose(self, order): @@ -722,18 +746,26 @@ def shape(self) -> _Shape: return np.broadcast(*self.key.tuple).shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = self.array[self.key] + else: array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) + + async def async_get_duck_array(self): + print("inside LazilyVectorizedIndexedArray.async_get_duck_array") + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = await self.array.async_getitem(self.key) else: - # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a BackendArray so use its __getitem__ - array = self.array[self.key] - # self.array[self.key] is now a numpy array when - # self.array is a BackendArray subclass - # and self.key is BasicIndexer((slice(None, None, None),)) - # so we need the explicit check for ExplicitlyIndexed - if isinstance(array, ExplicitlyIndexed): - array = array.get_duck_array() + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = await array.async_get_duck_array() return _wrap_numpy_scalars(array) def _updated_key(self, new_key: ExplicitIndexer): @@ -798,6 +830,9 @@ def _ensure_copied(self): def get_duck_array(self): return self.array.get_duck_array() + async def async_get_duck_array(self): + return await self.array.async_get_duck_array() + def _oindex_get(self, indexer: OuterIndexer): return 
type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) @@ -838,12 +873,17 @@ class MemoryCachedArray(ExplicitlyIndexedNDArrayMixin): def __init__(self, array): self.array = _wrap_numpy_scalars(as_indexable(array)) - def _ensure_cached(self): - self.array = as_indexable(self.array.get_duck_array()) - def get_duck_array(self): - self._ensure_cached() - return self.array.get_duck_array() + duck_array = self.array.get_duck_array() + # ensure the array object is cached in-memory + self.array = as_indexable(duck_array) + return duck_array + + async def async_get_duck_array(self): + duck_array = await self.array.async_get_duck_array() + # ensure the array object is cached in-memory + self.array = as_indexable(duck_array) + return duck_array def _oindex_get(self, indexer: OuterIndexer): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) @@ -1028,6 +1068,21 @@ def explicit_indexing_adapter( return result +async def async_explicit_indexing_adapter( + key: ExplicitIndexer, + shape: _Shape, + indexing_support: IndexingSupport, + raw_indexing_method: Callable[..., Any], +) -> Any: + raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support) + result = await raw_indexing_method(raw_key.tuple) + if numpy_indices.tuple: + # index the loaded duck array + indexable = as_indexable(result) + result = apply_indexer(indexable, numpy_indices) + return result + + def apply_indexer(indexable, indexer: ExplicitIndexer): """Apply an indexer to an indexable object.""" if isinstance(indexer, VectorizedIndexer): @@ -1527,7 +1582,7 @@ def is_fancy_indexer(indexer: Any) -> bool: return True -class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class NumpyIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap a NumPy array to use explicit indexing.""" __slots__ = ("array",) @@ -1606,7 +1661,7 @@ def __init__(self, array): self.array = array -class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class ArrayApiIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap an array API array to use explicit indexing.""" __slots__ = ("array",) @@ -1671,7 +1726,7 @@ def _assert_not_chunked_indexer(idxr: tuple[Any, ...]) -> None: ) -class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class DaskIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap a dask array to support explicit indexing.""" __slots__ = ("array",) @@ -1747,7 +1802,7 @@ def transpose(self, order): return self.array.transpose(order) -class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class PandasIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" __slots__ = ("_dtype", "array") @@ -2068,7 +2123,9 @@ def copy(self, deep: bool = True) -> Self: return type(self)(array, self._dtype, self.level) -class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class CoordinateTransformIndexingAdapter( + IndexingAdapter, ExplicitlyIndexedNDArrayMixin +): """Wrap a CoordinateTransform as a lazy coordinate array. Supports explicit indexing (both outer and vectorized). 
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 9c753a2ffa7..45ca533d957 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -47,6 +47,7 @@
 from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
 from xarray.namedarray.parallelcompat import get_chunked_array_type
 from xarray.namedarray.pycompat import (
+    async_to_duck_array,
     integer_types,
     is_0d_dask_array,
     is_chunked_array,
@@ -970,6 +971,10 @@ def load(self, **kwargs):
         self._data = to_duck_array(self._data, **kwargs)
         return self
 
+    async def load_async(self, **kwargs):
+        self._data = await async_to_duck_array(self._data, **kwargs)
+        return self
+
     def compute(self, **kwargs):
         """Manually trigger loading of this variable's data from disk or a
         remote source into memory and return a new variable. The original is
diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py
index 68b6a7853bf..6e61d3445ab 100644
--- a/xarray/namedarray/pycompat.py
+++ b/xarray/namedarray/pycompat.py
@@ -145,3 +145,23 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, _DType]:
         return data
     else:
         return np.asarray(data)  # type: ignore[return-value]
+
+
+async def async_to_duck_array(
+    data: Any, **kwargs: dict[str, Any]
+) -> duckarray[_ShapeType, _DType]:
+    from xarray.core.indexing import (
+        ExplicitlyIndexed,
+        ImplicitToExplicitIndexingAdapter,
+        IndexingAdapter,
+    )
+
+    if isinstance(data, IndexingAdapter):
+        # These wrap in-memory arrays, and async isn't needed
+        return data.get_duck_array()
+    elif isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter):
+        return await data.async_get_duck_array()  # type: ignore[no-untyped-call, no-any-return]
+    else:
+        return to_duck_array(data, **kwargs)
diff --git a/xarray/tests/test_async.py b/xarray/tests/test_async.py
new file mode 100644
index 00000000000..8a3ca76e16d
--- /dev/null
+++ b/xarray/tests/test_async.py
@@ -0,0 +1,201 @@
+import asyncio
+from typing import Literal, TypeVar
+from unittest.mock import patch
+
+import pytest
+
+import xarray as xr
+import xarray.testing as xrt
+from xarray.tests import has_zarr_v3, requires_zarr_v3
+from xarray.tests.test_dataset import create_test_data
+
+if has_zarr_v3:
+    import zarr
+    from zarr.abc.store import ByteRequest, Store
+    from zarr.core.buffer import Buffer, BufferPrototype
+    from zarr.storage import MemoryStore
+    from zarr.storage._wrapper import WrapperStore
+
+    T_Store = TypeVar("T_Store", bound=Store)
+
+    class ReadOnlyStore(WrapperStore[T_Store]):
+        """
+        We shouldn't need this - but we currently do just as a way around https://github.com/zarr-developers/zarr-python/issues/3105#issuecomment-2990367167
+
+        Works the same way as the zarr LoggingStore.
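+
+        Only the async ``get`` method is overridden here; every other store
+        method falls through to the wrapped store via ``WrapperStore``.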
+ """ + + read_only = True + + def __init__( + self, + store: T_Store, + ) -> None: + super().__init__(store) + + async def get( + self, + key: str, + prototype: BufferPrototype, + byte_range: ByteRequest | None = None, + ) -> Buffer | None: + return await self._store.get( + key=key, prototype=prototype, byte_range=byte_range + ) + +else: + ReadOnlyStore = {} + + +@pytest.fixture +def memorystore() -> "MemoryStore": + memorystore = zarr.storage.MemoryStore({}) + + ds = create_test_data() + ds.to_zarr(memorystore, zarr_format=3, consolidated=False) + + return memorystore + + +@pytest.fixture +def store(memorystore) -> "zarr.abc.store.Store": + # TODO we shouldn't this Store at all for the patched tests, but we currently use it just as a way around https://github.com/zarr-developers/zarr-python/issues/3105#issuecomment-2990367167 + return ReadOnlyStore(memorystore) + + +def get_xr_obj( + store: "zarr.abc.store.Store", cls_name: Literal["Variable", "DataArray", "Dataset"] +): + ds = xr.open_zarr(store, zarr_format=3, consolidated=False, chunks=None) + + match cls_name: + case "Variable": + return ds["var1"].variable + case "DataArray": + return ds["var1"] + case "Dataset": + return ds + + +@requires_zarr_v3 +@pytest.mark.asyncio +class TestAsyncLoad: + async def test_concurrent_load_multiple_variables(self, store) -> None: + target_class = zarr.AsyncArray + method_name = "getitem" + original_method = getattr(target_class, method_name) + + # TODO up the number of variables in the dataset? + # the coordinate variable is not lazy + N_LAZY_VARS = 1 + + with patch.object( + target_class, method_name, wraps=original_method, autospec=True + ) as mocked_meth: + # blocks upon loading the coordinate variables here + ds = xr.open_zarr(store, zarr_format=3, consolidated=False, chunks=None) + + # TODO we're not actually testing that these indexing methods are not blocking... + result_ds = await ds.load_async() + + mocked_meth.assert_called() + assert mocked_meth.call_count >= N_LAZY_VARS + mocked_meth.assert_awaited() + + xrt.assert_identical(result_ds, ds.load()) + + @pytest.mark.parametrize("cls_name", ["Variable", "DataArray", "Dataset"]) + async def test_concurrent_load_multiple_objects(self, store, cls_name) -> None: + N_OBJECTS = 5 + + target_class = zarr.AsyncArray + method_name = "getitem" + original_method = getattr(target_class, method_name) + + with patch.object( + target_class, method_name, wraps=original_method, autospec=True + ) as mocked_meth: + xr_obj = get_xr_obj(store, cls_name) + + # TODO we're not actually testing that these indexing methods are not blocking... 
+            coros = [xr_obj.load_async() for _ in range(N_OBJECTS)]
+            results = await asyncio.gather(*coros)
+
+            mocked_meth.assert_called()
+            assert mocked_meth.call_count >= N_OBJECTS
+            mocked_meth.assert_awaited()
+
+        for result in results:
+            xrt.assert_identical(result, xr_obj.load())
+
+    @pytest.mark.parametrize("cls_name", ["Variable", "DataArray", "Dataset"])
+    @pytest.mark.parametrize(
+        "indexer, method, zarr_class_and_method",
+        [
+            ({}, "sel", (zarr.AsyncArray, "getitem")),
+            ({}, "isel", (zarr.AsyncArray, "getitem")),
+            ({"dim2": 1.0}, "sel", (zarr.AsyncArray, "getitem")),
+            ({"dim2": 2}, "isel", (zarr.AsyncArray, "getitem")),
+            ({"dim2": slice(1.0, 3.0)}, "sel", (zarr.AsyncArray, "getitem")),
+            ({"dim2": slice(1, 3)}, "isel", (zarr.AsyncArray, "getitem")),
+            ({"dim2": [1.0, 3.0]}, "sel", (zarr.core.indexing.AsyncOIndex, "getitem")),
+            ({"dim2": [1, 3]}, "isel", (zarr.core.indexing.AsyncOIndex, "getitem")),
+            (
+                {
+                    "dim1": xr.Variable(data=[2, 3], dims="points"),
+                    "dim2": xr.Variable(data=[1.0, 2.0], dims="points"),
+                },
+                "sel",
+                (zarr.core.indexing.AsyncVIndex, "getitem"),
+            ),
+            (
+                {
+                    "dim1": xr.Variable(data=[2, 3], dims="points"),
+                    "dim2": xr.Variable(data=[1, 3], dims="points"),
+                },
+                "isel",
+                (zarr.core.indexing.AsyncVIndex, "getitem"),
+            ),
+        ],
+        ids=[
+            "no-indexing-sel",
+            "no-indexing-isel",
+            "basic-int-sel",
+            "basic-int-isel",
+            "basic-slice-sel",
+            "basic-slice-isel",
+            "outer-sel",
+            "outer-isel",
+            "vectorized-sel",
+            "vectorized-isel",
+        ],
+    )
+    async def test_indexing(
+        self,
+        store,
+        cls_name,
+        method,
+        indexer,
+        zarr_class_and_method,
+    ) -> None:
+        if cls_name == "Variable" and method == "sel":
+            pytest.skip("Variable doesn't have a .sel method")
+
+        # each type of indexing ends up calling a different zarr indexing method
+        target_class, method_name = zarr_class_and_method
+        original_method = getattr(target_class, method_name)
+
+        with patch.object(
+            target_class, method_name, wraps=original_method, autospec=True
+        ) as mocked_meth:
+            xr_obj = get_xr_obj(store, cls_name)
+
+            # TODO we're not actually testing that these indexing methods are not blocking...
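+            # ...but we do verify below that the awaited zarr method matches the
+            # indexer type (basic -> getitem, outer -> oindex, vectorized -> vindex)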
+            result = await getattr(xr_obj, method)(**indexer).load_async()
+
+            mocked_meth.assert_called()
+            mocked_meth.assert_awaited()
+            assert mocked_meth.call_count > 0
+
+        expected = getattr(xr_obj, method)(**indexer).load()
+        xrt.assert_identical(result, expected)
diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py
index 6dd75b58c6a..010987337a6 100644
--- a/xarray/tests/test_indexing.py
+++ b/xarray/tests/test_indexing.py
@@ -490,6 +490,25 @@ def test_sub_array(self) -> None:
         assert isinstance(child.array, indexing.NumpyIndexingAdapter)
         assert isinstance(wrapped.array, indexing.LazilyIndexedArray)
 
+    @pytest.mark.asyncio
+    async def test_async_wrapper(self) -> None:
+        original = indexing.LazilyIndexedArray(np.arange(10))
+        wrapped = indexing.MemoryCachedArray(original)
+        await wrapped.async_get_duck_array()
+        assert_array_equal(wrapped, np.arange(10))
+        assert isinstance(wrapped.array, indexing.NumpyIndexingAdapter)
+
+    @pytest.mark.asyncio
+    async def test_async_sub_array(self) -> None:
+        original = indexing.LazilyIndexedArray(np.arange(10))
+        wrapped = indexing.MemoryCachedArray(original)
+        child = wrapped[B[:5]]
+        assert isinstance(child, indexing.MemoryCachedArray)
+        await child.async_get_duck_array()
+        assert_array_equal(child, np.arange(5))
+        assert isinstance(child.array, indexing.NumpyIndexingAdapter)
+        assert isinstance(wrapped.array, indexing.LazilyIndexedArray)
+
     def test_setitem(self) -> None:
         original = np.arange(10)
         wrapped = indexing.MemoryCachedArray(original)