Skip to content

(fix): no fill_value on reindex #10304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any

import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

from xarray.compat import array_api_compat, npcompat
Expand Down Expand Up @@ -63,7 +64,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
# N.B. these casting rules should match pandas
dtype_: np.typing.DTypeLike
fill_value: Any
if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
if pd.api.types.is_extension_array_dtype(dtype):
return dtype, pd.NA
elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
# for now, we always promote string dtypes to object for consistency with existing behavior
# TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
dtype_ = object
Expand Down
30 changes: 23 additions & 7 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@

if xp == np:
# numpy currently doesn't have a astype:
return data.astype(dtype, **kwargs)

Check warning on line 253 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 253 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 253 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 253 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10

invalid value encountered in cast

Check warning on line 253 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 253 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast
return xp.astype(data, dtype, **kwargs)


Expand All @@ -272,18 +272,35 @@
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
non_nans = [x for x in scalars_or_arrays if not isna(x)]
if len(extension_array_types) == len(non_nans) and all(
non_nans_or_scalar = [
x for x in scalars_or_arrays if not (isna(x) or np.isscalar(x))
]
if len(extension_array_types) == len(non_nans_or_scalar) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return [
# Get the extension array class of the first element, guaranteed to be the same
# as the others thanks to the above check.
extension_array_class = type(
non_nans_or_scalar[0].array
if isinstance(non_nans_or_scalar[0], PandasExtensionArray)
else non_nans_or_scalar[0]
)
# Cast scalars/nans to extension array class
arrays_with_nan_to_sequence = [
x
if not isna(x)
else PandasExtensionArray(
type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
if not (isna(x) or np.isscalar(x))
else extension_array_class._from_sequence(
[x], dtype=non_nans_or_scalar[0].dtype
)
for x in scalars_or_arrays
]
# Wrap the output if necessary
return [
PandasExtensionArray(x)
if not isinstance(x, PandasExtensionArray)
else x
for x in arrays_with_nan_to_sequence
]
raise ValueError(
f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
)
Expand Down Expand Up @@ -406,7 +423,6 @@
condition = asarray(condition, dtype=dtype, xp=xp)
else:
condition = astype(condition, dtype=dtype, xp=xp)

return xp.where(condition, *as_shared_dtype([x, y], xp=xp))


Expand Down
15 changes: 15 additions & 0 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ def __extension_duck_array__concatenate(
return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined]


@implements(np.reshape)
def __extension_duck_array__reshape(
    arr: T_ExtensionArray, shape: tuple
) -> T_ExtensionArray:
    """Handle ``np.reshape`` for 1-D-only pandas extension arrays.

    Only identity reshapes are supported: a one-element shape equal to
    ``len(arr)`` or to ``-1``. In those cases ``arr`` itself is returned
    (no copy is made).

    Parameters
    ----------
    arr
        The extension array to reshape.
    shape
        Requested shape. A bare integer is treated as a one-element shape,
        mirroring what ``np.reshape`` accepts.

    Raises
    ------
    NotImplementedError
        If ``shape`` is anything other than the array's current 1-D shape.
    """
    # Normalize first so that an empty shape ``()`` or a bare integer reaches
    # the checks below instead of failing on ``shape[0]``.
    if isinstance(shape, (int, np.integer)):
        target = (shape,)
    else:
        target = tuple(shape)
    # Check the rank before indexing; ``-1`` means "infer", which for a 1-D
    # target is always the current length.
    if len(target) == 1 and target[0] in (len(arr), -1):
        return arr
    raise NotImplementedError(
        f"Cannot reshape 1d-only pandas extension array to: {shape}"
    )


@implements(np.where)
def __extension_duck_array__where(
condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
Expand Down Expand Up @@ -134,6 +145,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
return ufunc(*inputs, **kwargs)

def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
if (
isinstance(key, tuple) and len(key) == 1
): # pyarrow type arrays can't handle single-length tuples
key = key[0]
item = self.array[key]
if is_extension_array_dtype(item):
return PandasExtensionArray(item)
Expand Down
54 changes: 54 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
assert_no_warnings,
has_dask,
has_dask_ge_2025_1_0,
has_pyarrow,
raise_if_dask_computes,
requires_bottleneck,
requires_cupy,
Expand All @@ -61,6 +62,7 @@
requires_iris,
requires_numexpr,
requires_pint,
requires_pyarrow,
requires_scipy,
requires_sparse,
source_ndarray,
Expand Down Expand Up @@ -3075,6 +3077,58 @@ def test_propagate_attrs(self, func) -> None:
with set_options(keep_attrs=True):
assert func(da).attrs == da.attrs

@pytest.mark.parametrize(
    "fill_value,extension_array",
    [
        pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="categorical"),
    ]
    # Parenthesized on purpose: a conditional expression binds more loosely
    # than ``+``, so without the parentheses a missing pyarrow would empty the
    # whole list and silently drop the categorical case as well.
    + (
        [
            pytest.param(
                0,
                pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"),
                id="int64[pyarrow]",
            )
        ]
        if has_pyarrow
        else []
    ),
)
def test_fillna_extension_array(self, fill_value, extension_array) -> None:
    """``fillna`` replaces the leading missing value and keeps the extension dtype."""
    srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array)
    da = srs.to_xarray()
    filled = da.fillna(fill_value)
    assert filled.dtype == srs.dtype
    # Only the first element is missing in the inputs, so the expected result
    # is the fill value followed by the untouched remainder.
    assert (filled.values == np.array([fill_value, *(srs.values[1:])])).all()

@requires_pyarrow
def test_fillna_extension_array_bad_val(self) -> None:
    """Filling an ``int64[pyarrow]`` array with a string must raise ``ValueError``."""
    labels = np.array([1, 2, 3])
    values = pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]")
    series: pd.Series = pd.Series(index=labels, data=values)
    data_array = series.to_xarray()
    # "a" is not a valid value for an integer-typed array.
    with pytest.raises(ValueError):
        data_array.fillna("a")

@pytest.mark.parametrize(
    "extension_array",
    [
        pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="categorical"),
    ]
    # Parenthesized on purpose: a conditional expression binds more loosely
    # than ``+``, so without the parentheses a missing pyarrow would empty the
    # whole list and silently drop the categorical case as well.
    + (
        [
            pytest.param(
                pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]"
            )
        ]
        if has_pyarrow
        else []
    ),
)
def test_dropna_extension_array(self, extension_array) -> None:
    """``dropna`` removes the leading missing value and keeps the extension dtype."""
    srs: pd.Series = pd.Series(index=np.array([1, 2, 3]), data=extension_array)
    da = srs.to_xarray()
    dropped = da.dropna("index")
    assert dropped.dtype == srs.dtype
    # Only the first element is missing in the inputs, so the remainder must
    # survive unchanged.
    assert (dropped.values == srs.values[1:]).all()

