diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 2a71a94bde21e..150dc6484de8b 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -43,7 +43,7 @@
from pyspark.serializers import CloudPickleSerializer
from pyspark.storagelevel import StorageLevel
-from pyspark.sql.types import DataType
+from pyspark.sql.types import DataType, StructType
import pyspark.sql.connect.proto as proto
from pyspark.sql.column import Column
@@ -383,6 +383,40 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan
+
+
+class Parse(LogicalPlan):
+ """Parse a DataFrame with a single string column into a structured DataFrame."""
+
+ def __init__(
+ self,
+ child: "LogicalPlan",
+ format: int,
+ schema: Optional[str] = None,
+ options: Optional[Mapping[str, str]] = None,
+ ) -> None:
+ super().__init__(child)
+ self._format = format
+ self._schema = schema
+ self._options = options
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+ plan = self._create_proto_relation()
+ plan.parse.input.CopyFrom(self._child.plan(session))
+ plan.parse.format = self._format
+        if self._schema is not None and len(self._schema) > 0:
+            import json
+
+            plan.parse.schema.CopyFrom(
+                pyspark_types_to_proto_types(
+                    # schema() stores a StructType as its JSON serialization and
+                    # a plain string as DDL; a leading "{" tells them apart.
+                    # StructType.fromJson expects a dict, so decode the JSON
+                    # string before handing it over.
+                    StructType.fromJson(json.loads(self._schema))
+                    if self._schema.startswith("{")
+                    else StructType.fromDDL(self._schema)
+                )
+            )
+ if self._options is not None:
+ for k, v in self._options.items():
+ plan.parse.options[k] = v
+ return plan
+
+
class Read(LogicalPlan):
def __init__(
self,
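
The DDL-versus-JSON branch in Parse.plan mirrors how the Connect reader stores schemas: schema() keeps a StructType as its JSON serialization and a plain string as DDL, so a leading "{" distinguishes the two encodings. A minimal standalone sketch of that decoding, assuming an active SparkSession (StructType.fromDDL resolves DDL through the session; the helper name decode_reader_schema is made up for illustration):

    import json
    from pyspark.sql.types import StructType

    def decode_reader_schema(schema: str) -> StructType:
        # A JSON-serialized StructType always starts with "{"; anything else
        # is treated as a DDL string such as "name STRING, age INT".
        if schema.startswith("{"):
            return StructType.fromJson(json.loads(schema))
        return StructType.fromDDL(schema)

    ddl = "name STRING, age INT"
    # Round-trip: DDL -> StructType -> JSON string -> StructType.
    assert decode_reader_schema(decode_reader_schema(ddl).json()) == decode_reader_schema(ddl)
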
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index c951a9caf6a56..fb97a8f4af4fe 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -25,7 +25,9 @@
LogicalPlan,
WriteOperation,
WriteOperationV2,
+ Parse,
)
+import pyspark.sql.connect.proto as proto
from pyspark.sql.types import StructType
from pyspark.sql.utils import to_str
from pyspark.sql.readwriter import (
@@ -220,7 +222,28 @@ def json(
)
if isinstance(path, str):
path = [path]
- return self.load(path=path, format="json", schema=schema)
+ if isinstance(path, list):
+ return self.load(path=path, format="json", schema=schema)
+
+ from pyspark.sql.connect.dataframe import DataFrame
+
+    if isinstance(path, DataFrame):
+        if schema is not None:
+            # Register the user-provided schema so Parse can read it back
+            # from self._schema (load() normally performs this step).
+            self.schema(schema)
+        return self._df(
+            Parse(
+                child=path._plan,
+                format=proto.Parse.ParseFormat.PARSE_FORMAT_JSON,
+                schema=self._schema,
+                options=self._options,
+            )
+        )
+ raise PySparkTypeError(
+ errorClass="NOT_EXPECTED_TYPE",
+ messageParameters={
+ "arg_name": "path",
+ "expected_type": "str, list, or DataFrame",
+ "arg_type": type(path).__name__,
+ },
+ )
json.__doc__ = PySparkDataFrameReader.json.__doc__
@@ -344,7 +367,28 @@ def csv(
)
if isinstance(path, str):
path = [path]
- return self.load(path=path, format="csv", schema=schema)
+ if isinstance(path, list):
+ return self.load(path=path, format="csv", schema=schema)
+
+ from pyspark.sql.connect.dataframe import DataFrame
+
+    if isinstance(path, DataFrame):
+        if schema is not None:
+            # Register the user-provided schema so Parse can read it back
+            # from self._schema (load() normally performs this step).
+            self.schema(schema)
+        return self._df(
+            Parse(
+                child=path._plan,
+                format=proto.Parse.ParseFormat.PARSE_FORMAT_CSV,
+                schema=self._schema,
+                options=self._options,
+            )
+        )
+ raise PySparkTypeError(
+ errorClass="NOT_EXPECTED_TYPE",
+ messageParameters={
+ "arg_name": "path",
+ "expected_type": "str, list, or DataFrame",
+ "arg_type": type(path).__name__,
+ },
+ )
csv.__doc__ = PySparkDataFrameReader.csv.__doc__
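
With both branches in place, the Connect client builds a Parse relation instead of a Read when given a DataFrame. A quick smoke test, assuming a Spark Connect session bound to spark:

    # Hypothetical usage against a Spark Connect session.
    json_df = spark.createDataFrame(
        [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
        schema="value STRING",
    )
    # An explicit schema avoids a server-side inference pass over the column.
    parsed = spark.read.schema("name STRING, age INT").json(json_df)
    parsed.show()
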
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f332d97def61e..4c36868a510bf 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -320,7 +320,7 @@ def load(
def json(
self,
- path: Union[str, List[str], "RDD[str]"],
+ path: Union[str, List[str], "RDD[str]", "DataFrame"],
schema: Optional[Union[StructType, str]] = None,
primitivesAsString: Optional[Union[bool, str]] = None,
prefersDecimal: Optional[Union[bool, str]] = None,
@@ -361,11 +361,15 @@ def json(
.. versionchanged:: 3.4.0
Supports Spark Connect.
+ .. versionchanged:: 4.1.0
+ Supports DataFrame input.
+
Parameters
----------
- path : str, list or :class:`RDD`
+ path : str, list, :class:`RDD`, or :class:`DataFrame`
string represents path to the JSON dataset, or a list of paths,
- or RDD of Strings storing JSON objects.
+ or RDD of Strings storing JSON objects,
+ or a DataFrame with a single string column containing JSON strings.
schema : :class:`pyspark.sql.types.StructType` or str, optional
an optional :class:`pyspark.sql.types.StructType` for the input schema or
a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -434,6 +438,20 @@ def json(
+----+---+
| Bob| 30|
+----+---+
+
+ Example 4: Parse JSON from a DataFrame with a single string column.
+
+ >>> json_df = spark.createDataFrame(
+ ... [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+ ... schema="value STRING",
+ ... )
+ >>> spark.read.json(json_df).sort("name").show()
+ +---+-----+
+ |age| name|
+ +---+-----+
+ | 25|Alice|
+ | 30| Bob|
+ +---+-----+
"""
self._set_opts(
schema=schema,
@@ -486,12 +504,19 @@ def func(iterator: Iterable) -> Iterable:
assert self._spark._jvm is not None
jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
return self._df(self._jreader.json(jrdd))
+
+        from pyspark.sql.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            # Reinterpret the single string column as a JVM Dataset[String]
+            # and hand it to DataFrameReader.json on the JVM side.
+            assert self._spark._jvm is not None
+            string_encoder = self._spark._jvm.Encoders.STRING()
+            # `as` is a Python keyword, so Dataset.as(Encoder) needs getattr.
+            jdataset = getattr(path._jdf, "as")(string_encoder)
+            return self._df(self._jreader.json(jdataset))
-        else:
-            raise PySparkTypeError(
-                errorClass="NOT_EXPECTED_TYPE",
-                messageParameters={
-                    "arg_name": "path",
-                    "expected_type": "str or list[RDD]",
-                    "arg_type": type(path).__name__,
-                },
-            )
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "path",
+                "expected_type": "str, list, RDD, or DataFrame",
+                "arg_type": type(path).__name__,
+            },
+        )
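
One non-obvious detail in the branch above: `as` is a reserved word in Python, so the JVM method Dataset.as(Encoder) cannot be reached with attribute syntax over Py4J. A sketch of the same call chain using PySpark internals (_jvm, _jdf, _jreader, and _df are private APIs; df is assumed to be a DataFrame with one string column):

    # JVM: Encoders.STRING()
    string_encoder = spark._jvm.Encoders.STRING()
    # JVM: df.as(Encoders.STRING()) -- getattr because `as` is a keyword.
    jdataset = getattr(df._jdf, "as")(string_encoder)
    # JVM: spark.read().json(Dataset[String]), wrapped back into a Python DataFrame.
    reader = spark.read
    parsed = reader._df(reader._jreader.json(jdataset))
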
@@ -742,7 +767,7 @@ def text(
def csv(
self,
- path: PathOrPaths,
+ path: Union[str, List[str], "RDD[str]", "DataFrame"],
schema: Optional[Union[StructType, str]] = None,
sep: Optional[str] = None,
encoding: Optional[str] = None,
@@ -788,11 +813,15 @@ def csv(
.. versionchanged:: 3.4.0
Supports Spark Connect.
+ .. versionchanged:: 4.1.0
+ Supports DataFrame input.
+
Parameters
----------
- path : str or list
+ path : str, list, :class:`RDD`, or :class:`DataFrame`
string, or list of strings, for input path(s),
- or RDD of Strings storing CSV rows.
+ or RDD of Strings storing CSV rows,
+ or a DataFrame with a single string column containing CSV rows.
schema : :class:`pyspark.sql.types.StructType` or str, optional
an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -889,19 +918,26 @@ def func(iterator):
jrdd.rdd(), self._spark._jvm.Encoders.STRING()
)
return self._df(self._jreader.csv(jdataset))
+
+        from pyspark.sql.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            # Reinterpret the single string column as a JVM Dataset[String].
+            assert self._spark._jvm is not None
+            string_encoder = self._spark._jvm.Encoders.STRING()
+            # `as` is a Python keyword, so Dataset.as(Encoder) needs getattr.
+            jdataset = getattr(path._jdf, "as")(string_encoder)
+            return self._df(self._jreader.csv(jdataset))
-        else:
-            raise PySparkTypeError(
-                errorClass="NOT_EXPECTED_TYPE",
-                messageParameters={
-                    "arg_name": "path",
-                    "expected_type": "str or list[RDD]",
-                    "arg_type": type(path).__name__,
-                },
-            )
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "path",
+                "expected_type": "str, list, RDD, or DataFrame",
+                "arg_type": type(path).__name__,
+            },
+        )
def xml(
self,
- path: Union[str, List[str], "RDD[str]"],
+ path: Union[str, List[str], "RDD[str]", "DataFrame"],
rowTag: Optional[str] = None,
schema: Optional[Union[StructType, str]] = None,
excludeAttribute: Optional[Union[bool, str]] = None,
@@ -929,11 +965,15 @@ def xml(
.. versionadded:: 4.0.0
+ .. versionchanged:: 4.1.0
+ Supports DataFrame input.
+
Parameters
----------
- path : str, list or :class:`RDD`
+ path : str, list, :class:`RDD`, or :class:`DataFrame`
string, or list of strings, for input path(s),
- or RDD of Strings storing XML rows.
+ or RDD of Strings storing XML rows,
+ or a DataFrame with a single string column containing XML rows.
schema : :class:`pyspark.sql.types.StructType` or str, optional
an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -1017,12 +1057,19 @@ def func(iterator: Iterable) -> Iterable:
jrdd.rdd(), self._spark._jvm.Encoders.STRING()
)
return self._df(self._jreader.xml(jdataset))
+
+        from pyspark.sql.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            # Reinterpret the single string column as a JVM Dataset[String].
+            assert self._spark._jvm is not None
+            string_encoder = self._spark._jvm.Encoders.STRING()
+            # `as` is a Python keyword, so Dataset.as(Encoder) needs getattr.
+            jdataset = getattr(path._jdf, "as")(string_encoder)
+            return self._df(self._jreader.xml(jdataset))
-        else:
-            raise PySparkTypeError(
-                errorClass="NOT_EXPECTED_TYPE",
-                messageParameters={
-                    "arg_name": "path",
-                    "expected_type": "str or list[RDD]",
-                    "arg_type": type(path).__name__,
-                },
-            )
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "path",
+                "expected_type": "str, list, RDD, or DataFrame",
+                "arg_type": type(path).__name__,
+            },
+        )
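
Taken together, the classic-path changes let json, csv, and xml all accept a DataFrame with a single string column. A usage sketch, assuming a local SparkSession named spark (the column name "value" is arbitrary; any lone string column works, since the JVM side re-encodes it as Dataset[String]):

    csv_df = spark.createDataFrame([("Alice,25",), ("Bob,30",)], "value STRING")
    parsed = spark.read.schema("name STRING, age INT").csv(csv_df)
    assert parsed.schema.simpleString() == "struct<name:string,age:int>"
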
diff --git a/python/pyspark/sql/tests/connect/test_connect_readwriter.py b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
index fc27771fff74d..b77232bd3b554 100644
--- a/python/pyspark/sql/tests/connect/test_connect_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
@@ -177,6 +177,34 @@ def test_csv(self):
# Read the text file as a DataFrame.
self.assert_eq(self.connect.read.csv(d).toPandas(), self.spark.read.csv(d).toPandas())
+ def test_json_with_dataframe_input(self):
+ json_df = self.connect.createDataFrame(
+ [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+ schema="value STRING",
+ )
+ result = self.connect.read.json(json_df)
+ expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")]
+ self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
+ def test_csv_with_dataframe_input(self):
+ csv_df = self.connect.createDataFrame(
+ [("Alice,25",), ("Bob,30",)],
+ schema="value STRING",
+ )
+ result = self.connect.read.csv(csv_df)
+ expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")]
+ self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected)
+
+ def test_xml_with_dataframe_input_not_supported(self):
+ # XML DataFrame input is not supported on Spark Connect because
+ # the Parse proto only defines CSV and JSON formats.
+ xml_df = self.connect.createDataFrame(
+ [("