diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py
index c959a7f2536..132e6553412 100644
--- a/xarray/core/dtypes.py
+++ b/xarray/core/dtypes.py
@@ -4,6 +4,7 @@
 from typing import Any
 
 import numpy as np
+import pandas as pd
 from pandas.api.types import is_extension_array_dtype
 
 from xarray.compat import array_api_compat, npcompat
@@ -63,7 +64,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
     # N.B. these casting rules should match pandas
     dtype_: np.typing.DTypeLike
     fill_value: Any
-    if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
+    if pd.api.types.is_extension_array_dtype(dtype):
+        return dtype, pd.NA
+    elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
         # for now, we always promote string dtypes to object for consistency with existing behavior
         # TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
         dtype_ = object
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index e98ac0f36a1..f659c257aee 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -272,18 +272,35 @@ def as_shared_dtype(scalars_or_arrays, xp=None):
         extension_array_types = [
             x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
         ]
-        non_nans = [x for x in scalars_or_arrays if not isna(x)]
-        if len(extension_array_types) == len(non_nans) and all(
+        non_nans_or_scalar = [
+            x for x in scalars_or_arrays if not (isna(x) or np.isscalar(x))
+        ]
+        if len(extension_array_types) == len(non_nans_or_scalar) and all(
             isinstance(x, type(extension_array_types[0])) for x in extension_array_types
         ):
-            return [
+            # Get the extension array class of the first element, guaranteed to be the same
+            # as the others thanks to the above check.
+            extension_array_class = type(
+                non_nans_or_scalar[0].array
+                if isinstance(non_nans_or_scalar[0], PandasExtensionArray)
+                else non_nans_or_scalar[0]
+            )
+            # Cast scalars/nans to the extension array class
+            arrays_with_nan_to_sequence = [
                 x
-                if not isna(x)
-                else PandasExtensionArray(
-                    type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
+                if not (isna(x) or np.isscalar(x))
+                else extension_array_class._from_sequence(
+                    [x], dtype=non_nans_or_scalar[0].dtype
                 )
                 for x in scalars_or_arrays
             ]
+            # Wrap the output if necessary
+            return [
+                PandasExtensionArray(x)
+                if not isinstance(x, PandasExtensionArray)
+                else x
+                for x in arrays_with_nan_to_sequence
+            ]
         raise ValueError(
             f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
         )
@@ -406,7 +423,6 @@ def where(condition, x, y):
         condition = asarray(condition, dtype=dtype, xp=xp)
     else:
         condition = astype(condition, dtype=dtype, xp=xp)
-
     return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
 
 
diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py
index 7cc9db96d0d..49ffe23117f 100644
--- a/xarray/core/extension_array.py
+++ b/xarray/core/extension_array.py
@@ -52,6 +52,17 @@ def __extension_duck_array__concatenate(
     return type(arrays[0])._concat_same_type(arrays)  # type: ignore[attr-defined]
 
 
+@implements(np.reshape)
+def __extension_duck_array__reshape(
+    arr: T_ExtensionArray, shape: tuple
+) -> T_ExtensionArray:
+    if (shape[0] == len(arr) and len(shape) == 1) or shape == (-1,):
+        return arr
+    raise NotImplementedError(
+        f"Cannot reshape 1d-only pandas extension array to: {shape}"
+    )
+
+
 @implements(np.where)
 def __extension_duck_array__where(
     condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
@@ -134,6 +145,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
         return ufunc(*inputs, **kwargs)
 
     def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
+        if (
+            isinstance(key, tuple) and len(key) == 1
+        ):  # pyarrow type arrays can't handle single-element tuples
+            key = key[0]
         item = self.array[key]
         if is_extension_array_dtype(item):
             return PandasExtensionArray(item)
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index b2b9ae314c4..d4c45fa99eb 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -53,6 +53,7 @@
     assert_no_warnings,
     has_dask,
     has_dask_ge_2025_1_0,
+    has_pyarrow,
     raise_if_dask_computes,
     requires_bottleneck,
     requires_cupy,
@@ -61,6 +62,7 @@
     requires_iris,
     requires_numexpr,
     requires_pint,
+    requires_pyarrow,
     requires_scipy,
     requires_sparse,
     source_ndarray,
@@ -3075,6 +3077,58 @@ def test_propagate_attrs(self, func) -> None:
         with set_options(keep_attrs=True):
             assert func(da).attrs == da.attrs
 
+    @pytest.mark.parametrize(
+        "fill_value,extension_array",
+        [
+            pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="categorical"),
+        ]
+        + [
+            pytest.param(
+                0,
+                pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"),
+                id="int64[pyarrow]",
+            )
+        ]
+        if has_pyarrow
+        else [],
+    )
+    def test_fillna_extension_array(self, fill_value, extension_array) -> None:
+        srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array)
+        da = srs.to_xarray()
+        filled = da.fillna(fill_value)
+        assert filled.dtype == srs.dtype
+        assert (filled.values == np.array([fill_value, *(srs.values[1:])])).all()
+
+    @requires_pyarrow
+    def test_fillna_extension_array_bad_val(self) -> None:
+        srs: pd.Series = pd.Series(
+            index=np.array([1, 2, 3]),
+            data=pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"),
+        )
+        da = srs.to_xarray()
+        with pytest.raises(ValueError):
+            da.fillna("a")
+
+    @pytest.mark.parametrize(
+        "extension_array",
+        [
+            pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="categorical"),
+        ]
+        + [
+            pytest.param(
+                pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]"
+            )
+        ]
+        if has_pyarrow
+        else [],
+    )
+    def test_dropna_extension_array(self, extension_array) -> None:
+        srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array)
+        da = srs.to_xarray()
+        filled = da.dropna("index")
+        assert filled.dtype == srs.dtype
+        assert (filled.values == srs.values[1:]).all()
+
     def test_fillna(self) -> None:
         a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x")
         actual = a.fillna(-1)
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index b17ea252a58..096de7ec50a 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -70,6 +70,7 @@
     requires_dask,
     requires_numexpr,
     requires_pint,
