diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 2a71a94bde21e..b051be800c817 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -43,7 +43,7 @@
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
-from pyspark.sql.types import DataType
+from pyspark.sql.types import DataType, StructType
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.column import Column
@@ -383,6 +383,40 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
         return plan
 
 
+class Parse(LogicalPlan):
+    """Parse a DataFrame with a single string column into a structured DataFrame."""
+
+    def __init__(
+        self,
+        child: "LogicalPlan",
+        format: "proto.Parse.ParseFormat.ValueType",
+        schema: Optional[str] = None,
+        options: Optional[Mapping[str, str]] = None,
+    ) -> None:
+        super().__init__(child)
+        self._format = format
+        self._schema = schema
+        self._options = options
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = self._create_proto_relation()
+        plan.parse.input.CopyFrom(self._child.plan(session))
+        plan.parse.format = self._format
+        if self._schema is not None and len(self._schema) > 0:
+            plan.parse.schema.CopyFrom(
+                pyspark_types_to_proto_types(
+                    StructType.fromDDL(self._schema)
+                    if not self._schema.startswith("{")
+                    else StructType.fromJson(json.loads(self._schema))
+                )
+            )
+        if self._options is not None:
+            for k, v in self._options.items():
+                plan.parse.options[k] = v
+        return plan
+
+
 class Read(LogicalPlan):
     def __init__(
         self,
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index c951a9caf6a56..1096e771e1d01 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -25,7 +25,9 @@
     LogicalPlan,
     WriteOperation,
     WriteOperationV2,
+    Parse,
 )
+import pyspark.sql.connect.proto as proto
 from pyspark.sql.types import StructType
 from pyspark.sql.utils import to_str
 from pyspark.sql.readwriter import (
@@ -220,7 +222,32 @@ def json(
         )
         if isinstance(path, str):
             path = [path]
-        return self.load(path=path, format="json", schema=schema)
+        if isinstance(path, list):
+            return self.load(path=path, format="json", schema=schema)
+
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            # Schema must be set explicitly here because the DataFrame path
+            # bypasses load(), which normally calls self.schema(schema).
+            if schema is not None:
+                self.schema(schema)
+            return self._df(
+                Parse(
+                    child=path._plan,
+                    format=proto.Parse.ParseFormat.PARSE_FORMAT_JSON,
+                    schema=self._schema,
+                    options=self._options,
+                )
+            )
+        raise PySparkTypeError(
+            errorClass="NOT_EXPECTED_TYPE",
+            messageParameters={
+                "arg_name": "path",
+                "expected_type": "str, list, or DataFrame",
+                "arg_type": type(path).__name__,
+            },
+        )
 
     json.__doc__ = PySparkDataFrameReader.json.__doc__
 
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index f332d97def61e..e6a0e8c7237a4 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -320,7 +320,7 @@ def load(
 
     def json(
         self,
-        path: Union[str, List[str], "RDD[str]"],
+        path: Union[str, List[str], "RDD[str]", "DataFrame"],
         schema: Optional[Union[StructType, str]] = None,
         primitivesAsString: Optional[Union[bool, str]] = None,
         prefersDecimal: Optional[Union[bool, str]] = None,
@@ -361,11 +361,15 @@ def json(
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.1.0
+            Supports DataFrame input.
+
         Parameters
         ----------
-        path : str, list or :class:`RDD`
+        path : str, list, :class:`RDD`, or :class:`DataFrame`
             string represents path to the JSON dataset, or a list of paths,
-            or RDD of Strings storing JSON objects.
+            or RDD of Strings storing JSON objects,
+            or a DataFrame with a single string column containing JSON strings.
         schema : :class:`pyspark.sql.types.StructType` or str, optional
             an optional :class:`pyspark.sql.types.StructType` for the input schema
             or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -434,6 +438,20 @@
         +----+---+
         | Bob| 30|
         +----+---+
+
+        Example 4: Parse JSON from a DataFrame with a single string column.
+
+        >>> json_df = spark.createDataFrame(
+        ...     [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+        ...     schema="value STRING",
+        ... )
+        >>> spark.read.json(json_df).sort("name").show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        | 25|Alice|
+        | 30|  Bob|
+        +---+-----+
         """
         self._set_opts(
             schema=schema,
@@ -486,12 +504,20 @@ def func(iterator: Iterable) -> Iterable:
             assert self._spark._jvm is not None
             jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
             return self._df(self._jreader.json(jrdd))
+
+        from pyspark.sql.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            assert self._spark._jvm is not None
+            return self._df(
+                self._spark._jvm.PythonSQLUtils.jsonFromDataFrame(self._jreader, path._jdf)
+            )
         else:
             raise PySparkTypeError(
                 errorClass="NOT_EXPECTED_TYPE",
                 messageParameters={
                     "arg_name": "path",
-                    "expected_type": "str or list[RDD]",
+                    "expected_type": "str, list, RDD, or DataFrame",
                     "arg_type": type(path).__name__,
                 },
             )
diff --git a/python/pyspark/sql/tests/connect/test_connect_readwriter.py b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
index fc27771fff74d..d5ef75d963527 100644
--- a/python/pyspark/sql/tests/connect/test_connect_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
@@ -177,6 +177,41 @@ def test_csv(self):
             # Read the text file as a DataFrame.
             self.assert_eq(self.connect.read.csv(d).toPandas(), self.spark.read.csv(d).toPandas())
 
+    def test_json_with_dataframe_input(self):
+        json_df = self.connect.createDataFrame(
+            [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+            schema="value STRING",
+        )
+        result = self.connect.read.json(json_df)
+        expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
+    def test_json_with_dataframe_input_and_schema(self):
+        json_df = self.connect.createDataFrame(
+            [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+            schema="value STRING",
+        )
+        result = self.connect.read.json(json_df, schema="name STRING, age INT")
+        expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
+    def test_json_with_dataframe_input_non_string_column(self):
+        int_df = self.connect.createDataFrame([(1,), (2,)], schema="value INT")
+        with self.assertRaises(Exception):
+            self.connect.read.json(int_df).collect()
+
+    def test_json_with_dataframe_input_multiple_columns(self):
+        multi_df = self.connect.createDataFrame(
+            [("a", "b"), ("c", "d")], schema="col1 STRING, col2 STRING"
+        )
+        with self.assertRaises(Exception):
+            self.connect.read.json(multi_df).collect()
+
+    def test_json_with_dataframe_input_zero_columns(self):
+        empty_schema_df = self.connect.range(1).select()
+        with self.assertRaises(Exception):
+            self.connect.read.json(empty_schema_df).collect()
+
     def test_multi_paths(self):
         # SPARK-42041: DataFrameReader should support list of paths
diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py
index 1ceb74c1d907c..9084f23207b19 100644
--- a/python/pyspark/sql/tests/test_datasources.py
+++ b/python/pyspark/sql/tests/test_datasources.py
@@ -93,6 +93,41 @@ def test_linesep_json(self):
         finally:
             shutil.rmtree(tpath)
 
+    def test_json_with_dataframe_input(self):
+        json_df = self.spark.createDataFrame(
+            [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+            schema="value STRING",
+        )
+        result = self.spark.read.json(json_df)
+        expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
+    def test_json_with_dataframe_input_and_schema(self):
+        json_df = self.spark.createDataFrame(
+            [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
+            schema="value STRING",
+        )
+        result = self.spark.read.json(json_df, schema="name STRING, age INT")
+        expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
+    def test_json_with_dataframe_input_non_string_column(self):
+        int_df = self.spark.createDataFrame([(1,), (2,)], schema="value INT")
+        with self.assertRaises(Exception):
+            self.spark.read.json(int_df).collect()
+
+    def test_json_with_dataframe_input_multiple_columns(self):
+        multi_df = self.spark.createDataFrame(
+            [("a", "b"), ("c", "d")], schema="col1 STRING, col2 STRING"
+        )
+        with self.assertRaises(Exception):
+            self.spark.read.json(multi_df).collect()
+
+    def test_json_with_dataframe_input_zero_columns(self):
+        empty_schema_df = self.spark.range(1).select()
+        with self.assertRaises(Exception):
+            self.spark.read.json(empty_schema_df).collect()
+
     def test_multiline_csv(self):
         ages_newlines = self.spark.read.csv(
             "python/test_support/sql/ages_newlines.csv", multiLine=True
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 37bcf995ee16d..ed55234844ac4 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1760,7 +1760,15 @@ class SparkConnectPlanner(
       localMap.foreach { case (key, value) => reader.option(key, value) }
       reader
     }
-    def ds: Dataset[String] = Dataset(session, transformRelation(rel.getInput))(Encoders.STRING)
+    def ds: Dataset[String] = {
+      val input = transformRelation(rel.getInput)
+      val inputSchema = Dataset.ofRows(session, input).schema
+      require(inputSchema.fields.length == 1,
+        s"Input DataFrame must have exactly one column, but got ${inputSchema.fields.length}")
+      require(inputSchema.fields.head.dataType == org.apache.spark.sql.types.StringType,
+        s"Input DataFrame column must be StringType, but got ${inputSchema.fields.head.dataType}")
+      Dataset(session, input)(Encoders.STRING)
+    }
 
     rel.getFormat match {
       case ParseFormat.PARSE_FORMAT_CSV =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 5607c98bf29e5..308d422368bd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.python.DechunkedInputStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.CLASS_LOADER
 import org.apache.spark.security.SocketAuthServer
-import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession, TableArg}
+import org.apache.spark.sql.{internal, Column, DataFrame, DataFrameReader, Encoders, Row, SparkSession, TableArg}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -193,6 +193,21 @@ private[sql] object PythonSQLUtils extends Logging {
   @scala.annotation.varargs
   def internalFn(name: String, inputs: Column*): Column = Column.internalFn(name, inputs: _*)
 
+  /**
+   * Parses a [[DataFrame]] containing JSON strings into a structured [[DataFrame]].
+   * The input DataFrame must have exactly one column of StringType.
+   * This is used by PySpark to avoid manual Dataset[String] conversion on the Python side.
+   */
+  def jsonFromDataFrame(
+      reader: DataFrameReader,
+      df: DataFrame): DataFrame = {
+    require(df.schema.fields.length == 1,
+      s"Input DataFrame must have exactly one column, but got ${df.schema.fields.length}")
+    require(df.schema.fields.head.dataType == org.apache.spark.sql.types.StringType,
+      s"Input DataFrame column must be StringType, but got ${df.schema.fields.head.dataType}")
+    reader.json(df.as(Encoders.STRING))
+  }
+
   def cleanupPythonWorkerLogs(sessionUUID: String, sparkContext: SparkContext): Unit = {
     if (!sparkContext.isStopped) {
       try {
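
Usage sketch (not part of the patch): a minimal example of the behavior this change adds, assuming the patch is applied and `spark` is an active SparkSession. It mirrors the doctest and tests introduced above; the variable name `json_df` is illustrative.

    # Build a single string column of JSON objects.
    json_df = spark.createDataFrame(
        [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)],
        schema="value STRING",
    )
    # Let Spark infer the schema from the JSON strings ...
    spark.read.json(json_df).show()
    # ... or supply one explicitly, as the new tests do.
    spark.read.json(json_df, schema="name STRING, age INT").show()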