def test_fillna(self) -> None:
a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x")
actual = a.fillna(-1)
Expand Down
94 changes: 80 additions & 14 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
requires_dask,
requires_numexpr,
requires_pint,
requires_pyarrow,
requires_scipy,
requires_sparse,
source_ndarray,
Expand Down Expand Up @@ -1802,28 +1803,48 @@ def test_categorical_index_reindex(self) -> None:
actual = ds.reindex(cat=["foo"])["cat"].values
assert (actual == np.array(["foo"])).all()

@pytest.mark.parametrize("fill_value", [np.nan, pd.NA])
def test_extensionarray_negative_reindex(self, fill_value) -> None:
cat = pd.Categorical(
["foo", "bar", "baz"],
categories=["foo", "bar", "baz", "qux", "quux", "corge"],
)
@pytest.mark.parametrize("fill_value", [np.nan, pd.NA, None])
@pytest.mark.parametrize(
"extension_array",
[
pytest.param(
pd.Categorical(
["foo", "bar", "baz"],
categories=["foo", "bar", "baz", "qux"],
),
id="categorical",
),
]
+ [
pytest.param(
pd.array([1, 1, None], dtype="int64[pyarrow]"), id="int64[pyarrow]"
)
]
if has_pyarrow
else [],
)
def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None:
ds = xr.Dataset(
{"cat": ("index", cat)},
{"arr": ("index", extension_array)},
coords={"index": ("index", np.arange(3))},
)
kwargs = {}
if fill_value is not None:
kwargs["fill_value"] = fill_value
reindexed_cat = cast(
pd.api.extensions.ExtensionArray,
(
ds.reindex(index=[-1, 1, 1], fill_value=fill_value)["cat"]
.to_pandas()
.values
),
(ds.reindex(index=[-1, 1, 1], **kwargs)["arr"].to_pandas().values),
)
assert reindexed_cat.equals( # type: ignore[attr-defined]
pd.array(
[pd.NA, extension_array[1], extension_array[1]],
dtype=extension_array.dtype,
)
)
assert reindexed_cat.equals(pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype)) # type: ignore[attr-defined]