+    requires_pyarrow,
     requires_scipy,
     requires_sparse,
     source_ndarray,
@@ -1802,28 +1803,48 @@ def test_categorical_index_reindex(self) -> None:
         actual = ds.reindex(cat=["foo"])["cat"].values
         assert (actual == np.array(["foo"])).all()
 
-    @pytest.mark.parametrize("fill_value", [np.nan, pd.NA])
-    def test_extensionarray_negative_reindex(self, fill_value) -> None:
-        cat = pd.Categorical(
-            ["foo", "bar", "baz"],
-            categories=["foo", "bar", "baz", "qux", "quux", "corge"],
-        )
+    @pytest.mark.parametrize("fill_value", [np.nan, pd.NA, None])
+    @pytest.mark.parametrize(
+        "extension_array",
+        [
+            pytest.param(
+                pd.Categorical(
+                    ["foo", "bar", "baz"],
+                    categories=["foo", "bar", "baz", "qux"],
+                ),
+                id="categorical",
+            ),
+        ]
+        + [
+            pytest.param(
+                pd.array([1, 1, None], dtype="int64[pyarrow]"), id="int64[pyarrow]"
id="int64[pyarrow]" + ) + ] + if has_pyarrow + else [], + ) + def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None: ds = xr.Dataset( - {"cat": ("index", cat)}, + {"arr": ("index", extension_array)}, coords={"index": ("index", np.arange(3))}, ) + kwargs = {} + if fill_value is not None: + kwargs["fill_value"] = fill_value reindexed_cat = cast( pd.api.extensions.ExtensionArray, - ( - ds.reindex(index=[-1, 1, 1], fill_value=fill_value)["cat"] - .to_pandas() - .values - ), + (ds.reindex(index=[-1, 1, 1], **kwargs)["arr"].to_pandas().values), + ) + assert reindexed_cat.equals( # type: ignore[attr-defined] + pd.array( + [pd.NA, extension_array[1], extension_array[1]], + dtype=extension_array.dtype, + ) ) - assert reindexed_cat.equals(pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype)) # type: ignore[attr-defined] + @requires_pyarrow def test_extension_array_reindex_same(self) -> None: - series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype()) + series = pd.Series([1, 2, pd.NA, 3], dtype="int32[pyarrow]") test = xr.Dataset({"test": series}) res = test.reindex(dim_0=series.index) align(res, test, join="exact") @@ -5473,6 +5494,51 @@ def test_dropna(self) -> None: with pytest.raises(TypeError, match=r"must specify how or thresh"): ds.dropna("a", how=None) # type: ignore[arg-type] + @pytest.mark.parametrize( + "fill_value,extension_array", + [ + pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="category"), + ] + + [ + pytest.param( + 0, + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), + id="int64[pyarrow]", + ) + ] + if has_pyarrow + else [], + ) + def test_fillna_extension_array(self, fill_value, extension_array) -> None: + srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) + ds = srs.to_xarray() + filled = ds.fillna(fill_value) + assert filled["data"].dtype == extension_array.dtype + assert ( + filled["data"].values + == np.array([fill_value, *srs["data"].values[1:]], dtype="object") + ).all() + + @pytest.mark.parametrize( + "extension_array", + [ + pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="category"), + ] + + [ + pytest.param( + pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]" + ) + ] + if has_pyarrow + else [], + ) + def test_dropna_extension_array(self, extension_array) -> None: + srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3])) + ds = srs.to_xarray() + dropped = ds.dropna("index") + assert dropped["data"].dtype == extension_array.dtype + assert (dropped["data"].values == srs["data"].values[1:]).all() + def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]})