Skip to content

Commit db59d2d

Browse files
committed
[SPARK-53097] Make WriteOperationV2 in SparkConnectPlanner side effect free
1 parent 6ef9a9d commit db59d2d

File tree

2 files changed

+61
-38
lines changed

2 files changed

+61
-38
lines changed

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3146,16 +3146,12 @@ class SparkConnectPlanner(
31463146
executeHolder.eventsManager.postFinished()
31473147
}
31483148

3149-
/**
3150-
* Transforms the write operation and executes it.
3151-
*
3152-
* The input write operation contains a reference to the input plan and transforms it to the
3153-
* corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
3154-
* parameters of the WriteOperation into the corresponding methods calls.
3155-
*
3156-
* @param writeOperation
3157-
*/
3158-
private def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = {
3149+
private def runCommand(command: LogicalPlan, tracker: QueryPlanningTracker): Unit = {
3150+
val qe = new QueryExecution(session, command, tracker)
3151+
qe.assertCommandExecuted()
3152+
}
3153+
3154+
private def transformWriteOperationV2(writeOperation: proto.WriteOperationV2): LogicalPlan = {
31593155
// Transform the input plan into the logical plan.
31603156
val plan = transformRelation(writeOperation.getInput)
31613157
// And create a Dataset from the plan.
@@ -3190,31 +3186,42 @@ class SparkConnectPlanner(
31903186
writeOperation.getMode match {
31913187
case proto.WriteOperationV2.Mode.MODE_CREATE =>
31923188
if (writeOperation.hasProvider) {
3193-
w.using(writeOperation.getProvider).create()
3194-
} else {
3195-
w.create()
3189+
w.using(writeOperation.getProvider)
31963190
}
3191+
w.createCommand()
31973192
case proto.WriteOperationV2.Mode.MODE_OVERWRITE =>
3198-
w.overwrite(Column(transformExpression(writeOperation.getOverwriteCondition)))
3193+
w.overwriteCommand(Column(transformExpression(writeOperation.getOverwriteCondition)))
31993194
case proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS =>
3200-
w.overwritePartitions()
3195+
w.overwritePartitionsCommand()
32013196
case proto.WriteOperationV2.Mode.MODE_APPEND =>
3202-
w.append()
3197+
w.appendCommand()
32033198
case proto.WriteOperationV2.Mode.MODE_REPLACE =>
32043199
if (writeOperation.hasProvider) {
3205-
w.using(writeOperation.getProvider).replace()
3206-
} else {
3207-
w.replace()
3200+
w.using(writeOperation.getProvider)
32083201
}
3202+
w.replaceCommand(orCreate = false)
32093203
case proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE =>
32103204
if (writeOperation.hasProvider) {
3211-
w.using(writeOperation.getProvider).createOrReplace()
3212-
} else {
3213-
w.createOrReplace()
3205+
w.using(writeOperation.getProvider)
32143206
}
3207+
w.replaceCommand(orCreate = true)
32153208
case other =>
32163209
throw InvalidInputErrors.invalidEnum(other)
32173210
}
3211+
}
3212+
3213+
/**
3214+
* Transforms the write operation and executes it.
3215+
*
3216+
* The input write operation contains a reference to the input plan and transforms it to the
3217+
* corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
3218+
* parameters of the WriteOperation into the corresponding methods calls.
3219+
*
3220+
* @param writeOperation
3221+
*/
3222+
private def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = {
3223+
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
3224+
runCommand(transformWriteOperationV2(writeOperation), tracker)
32183225
executeHolder.eventsManager.postFinished()
32193226
}
32203227

sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,17 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
148148

149149
/** @inheritdoc */
150150
override def create(): Unit = {
151-
runCommand(
152-
CreateTableAsSelect(
153-
UnresolvedIdentifier(tableName),
154-
partitioning.getOrElse(Seq.empty) ++ clustering,
155-
logicalPlan,
156-
buildTableSpec(),
157-
options.toMap,
158-
false))
151+
runCommand(createCommand())
152+
}
153+
154+
private[sql] def createCommand(): LogicalPlan = {
155+
CreateTableAsSelect(
156+
UnresolvedIdentifier(tableName),
157+
partitioning.getOrElse(Seq.empty) ++ clustering,
158+
logicalPlan,
159+
buildTableSpec(),
160+
options.toMap,
161+
false)
159162
}
160163

161164
private def buildTableSpec(): UnresolvedTableSpec = {
@@ -186,28 +189,37 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
186189
/** @inheritdoc */
187190
@throws(classOf[NoSuchTableException])
188191
def append(): Unit = {
189-
val append = AppendData.byName(
192+
runCommand(appendCommand())
193+
}
194+
195+
private[sql] def appendCommand(): LogicalPlan = {
196+
AppendData.byName(
190197
UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)),
191198
logicalPlan, options.toMap)
192-
runCommand(append)
193199
}
194200

195201
/** @inheritdoc */
196202
@throws(classOf[NoSuchTableException])
197203
def overwrite(condition: Column): Unit = {
198-
val overwrite = OverwriteByExpression.byName(
204+
runCommand(overwriteCommand(condition))
205+
}
206+
207+
private[sql] def overwriteCommand(condition: Column): LogicalPlan = {
208+
OverwriteByExpression.byName(
199209
UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
200210
logicalPlan, expression(condition), options.toMap)
201-
runCommand(overwrite)
202211
}
203212

204213
/** @inheritdoc */
205214
@throws(classOf[NoSuchTableException])
206215
def overwritePartitions(): Unit = {
207-
val dynamicOverwrite = OverwritePartitionsDynamic.byName(
216+
runCommand(overwritePartitionsCommand())
217+
}
218+
219+
private[sql] def overwritePartitionsCommand(): LogicalPlan = {
220+
OverwritePartitionsDynamic.byName(
208221
UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
209222
logicalPlan, options.toMap)
210-
runCommand(dynamicOverwrite)
211223
}
212224

213225
/**
@@ -220,13 +232,17 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
220232
}
221233

222234
private def internalReplace(orCreate: Boolean): Unit = {
223-
runCommand(ReplaceTableAsSelect(
235+
runCommand(replaceCommand(orCreate))
236+
}
237+
238+
private[sql] def replaceCommand(orCreate: Boolean): LogicalPlan = {
239+
ReplaceTableAsSelect(
224240
UnresolvedIdentifier(tableName),
225241
partitioning.getOrElse(Seq.empty) ++ clustering,
226242
logicalPlan,
227243
buildTableSpec(),
228244
writeOptions = options.toMap,
229-
orCreate = orCreate))
245+
orCreate = orCreate)
230246
}
231247
}
232248

0 commit comments

Comments
 (0)