Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/pyspark/pandas/data_type_ops/string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from typing import Any, Union, cast

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype

Expand Down Expand Up @@ -153,6 +154,9 @@ def restore(self, col: pd.Series) -> pd.Series:
if LooseVersion(pd.__version__) < "3.0.0":
return super().restore(col)
else:
if is_str_dtype(col.dtype) and not is_str_dtype(self.dtype):
# treat missing values as None for string dtype
col = col.replace({np.nan: None})
return col.astype(self.dtype)


Expand Down
11 changes: 10 additions & 1 deletion python/pyspark/sql/pandas/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,11 @@ def _to_corrected_pandas_type(dt: DataType) -> Optional[Any]:
return np.dtype("timedelta64[ns]")
else:
return np.dtype("timedelta64[us]")
elif type(dt) == StringType:
if LooseVersion(pd.__version__) < "3.0.0":
return None
else:
return pd.StringDtype(na_value=np.nan)
else:
return None

Expand Down Expand Up @@ -1007,7 +1012,11 @@ def correct_dtype(pser: pd.Series) -> pd.Series:
def correct_dtype(pser: pd.Series) -> pd.Series:
if not isinstance(pser.dtype, pd.DatetimeTZDtype):
pser = pser.astype(pandas_type, copy=False)
return _check_series_convert_timestamps_local_tz(pser, timezone=timezone)
pser = _check_series_convert_timestamps_local_tz(pser, timezone=timezone)
if LooseVersion(pd.__version__) < "3.0.0":
return pser
else:
return pser.astype(pandas_type, copy=False)

else:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import unittest

