From b17a539e82edca5037a172e5909a5737c8316e5f Mon Sep 17 00:00:00 2001
From: Takuya Ueshin
Date: Tue, 31 Mar 2026 11:59:27 -0700
Subject: [PATCH] Handle pandas 3 string dtype in DataFrame.toPandas

---
 .../pandas/data_type_ops/string_ops.py        |  4 +
 python/pyspark/sql/pandas/types.py            | 11 ++-
 .../test_connect_dataframe_property.py        |  6 +-
 python/pyspark/sql/tests/test_collection.py   | 77 ++++++++++++-------
 4 files changed, 68 insertions(+), 30 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/string_ops.py b/python/pyspark/pandas/data_type_ops/string_ops.py
index de35568cf2de4..c416d03a9c8f6 100644
--- a/python/pyspark/pandas/data_type_ops/string_ops.py
+++ b/python/pyspark/pandas/data_type_ops/string_ops.py
@@ -17,6 +17,7 @@
 
 from typing import Any, Union, cast
 
+import numpy as np
 import pandas as pd
 from pandas.api.types import CategoricalDtype
 
@@ -153,6 +154,9 @@ def restore(self, col: pd.Series) -> pd.Series:
         if LooseVersion(pd.__version__) < "3.0.0":
             return super().restore(col)
         else:
+            if is_str_dtype(col.dtype) and not is_str_dtype(self.dtype):
+                # treat missing values as None for string dtype
+                col = col.replace({np.nan: None})
             return col.astype(self.dtype)
 
 
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 029478bb7fca0..a98f8dc516bd4 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -896,6 +896,11 @@ def _to_corrected_pandas_type(dt: DataType) -> Optional[Any]:
             return np.dtype("timedelta64[ns]")
         else:
             return np.dtype("timedelta64[us]")
+    elif type(dt) == StringType:
+        if LooseVersion(pd.__version__) < "3.0.0":
+            return None
+        else:
+            return pd.StringDtype(na_value=np.nan)
     else:
         return None
 
@@ -1007,7 +1012,11 @@ def correct_dtype(pser: pd.Series) -> pd.Series:
                 def correct_dtype(pser: pd.Series) -> pd.Series:
                     if not isinstance(pser.dtype, pd.DatetimeTZDtype):
                         pser = pser.astype(pandas_type, copy=False)
-                    return _check_series_convert_timestamps_local_tz(pser, timezone=timezone)
+                    pser = _check_series_convert_timestamps_local_tz(pser, timezone=timezone)
+                    if LooseVersion(pd.__version__) < "3.0.0":
+                        return pser
+                    else:
+                        return pser.astype(pandas_type, copy=False)
 
             else:
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index 318bebf60b0dd..44f56828685c0 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -17,6 +17,7 @@
 
 import unittest
 
+from pyspark.loose_version import LooseVersion
 from pyspark.sql.types import (
     StructType,
     StructField,
@@ -113,7 +114,10 @@ def test_cached_schema_map_in_pandas(self):
         def func(iterator):
             for pdf in iterator:
                 assert isinstance(pdf, pd.DataFrame)
-                assert [d.name for d in list(pdf.dtypes)] == ["int32", "object"]
+                if LooseVersion(pd.__version__) < "3.0.0":
+                    assert [d.name for d in list(pdf.dtypes)] == ["int32", "object"]
+                else:
+                    assert [d.name for d in list(pdf.dtypes)] == ["int32", "str"]
                 yield pdf
 
         schema = StructType(
 
diff --git a/python/pyspark/sql/tests/test_collection.py b/python/pyspark/sql/tests/test_collection.py
index c6d44d6144740..5d0e48b73ce83 100644
--- a/python/pyspark/sql/tests/test_collection.py
+++ b/python/pyspark/sql/tests/test_collection.py
@@ -18,6 +18,7 @@
 import datetime
 import unittest
 
+from pyspark.loose_version import LooseVersion
 from pyspark.sql.types import (
     Row,
     ArrayType,
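A note on the dtype at the center of this patch: under pandas 3 the default string dtype is named "str", and constructing it as pd.StringDtype(na_value=np.nan) selects the NaN-backed variant that _to_corrected_pandas_type now returns for StringType columns. A minimal standalone sketch of that behavior, assuming pandas >= 3.0 (not part of the patch):

    # Sketch only: the dtype StringType columns map to after this change.
    import numpy as np
    import pandas as pd

    dtype = pd.StringDtype(na_value=np.nan)
    ser = pd.Series(["foo", None, "bar"], dtype=dtype)

    print(ser.dtype)            # str -- the pandas 3 default string dtype
    print(ser.isna().tolist())  # [False, True, False]; None is held as NaN

This is also why the connect test above expects the dtype name "str" rather than "object" on pandas 3, and why restore() in string_ops.py maps NaN back to None before casting to the original dtype.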
@@ -95,18 +96,27 @@ def _to_pandas(self):
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
     def test_to_pandas(self):
+        import pandas as pd
         import numpy as np
 
         pdf = self._to_pandas()
         types = pdf.dtypes
-        self.assertEqual(types[0], np.int32)
-        self.assertEqual(types[1], object)
-        self.assertEqual(types[2], bool)
-        self.assertEqual(types[3], np.float32)
-        self.assertEqual(types[4], object)  # datetime.date
-        self.assertEqual(types[5], "datetime64[ns]")
-        self.assertEqual(types[6], "datetime64[ns]")
-        self.assertEqual(types[7], "timedelta64[ns]")
+        self.assertEqual(types.iloc[0], np.int32)
+        if LooseVersion(pd.__version__) < "3.0.0":
+            self.assertEqual(types.iloc[1], object)
+        else:
+            self.assertEqual(types.iloc[1], pd.StringDtype(na_value=np.nan))
+        self.assertEqual(types.iloc[2], bool)
+        self.assertEqual(types.iloc[3], np.float32)
+        self.assertEqual(types.iloc[4], object)  # datetime.date
+        if LooseVersion(pd.__version__) < "3.0.0":
+            self.assertEqual(types.iloc[5], "datetime64[ns]")
+            self.assertEqual(types.iloc[6], "datetime64[ns]")
+            self.assertEqual(types.iloc[7], "timedelta64[ns]")
+        else:
+            self.assertEqual(types.iloc[5], "datetime64[us]")
+            self.assertEqual(types.iloc[6], "datetime64[us]")
+            self.assertEqual(types.iloc[7], "timedelta64[us]")
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
     def test_to_pandas_with_duplicated_column_names(self):
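An aside on the indexing change running through these tests: df.dtypes is a Series indexed by column name, and pandas 3 removes the old fallback that treated an integer key as a position, so types[0] stops working once the index holds strings. Hence the switch to types.iloc[i]. A small sketch with stock pandas and a hypothetical frame (not from the patch):

    # Sketch: positional vs. label access on a dtypes Series.
    import pandas as pd

    pdf = pd.DataFrame({"a": [1], "b": ["x"]})
    types = pdf.dtypes       # index is ["a", "b"], not 0..N

    print(types.iloc[0])     # int64 -- explicit positional lookup
    print(types["b"])        # label lookup is unchanged
    # On pandas 3, types[0] raises KeyError: integer keys are always
    # treated as labels, and 0 is not a column name here.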
@@ -155,22 +165,27 @@ def test_to_pandas_required_pandas_not_found(self):
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
     def test_to_pandas_avoid_astype(self):
+        import pandas as pd
         import numpy as np
 
         schema = StructType().add("a", IntegerType()).add("b", StringType()).add("c", IntegerType())
         data = [(1, "foo", 16777220), (None, "bar", None)]
         df = self.spark.createDataFrame(data, schema)
         types = df.toPandas().dtypes
-        self.assertEqual(types[0], np.float64)  # doesn't convert to np.int32 due to NaN value.
-        self.assertEqual(types[1], object)
-        self.assertEqual(types[2], np.float64)
+        self.assertEqual(types.iloc[0], np.float64)  # doesn't convert to np.int32 due to NaN value.
+        if LooseVersion(pd.__version__) < "3.0.0":
+            self.assertEqual(types.iloc[1], object)
+        else:
+            self.assertEqual(types.iloc[1], pd.StringDtype(na_value=np.nan))
+        self.assertEqual(types.iloc[2], np.float64)
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
     def test_to_pandas_from_empty_dataframe(self):
         is_arrow_enabled = [True, False]
         for value in is_arrow_enabled:
-            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
-                self.check_to_pandas_from_empty_dataframe()
+            with self.subTest(arrow_enabled=value):
+                with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
+                    self.check_to_pandas_from_empty_dataframe()
 
     def check_to_pandas_from_empty_dataframe(self):
         # SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes
@@ -199,13 +214,15 @@ def check_to_pandas_from_empty_dataframe(self):
     def test_to_pandas_from_null_dataframe(self):
         is_arrow_enabled = [True, False]
         for value in is_arrow_enabled:
-            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
-                self.check_to_pandas_from_null_dataframe()
+            with self.subTest(arrow_enabled=value):
+                with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
+                    self.check_to_pandas_from_null_dataframe()
 
     def check_to_pandas_from_null_dataframe(self):
         # SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes
         # SPARK-30537 test that toPandas() on a dataframe with only nulls has correct dtypes
         # using arrow
+        import pandas as pd
         import numpy as np
 
         sql = """
@@ -223,24 +240,28 @@ def check_to_pandas_from_null_dataframe(self):
         """
         pdf = self.spark.sql(sql).toPandas()
         types = pdf.dtypes
-        self.assertEqual(types[0], np.float64)
-        self.assertEqual(types[1], np.float64)
-        self.assertEqual(types[2], np.float64)
-        self.assertEqual(types[3], np.float64)
-        self.assertEqual(types[4], np.float32)
-        self.assertEqual(types[5], np.float64)
-        self.assertEqual(types[6], object)
-        self.assertEqual(types[7], object)
-        self.assertTrue(np.can_cast(np.datetime64, types[8]))
-        self.assertTrue(np.can_cast(np.datetime64, types[9]))
-        self.assertTrue(np.can_cast(np.timedelta64, types[10]))
+        self.assertEqual(types.iloc[0], np.float64)
+        self.assertEqual(types.iloc[1], np.float64)
+        self.assertEqual(types.iloc[2], np.float64)
+        self.assertEqual(types.iloc[3], np.float64)
+        self.assertEqual(types.iloc[4], np.float32)
+        self.assertEqual(types.iloc[5], np.float64)
+        self.assertEqual(types.iloc[6], object)
+        if LooseVersion(pd.__version__) < "3.0.0":
+            self.assertEqual(types.iloc[7], object)
+        else:
+            self.assertEqual(types.iloc[7], pd.StringDtype(na_value=np.nan))
+        self.assertTrue(np.can_cast(np.datetime64, types.iloc[8]))
+        self.assertTrue(np.can_cast(np.datetime64, types.iloc[9]))
+        self.assertTrue(np.can_cast(np.timedelta64, types.iloc[10]))
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
     def test_to_pandas_from_mixed_dataframe(self):
         is_arrow_enabled = [True, False]
         for value in is_arrow_enabled:
-            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
-                self.check_to_pandas_from_mixed_dataframe()
+            with self.subTest(arrow_enabled=value):
+                with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
+                    self.check_to_pandas_from_mixed_dataframe()
 
     def check_to_pandas_from_mixed_dataframe(self):
         # SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes
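The self.subTest(arrow_enabled=value) wrappers added throughout are plain unittest: each configuration is reported as its own labeled sub-result, so a failure with Arrow enabled no longer masks the non-Arrow run. A self-contained sketch of the pattern, using a hypothetical test class (not from the patch):

    # Sketch of the subTest pattern the tests adopt.
    import unittest

    class ArrowToggleExample(unittest.TestCase):
        def test_both_modes(self):
            for value in [True, False]:
                with self.subTest(arrow_enabled=value):
                    # Each iteration is reported separately on failure
                    # instead of aborting the whole test at the first one.
                    self.assertIsInstance(value, bool)

    if __name__ == "__main__":
        unittest.main()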