36 changes: 35 additions & 1 deletion python/pyspark/sql/connect/plan.py
@@ -43,7 +43,7 @@

from pyspark.serializers import CloudPickleSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.sql.types import DataType
from pyspark.sql.types import DataType, StructType

import pyspark.sql.connect.proto as proto
from pyspark.sql.column import Column
@@ -383,6 +383,40 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan


class Parse(LogicalPlan):
"""Parse a DataFrame with a single string column into a structured DataFrame."""

def __init__(
self,
child: "LogicalPlan",
format: int,
schema: Optional[str] = None,
options: Optional[Mapping[str, str]] = None,
) -> None:
super().__init__(child)
self._format = format
self._schema = schema
self._options = options

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.parse.input.CopyFrom(self._child.plan(session))
plan.parse.format = self._format
if self._schema is not None and len(self._schema) > 0:
    import json

    plan.parse.schema.CopyFrom(
        pyspark_types_to_proto_types(
            # A JSON-serialized schema string starts with "{"; anything else is DDL.
            StructType.fromDDL(self._schema)
            if not self._schema.startswith("{")
            # StructType.fromJson expects a decoded dict, not the raw JSON string.
            else StructType.fromJson(json.loads(self._schema))
        )
    )
if self._options is not None:
for k, v in self._options.items():
plan.parse.options[k] = v
return plan
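
For orientation (an illustrative sketch, not part of this diff): on Spark Connect, spark.read.json(df) in readwriter.py below wraps the child relation in this Parse plan, and the server performs the actual parsing. Assuming a connect session named spark:

lines = spark.createDataFrame([('{"a": 1}',)], schema="value STRING")
parsed = spark.read.json(lines)  # builds Parse(child=lines._plan, format=PARSE_FORMAT_JSON)
parsed.printSchema()  # schema is inferred server-side, e.g. "a: long"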


class Read(LogicalPlan):
def __init__(
self,
48 changes: 46 additions & 2 deletions python/pyspark/sql/connect/readwriter.py
@@ -25,7 +25,9 @@
LogicalPlan,
WriteOperation,
WriteOperationV2,
Parse,
)
import pyspark.sql.connect.proto as proto
from pyspark.sql.types import StructType
from pyspark.sql.utils import to_str
from pyspark.sql.readwriter import (
@@ -220,7 +222,28 @@ def json(
)
if isinstance(path, str):
path = [path]
return self.load(path=path, format="json", schema=schema)
if isinstance(path, list):
return self.load(path=path, format="json", schema=schema)

from pyspark.sql.connect.dataframe import DataFrame

if isinstance(path, DataFrame):
return self._df(
Parse(
child=path._plan,
format=proto.Parse.ParseFormat.PARSE_FORMAT_JSON,
schema=self._schema,
options=self._options,
)
)
raise PySparkTypeError(
errorClass="NOT_EXPECTED_TYPE",
messageParameters={
"arg_name": "path",
"expected_type": "str, list, or DataFrame",
"arg_type": type(path).__name__,
},
)

json.__doc__ = PySparkDataFrameReader.json.__doc__

@@ -344,7 +367,28 @@ def csv(
)
if isinstance(path, str):
path = [path]
return self.load(path=path, format="csv", schema=schema)
if isinstance(path, list):
return self.load(path=path, format="csv", schema=schema)

from pyspark.sql.connect.dataframe import DataFrame

if isinstance(path, DataFrame):
return self._df(
Parse(
child=path._plan,
format=proto.Parse.ParseFormat.PARSE_FORMAT_CSV,
schema=self._schema,
options=self._options,
)
)
raise PySparkTypeError(
errorClass="NOT_EXPECTED_TYPE",
messageParameters={
"arg_name": "path",
"expected_type": "str, list, or DataFrame",
"arg_type": type(path).__name__,
},
)

csv.__doc__ = PySparkDataFrameReader.csv.__doc__

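A hedged usage sketch for the CSV branch above (mirrors the JSON case; assumes a Spark Connect session named spark, and that .schema() stores the DDL string the Parse plan later reads from self._schema):

rows = spark.createDataFrame([("Alice,25",), ("Bob,30",)], schema="value STRING")
parsed = spark.read.schema("name STRING, age INT").csv(rows)
parsed.sort("name").collect()  # expected: [Row(name='Alice', age=25), Row(name='Bob', age=30)]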
71 changes: 59 additions & 12 deletions python/pyspark/sql/readwriter.py
@@ -320,7 +320,7 @@ def load(

def json(
self,
path: Union[str, List[str], "RDD[str]"],
path: Union[str, List[str], "RDD[str]", "DataFrame"],
schema: Optional[Union[StructType, str]] = None,
primitivesAsString: Optional[Union[bool, str]] = None,
prefersDecimal: Optional[Union[bool, str]] = None,
@@ -361,11 +361,15 @@ def json(
.. versionchanged:: 3.4.0
Supports Spark Connect.

.. versionchanged:: 4.1.0
Supports DataFrame input.

Parameters
----------
path : str, list or :class:`RDD`
path : str, list, :class:`RDD`, or :class:`DataFrame`
string represents path to the JSON dataset, or a list of paths,
or RDD of Strings storing JSON objects.
or RDD of Strings storing JSON objects,
or a DataFrame with a single string column containing JSON strings.
schema : :class:`pyspark.sql.types.StructType` or str, optional
an optional :class:`pyspark.sql.types.StructType` for the input schema or
a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -434,6 +438,20 @@ def json(
+----+---+
| Bob| 30|
+----+---+

Example 4: Parse JSON from a DataFrame with a single string column.

>>> json_df = spark.createDataFrame(
... [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
... schema="value STRING",
... )
>>> spark.read.json(json_df).sort("name").show()
+---+-----+
|age| name|
+---+-----+
| 25|Alice|
| 30| Bob|
+---+-----+
"""
self._set_opts(
schema=schema,
@@ -486,12 +504,19 @@ def func(iterator: Iterable) -> Iterable:
assert self._spark._jvm is not None
jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
return self._df(self._jreader.json(jrdd))

from pyspark.sql.dataframe import DataFrame

if isinstance(path, DataFrame):
string_encoder = self._spark._jvm.Encoders.STRING()
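# "as" is a Python keyword, so Dataset.as(Encoder) is called via getattr.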
jdataset = getattr(path._jdf, "as")(string_encoder)
return self._df(self._jreader.json(jdataset))
else:
raise PySparkTypeError(
errorClass="NOT_EXPECTED_TYPE",
messageParameters={
"arg_name": "path",
"expected_type": "str or list[RDD]",
"expected_type": "str, list, RDD, or DataFrame",
"arg_type": type(path).__name__,
},
)
@@ -742,7 +767,7 @@ def text(

def csv(
self,
path: PathOrPaths,
path: Union[str, List[str], "RDD[str]", "DataFrame"],
schema: Optional[Union[StructType, str]] = None,
sep: Optional[str] = None,
encoding: Optional[str] = None,
@@ -788,11 +813,15 @@ def csv(
.. versionchanged:: 3.4.0
Supports Spark Connect.

.. versionchanged:: 4.1.0
Supports DataFrame input.

Parameters
----------
path : str or list
path : str, list, :class:`RDD`, or :class:`DataFrame`
string, or list of strings, for input path(s),
or RDD of Strings storing CSV rows.
or RDD of Strings storing CSV rows,
or a DataFrame with a single string column containing CSV rows.
schema : :class:`pyspark.sql.types.StructType` or str, optional
an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -889,19 +918,26 @@ def func(iterator):
jrdd.rdd(), self._spark._jvm.Encoders.STRING()
)
return self._df(self._jreader.csv(jdataset))

from pyspark.sql.dataframe import DataFrame

if isinstance(path, DataFrame):
string_encoder = self._spark._jvm.Encoders.STRING()
jdataset = getattr(path._jdf, "as")(string_encoder)
return self._df(self._jreader.csv(jdataset))
else:
raise PySparkTypeError(
errorClass="NOT_EXPECTED_TYPE",
messageParameters={
"arg_name": "path",
"expected_type": "str or list[RDD]",
"expected_type": "str, list, RDD, or DataFrame",
"arg_type": type(path).__name__,
},
)

def xml(
self,
path: Union[str, List[str], "RDD[str]"],
path: Union[str, List[str], "RDD[str]", "DataFrame"],
rowTag: Optional[str] = None,
schema: Optional[Union[StructType, str]] = None,
excludeAttribute: Optional[Union[bool, str]] = None,
@@ -929,11 +965,15 @@ def xml(

.. versionadded:: 4.0.0

.. versionchanged:: 4.1.0
Supports DataFrame input.

Parameters
----------
path : str, list or :class:`RDD`
path : str, list, :class:`RDD`, or :class:`DataFrame`
string, or list of strings, for input path(s),
or RDD of Strings storing XML rows.
or RDD of Strings storing XML rows,
or a DataFrame with a single string column containing XML rows.
schema : :class:`pyspark.sql.types.StructType` or str, optional
an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -1017,12 +1057,19 @@ def func(iterator: Iterable) -> Iterable:
jrdd.rdd(), self._spark._jvm.Encoders.STRING()
)
return self._df(self._jreader.xml(jdataset))

from pyspark.sql.dataframe import DataFrame

if isinstance(path, DataFrame):
string_encoder = self._spark._jvm.Encoders.STRING()
jdataset = getattr(path._jdf, "as")(string_encoder)
return self._df(self._jreader.xml(jdataset))
else:
raise PySparkTypeError(
errorClass="NOT_EXPECTED_TYPE",
messageParameters={
"arg_name": "path",
"expected_type": "str or list[RDD]",
"expected_type": "str, list, RDD, or DataFrame",
"arg_type": type(path).__name__,
},
)
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_readwriter.py
@@ -177,6 +177,34 @@ def test_csv(self):
# Read the text file as a DataFrame.
self.assert_eq(self.connect.read.csv(d).toPandas(), self.spark.read.csv(d).toPandas())

def test_json_with_dataframe_input(self):
json_df = self.connect.createDataFrame(
[('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
schema="value STRING",
)
result = self.connect.read.json(json_df)
expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")]
self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)

def test_csv_with_dataframe_input(self):
csv_df = self.connect.createDataFrame(
[("Alice,25",), ("Bob,30",)],
schema="value STRING",
)
result = self.connect.read.csv(csv_df)
expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")]
self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected)

def test_xml_with_dataframe_input_not_supported(self):
# XML DataFrame input is not supported on Spark Connect because
# the Parse proto only defines CSV and JSON formats.
xml_df = self.connect.createDataFrame(
[("<person><name>Alice</name><age>25</age></person>",)],
schema="value STRING",
)
with self.assertRaises(Exception):
self.connect.read.xml(xml_df, rowTag="person")

def test_multi_paths(self):
# SPARK-42041: DataFrameReader should support list of paths

39 changes: 39 additions & 0 deletions python/pyspark/sql/tests/test_datasources.py
@@ -93,6 +93,24 @@ def test_linesep_json(self):
finally:
shutil.rmtree(tpath)

def test_json_with_dataframe_input(self):
json_df = self.spark.createDataFrame(
[('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
schema="value STRING",
)
result = self.spark.read.json(json_df)
expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")]
self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)

def test_json_with_dataframe_input_and_schema(self):
json_df = self.spark.createDataFrame(
[('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
schema="value STRING",
)
result = self.spark.read.json(json_df, schema="name STRING, age INT")
expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)]
self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)

def test_multiline_csv(self):
ages_newlines = self.spark.read.csv(
"python/test_support/sql/ages_newlines.csv", multiLine=True
@@ -116,6 +134,15 @@ def test_ignorewhitespace_csv(self):
self.assertEqual(readback.collect(), expected)
shutil.rmtree(tmpPath)

def test_csv_with_dataframe_input(self):
csv_df = self.spark.createDataFrame(
[("Alice,25",), ("Bob,30",)],
schema="value STRING",
)
result = self.spark.read.csv(csv_df)
expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")]
self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected)

def test_xml(self):
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -157,6 +184,18 @@ def test_xml(self):
shutil.rmtree(tmpPath)
shutil.rmtree(xsdPath)

def test_xml_with_dataframe_input(self):
xml_df = self.spark.createDataFrame(
[
("<person><name>Alice</name><age>25</age></person>",),
("<person><name>Bob</name><age>30</age></person>",),
],
schema="value STRING",
)
result = self.spark.read.xml(xml_df, rowTag="person")
expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")]
self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)

def test_xml_sampling_ratio(self):
rdd = self.spark.sparkContext.range(0, 100, 1, 1).map(
lambda x: "<p><a>0.1</a></p>" if x == 1 else "<p><a>%s</a></p>" % str(x)