diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 679b5ad023..79126ccd03 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -25,7 +25,7 @@ from types import ModuleType import dask.dataframe.dask_expr as dx - from typing_extensions import Self, TypeIs + from typing_extensions import Self, TypeAlias, TypeIs from narwhals._compliant.typing import CompliantDataFrameAny from narwhals._dask.expr import DaskExpr @@ -36,6 +36,13 @@ from narwhals.dtypes import DType from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy +Incomplete: TypeAlias = "Any" +"""Using `_pandas_like` utils with `_dask`. + +Typing this correctly will complicate the `_pandas_like`-side. +Very low priority until `dask` adds typing. +""" + class DaskLazyFrame( CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"] @@ -159,8 +166,9 @@ def filter(self, predicate: DaskExpr) -> Self: return self._with_native(self.native.loc[mask]) def simple_select(self, *column_names: str) -> Self: + df: Incomplete = self.native native = select_columns_by_name( - self.native, list(column_names), self._backend_version, self._implementation + df, list(column_names), self._backend_version, self._implementation ) return self._with_native(native) @@ -171,8 +179,9 @@ def aggregate(self, *exprs: DaskExpr) -> Self: def select(self, *exprs: DaskExpr) -> Self: new_series = evaluate_exprs(self, *exprs) + df: Incomplete = self.native df = select_columns_by_name( - self.native.assign(**dict(new_series)), + df.assign(**dict(new_series)), [s[0] for s in new_series], self._backend_version, self._implementation, @@ -287,6 +296,7 @@ def join( # noqa: C901 ) .drop(columns=key_token) ) + other_native: Incomplete = other.native if how == "anti": indicator_token = generate_temporary_column_name( @@ -298,7 +308,7 @@ def join( # noqa: C901 raise TypeError(msg) other_native = ( select_columns_by_name( - other.native, + other_native, list(right_on), self._backend_version, self._implementation, @@ -325,7 +335,7 @@ def join( # noqa: C901 raise TypeError(msg) other_native = ( select_columns_by_name( - other.native, + other_native, list(right_on), self._backend_version, self._implementation, diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 0823a972c4..97de9aed42 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -17,7 +17,7 @@ import dask.dataframe as dd import dask.dataframe.dask_expr as dx - from narwhals._dask.dataframe import DaskLazyFrame + from narwhals._dask.dataframe import DaskLazyFrame, Incomplete from narwhals._dask.expr import DaskExpr from narwhals.typing import IntoDType else: @@ -65,9 +65,9 @@ def add_row_index( implementation: Implementation, ) -> dd.DataFrame: original_cols = frame.columns - frame = frame.assign(**{name: 1}) + df: Incomplete = frame.assign(**{name: 1}) return select_columns_by_name( - frame.assign(**{name: frame[name].cumsum(method="blelloch") - 1}), + df.assign(**{name: df[name].cumsum(method="blelloch") - 1}), [name, *original_cols], backend_version, implementation, diff --git a/narwhals/_namespace.py b/narwhals/_namespace.py index 996ebc78c6..d894e8a93e 100644 --- a/narwhals/_namespace.py +++ b/narwhals/_namespace.py @@ -30,6 +30,7 @@ ) if TYPE_CHECKING: + from collections.abc import Collection, Sized from types import ModuleType from typing import ClassVar @@ -39,7 +40,7 @@ import pyarrow as pa import pyspark.sql as pyspark_sql from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame - from typing_extensions import TypeAlias, TypeIs + from typing_extensions import Self, TypeAlias, TypeIs from narwhals._arrow.namespace import ArrowNamespace from narwhals._dask.namespace import DaskNamespace @@ -98,15 +99,36 @@ Implementation.POLARS, ] + class _BasePandasLike(Sized, Protocol): + index: Any + """`mypy` doesn't like the asymmetric `property` setter in `pandas`.""" + + def __getitem__(self, key: Any, /) -> Any: ... + def __mul__(self, other: float | Collection[float] | Self) -> Self: ... + def __floordiv__(self, other: float | Collection[float] | Self) -> Self: ... + @property + def loc(self) -> Any: ... + @property + def shape(self) -> tuple[int, ...]: ... + def set_axis(self, labels: Any, *, axis: Any = ..., copy: bool = ...) -> Self: ... + def copy(self, deep: bool = ...) -> Self: ... # noqa: FBT001 + def rename(self, *args: Any, inplace: Literal[False], **kwds: Any) -> Self: + """`inplace=False` is required to avoid (incorrect?) default overloads.""" + ... + + class _BasePandasLikeFrame(NativeFrame, _BasePandasLike, Protocol): ... + + class _BasePandasLikeSeries(NativeSeries, _BasePandasLike, Protocol): + def where(self, cond: Any, other: Any = ..., **kwds: Any) -> Any: ... + class _NativeDask(Protocol): _partition_type: type[pd.DataFrame] - class _CuDFDataFrame(NativeFrame, Protocol): + class _CuDFDataFrame(_BasePandasLikeFrame, Protocol): def to_pylibcudf(self, *args: Any, **kwds: Any) -> Any: ... - class _CuDFSeries(NativeSeries, Protocol): + class _CuDFSeries(_BasePandasLikeSeries, Protocol): def to_pylibcudf(self, *args: Any, **kwds: Any) -> Any: ... - def where(self, cond: Any, other: Any = ..., **kwds: Any) -> Any: ... class _NativeIbis(Protocol): def sql(self, *args: Any, **kwds: Any) -> Any: ... @@ -114,14 +136,12 @@ def __pyarrow_result__(self, *args: Any, **kwds: Any) -> Any: ... def __pandas_result__(self, *args: Any, **kwds: Any) -> Any: ... def __polars_result__(self, *args: Any, **kwds: Any) -> Any: ... - class _ModinDataFrame(NativeFrame, Protocol): + class _ModinDataFrame(_BasePandasLikeFrame, Protocol): _pandas_class: type[pd.DataFrame] - class _ModinSeries(NativeSeries, Protocol): + class _ModinSeries(_BasePandasLikeSeries, Protocol): _pandas_class: type[pd.Series[Any]] - def where(self, cond: Any, other: Any = ..., **kwds: Any) -> Any: ... - _NativePolars: TypeAlias = "pl.DataFrame | pl.LazyFrame | pl.Series" _NativeArrow: TypeAlias = "pa.Table | pa.ChunkedArray[Any]" _NativeDuckDB: TypeAlias = "duckdb.DuckDBPyRelation" diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 6e442ffc32..0d3a0b7498 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -333,6 +333,9 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ~null_mask_result, None ) else: + # NOTE: Trying to help `mypy` later + # error: Cannot determine type of "values" [has-type] + values: list[PandasLikeSeries] init_value, *values = [ s.zip_with(~nm, "") for s, nm in zip(series, null_mask) ] diff --git a/narwhals/_pandas_like/typing.py b/narwhals/_pandas_like/typing.py index 1e2453f7e9..054b011eac 100644 --- a/narwhals/_pandas_like/typing.py +++ b/narwhals/_pandas_like/typing.py @@ -10,16 +10,34 @@ import pandas as pd from typing_extensions import TypeAlias - from narwhals._namespace import _NativePandasLikeDataFrame, _NativePandasLikeSeries + from narwhals._namespace import ( + _CuDFDataFrame, + _CuDFSeries, + _ModinDataFrame, + _ModinSeries, + _NativePandasLikeDataFrame, + ) from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries IntoPandasLikeExpr: TypeAlias = "PandasLikeExpr | PandasLikeSeries" - NDFrameT = TypeVar("NDFrameT", "pd.DataFrame", "pd.Series[Any]") NativeSeriesT = TypeVar( - "NativeSeriesT", bound="_NativePandasLikeSeries", default="pd.Series[Any]" + "NativeSeriesT", + "pd.Series[Any]", + "_CuDFSeries", + "_ModinSeries", + default="pd.Series[Any]", ) NativeDataFrameT = TypeVar( "NativeDataFrameT", bound="_NativePandasLikeDataFrame", default="pd.DataFrame" ) +NativeNDFrameT = TypeVar( + "NativeNDFrameT", + "pd.DataFrame", + "pd.Series[Any]", + "_CuDFDataFrame", + "_CuDFSeries", + "_ModinDataFrame", + "_ModinSeries", +) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 254eabddf6..62e8b42d88 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -2,7 +2,6 @@ import functools import re -from collections.abc import Sized from contextlib import suppress from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar @@ -18,13 +17,16 @@ ) from narwhals.exceptions import DuplicateError, ShapeError -T = TypeVar("T", bound=Sized) - if TYPE_CHECKING: from pandas._typing import Dtype as PandasDtype from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries + from narwhals._pandas_like.typing import ( + NativeDataFrameT, + NativeNDFrameT, + NativeSeriesT, + ) from narwhals.dtypes import DType from narwhals.typing import DTypeBackend, IntoDType, TimeUnit, _1DArray @@ -123,12 +125,12 @@ def align_and_extract_native( def set_index( - obj: T, + obj: NativeNDFrameT, index: Any, *, implementation: Implementation, backend_version: tuple[int, ...], -) -> T: +) -> NativeNDFrameT: """Wrapper around pandas' set_axis to set object index. We can set `copy` / `inplace` based on implementation/version. @@ -139,30 +141,30 @@ def set_index( msg = f"Expected object of length {expected_len}, got length: {actual_len}" raise ShapeError(msg) if implementation is Implementation.CUDF: # pragma: no cover - obj = obj.copy(deep=False) # type: ignore[attr-defined] - obj.index = index # type: ignore[attr-defined] + obj = obj.copy(deep=False) + obj.index = index return obj if implementation is Implementation.PANDAS and ( (1, 5) <= backend_version < (3,) ): # pragma: no cover - return obj.set_axis(index, axis=0, copy=False) # type: ignore[attr-defined] + return obj.set_axis(index, axis=0, copy=False) else: # pragma: no cover - return obj.set_axis(index, axis=0) # type: ignore[attr-defined] + return obj.set_axis(index, axis=0) def rename( - obj: T, + obj: NativeNDFrameT, *args: Any, implementation: Implementation, backend_version: tuple[int, ...], **kwargs: Any, -) -> T: +) -> NativeNDFrameT: """Wrapper around pandas' rename so that we can set `copy` based on implementation/version.""" if implementation is Implementation.PANDAS and ( backend_version >= (3,) ): # pragma: no cover - return obj.rename(*args, **kwargs) # type: ignore[attr-defined] - return obj.rename(*args, **kwargs, copy=False) # type: ignore[attr-defined] + return obj.rename(*args, **kwargs, inplace=False) + return obj.rename(*args, **kwargs, copy=False, inplace=False) @functools.lru_cache(maxsize=16) @@ -506,8 +508,8 @@ def int_dtype_mapper(dtype: Any) -> str: def calculate_timestamp_datetime( # noqa: C901, PLR0912 - s: pd.Series[int], original_time_unit: str, time_unit: str -) -> pd.Series[int]: + s: NativeSeriesT, original_time_unit: str, time_unit: str +) -> NativeSeriesT: if original_time_unit == "ns": if time_unit == "ns": result = s @@ -542,7 +544,7 @@ def calculate_timestamp_datetime( # noqa: C901, PLR0912 return result -def calculate_timestamp_date(s: pd.Series[int], time_unit: str) -> pd.Series[int]: +def calculate_timestamp_date(s: NativeSeriesT, time_unit: str) -> NativeSeriesT: s = s * 86_400 # number of seconds in a day if time_unit == "ns": result = s * 1_000_000_000 @@ -554,36 +556,30 @@ def calculate_timestamp_date(s: pd.Series[int], time_unit: str) -> pd.Series[int def select_columns_by_name( - df: T, + df: NativeDataFrameT, column_names: list[str] | _1DArray, # NOTE: Cannot be a tuple! backend_version: tuple[int, ...], implementation: Implementation, -) -> T: +) -> NativeDataFrameT | Any: """Select columns by name. Prefer this over `df.loc[:, column_names]` as it's generally more performant. """ - if len(column_names) == df.shape[1] and all(column_names == df.columns): # type: ignore[attr-defined] + if len(column_names) == df.shape[1] and (df.columns == column_names).all(): return df - if (df.columns.dtype.kind == "b") or ( # type: ignore[attr-defined] + if (df.columns.dtype.kind == "b") or ( implementation is Implementation.PANDAS and backend_version < (1, 5) ): # See https://github.com/narwhals-dev/narwhals/issues/1349#issuecomment-2470118122 # for why we need this - if error := check_columns_exist( - column_names, - available=df.columns.tolist(), # type: ignore[attr-defined] - ): + if error := check_columns_exist(column_names, available=df.columns.tolist()): raise error - return df.loc[:, column_names] # type: ignore[attr-defined] + return df.loc[:, column_names] try: - return df[column_names] # type: ignore[index] + return df[column_names] except KeyError as e: - if error := check_columns_exist( - column_names, - available=df.columns.tolist(), # type: ignore[attr-defined] - ): + if error := check_columns_exist(column_names, available=df.columns.tolist()): raise error from e raise