diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 256cee6003189..b0879c9410f97 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -2464,30 +2464,24 @@ def idxmax(self, skipna: bool = True) -> FrameLike: index = self._psdf._internal.index_spark_column_names[0] index_spark_type = self._psdf._internal.index_fields[0].spark_type + pd_version = LooseVersion(pd.__version__) + stat_exprs = [] for psser, scol in zip(self._agg_columns, self._agg_columns_scols): name = psser._internal.data_spark_column_names[0] - if LooseVersion(pd.__version__) < "3.0.0" or skipna: + if pd_version < "3.0.0" or skipna: order_column = scol.desc_nulls_last() window = Window.partitionBy(*groupkey_names).orderBy( order_column, NATURAL_ORDER_COLUMN_NAME ) - has_na_name = "__has_na_{}__".format(name) - sdf = sdf.withColumn(has_na_name, scol.isNull()).withColumn( + sdf = sdf.withColumn( name, F.when(F.row_number().over(window) == 1, scol_for(sdf, index)).otherwise(None), ) - if skipna: - stat_exprs.append(F.max(scol_for(sdf, name)).alias(name)) - else: - stat_exprs.append( - F.when(F.max(scol_for(sdf, has_na_name)), None) - .otherwise(F.max(scol_for(sdf, name))) - .alias(name) - ) + stat_exprs.append(F.max(scol_for(sdf, name)).alias(name)) else: # pandas 3 skipna=False: raise on any NA, otherwise return all-missing labels stat_exprs.append( @@ -2565,29 +2559,24 @@ def idxmin(self, skipna: bool = True) -> FrameLike: index = self._psdf._internal.index_spark_column_names[0] index_spark_type = self._psdf._internal.index_fields[0].spark_type + pd_version = LooseVersion(pd.__version__) + stat_exprs = [] for psser, scol in zip(self._agg_columns, self._agg_columns_scols): name = psser._internal.data_spark_column_names[0] - if LooseVersion(pd.__version__) < "3.0.0" or skipna: + if pd_version < "3.0.0" or skipna: order_column = scol.asc_nulls_last() window = Window.partitionBy(*groupkey_names).orderBy( order_column, NATURAL_ORDER_COLUMN_NAME ) - has_na_name = "__has_na_{}__".format(name) - sdf = sdf.withColumn(has_na_name, scol.isNull()).withColumn( + + sdf = sdf.withColumn( name, F.when(F.row_number().over(window) == 1, scol_for(sdf, index)).otherwise(None), ) - if skipna: - stat_exprs.append(F.max(scol_for(sdf, name)).alias(name)) - else: - stat_exprs.append( - F.when(F.max(scol_for(sdf, has_na_name)), None) - .otherwise(F.max(scol_for(sdf, name))) - .alias(name) - ) + stat_exprs.append(F.max(scol_for(sdf, name)).alias(name)) else: # pandas 3 skipna=False: raise on any NA, otherwise return all-missing labels stat_exprs.append( diff --git a/python/pyspark/pandas/tests/groupby/test_index.py b/python/pyspark/pandas/tests/groupby/test_index.py index a9d4ec6c2bbe2..f43fd9549fb53 100644 --- a/python/pyspark/pandas/tests/groupby/test_index.py +++ b/python/pyspark/pandas/tests/groupby/test_index.py @@ -242,23 +242,25 @@ def test_idxmax_idxmin_skipna_false_with_na(self): with self.subTest(i=i): psdf = ps.from_pandas(pdf) if LooseVersion(pd.__version__) < "3.0.0": + # pandas-on-Spark preserves the legacy idxmax/idxmin result for skipna=False. self.assert_eq( - pdf.groupby(["a"]).idxmax(skipna=False).sort_index(), + pdf.groupby(["a"]).idxmax().sort_index(), psdf.groupby(["a"]).idxmax(skipna=False).sort_index(), ) self.assert_eq( - pdf.groupby(["a"]).idxmin(skipna=False).sort_index(), + pdf.groupby(["a"]).idxmin().sort_index(), psdf.groupby(["a"]).idxmin(skipna=False).sort_index(), ) self.assert_eq( - pdf.groupby(["a"])["b"].idxmax(skipna=False).sort_index(), + pdf.groupby(["a"])["b"].idxmax().sort_index(), psdf.groupby(["a"])["b"].idxmax(skipna=False).sort_index(), ) self.assert_eq( - pdf.groupby(["a"])["b"].idxmin(skipna=False).sort_index(), + pdf.groupby(["a"])["b"].idxmin().sort_index(), psdf.groupby(["a"])["b"].idxmin(skipna=False).sort_index(), ) else: + # pandas 3 raises for skipna=False when NA values are present. with self.assertRaisesRegex( Exception, "idxmax with skipna=False encountered an NA value" ):