diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 2a71a94bde21e..150dc6484de8b 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -43,7 +43,7 @@
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
-from pyspark.sql.types import DataType
+from pyspark.sql.types import DataType, StructType
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.column import Column
 
@@ -383,6 +383,40 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
         return plan
 
 
+class Parse(LogicalPlan):
+    """Parse a DataFrame with a single string column into a structured DataFrame."""
+
+    def __init__(
+        self,
+        child: "LogicalPlan",
+        format: int,
+        schema: Optional[str] = None,
+        options: Optional[Mapping[str, str]] = None,
+    ) -> None:
+        super().__init__(child)
+        self._format = format
+        self._schema = schema
+        self._options = options
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = self._create_proto_relation()
+        plan.parse.input.CopyFrom(self._child.plan(session))
+        plan.parse.format = self._format
+        if self._schema is not None and len(self._schema) > 0:
+            plan.parse.schema.CopyFrom(
+                pyspark_types_to_proto_types(
+                    StructType.fromDDL(self._schema)
+                    if not self._schema.startswith("{")
+                    else StructType.fromJson(self._schema)
+                )
+            )
+        if self._options is not None:
+            for k, v in self._options.items():
+                plan.parse.options[k] = v
+        return plan
+
+
 class Read(LogicalPlan):
     def __init__(
         self,
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index c951a9caf6a56..fb97a8f4af4fe 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -25,7 +25,9 @@
     LogicalPlan,
     WriteOperation,
     WriteOperationV2,
+    Parse,
 )
+import pyspark.sql.connect.proto as proto
 from pyspark.sql.types import StructType
 from pyspark.sql.utils import to_str
 from pyspark.sql.readwriter import (
@@ -220,7 +222,28 @@ def json(
         )
         if isinstance(path, str):
             path = [path]
-        return self.load(path=path, format="json", schema=schema)
+        if isinstance(path, list):
+            return self.load(path=path, format="json", schema=schema)
+
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            return self._df(
+                Parse(
+                    child=path._plan,
+                    format=proto.Parse.ParseFormat.PARSE_FORMAT_JSON,
+                    schema=self._schema,
+                    options=self._options,
+                )
+            )
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "path",
+                "expected_type": "str, list, or DataFrame",
+                "arg_type": type(path).__name__,
+            },
+        )
 
     json.__doc__ = PySparkDataFrameReader.json.__doc__
 
@@ -344,7 +367,28 @@ def csv(
         )
         if isinstance(path, str):
             path = [path]
-        return self.load(path=path, format="csv", schema=schema)
+        if isinstance(path, list):
+            return self.load(path=path, format="csv", schema=schema)
+
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            return self._df(
+                Parse(
+                    child=path._plan,
+                    format=proto.Parse.ParseFormat.PARSE_FORMAT_CSV,
+                    schema=self._schema,
+                    options=self._options,
+                )
+            )
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "path",
+                "expected_type": "str, list, or DataFrame",
+                "arg_type": type(path).__name__,
+            },
+        )
 
     csv.__doc__ = PySparkDataFrameReader.csv.__doc__
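
For orientation before the classic-path changes below, a minimal usage sketch of the Connect-side reader added above. It assumes an active Spark Connect session bound to `spark`; the column name "value" is arbitrary. The reader builds a Parse relation rather than a file scan, and the schema string may be either DDL or a JSON-encoded StructType (hence the `startswith("{")` branch in `Parse.plan()`):

    # Sketch (assumed session `spark`): parse JSON rows held in a DataFrame column.
    json_df = spark.createDataFrame([('{"name": "Alice", "age": 25}',)], "value STRING")
    # The DDL schema string is forwarded into the Parse relation via self._schema.
    spark.read.schema("name STRING, age INT").json(json_df).show()
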
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f332d97def61e..4c36868a510bf 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -320,7 +320,7 @@ def load(
 
     def json(
         self,
-        path: Union[str, List[str], "RDD[str]"],
+        path: Union[str, List[str], "RDD[str]", "DataFrame"],
         schema: Optional[Union[StructType, str]] = None,
         primitivesAsString: Optional[Union[bool, str]] = None,
         prefersDecimal: Optional[Union[bool, str]] = None,
@@ -361,11 +361,15 @@ def json(
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.1.0
+            Supports DataFrame input.
+
         Parameters
         ----------
-        path : str, list or :class:`RDD`
+        path : str, list, :class:`RDD`, or :class:`DataFrame`
             string represents path to the JSON dataset, or a list of paths,
-            or RDD of Strings storing JSON objects.
+            or RDD of Strings storing JSON objects,
+            or a DataFrame with a single string column containing JSON strings.
         schema : :class:`pyspark.sql.types.StructType` or str, optional
             an optional :class:`pyspark.sql.types.StructType` for the input schema
             or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -434,6 +438,20 @@ def json(
         +----+---+
         | Bob| 30|
         +----+---+
+
+        Example 4: Parse JSON from a DataFrame with a single string column.
+
+        >>> json_df = spark.createDataFrame(
+        ...     [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+        ...     schema="value STRING",
+        ... )
+        >>> spark.read.json(json_df).sort("name").show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        | 25|Alice|
+        | 30| Bob|
+        +---+-----+
         """
         self._set_opts(
             schema=schema,
@@ -486,12 +504,19 @@ def func(iterator: Iterable) -> Iterable:
             assert self._spark._jvm is not None
             jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
             return self._df(self._jreader.json(jrdd))
+
+        from pyspark.sql.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            string_encoder = self._spark._jvm.Encoders.STRING()
+            jdataset = getattr(path._jdf, "as")(string_encoder)
+            return self._df(self._jreader.json(jdataset))
         else:
             raise PySparkTypeError(
                 errorClass="NOT_EXPECTED_TYPE",
                 messageParameters={
                     "arg_name": "path",
-                    "expected_type": "str or list[RDD]",
+                    "expected_type": "str, list, RDD, or DataFrame",
                     "arg_type": type(path).__name__,
                 },
             )
@@ -742,7 +767,7 @@ def text(
 
     def csv(
         self,
-        path: PathOrPaths,
+        path: Union[str, List[str], "RDD[str]", "DataFrame"],
         schema: Optional[Union[StructType, str]] = None,
         sep: Optional[str] = None,
         encoding: Optional[str] = None,
@@ -788,11 +813,15 @@ def csv(
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.1.0
+            Supports DataFrame input.
+
         Parameters
         ----------
-        path : str or list
+        path : str, list, :class:`RDD`, or :class:`DataFrame`
             string, or list of strings, for input path(s),
-            or RDD of Strings storing CSV rows.
+            or RDD of Strings storing CSV rows,
+            or a DataFrame with a single string column containing CSV rows.
         schema : :class:`pyspark.sql.types.StructType` or str, optional
             an optional :class:`pyspark.sql.types.StructType` for the input schema
             or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
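
A note on the JVM bridge in the DataFrame branch added to json() above (mirrored for csv() and xml() below): the input DataFrame is cast to a JVM Dataset[String] and handed to the existing DataFrameReader.json(Dataset[String]) overload. A rough standalone equivalent, for illustration only, using the internal attributes that appear in the diff (`_jdf`, `_jreader`, `_df` are private API):

    # Illustration of the conversion performed by the new branch (classic PySpark).
    # "as" is fetched via getattr() because it is a reserved keyword in Python.
    reader = spark.read
    jdataset = getattr(json_df._jdf, "as")(spark._jvm.Encoders.STRING())
    parsed = reader._df(reader._jreader.json(jdataset))
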
@@ -889,19 +918,26 @@ def func(iterator):
                 jrdd.rdd(), self._spark._jvm.Encoders.STRING()
             )
             return self._df(self._jreader.csv(jdataset))
+
+        from pyspark.sql.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            string_encoder = self._spark._jvm.Encoders.STRING()
+            jdataset = getattr(path._jdf, "as")(string_encoder)
+            return self._df(self._jreader.csv(jdataset))
         else:
             raise PySparkTypeError(
                 errorClass="NOT_EXPECTED_TYPE",
                 messageParameters={
                     "arg_name": "path",
-                    "expected_type": "str or list[RDD]",
+                    "expected_type": "str, list, RDD, or DataFrame",
                     "arg_type": type(path).__name__,
                 },
             )
 
     def xml(
         self,
-        path: Union[str, List[str], "RDD[str]"],
+        path: Union[str, List[str], "RDD[str]", "DataFrame"],
         rowTag: Optional[str] = None,
         schema: Optional[Union[StructType, str]] = None,
         excludeAttribute: Optional[Union[bool, str]] = None,
@@ -929,11 +965,15 @@ def xml(
 
         .. versionadded:: 4.0.0
 
+        .. versionchanged:: 4.1.0
+            Supports DataFrame input.
+
         Parameters
         ----------
-        path : str, list or :class:`RDD`
+        path : str, list, :class:`RDD`, or :class:`DataFrame`
             string, or list of strings, for input path(s),
-            or RDD of Strings storing XML rows.
+            or RDD of Strings storing XML rows,
+            or a DataFrame with a single string column containing XML rows.
         schema : :class:`pyspark.sql.types.StructType` or str, optional
             an optional :class:`pyspark.sql.types.StructType` for the input schema
             or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -1017,12 +1057,19 @@ def func(iterator: Iterable) -> Iterable:
                 jrdd.rdd(), self._spark._jvm.Encoders.STRING()
             )
             return self._df(self._jreader.xml(jdataset))
+
+        from pyspark.sql.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            string_encoder = self._spark._jvm.Encoders.STRING()
+            jdataset = getattr(path._jdf, "as")(string_encoder)
+            return self._df(self._jreader.xml(jdataset))
         else:
             raise PySparkTypeError(
                 errorClass="NOT_EXPECTED_TYPE",
                 messageParameters={
                     "arg_name": "path",
-                    "expected_type": "str or list[RDD]",
+                    "expected_type": "str, list, RDD, or DataFrame",
                     "arg_type": type(path).__name__,
                 },
             )
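
All three classic readers now share the same fallback: any input that is not a str, list, RDD, or DataFrame reaches the NOT_EXPECTED_TYPE error with the widened expected-type list. A quick sketch of the failure mode (the exact message text is rendered by the error class, so the comment below is approximate):

    from pyspark.errors import PySparkTypeError

    try:
        spark.read.csv(123)  # int is not an accepted input type
    except PySparkTypeError as e:
        print(e)  # roughly: [NOT_EXPECTED_TYPE] ... "str, list, RDD, or DataFrame"
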
diff --git a/python/pyspark/sql/tests/connect/test_connect_readwriter.py b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
index fc27771fff74d..b77232bd3b554 100644
--- a/python/pyspark/sql/tests/connect/test_connect_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
@@ -177,6 +177,34 @@ def test_csv(self):
             # Read the text file as a DataFrame.
             self.assert_eq(self.connect.read.csv(d).toPandas(), self.spark.read.csv(d).toPandas())
 
+    def test_json_with_dataframe_input(self):
+        json_df = self.connect.createDataFrame(
+            [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+            schema="value STRING",
+        )
+        result = self.connect.read.json(json_df)
+        expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
+    def test_csv_with_dataframe_input(self):
+        csv_df = self.connect.createDataFrame(
+            [("Alice,25",), ("Bob,30",)],
+            schema="value STRING",
+        )
+        result = self.connect.read.csv(csv_df)
+        expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected)
+
+    def test_xml_with_dataframe_input_not_supported(self):
+        # XML DataFrame input is not supported on Spark Connect because
+        # the Parse proto only defines CSV and JSON formats.
+        xml_df = self.connect.createDataFrame(
+            [("<person><name>Alice</name><age>25</age></person>",)],
+            schema="value STRING",
+        )
+        with self.assertRaises(Exception):
+            self.connect.read.xml(xml_df, rowTag="person")
+
     def test_multi_paths(self):
         # SPARK-42041: DataFrameReader should support list of paths
 
diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py
index 1ceb74c1d907c..1996860a75452 100644
--- a/python/pyspark/sql/tests/test_datasources.py
+++ b/python/pyspark/sql/tests/test_datasources.py
@@ -93,6 +93,24 @@ def test_linesep_json(self):
         finally:
             shutil.rmtree(tpath)
 
+    def test_json_with_dataframe_input(self):
+        json_df = self.spark.createDataFrame(
+            [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+            schema="value STRING",
+        )
+        result = self.spark.read.json(json_df)
+        expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
+    def test_json_with_dataframe_input_and_schema(self):
+        json_df = self.spark.createDataFrame(
+            [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+            schema="value STRING",
+        )
+        result = self.spark.read.json(json_df, schema="name STRING, age INT")
+        expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
     def test_multiline_csv(self):
         ages_newlines = self.spark.read.csv(
             "python/test_support/sql/ages_newlines.csv", multiLine=True
@@ -116,6 +134,15 @@ def test_ignorewhitespace_csv(self):
         self.assertEqual(readback.collect(), expected)
         shutil.rmtree(tmpPath)
 
+    def test_csv_with_dataframe_input(self):
+        csv_df = self.spark.createDataFrame(
+            [("Alice,25",), ("Bob,30",)],
+            schema="value STRING",
+        )
+        result = self.spark.read.csv(csv_df)
+        expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected)
+
     def test_xml(self):
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -157,6 +184,18 @@ def test_xml(self):
         shutil.rmtree(tmpPath)
         shutil.rmtree(xsdPath)
 
+    def test_xml_with_dataframe_input(self):
+        xml_df = self.spark.createDataFrame(
+            [
+                ("<person><name>Alice</name><age>25</age></person>",),
+                ("<person><name>Bob</name><age>30</age></person>",),
+            ],
+            schema="value STRING",
+        )
+        result = self.spark.read.xml(xml_df, rowTag="person")
+        expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
     def test_xml_sampling_ratio(self):
         rdd = self.spark.sparkContext.range(0, 100, 1, 1).map(
             lambda x: "<p><a>0.1</a></p>" if x == 1 else "<p><a>%s</a></p>" % str(x)