Commit 3e22c47

[SPARK-48960][CONNECT] Make spark-submit work with Spark Connect
### What changes were proposed in this pull request?

This PR proposes to add support for `--remote` at `bin/spark-submit` so it can use Spark Connect easily. This PR includes:

- Making `bin/spark-submit` work with the Scala Spark Connect client
- Passing `--conf` and loaded configurations to both the Scala and Python Spark Connect clients

### Why are the changes needed?

`bin/pyspark --remote` already works. We should also make `bin/spark-submit` work so that end users can try Spark Connect out and have a consistent way of doing so.

### Does this PR introduce _any_ user-facing change?

Yes:

- `bin/spark-submit` supports the `--remote` option in Scala.
- `bin/spark-submit` passes `--conf` and loaded Spark configurations to the clients in Scala and Python.

### How was this patch tested?

Python:

```bash
echo "from pyspark.sql import SparkSession;spark = SparkSession.builder.getOrCreate();assert 'connect' in str(type(spark));assert spark.range(1).first()[0] == 0" > test.py
```

```bash
./bin/spark-submit --name "testApp" --remote "local" test.py
```

Scala: https://github.com/HyukjinKwon/spark-connect-example

```bash
git clone https://github.com/HyukjinKwon/spark-connect-example
cd spark-connect-example
build/sbt package
cd ..
git clone https://github.com/apache/spark.git
cd spark
build/sbt package
sbin/start-connect-server.sh
bin/spark-submit --name "testApp" --remote "sc://localhost" --class com.hyukjinkwon.SparkConnectExample ../spark-connect-example/target/scala-2.13/spark-connect-example_2.13-0.0.1.jar
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47434 from HyukjinKwon/SPARK-48960.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
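For reference, a minimal Spark Connect application in the spirit of the example repository above might look like the following sketch (hypothetical; the actual `com.hyukjinkwon.SparkConnectExample` source lives in the linked repo and is not reproduced here):

```scala
package com.hyukjinkwon

import org.apache.spark.sql.SparkSession

// Hypothetical minimal app. When launched via `bin/spark-submit --remote sc://localhost`,
// the Spark Connect Scala client is on the classpath and spark.remote is set, so
// builder().getOrCreate() returns a session backed by the Connect server.
object SparkConnectExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    // A trivial query that exercises the Connect server end to end.
    assert(spark.range(1).count() == 1)
    spark.close()
  }
}
```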
1 parent 90a236e commit 3e22c47

File tree

7 files changed (+149, -72 lines)

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 63 additions & 2 deletions
@@ -17,11 +17,13 @@
 package org.apache.spark.sql
 
 import java.net.URI
+import java.nio.file.{Files, Paths}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
 
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
+import scala.util.Try
 
 import com.google.common.cache.{CacheBuilder, CacheLoader}
 import io.grpc.ClientInterceptor
@@ -591,6 +593,10 @@ class SparkSession private[sql] (
 object SparkSession extends Logging {
   private val MAX_CACHED_SESSIONS = 100
   private val planIdGenerator = new AtomicLong
+  private var server: Option[Process] = None
+  private[sql] val sparkOptions = sys.props.filter { p =>
+    p._1.startsWith("spark.") && p._2.nonEmpty
+  }.toMap
 
   private val sessions = CacheBuilder
     .newBuilder()
@@ -623,6 +629,51 @@
     }
   }
 
+  /**
+   * Create a new Spark Connect server to connect locally.
+   */
+  private[sql] def withLocalConnectServer[T](f: => T): T = {
+    synchronized {
+      val remoteString = sparkOptions
+        .get("spark.remote")
+        .orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
+        .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
+
+      val maybeConnectScript =
+        Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+
+      if (server.isEmpty &&
+        remoteString.exists(_.startsWith("local")) &&
+        maybeConnectScript.exists(Files.exists(_))) {
+        server = Some {
+          val args =
+            Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
+              .filter(p => !p._1.startsWith("spark.remote"))
+              .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+          val pb = new ProcessBuilder(args: _*)
+          // So don't exclude spark-sql jar in classpath
+          pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+          pb.start()
+        }
+
+        // Let the server start. We will directly request to set the configurations
+        // and this sleep makes less noisy with retries.
+        Thread.sleep(2000L)
+        System.setProperty("spark.remote", "sc://localhost")
+
+        // scalastyle:off runtimeaddshutdownhook
+        Runtime.getRuntime.addShutdownHook(new Thread() {
+          override def run(): Unit = if (server.isDefined) {
+            new ProcessBuilder(maybeConnectScript.get.toString)
+              .start()
+          }
+        })
+        // scalastyle:on runtimeaddshutdownhook
+      }
+    }
+    f
+  }
+
   /**
    * Create a new [[SparkSession]] based on the connect client [[Configuration]].
    */
@@ -765,6 +816,16 @@
     }
 
     private def applyOptions(session: SparkSession): Unit = {
+      // Only attempts to set Spark SQL configurations.
+      // If the configurations are static, it might throw an exception so
+      // simply ignore it for now.
+      sparkOptions
+        .filter { case (k, _) =>
+          k.startsWith("spark.sql.")
+        }
+        .foreach { case (key, value) =>
+          Try(session.conf.set(key, value))
+        }
       options.foreach { case (key, value) =>
         session.conf.set(key, value)
       }
@@ -787,7 +848,7 @@
      *
      * @since 3.5.0
      */
-    def create(): SparkSession = {
+    def create(): SparkSession = withLocalConnectServer {
       val session = tryCreateSessionFromClient()
         .getOrElse(SparkSession.this.create(builder.configuration))
       setDefaultAndActiveSession(session)
@@ -807,7 +868,7 @@
      *
      * @since 3.5.0
      */
-    def getOrCreate(): SparkSession = {
+    def getOrCreate(): SparkSession = withLocalConnectServer {
       val session = tryCreateSessionFromClient()
         .getOrElse({
           var existingSession = sessions.get(builder.configuration)
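To make the command-line assembly in `withLocalConnectServer` above concrete, here is a small standalone sketch; the `SPARK_HOME` fallback and the property values are invented for illustration, not taken from the commit:

```scala
import java.nio.file.Paths

// Standalone sketch of how sparkOptions is turned into the start-connect-server.sh
// invocation in withLocalConnectServer. The values below are made up.
object LaunchArgsSketch {
  def main(args: Array[String]): Unit = {
    val sparkHome = sys.env.getOrElse("SPARK_HOME", "/opt/spark") // assumed fallback
    val sparkOptions = Map(
      "spark.remote" -> "local",
      "spark.app.name" -> "testApp",
      "spark.sql.shuffle.partitions" -> "4")

    val script = Paths.get(sparkHome, "sbin", "start-connect-server.sh").toString
    val cmd = Seq(script, "--master", sparkOptions("spark.remote")) ++ sparkOptions
      .filter { case (k, _) => !k.startsWith("spark.remote") }
      .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }

    // e.g. /opt/spark/sbin/start-connect-server.sh --master local \
    //      --conf spark.app.name=testApp --conf spark.sql.shuffle.partitions=4
    println(cmd.mkString(" "))
  }
}
```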

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala

Lines changed: 8 additions & 52 deletions
@@ -17,10 +17,8 @@
 package org.apache.spark.sql.application
 
 import java.io.{InputStream, OutputStream}
-import java.nio.file.Paths
 import java.util.concurrent.Semaphore
 
-import scala.util.Try
 import scala.util.control.NonFatal
 
 import ammonite.compiler.CodeClassWrapper
@@ -34,6 +32,7 @@ import ammonite.util.Util.newLine
 import org.apache.spark.SparkBuildInfo.spark_version
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.SparkSession.withLocalConnectServer
 import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser}
 
 /**
@@ -64,37 +63,7 @@ Spark session available as 'spark'.
       semaphore: Option[Semaphore] = None,
       inputStream: InputStream = System.in,
       outputStream: OutputStream = System.out,
-      errorStream: OutputStream = System.err): Unit = {
-    val configs: Map[String, String] =
-      sys.props
-        .filter(p =>
-          p._1.startsWith("spark.") &&
-            p._2.nonEmpty &&
-            // Don't include spark.remote that we manually set later.
-            !p._1.startsWith("spark.remote"))
-        .toMap
-
-    val remoteString: Option[String] =
-      Option(System.getProperty("spark.remote")) // Set from Spark Submit
-        .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
-
-    if (remoteString.exists(_.startsWith("local"))) {
-      server = Some {
-        val args = Seq(
-          Paths.get(sparkHome, "sbin", "start-connect-server.sh").toString,
-          "--master",
-          remoteString.get) ++ configs.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
-        val pb = new ProcessBuilder(args: _*)
-        // So don't exclude spark-sql jar in classpath
-        pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
-        pb.start()
-      }
-      // Let the server start. We will directly request to set the configurations
-      // and this sleep makes less noisy with retries.
-      Thread.sleep(2000L)
-      System.setProperty("spark.remote", "sc://localhost")
-    }
-
+      errorStream: OutputStream = System.err): Unit = withLocalConnectServer {
     // Build the client.
     val client =
       try {
@@ -118,13 +87,6 @@ Spark session available as 'spark'.
 
     // Build the session.
     val spark = SparkSession.builder().client(client).getOrCreate()
-
-    // The configurations might not be all runtime configurations.
-    // Try to set them with ignoring failures for now.
-    configs
-      .filter(_._1.startsWith("spark.sql"))
-      .foreach { case (k, v) => Try(spark.conf.set(k, v)) }
-
     val sparkBind = new Bind("spark", spark)
 
     // Add the proper imports and register a [[ClassFinder]].
@@ -197,18 +159,12 @@ Spark session available as 'spark'.
         }
       }
     }
-    try {
-      if (semaphore.nonEmpty) {
-        // Used for testing.
-        main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get))
-      } else {
-        main.run(sparkBind)
-      }
-    } finally {
-      if (server.isDefined) {
-        new ProcessBuilder(Paths.get(sparkHome, "sbin", "stop-connect-server.sh").toString)
-          .start()
-      }
+
+    if (semaphore.nonEmpty) {
+      // Used for testing.
+      main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get))
+    } else {
+      main.run(sparkBind)
     }
   }
 }

core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala

Lines changed: 29 additions & 13 deletions
@@ -25,6 +25,9 @@ import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.Try
 
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
 import org.apache.spark.{SparkConf, SparkUserAppException}
 import org.apache.spark.api.python.{Py4JServer, PythonUtils}
 import org.apache.spark.internal.config._
@@ -50,18 +53,21 @@ object PythonRunner {
     val formattedPythonFile = formatPath(pythonFile)
     val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles))
 
-    val gatewayServer = new Py4JServer(sparkConf)
+    var gatewayServer: Option[Py4JServer] = None
+    if (sparkConf.getOption("spark.remote").isEmpty) {
+      gatewayServer = Some(new Py4JServer(sparkConf))
 
-    val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() })
-    thread.setName("py4j-gateway-init")
-    thread.setDaemon(true)
-    thread.start()
+      val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.get.start() })
+      thread.setName("py4j-gateway-init")
+      thread.setDaemon(true)
+      thread.start()
 
-    // Wait until the gateway server has started, so that we know which port is it bound to.
-    // `gatewayServer.start()` will start a new thread and run the server code there, after
-    // initializing the socket, so the thread started above will end as soon as the server is
-    // ready to serve connections.
-    thread.join()
+      // Wait until the gateway server has started, so that we know which port is it bound to.
+      // `gatewayServer.start()` will start a new thread and run the server code there, after
+      // initializing the socket, so the thread started above will end as soon as the server is
+      // ready to serve connections.
+      thread.join()
+    }
 
     // Build up a PYTHONPATH that includes the Spark assembly (where this class is), the
     // python directories in SPARK_HOME (if set), and any files in the pyFiles argument
@@ -74,12 +80,22 @@
     // Launch Python process
     val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
     val env = builder.environment()
+    if (sparkConf.getOption("spark.remote").nonEmpty) {
+      // For non-local remote, pass configurations to environment variables so
+      // Spark Connect client sets them. For local remotes, they will be set
+      // via Py4J.
+      val grouped = sparkConf.getAll.toMap.grouped(10).toSeq
+      env.put("PYSPARK_REMOTE_INIT_CONF_LEN", grouped.length.toString)
+      grouped.zipWithIndex.foreach { case (group, idx) =>
+        env.put(s"PYSPARK_REMOTE_INIT_CONF_$idx", compact(render(group)))
+      }
+    }
     sparkConf.getOption("spark.remote").foreach(url => env.put("SPARK_REMOTE", url))
     env.put("PYTHONPATH", pythonPath)
     // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
    env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
-    env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
-    env.put("PYSPARK_GATEWAY_SECRET", gatewayServer.secret)
+    gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_PORT", s.getListeningPort.toString))
+    gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_SECRET", s.secret))
     // pass conf spark.pyspark.python to python process, the only way to pass info to
     // python process is through environment variable.
     sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
@@ -103,7 +119,7 @@
         throw new SparkUserAppException(exitCode)
       }
     } finally {
-      gatewayServer.shutdown()
+      gatewayServer.foreach(_.shutdown())
     }
   }
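To illustrate the handoff format introduced above: each group of at most 10 configurations becomes one JSON object in a `PYSPARK_REMOTE_INIT_CONF_<i>` environment variable, and `PYSPARK_REMOTE_INIT_CONF_LEN` records how many groups there are. A standalone sketch with made-up configuration values (requires json4s on the classpath, as PythonRunner itself does):

```scala
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

// Standalone sketch of the environment-variable encoding PythonRunner performs
// for non-local remotes. The configuration entries below are invented.
object RemoteInitConfSketch {
  def main(args: Array[String]): Unit = {
    val conf = Map(
      "spark.remote" -> "sc://localhost",
      "spark.app.name" -> "testApp",
      "spark.sql.ansi.enabled" -> "true")

    val grouped = conf.grouped(10).toSeq
    println(s"PYSPARK_REMOTE_INIT_CONF_LEN=${grouped.length}")
    grouped.zipWithIndex.foreach { case (group, idx) =>
      // One JSON object per group, e.g. {"spark.remote":"sc://localhost",...}
      println(s"PYSPARK_REMOTE_INIT_CONF_$idx=${compact(render(group))}")
    }
  }
}
```

On the Python side, `session.py` (further below) reads these variables back with `json.loads` and applies only the `spark.sql.*` entries to the new Connect session.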

core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala

Lines changed: 17 additions & 0 deletions
@@ -1802,6 +1802,23 @@ class SparkSubmitSuite
     val (_, classpath, _, _) = submit.prepareSubmitEnvironment(appArgs)
     assert(classpath.contains("."))
   }
+
+  // Requires Python dependencies for Spark Connect. Should be enabled by default.
+  ignore("Spark Connect application submission (Python)") {
+    val pyFile = File.createTempFile("remote_test", ".py")
+    pyFile.deleteOnExit()
+    val content =
+      "from pyspark.sql import SparkSession;" +
+      "spark = SparkSession.builder.getOrCreate();" +
+      "assert 'connect' in str(type(spark));" +
+      "assert spark.range(1).first()[0] == 0"
+    FileUtils.write(pyFile, content, StandardCharsets.UTF_8)
+    val args = Seq(
+      "--name", "testPyApp",
+      "--remote", "local",
+      pyFile.getAbsolutePath)
+    runSparkSubmit(args)
+  }
 }
 
 object JarCreationTest extends Logging {

launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java

Lines changed: 3 additions & 1 deletion
@@ -214,7 +214,9 @@ List<String> buildClassPath(String appClassPath) throws IOException {
         addToClassPath(cp, f.toString());
       }
     }
-    if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL"))) {
+    // If we're in 'spark.local.connect', it should create a Spark Classic Spark Context
+    // that launches Spark Connect server.
+    if (isRemote && System.getenv("SPARK_LOCAL_CONNECT") == null) {
       for (File f: new File(jarsDir).listFiles()) {
         // Exclude Spark Classic SQL and Spark Connect server jars
         // if we're in Spark Connect Shell. Also exclude Spark SQL API and

launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java

Lines changed: 0 additions & 4 deletions
@@ -82,10 +82,6 @@ public List<String> buildCommand(Map<String, String> env)
         javaOptsKeys.add("SPARK_BEELINE_OPTS");
         yield "SPARK_BEELINE_MEMORY";
       }
-      case "org.apache.spark.sql.application.ConnectRepl" -> {
-        isRemote = true;
-        yield "SPARK_DRIVER_MEMORY";
-      }
       default -> "SPARK_DRIVER_MEMORY";
     };

python/pyspark/sql/connect/session.py

Lines changed: 29 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 check_dependencies(__name__)
 
+import json
 import threading
 import os
 import warnings
@@ -200,6 +201,26 @@ def enableHiveSupport(self) -> "SparkSession.Builder":
             )
 
         def _apply_options(self, session: "SparkSession") -> None:
+            init_opts = {}
+            for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))):
+                init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"])
+
+            with self._lock:
+                for k, v in init_opts.items():
+                    # the options are applied after session creation,
+                    # so following options always take no effect
+                    if k not in [
+                        "spark.remote",
+                        "spark.master",
+                    ] and k.startswith("spark.sql."):
+                        # Only attempts to set Spark SQL configurations.
+                        # If the configurations are static, it might throw an exception so
+                        # simply ignore it for now.
+                        try:
+                            session.conf.set(k, v)
+                        except Exception:
+                            pass
+
             with self._lock:
                 for k, v in self._options.items():
                     # the options are applied after session creation,
@@ -993,10 +1014,17 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
 
         session = PySparkSession._instantiatedSession
         if session is None or session._sc._jsc is None:
+            init_opts = {}
+            for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))):
+                init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"])
+            init_opts.update(opts)
+            opts = init_opts
+
             # Configurations to be overwritten
             overwrite_conf = opts
             overwrite_conf["spark.master"] = master
             overwrite_conf["spark.local.connect"] = "1"
+            os.environ["SPARK_LOCAL_CONNECT"] = "1"
 
             # Configurations to be set if unset.
             default_conf = {"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin"}
@@ -1030,6 +1058,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
             finally:
                 if origin_remote is not None:
                     os.environ["SPARK_REMOTE"] = origin_remote
+                del os.environ["SPARK_LOCAL_CONNECT"]
         else:
             raise PySparkRuntimeError(
                 errorClass="SESSION_OR_CONTEXT_EXISTS",