from pyspark.loose_version import LooseVersion
from pyspark.sql.types import (
StructType,
StructField,
Expand Down Expand Up @@ -113,7 +114,10 @@ def test_cached_schema_map_in_pandas(self):
def func(iterator):
for pdf in iterator:
assert isinstance(pdf, pd.DataFrame)
assert [d.name for d in list(pdf.dtypes)] == ["int32", "object"]
if LooseVersion(pd.__version__) < "3.0.0":
assert [d.name for d in list(pdf.dtypes)] == ["int32", "object"]
else:
assert [d.name for d in list(pdf.dtypes)] == ["int32", "str"]
yield pdf

schema = StructType(
Expand Down
77 changes: 49 additions & 28 deletions python/pyspark/sql/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import datetime
import unittest

from pyspark.loose_version import LooseVersion
from pyspark.sql.types import (
Row,
ArrayType,
Expand Down Expand Up @@ -95,18 +96,27 @@ def _to_pandas(self):

@unittest.skipIf(not have_pandas, pandas_requirement_message)
def test_to_pandas(self):
import pandas as pd
import numpy as np

pdf = self._to_pandas()
types = pdf.dtypes
self.assertEqual(types[0], np.int32)
self.assertEqual(types[1], object)
self.assertEqual(types[2], bool)
self.assertEqual(types[3], np.float32)
self.assertEqual(types[4], object) # datetime.date
self.assertEqual(types[5], "datetime64[ns]")
self.assertEqual(types[6], "datetime64[ns]")
self.assertEqual(types[7], "timedelta64[ns]")
self.assertEqual(types.iloc[0], np.int32)
if LooseVersion(pd.__version__) < "3.0.0":
self.assertEqual(types.iloc[1], object)
else:
self.assertEqual(types.iloc[1], pd.StringDtype(na_value=np.nan)) # datetime.date
self.assertEqual(types.iloc[2], bool)
self.assertEqual(types.iloc[3], np.float32)
self.assertEqual(types.iloc[4], object) # datetime.date
if LooseVersion(pd.__version__) < "3.0.0":
self.assertEqual(types.iloc[5], "datetime64[ns]")
self.assertEqual(types.iloc[6], "datetime64[ns]")
self.assertEqual(types.iloc[7], "timedelta64[ns]")
else:
self.assertEqual(types.iloc[5], "datetime64[us]")
self.assertEqual(types.iloc[6], "datetime64[us]")
self.assertEqual(types.iloc[7], "timedelta64[us]")

@unittest.skipIf(not have_pandas, pandas_requirement_message)
def test_to_pandas_with_duplicated_column_names(self):
Expand Down Expand Up @@ -155,22 +165,27 @@ def test_to_pandas_required_pandas_not_found(self):

@unittest.skipIf(not have_pandas, pandas_requirement_message)
def test_to_pandas_avoid_astype(self):
import pandas as pd
import numpy as np

schema = StructType().add("a", IntegerType()).add("b", StringType()).add("c", IntegerType())
data = [(1, "foo", 16777220), (None, "bar", None)]
df = self.spark.createDataFrame(data, schema)
types = df.toPandas().dtypes
self.assertEqual(types[0], np.float64) # doesn't convert to np.int32 due to NaN value.
self.assertEqual(types[1], object)
self.assertEqual(types[2], np.float64)
self.assertEqual(types.iloc[0], np.float64) # doesn't convert to np.int32 due to NaN value.
if LooseVersion(pd.__version__) < "3.0.0":
self.assertEqual(types.iloc[1], object)
else:
self.assertEqual(types.iloc[1], pd.StringDtype(na_value=np.nan))
self.assertEqual(types.iloc[2], np.float64)

@unittest.skipIf(not have_pandas, pandas_requirement_message)
def test_to_pandas_from_empty_dataframe(self):
is_arrow_enabled = [True, False]
for value in is_arrow_enabled:
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
self.check_to_pandas_from_empty_dataframe()
with self.subTest(arrow_enabled=value):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
self.check_to_pandas_from_empty_dataframe()

def check_to_pandas_from_empty_dataframe(self):
# SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes
Expand Down Expand Up @@ -199,13 +214,15 @@ def check_to_pandas_from_empty_dataframe(self):
def test_to_pandas_from_null_dataframe(self):
is_arrow_enabled = [True, False]
for value in is_arrow_enabled:
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
self.check_to_pandas_from_null_dataframe()
with self.subTest(arrow_enabled=value):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
self.check_to_pandas_from_null_dataframe()

def check_to_pandas_from_null_dataframe(self):
# SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes
# SPARK-30537 test that toPandas() on a dataframe with only nulls has correct dtypes
# using arrow
import pandas as pd
import numpy as np

sql = """
Expand All @@ -223,24 +240,28 @@ def check_to_pandas_from_null_dataframe(self):
"""
pdf = self.spark.sql(sql).toPandas()
types = pdf.dtypes
self.assertEqual(types[0], np.float64)
self.assertEqual(types[1], np.float64)
self.assertEqual(types[2], np.float64)
self.assertEqual(types[3], np.float64)
self.assertEqual(types[4], np.float32)
self.assertEqual(types[5], np.float64)
self.assertEqual(types[6], object)
self.assertEqual(types[7], object)
self.assertTrue(np.can_cast(np.datetime64, types[8]))
self.assertTrue(np.can_cast(np.datetime64, types[9]))
self.assertTrue(np.can_cast(np.timedelta64, types[10]))
self.assertEqual(types.iloc[0], np.float64)
self.assertEqual(types.iloc[1], np.float64)
self.assertEqual(types.iloc[2], np.float64)
self.assertEqual(types.iloc[3], np.float64)
self.assertEqual(types.iloc[4], np.float32)
self.assertEqual(types.iloc[5], np.float64)
self.assertEqual(types.iloc[6], object)
if LooseVersion(pd.__version__) < "3.0.0":
self.assertEqual(types.iloc[7], object)
else:
self.assertEqual(types.iloc[7], pd.StringDtype(na_value=np.nan))
self.assertTrue(np.can_cast(np.datetime64, types.iloc[8]))
self.assertTrue(np.can_cast(np.datetime64, types.iloc[9]))
self.assertTrue(np.can_cast(np.timedelta64, types.iloc[10]))

@unittest.skipIf(not have_pandas, pandas_requirement_message)
def test_to_pandas_from_mixed_dataframe(self):
is_arrow_enabled = [True, False]
for value in is_arrow_enabled:
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
self.check_to_pandas_from_mixed_dataframe()
with self.subTest(arrow_enabled=value):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
self.check_to_pandas_from_mixed_dataframe()

def check_to_pandas_from_mixed_dataframe(self):
# SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes
Expand Down