@requires_pyarrow
def test_extension_array_reindex_same(self) -> None:
series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype())
series = pd.Series([1, 2, pd.NA, 3], dtype="int32[pyarrow]")
test = xr.Dataset({"test": series})
res = test.reindex(dim_0=series.index)
align(res, test, join="exact")
Expand Down Expand Up @@ -5473,6 +5494,51 @@ def test_dropna(self) -> None:
with pytest.raises(TypeError, match=r"must specify how or thresh"):
ds.dropna("a", how=None) # type: ignore[arg-type]

@pytest.mark.parametrize(
    "fill_value,extension_array",
    [
        pytest.param("a", pd.Categorical([pd.NA, "a", "b"]), id="category"),
    ]
    # Parenthesized on purpose: a conditional expression binds more loosely
    # than ``+``, so without the parentheses a missing pyarrow would empty the
    # whole list and silently drop the categorical case as well.
    + (
        [
            pytest.param(
                0,
                pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"),
                id="int64[pyarrow]",
            )
        ]
        if has_pyarrow
        else []
    ),
)
def test_fillna_extension_array(self, fill_value, extension_array) -> None:
    """``Dataset.fillna`` replaces the leading missing value and keeps the extension dtype."""
    srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3]))
    ds = srs.to_xarray()
    filled = ds.fillna(fill_value)
    assert filled["data"].dtype == extension_array.dtype
    # Only the first element is missing in the inputs, so the expected result
    # is the fill value followed by the untouched remainder.
    assert (
        filled["data"].values
        == np.array([fill_value, *srs["data"].values[1:]], dtype="object")
    ).all()

@pytest.mark.parametrize(
    "extension_array",
    [
        pytest.param(pd.Categorical([pd.NA, "a", "b"]), id="category"),
    ]
    # Parenthesized on purpose: a conditional expression binds more loosely
    # than ``+``, so without the parentheses a missing pyarrow would empty the
    # whole list and silently drop the categorical case as well.
    + (
        [
            pytest.param(
                pd.array([pd.NA, 1, 1], dtype="int64[pyarrow]"), id="int64[pyarrow]"
            )
        ]
        if has_pyarrow
        else []
    ),
)
def test_dropna_extension_array(self, extension_array) -> None:
    """``Dataset.dropna`` removes the leading missing value and keeps the extension dtype."""
    srs = pd.DataFrame({"data": extension_array}, index=np.array([1, 2, 3]))
    ds = srs.to_xarray()
    dropped = ds.dropna("index")
    assert dropped["data"].dtype == extension_array.dtype
    # Only the first element is missing in the inputs, so the remainder must
    # survive unchanged.
    assert (dropped["data"].values == srs["data"].values[1:]).all()

def test_fillna(self) -> None:
ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]})

Expand Down
Loading