Skip to content

[SPARK-53097][CONNECT][SQL] Make WriteOperationV2 in SparkConnectPlanner side effect free #51813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,13 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
dataframe).foreach(responseObserver.onNext)
case proto.Plan.OpTypeCase.COMMAND =>
val command = request.getPlan.getCommand
planner.transformCommand(command, tracker) match {
case Some(plan) =>
val qe =
new QueryExecution(session, plan, tracker, shuffleCleanupMode = shuffleCleanupMode)
planner.transformCommand(command) match {
case Some(transformer) =>
val qe = new QueryExecution(
session,
transformer(tracker),
tracker,
shuffleCleanupMode = shuffleCleanupMode)
qe.assertCommandExecuted()
executeHolder.eventsManager.postFinished()
case None =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2646,12 +2646,12 @@ class SparkConnectPlanner(
process(command, new MockObserver())
}

def transformCommand(
command: proto.Command,
tracker: QueryPlanningTracker): Option[LogicalPlan] = {
def transformCommand(command: proto.Command): Option[QueryPlanningTracker => LogicalPlan] = {
command.getCommandTypeCase match {
case proto.Command.CommandTypeCase.WRITE_OPERATION =>
Some(transformWriteOperation(command.getWriteOperation, tracker))
Some(transformWriteOperation(command.getWriteOperation))
case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
Some(transformWriteOperationV2(command.getWriteOperationV2))
case _ =>
None
}
Expand All @@ -2660,19 +2660,20 @@ class SparkConnectPlanner(
def process(
command: proto.Command,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
val transformerOpt = transformCommand(command)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@heyihong Should unit tests be added to cover the changes in the method process?
The method flow has been significantly altered with new transformation logic and early returns. The PR mentions "Existing tests". Are there existing tests covering the change in the method? If yes, can you please point me to that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (transformerOpt.isDefined) {
transformAndRunCommand(transformerOpt.get)
return
}
command.getCommandTypeCase match {
case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
handleRegisterUserDefinedFunction(command.getRegisterFunction)
case proto.Command.CommandTypeCase.REGISTER_TABLE_FUNCTION =>
handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction)
case proto.Command.CommandTypeCase.REGISTER_DATA_SOURCE =>
handleRegisterUserDefinedDataSource(command.getRegisterDataSource)
case proto.Command.CommandTypeCase.WRITE_OPERATION =>
handleWriteOperation(command.getWriteOperation)
case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
handleCreateViewCommand(command.getCreateDataframeView)
case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
handleWriteOperationV2(command.getWriteOperationV2)
case proto.Command.CommandTypeCase.EXTENSION =>
handleCommandPlugin(command.getExtension)
case proto.Command.CommandTypeCase.SQL_COMMAND =>
Expand Down Expand Up @@ -3089,8 +3090,16 @@ class SparkConnectPlanner(
executeHolder.eventsManager.postFinished()
}

private def transformWriteOperation(
writeOperation: proto.WriteOperation,
/**
* Transforms the write operation.
*
* The input write operation contains a reference to the input plan and transforms it to the
* corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
* parameters of the WriteOperation into the corresponding methods calls.
*
* @param writeOperation
*/
private def transformWriteOperation(writeOperation: proto.WriteOperation)(
tracker: QueryPlanningTracker): LogicalPlan = {
// Transform the input plan into the logical plan.
val plan = transformRelation(writeOperation.getInput)
Expand Down Expand Up @@ -3149,41 +3158,27 @@ class SparkConnectPlanner(
}
}

/**
 * Eagerly executes a command plan.
 *
 * Wraps the given logical plan in a [[QueryExecution]] for the current session and
 * forces command execution via `assertCommandExecuted()`.
 *
 * @param command the command logical plan to execute
 * @param tracker the planning tracker used to record planning phases
 */
private def runCommand(command: LogicalPlan, tracker: QueryPlanningTracker): Unit = {
  new QueryExecution(session, command, tracker).assertCommandExecuted()
}

/**
* Transforms the write operation and executes it.
*
* The input write operation contains a reference to the input plan and transforms it to the
* corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
* parameters of the WriteOperation into the corresponding methods calls.
*
* @param writeOperation
*/
private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = {
private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
runCommand(transformWriteOperation(writeOperation, tracker), tracker)

val qe = new QueryExecution(session, transformer(tracker), tracker)
qe.assertCommandExecuted()
executeHolder.eventsManager.postFinished()
}

/**
* Transforms the write operation and executes it.
* Transforms the write operation.
*
* The input write operation contains a reference to the input plan and transforms it to the
* corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
* parameters of the WriteOperation into the corresponding methods calls.
*
* @param writeOperation
*/
private def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = {
private def transformWriteOperationV2(writeOperation: proto.WriteOperationV2)(
tracker: QueryPlanningTracker): LogicalPlan = {
// Transform the input plan into the logical plan.
val plan = transformRelation(writeOperation.getInput)
// And create a Dataset from the plan.
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
val dataset = Dataset.ofRows(session, plan, tracker)

val w = dataset.writeTo(table = writeOperation.getTableName)
Expand Down Expand Up @@ -3214,32 +3209,28 @@ class SparkConnectPlanner(
writeOperation.getMode match {
case proto.WriteOperationV2.Mode.MODE_CREATE =>
if (writeOperation.hasProvider) {
w.using(writeOperation.getProvider).create()
} else {
w.create()
w.using(writeOperation.getProvider)
}
w.createCommand()
case proto.WriteOperationV2.Mode.MODE_OVERWRITE =>
w.overwrite(Column(transformExpression(writeOperation.getOverwriteCondition)))
w.overwriteCommand(Column(transformExpression(writeOperation.getOverwriteCondition)))
case proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS =>
w.overwritePartitions()
w.overwritePartitionsCommand()
case proto.WriteOperationV2.Mode.MODE_APPEND =>
w.append()
w.appendCommand()
case proto.WriteOperationV2.Mode.MODE_REPLACE =>
if (writeOperation.hasProvider) {
w.using(writeOperation.getProvider).replace()
} else {
w.replace()
w.using(writeOperation.getProvider)
}
w.replaceCommand(orCreate = false)
case proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE =>
if (writeOperation.hasProvider) {
w.using(writeOperation.getProvider).createOrReplace()
} else {
w.createOrReplace()
w.using(writeOperation.getProvider)
}
w.replaceCommand(orCreate = true)
case other =>
throw InvalidInputErrors.invalidEnum(other)
}
executeHolder.eventsManager.postFinished()
}

private def handleWriteStreamOperationStart(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,17 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])

/** @inheritdoc */
override def create(): Unit = {
runCommand(
CreateTableAsSelect(
UnresolvedIdentifier(tableName),
partitioning.getOrElse(Seq.empty) ++ clustering,
logicalPlan,
buildTableSpec(),
options.toMap,
false))
runCommand(createCommand())
}

/**
 * Builds (without executing) the CREATE TABLE AS SELECT logical plan for this writer.
 *
 * Side-effect free: callers decide when and how to run the returned plan.
 *
 * @return the CTAS [[LogicalPlan]] assembled from this writer's configuration
 */
private[sql] def createCommand(): LogicalPlan = {
  // Partitioning transforms come first, followed by any clustering columns.
  val transforms = partitioning.getOrElse(Seq.empty) ++ clustering
  val tableSpec = buildTableSpec()
  CreateTableAsSelect(
    UnresolvedIdentifier(tableName),
    transforms,
    logicalPlan,
    tableSpec,
    options.toMap,
    false)
}

private def buildTableSpec(): UnresolvedTableSpec = {
Expand Down Expand Up @@ -186,28 +189,37 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
/** @inheritdoc */
@throws(classOf[NoSuchTableException])
def append(): Unit = {
val append = AppendData.byName(
runCommand(appendCommand())
}

private[sql] def appendCommand(): LogicalPlan = {
AppendData.byName(
UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)),
logicalPlan, options.toMap)
runCommand(append)
}

/** @inheritdoc */
@throws(classOf[NoSuchTableException])
def overwrite(condition: Column): Unit = {
val overwrite = OverwriteByExpression.byName(
runCommand(overwriteCommand(condition))
}

private[sql] def overwriteCommand(condition: Column): LogicalPlan = {
OverwriteByExpression.byName(
UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
logicalPlan, expression(condition), options.toMap)
runCommand(overwrite)
}

/** @inheritdoc */
@throws(classOf[NoSuchTableException])
def overwritePartitions(): Unit = {
val dynamicOverwrite = OverwritePartitionsDynamic.byName(
runCommand(overwritePartitionsCommand())
}

private[sql] def overwritePartitionsCommand(): LogicalPlan = {
OverwritePartitionsDynamic.byName(
UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
logicalPlan, options.toMap)
runCommand(dynamicOverwrite)
}

/**
Expand All @@ -220,13 +232,17 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
}

private def internalReplace(orCreate: Boolean): Unit = {
runCommand(ReplaceTableAsSelect(
runCommand(replaceCommand(orCreate))
}

private[sql] def replaceCommand(orCreate: Boolean): LogicalPlan = {
ReplaceTableAsSelect(
UnresolvedIdentifier(tableName),
partitioning.getOrElse(Seq.empty) ++ clustering,
logicalPlan,
buildTableSpec(),
writeOptions = options.toMap,
orCreate = orCreate))
orCreate = orCreate)
}
}

Expand Down