diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 0474269d95..7aecb6414e 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -331,7 +331,7 @@ def join( self, other: Self, *, - how: Literal["left", "inner", "cross", "anti", "semi"], + how: JoinStrategy, left_on: Sequence[str] | None, right_on: Sequence[str] | None, suffix: str, diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index a4d0ea52cf..e94f97d4d8 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -244,7 +244,130 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> self.native.sort_values(list(by), ascending=ascending, na_position=position) ) - def join( # noqa: C901 + def _join_inner( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str + ) -> Self: + return self._with_native( + self.native.merge( + other.native, + left_on=left_on, + right_on=right_on, + how="inner", + suffixes=("", suffix), + ) + ) + + def _join_left( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str + ) -> Self: + result_native = self.native.merge( + other.native, + how="left", + left_on=left_on, + right_on=right_on, + suffixes=("", suffix), + ) + extra = [ + right_key if right_key not in self.columns else f"{right_key}{suffix}" + for left_key, right_key in zip(left_on, right_on) + if right_key != left_key + ] + return self._with_native(result_native.drop(columns=extra)) + + def _join_full( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str + ) -> Self: + # dask does not retain keys post-join + # we must append the suffix to each key before-hand + + right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix) + other_native = other.native.rename(columns=right_on_mapper) + check_column_names_are_unique(other_native.columns) + right_suffixed = list(right_on_mapper.values()) + return 
self._with_native( + self.native.merge( + other_native, + left_on=left_on, + right_on=right_suffixed, + how="outer", + suffixes=("", suffix), + ) + ) + + def _join_cross(self, other: Self, *, suffix: str) -> Self: + key_token = generate_temporary_column_name( + n_bytes=8, columns=(*self.columns, *other.columns) + ) + return self._with_native( + self.native.assign(**{key_token: 0}) + .merge( + other.native.assign(**{key_token: 0}), + how="inner", + left_on=key_token, + right_on=key_token, + suffixes=("", suffix), + ) + .drop(columns=key_token) + ) + + def _join_semi( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str] + ) -> Self: + other_native = self._join_filter_rename( + other=other, + columns_to_select=list(right_on), + columns_mapping=dict(zip(right_on, left_on)), + ) + return self._with_native( + self.native.merge( + other_native, how="inner", left_on=left_on, right_on=left_on + ) + ) + + def _join_anti( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str] + ) -> Self: + indicator_token = generate_temporary_column_name( + n_bytes=8, columns=(*self.columns, *other.columns) + ) + other_native = self._join_filter_rename( + other=other, + columns_to_select=list(right_on), + columns_mapping=dict(zip(right_on, left_on)), + ) + df = self.native.merge( + other_native, + how="left", + indicator=indicator_token, # pyright: ignore[reportArgumentType] + left_on=left_on, + right_on=left_on, + ) + return self._with_native( + df[df[indicator_token] == "left_only"].drop(columns=[indicator_token]) + ) + + def _join_filter_rename( + self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str] + ) -> dd.DataFrame: + """Helper function to avoid creating extra columns and row duplication. + + Used in `"anti"` and `"semi"` joins. + + Notice that a native object is returned. 
+ """ + return ( + select_columns_by_name( + other.native, + column_names=columns_to_select, + backend_version=self._backend_version, + implementation=self._implementation, + ) + # rename to avoid creating extra columns in join + .rename(columns=columns_mapping) + .drop_duplicates() + ) + + def join( self, other: Self, *, @@ -254,122 +377,30 @@ def join( # noqa: C901 suffix: str, ) -> Self: if how == "cross": - key_token = generate_temporary_column_name( - n_bytes=8, columns=[*self.columns, *other.columns] - ) - - return self._with_native( - self.native.assign(**{key_token: 0}) - .merge( - other.native.assign(**{key_token: 0}), - how="inner", - left_on=key_token, - right_on=key_token, - suffixes=("", suffix), - ) - .drop(columns=key_token) - ) + return self._join_cross(other=other, suffix=suffix) - if how == "anti": - indicator_token = generate_temporary_column_name( - n_bytes=8, columns=[*self.columns, *other.columns] - ) + if left_on is None or right_on is None: # pragma: no cover + raise ValueError(left_on, right_on) - if right_on is None: # pragma: no cover - msg = "`right_on` cannot be `None` in anti-join" - raise TypeError(msg) - other_native = ( - select_columns_by_name( - other.native, - list(right_on), - self._backend_version, - self._implementation, - ) - .rename( # rename to avoid creating extra columns in join - columns=dict(zip(right_on, left_on)) # type: ignore[arg-type] - ) - .drop_duplicates() + if how == "inner": + return self._join_inner( + other=other, left_on=left_on, right_on=right_on, suffix=suffix ) - df = self.native.merge( - other_native, - how="outer", - indicator=indicator_token, # pyright: ignore[reportArgumentType] - left_on=left_on, - right_on=left_on, - ) - return self._with_native( - df[df[indicator_token] == "left_only"].drop(columns=[indicator_token]) - ) - + if how == "anti": + return self._join_anti(other=other, left_on=left_on, right_on=right_on) if how == "semi": - if right_on is None: # pragma: no cover - msg = "`right_on` 
cannot be `None` in semi-join" - raise TypeError(msg) - other_native = ( - select_columns_by_name( - other.native, - list(right_on), - self._backend_version, - self._implementation, - ) - .rename( # rename to avoid creating extra columns in join - columns=dict(zip(right_on, left_on)) # type: ignore[arg-type] - ) - .drop_duplicates() # avoids potential rows duplication from inner join - ) - return self._with_native( - self.native.merge( - other_native, how="inner", left_on=left_on, right_on=left_on - ) - ) - + return self._join_semi(other=other, left_on=left_on, right_on=right_on) if how == "left": - result_native = self.native.merge( - other.native, - how="left", - left_on=left_on, - right_on=right_on, - suffixes=("", suffix), + return self._join_left( + other=other, left_on=left_on, right_on=right_on, suffix=suffix ) - extra = [] - for left_key, right_key in zip(left_on, right_on): # type: ignore[arg-type] - if right_key != left_key and right_key not in self.columns: - extra.append(right_key) - elif right_key != left_key: - extra.append(f"{right_key}_right") - return self._with_native(result_native.drop(columns=extra)) - if how == "full": - # dask does not retain keys post-join - # we must append the suffix to each key before-hand - - # help mypy - assert left_on is not None # noqa: S101 - assert right_on is not None # noqa: S101 - - right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix) - other_native = other.native.rename(columns=right_on_mapper) - check_column_names_are_unique(other_native.columns) - right_on = list(right_on_mapper.values()) # we now have the suffixed keys - return self._with_native( - self.native.merge( - other_native, - left_on=left_on, - right_on=right_on, - how="outer", - suffixes=("", suffix), - ) + return self._join_full( + other=other, left_on=left_on, right_on=right_on, suffix=suffix ) - return self._with_native( - self.native.merge( - other.native, - left_on=left_on, - right_on=right_on, - how=how, - suffixes=("", suffix), 
- ) - ) + msg = f"Unreachable code, got unexpected join method: {how}" # pragma: no cover + raise AssertionError(msg) def join_asof( self, diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 881395c8e8..296817d600 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -559,153 +559,185 @@ def group_by( return PandasLikeGroupBy(self, keys, drop_null_keys=drop_null_keys) - def join( # noqa: C901, PLR0911, PLR0912 - self, - other: Self, - *, - how: JoinStrategy, - left_on: Sequence[str] | None, - right_on: Sequence[str] | None, - suffix: str, + def _join_inner( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str ) -> Self: - if how == "cross": - if ( - self._implementation is Implementation.MODIN - or self._implementation is Implementation.CUDF - ) or ( - self._implementation is Implementation.PANDAS - and self._backend_version < (1, 4) - ): - key_token = generate_temporary_column_name( - n_bytes=8, columns=[*self.columns, *other.columns] - ) + return self._with_native( + self.native.merge( + other.native, + left_on=left_on, + right_on=right_on, + how="inner", + suffixes=("", suffix), + ) + ) - return self._with_native( - self.native.assign(**{key_token: 0}) - .merge( - other.native.assign(**{key_token: 0}), - how="inner", - left_on=key_token, - right_on=key_token, - suffixes=("", suffix), - ) - .drop(columns=key_token) - ) - else: - return self._with_native( - self.native.merge(other.native, how="cross", suffixes=("", suffix)) - ) + def _join_left( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str + ) -> Self: + result_native = self.native.merge( + other.native, + how="left", + left_on=left_on, + right_on=right_on, + suffixes=("", suffix), + ) + extra = [ + right_key if right_key not in self.columns else f"{right_key}{suffix}" + for left_key, right_key in zip(left_on, right_on) + if right_key != left_key + ] + return 
self._with_native(result_native.drop(columns=extra)) - if how == "anti": - if self._implementation is Implementation.CUDF: - return self._with_native( - self.native.merge( - other.native, how="leftanti", left_on=left_on, right_on=right_on - ) - ) - else: - indicator_token = generate_temporary_column_name( - n_bytes=8, columns=[*self.columns, *other.columns] - ) - if right_on is None: # pragma: no cover - msg = "`right_on` cannot be `None` in anti-join" - raise TypeError(msg) - - # rename to avoid creating extra columns in join - other_native = rename( - select_columns_by_name( - other.native, - list(right_on), - self._backend_version, - self._implementation, - ), - columns=dict(zip(right_on, left_on)), # type: ignore[arg-type] - implementation=self._implementation, - backend_version=self._backend_version, - ).drop_duplicates() - return self._with_native( - self.native.merge( - other_native, - how="outer", - indicator=indicator_token, - left_on=left_on, - right_on=left_on, - ) - .loc[lambda t: t[indicator_token] == "left_only"] - .drop(columns=indicator_token) - ) + def _join_full( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str + ) -> Self: + # Pandas coalesces keys in full joins unless there's no collision + right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix) + other_native = other.native.rename(columns=right_on_mapper) + check_column_names_are_unique(other_native.columns) + right_suffixed = list(right_on_mapper.values()) + return self._with_native( + self.native.merge( + other_native, + left_on=left_on, + right_on=right_suffixed, + how="outer", + suffixes=("", suffix), + ) + ) - if how == "semi": - if right_on is None: # pragma: no cover - msg = "`right_on` cannot be `None` in semi-join" - raise TypeError(msg) - # rename to avoid creating extra columns in join - other_native = ( - rename( - select_columns_by_name( - other.native, - list(right_on), - self._backend_version, - self._implementation, - ), - 
columns=dict(zip(right_on, left_on)), # type: ignore[arg-type] - implementation=self._implementation, - backend_version=self._backend_version, - ).drop_duplicates() # avoids potential rows duplication from inner join + def _join_cross(self, other: Self, *, suffix: str) -> Self: + implementation = self._implementation + backend_version = self._backend_version + if (implementation.is_modin() or implementation.is_cudf()) or ( + implementation.is_pandas() and backend_version < (1, 4) + ): + key_token = generate_temporary_column_name( + n_bytes=8, columns=(*self.columns, *other.columns) ) return self._with_native( - self.native.merge( - other_native, how="inner", left_on=left_on, right_on=left_on + self.native.assign(**{key_token: 0}) + .merge( + other.native.assign(**{key_token: 0}), + how="inner", + left_on=key_token, + right_on=key_token, + suffixes=("", suffix), ) + .drop(columns=key_token) ) + return self._with_native( + self.native.merge(other.native, how="cross", suffixes=("", suffix)) + ) - if how == "left": - result_native = self.native.merge( - other.native, - how="left", - left_on=left_on, - right_on=right_on, - suffixes=("", suffix), + def _join_semi( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str] + ) -> Self: + other_native = self._join_filter_rename( + other=other, + columns_to_select=list(right_on), + columns_mapping=dict(zip(right_on, left_on)), + ) + return self._with_native( + self.native.merge( + other_native, how="inner", left_on=left_on, right_on=left_on ) - extra = [] - for left_key, right_key in zip(left_on, right_on): # type: ignore[arg-type] - if right_key != left_key and right_key not in self.columns: - extra.append(right_key) - elif right_key != left_key: - extra.append(f"{right_key}{suffix}") - return self._with_native(result_native.drop(columns=extra)) - - if how == "full": - # Pandas coalesces keys in full joins unless there's no collision + ) - # help mypy - assert left_on is not None # noqa: S101 - assert right_on 
is not None # noqa: S101 + def _join_anti( + self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str] + ) -> Self: + implementation = self._implementation - right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix) - other_native = other.native.rename(columns=right_on_mapper) - check_column_names_are_unique(other_native.columns) - right_on = list(right_on_mapper.values()) # we now have the suffixed keys + if implementation.is_cudf(): return self._with_native( self.native.merge( - other_native, - left_on=left_on, - right_on=right_on, - how="outer", - suffixes=("", suffix), + other.native, how="leftanti", left_on=left_on, right_on=right_on ) ) + indicator_token = generate_temporary_column_name( + n_bytes=8, columns=(*self.columns, *other.columns) + ) + + other_native = self._join_filter_rename( + other=other, + columns_to_select=list(right_on), + columns_mapping=dict(zip(right_on, left_on)), + ) return self._with_native( self.native.merge( - other.native, + other_native, + # TODO(FBruzzesi): Raise issue upstream for Modin + how="left" if implementation.is_pandas() else "outer", + indicator=indicator_token, left_on=left_on, - right_on=right_on, - how=how, - suffixes=("", suffix), + right_on=left_on, ) + .loc[lambda t: t[indicator_token] == "left_only"] + .drop(columns=indicator_token) ) + def _join_filter_rename( + self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str] + ) -> pd.DataFrame: + """Helper function to avoid creating extra columns and row duplication. + + Used in `"anti"` and `"semi"` joins. + + Notice that a native object is returned. 
+ """ + implementation = self._implementation + backend_version = self._backend_version + + return rename( + select_columns_by_name( + other.native, + column_names=columns_to_select, + backend_version=backend_version, + implementation=implementation, + ), + columns=columns_mapping, + implementation=implementation, + backend_version=backend_version, + ).drop_duplicates() + + def join( + self, + other: Self, + *, + how: JoinStrategy, + left_on: Sequence[str] | None, + right_on: Sequence[str] | None, + suffix: str, + ) -> Self: + if how == "cross": + return self._join_cross(other=other, suffix=suffix) + + if left_on is None or right_on is None: # pragma: no cover + raise ValueError(left_on, right_on) + + if how == "inner": + return self._join_inner( + other=other, left_on=left_on, right_on=right_on, suffix=suffix + ) + if how == "anti": + return self._join_anti(other=other, left_on=left_on, right_on=right_on) + if how == "semi": + return self._join_semi(other=other, left_on=left_on, right_on=right_on) + if how == "left": + return self._join_left( + other=other, left_on=left_on, right_on=right_on, suffix=suffix + ) + if how == "full": + return self._join_full( + other=other, left_on=left_on, right_on=right_on, suffix=suffix + ) + + msg = f"Unreachable code, got unexpected join method: {how}" # pragma: no cover + raise AssertionError(msg) + def join_asof( self, other: Self, diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 303ff0636d..7d231a15d6 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -231,75 +231,66 @@ def sort( def join( self, other: Self, - on: str | list[str] | None = None, - how: JoinStrategy = "inner", + on: str | list[str] | None, + how: JoinStrategy, *, - left_on: str | list[str] | None = None, - right_on: str | list[str] | None = None, - suffix: str = "_right", + left_on: str | list[str] | None, + right_on: str | list[str] | None, + suffix: str, ) -> Self: + _supported_joins = ("inner", "left", "full", "cross", "anti", 
"semi") on = [on] if isinstance(on, str) else on left_on = [left_on] if isinstance(left_on, str) else left_on right_on = [right_on] if isinstance(right_on, str) else right_on + compliant = self._compliant_frame + other = self._extract_compliant(other) - if how not in ( - _supported_joins := ("inner", "left", "full", "cross", "anti", "semi") - ): + if how not in _supported_joins: msg = f"Only the following join strategies are supported: {_supported_joins}; found '{how}'." raise NotImplementedError(msg) - - if how == "cross" and ( - left_on is not None or right_on is not None or on is not None - ): - msg = "Can not pass `left_on`, `right_on` or `on` keys for cross join" - raise ValueError(msg) - - if how != "cross" and (on is None and (left_on is None or right_on is None)): - msg = f"Either (`left_on` and `right_on`) or `on` keys should be specified for {how}." - raise ValueError(msg) - - if how != "cross" and ( - on is not None and (left_on is not None or right_on is not None) - ): - msg = f"If `on` is specified, `left_on` and `right_on` should be None for {how}." - raise ValueError(msg) - - if on is not None: - left_on = right_on = on - - if (isinstance(left_on, list) and isinstance(right_on, list)) and ( - len(left_on) != len(right_on) - ): - msg = "`left_on` and `right_on` must have the same length." - raise ValueError(msg) - - return self._with_compliant( - self._compliant_frame.join( - self._extract_compliant(other), - how=how, - left_on=left_on, - right_on=right_on, - suffix=suffix, + if how == "cross": + if left_on is not None or right_on is not None or on is not None: + msg = "Can not pass `left_on`, `right_on` or `on` keys for cross join" + raise ValueError(msg) + result = compliant.join( + other, how=how, left_on=None, right_on=None, suffix=suffix ) - ) + elif on is None: + if left_on is None or right_on is None: + msg = f"Either (`left_on` and `right_on`) or `on` keys should be specified for {how}." 
+ raise ValueError(msg) + if len(left_on) != len(right_on): + msg = "`left_on` and `right_on` must have the same length." + raise ValueError(msg) + result = compliant.join( + other, how=how, left_on=left_on, right_on=right_on, suffix=suffix + ) + else: + if left_on is not None or right_on is not None: + msg = f"If `on` is specified, `left_on` and `right_on` should be None for {how}." + raise ValueError(msg) + result = compliant.join( + other, how=how, left_on=on, right_on=on, suffix=suffix + ) + return self._with_compliant(result) def gather_every(self, n: int, offset: int = 0) -> Self: return self._with_compliant( self._compliant_frame.gather_every(n=n, offset=offset) ) - def join_asof( # noqa: C901 + def join_asof( self, other: Self, *, - left_on: str | None = None, - right_on: str | None = None, - on: str | None = None, - by_left: str | list[str] | None = None, - by_right: str | list[str] | None = None, - by: str | list[str] | None = None, - strategy: AsofJoinStrategy = "backward", - suffix: str = "_right", + left_on: str | None, + right_on: str | None, + on: str | None, + by_left: str | list[str] | None, + by_right: str | list[str] | None, + by: str | list[str] | None, + strategy: AsofJoinStrategy, + suffix: str, ) -> Self: _supported_strategies = ("backward", "forward", "nearest") @@ -328,10 +319,9 @@ def join_asof( # noqa: C901 left_on = right_on = on if by is not None: by_left = by_right = by - if isinstance(by_left, str): - by_left = [by_left] - if isinstance(by_right, str): - by_right = [by_right] + + by_left = [by_left] if isinstance(by_left, str) else by_left + by_right = [by_right] if isinstance(by_right, str) else by_right if (isinstance(by_left, list) and isinstance(by_right, list)) and ( len(by_left) != len(by_right)