
Commit b017473

[SPARK-53401][SQL] Enable Direct Passthrough Partitioning in the DataFrame API
### What changes were proposed in this pull request?

Currently, Spark's DataFrame `repartition()` API only supports hash-based and range-based partitioning strategies. Users who need precise control over which partition each row goes to (similar to RDD's `partitionBy` with custom partitioners) have no direct way to achieve this at the DataFrame level.

This PR introduces a new DataFrame API, `repartitionById(numPartitions, col)`, which allows users to directly specify target partition IDs in DataFrame repartitioning operations:

```
// Partition rows based on a computed partition ID
val df = spark.range(100).withColumn("partition_id", (col("id") % 10).cast("int"))
val repartitioned = df.repartitionById(10, $"partition_id")
```

### Why are the changes needed?

To enable precise control over which partition each row goes to (similar to RDD's `partitionBy` with custom partitioners) at the DataFrame level.

### Does this PR introduce _any_ user-facing change?

Yes, it adds the `repartitionById` API shown above.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #52153 from shujingyang-db/direct-partitionId-pass-through.

Lead-authored-by: Shujing Yang <[email protected]>
Co-authored-by: Shujing Yang <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
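As a quick sanity check of the new API, a minimal sketch (assuming a `SparkSession` named `spark`; the `PlannerSuite` tests added below perform essentially the same verification):

```
import org.apache.spark.sql.functions.{col, spark_partition_id}

// Build an INT partition id in [0, 10) and repartition rows by it directly.
val df = spark.range(100).withColumn("pid", (col("id") % 10).cast("int"))
val byId = df.repartitionById(10, col("pid"))

// Every row should land in exactly the partition named by `pid`.
val mismatches = byId
  .withColumn("actual", spark_partition_id())
  .filter(col("actual") =!= col("pid"))
  .count()
assert(mismatches == 0)
assert(byId.rdd.getNumPartitions == 10)
```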
1 parent fbbc550 commit b017473

File tree

6 files changed: +313 −14 lines
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}
+
+/**
+ * Expression that takes a partition ID value and passes it through directly for use in
+ * shuffle partitioning. This is used with RepartitionByExpression to allow users to
+ * directly specify target partition IDs.
+ *
+ * The child expression must evaluate to an integral type and must not be null.
+ * The resulting partition ID must be in the range [0, numPartitions).
+ */
+case class DirectShufflePartitionID(child: Expression)
+  extends UnaryExpression
+  with ExpectsInputTypes
+  with Unevaluable {
+
+  override def dataType: DataType = child.dataType
+
+  override def inputTypes: Seq[AbstractDataType] = IntegerType :: Nil
+
+  override def nullable: Boolean = false
+
+  override val prettyName: String = "direct_shuffle_partition_id"
+
+  override protected def withNewChildInternal(newChild: Expression): DirectShufflePartitionID =
+    copy(child = newChild)
+}

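Note that `inputTypes` above accepts only `IntegerType`, and the expression is `Unevaluable` (it is a marker consumed by the planner rather than something evaluated per row). A small usage sketch, assuming a `SparkSession` `spark`, of what that type constraint means in practice:

```
import org.apache.spark.sql.functions.col

val df = spark.range(100)                                      // `id` is LONG
val ok = df.repartitionById(10, (col("id") % 10).cast("int"))  // INT child: analyzes fine

// A non-integer child is rejected at analysis time with
// DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE, which the PlannerSuite test below
// asserts for a STRING column:
// df.withColumn("s", lit("a")).repartitionById(10, col("s"))
```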
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 23 additions & 13 deletions
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, ShufflePartitionIdPassThrough, SinglePartition}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.types.DataTypeUtils

@@ -1871,19 +1871,29 @@ trait HasPartitionExpressions extends SQLConfHelper {
   protected def partitioning: Partitioning = if (partitionExpressions.isEmpty) {
     RoundRobinPartitioning(numPartitions)
   } else {
-    val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder])
-    require(sortOrder.isEmpty || nonSortOrder.isEmpty,
-      s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of type " +
-        "`SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`, which " +
-        "means `HashPartitioning`. In this case we have:" +
-        s"""
-           |SortOrder: $sortOrder
-           |NonSortOrder: $nonSortOrder
-       """.stripMargin)
-    if (sortOrder.nonEmpty) {
-      RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
+    val directShuffleExprs = partitionExpressions.filter(_.isInstanceOf[DirectShufflePartitionID])
+    if (directShuffleExprs.nonEmpty) {
+      assert(directShuffleExprs.length == 1 && partitionExpressions.length == 1,
+        s"DirectShufflePartitionID can only be used as a single partition expression, " +
+          s"but found ${directShuffleExprs.length} DirectShufflePartitionID expressions " +
+          s"out of ${partitionExpressions.length} total expressions")
+      ShufflePartitionIdPassThrough(
+        partitionExpressions.head.asInstanceOf[DirectShufflePartitionID], numPartitions)
     } else {
-      HashPartitioning(partitionExpressions, numPartitions)
+      val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder])
+      require(sortOrder.isEmpty || nonSortOrder.isEmpty,
+        s"${getClass.getSimpleName} expects that either all its `partitionExpressions` are of" +
+          " type `SortOrder`, which means `RangePartitioning`, or none of them are `SortOrder`," +
+          " which means `HashPartitioning`. In this case we have:" +
+          s"""
+             |SortOrder: $sortOrder
+             |NonSortOrder: $nonSortOrder
+         """.stripMargin)
+      if (sortOrder.nonEmpty) {
+        RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
+      } else {
+        HashPartitioning(partitionExpressions, numPartitions)
+      }
     }
   }
 }

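To make the new branch concrete, a rough sketch (column names and a `SparkSession` `spark` assumed) of which DataFrame call lands in which branch of `HasPartitionExpressions.partitioning`:

```
import org.apache.spark.sql.functions.col

val df = spark.range(100)
  .select((col("id") % 10).cast("int").as("pid"), col("id").as("v"))

df.repartitionById(10, col("pid"))       // single DirectShufflePartitionID
                                         //   -> ShufflePartitionIdPassThrough(pid, 10)
df.repartition(10, col("v"))             // plain expressions -> HashPartitioning(v, 10)
df.repartitionByRange(10, col("v").asc)  // SortOrder expressions -> RangePartitioning(v ASC, 10)
```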
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 41 additions & 0 deletions
@@ -626,6 +626,47 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
  * - Creating a partitioning that can be used to re-partition another child, so that to make it
  *   having a compatible partitioning as this node.
  */
+
+/**
+ * Represents a partitioning where partition IDs are passed through directly from the
+ * DirectShufflePartitionID expression. This partitioning scheme is used when users
+ * want to directly control partition placement rather than using hash-based partitioning.
+ *
+ * This partitioning maps directly to the PartitionIdPassthrough RDD partitioner.
+ */
+case class ShufflePartitionIdPassThrough(
+    expr: DirectShufflePartitionID,
+    numPartitions: Int) extends Expression with Partitioning with Unevaluable {
+
+  // TODO(SPARK-53401): Support Shuffle Spec in Direct Partition ID Pass Through
+  def partitionIdExpression: Expression = Pmod(expr.child, Literal(numPartitions))
+
+  def expressions: Seq[Expression] = expr :: Nil
+  override def children: Seq[Expression] = expr :: Nil
+  override def nullable: Boolean = false
+  override def dataType: DataType = IntegerType
+
+  override def satisfies0(required: Distribution): Boolean = {
+    super.satisfies0(required) || {
+      required match {
+        // TODO(SPARK-53428): Support Direct Passthrough Partitioning in the Streaming Joins
+        case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+          val partitioningExpressions = expr.child :: Nil
+          if (requireAllClusterKeys) {
+            c.areAllClusterKeysMatched(partitioningExpressions)
+          } else {
+            partitioningExpressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+          }
+        case _ => false
+      }
+    }
+  }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): ShufflePartitionIdPassThrough =
+    copy(expr = newChildren.head.asInstanceOf[DirectShufflePartitionID])
+}
+
 trait ShuffleSpec {
   /**
    * Returns the number of partitions of this shuffle spec

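Since `partitionIdExpression` is `Pmod(expr.child, Literal(numPartitions))`, negative and out-of-range ids wrap into `[0, numPartitions)` instead of failing. A tiny plain-Scala illustration of that arithmetic (mirroring the pmod-related tests in `PlannerSuite` below):

```
// pmod semantics: the result always falls in [0, n), unlike plain `%`.
def pmod(a: Int, n: Int): Int = { val r = a % n; if (r < 0) r + n else r }

assert(pmod(-5, 10) == 5)   // negative ids wrap: (-5) pmod 10 = 5
assert(pmod(-4, 10) == 6)
assert(pmod(15, 10) == 5)   // id >= numPartitions wraps instead of throwing
assert(pmod(3, 10) == 3)
```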
sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala

Lines changed: 15 additions & 0 deletions
@@ -1544,6 +1544,21 @@ class Dataset[T] private[sql](
     }
   }

+  /**
+   * Repartitions the Dataset into the given number of partitions using the specified
+   * partition ID expression.
+   *
+   * @param numPartitions the number of partitions to use.
+   * @param partitionIdExpr the expression to be used as the partition ID. Must be an integer type.
+   *
+   * @group typedrel
+   * @since 4.1.0
+   */
+  def repartitionById(numPartitions: Int, partitionIdExpr: Column): Dataset[T] = {
+    val directShufflePartitionIdCol = Column(DirectShufflePartitionID(partitionIdExpr.expr))
+    repartitionByExpression(Some(numPartitions), Seq(directShufflePartitionIdCol))
+  }
+
   protected def repartitionByRange(
       numPartitions: Option[Int],
       partitionExprs: Seq[Column]): Dataset[T] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala

Lines changed: 9 additions & 0 deletions
@@ -344,6 +344,10 @@ object ShuffleExchangeExec {
      // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
      // `HashPartitioning.partitionIdExpression` to produce partitioning key.
      new PartitionIdPassthrough(n)
+    case ShufflePartitionIdPassThrough(_, n) =>
+      // For ShufflePartitionIdPassThrough, the DirectShufflePartitionID expression directly
+      // produces partition IDs, so we use PartitionIdPassthrough to pass them through directly.
+      new PartitionIdPassthrough(n)
     case RangePartitioning(sortingExpressions, numPartitions) =>
      // Extract only fields used for sorting to avoid collecting large fields that does not
      // affect sorting result when deciding partition bounds in RangePartitioner

@@ -399,6 +403,11 @@
     case SinglePartition => identity
     case KeyGroupedPartitioning(expressions, _, _, _) =>
      row => bindReferences(expressions, outputAttributes).map(_.eval(row))
+    case s: ShufflePartitionIdPassThrough =>
+      // For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID
+      // If the value is null, `InternalRow#getInt` returns 0.
+      val projection = UnsafeProjection.create(s.partitionIdExpression :: Nil, outputAttributes)
+      row => projection(row).getInt(0)
     case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
   }

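The second hunk computes the partition id by evaluating `partitionIdExpression` through an `UnsafeProjection` and reading `getInt(0)`, which yields 0 when the expression is null. A short sketch (assuming a `SparkSession` `spark`) of the observable effect, matching the null-handling test below:

```
import org.apache.spark.sql.functions.{col, lit, spark_partition_id, when}

val df = spark.range(10).toDF("id")
val pid = when(col("id") < 5, col("id")).otherwise(lit(null)).cast("int")

val result = df.repartitionById(10, pid).withColumn("actual", spark_partition_id())

// Rows whose partition-id expression is null (id >= 5) all land in partition 0.
result.filter(col("id") >= 5).select("actual").distinct().show()  // expect a single value: 0
```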
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 179 additions & 1 deletion
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._

@@ -28,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.functions._

@@ -1406,6 +1407,183 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     assert(planned.exists(_.isInstanceOf[GlobalLimitExec]))
     assert(planned.exists(_.isInstanceOf[LocalLimitExec]))
   }
+
+  test("SPARK-53401: repartitionById - should partition rows to the specified partition ID") {
+    val numPartitions = 10
+    val df = spark.range(100).withColumn("expected_p_id", col("id") % numPartitions)
+
+    val repartitioned = df.repartitionById(numPartitions, $"expected_p_id".cast("int"))
+    val result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+    assert(result.filter(col("expected_p_id") =!= col("actual_p_id")).count() == 0)
+
+    assert(result.rdd.getNumPartitions == numPartitions)
+  }
+
+  test("SPARK-53401: repartitionById should handle negative partition ids correctly with pmod") {
+    val df = spark.range(10).toDF("id")
+    val repartitioned = df.repartitionById(10, ($"id" - 5).cast("int"))
+
+    // With pmod, negative values should be converted to positive values
+    // (-5) pmod 10 = 5, (-4) pmod 10 = 6
+    val result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+    assert(result.forall(row => {
+      val actualPartitionId = row.getAs[Int]("actual_p_id")
+      val id = row.getAs[Long]("id")
+      val expectedPartitionId = {
+        val mod = (id - 5) % 10
+        if (mod < 0) mod + 10 else mod
+      }
+      actualPartitionId == expectedPartitionId
+    }))
+  }
+
+  test("SPARK-53401: repartitionById should fail analysis for non-integral types") {
+    val df = spark.range(5).withColumn("s", lit("a"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.repartitionById(5, $"s").collect()
+      },
+      condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"direct_shuffle_partition_id(s)\"",
+        "paramIndex" -> "first",
+        "requiredType" -> "\"INT\"",
+        "inputType" -> "\"STRING\"",
+        "inputSql" -> "\"s\""
+      )
+    )
+  }
+
+  test("SPARK-53401: repartitionById should send null partition ids to partition 0") {
+    val df = spark.range(10).toDF("id")
+    val partitionExpr = when($"id" < 5, $"id").otherwise(lit(null)).cast("int")
+    val repartitioned = df.repartitionById(10, partitionExpr)
+
+    val result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+    val nullRows = result.filter(_.getAs[Long]("id") >= 5)
+    assert(nullRows.nonEmpty, "Should have rows with null partition expression")
+    assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0),
+      "All null partition id rows should go to partition 0")
+
+    val nonNullRows = result.filter(_.getAs[Long]("id") < 5)
+    nonNullRows.foreach { row =>
+      val id = row.getAs[Long]("id").toInt
+      val actualPartitionId = row.getAs[Int]("actual_p_id")
+      assert(actualPartitionId == id % 10,
+        s"Row with id=$id should be in partition ${id % 10}, " +
+          s"but was in partition $actualPartitionId")
+    }
+  }
+
+  test("SPARK-53401: repartitionById should not" +
+    " throw an exception for partition id >= numPartitions") {
+    val numPartitions = 10
+    val df = spark.range(20).toDF("id")
+    val repartitioned = df.repartitionById(numPartitions, $"id".cast("int"))
+
+    assert(repartitioned.collect().length == 20)
+    assert(repartitioned.rdd.getNumPartitions == numPartitions)
+  }
+
+  /**
+   * A helper function to check the number of shuffle exchanges in a physical plan.
+   *
+   * @param df The DataFrame whose physical plan will be examined.
+   * @param expectedShuffles The expected number of shuffle exchanges.
+   */
+  private def checkShuffleCount(df: DataFrame, expectedShuffles: Int): Unit = {
+    val plan = df.queryExecution.executedPlan
+    val shuffles = collect(plan) {
+      case s: ShuffleExchangeLike => s
+      case s: BroadcastExchangeLike => s
+    }
+    assert(
+      shuffles.size == expectedShuffles,
+      s"Expected $expectedShuffles shuffle(s), but found ${shuffles.size} in the plan:\n$plan"
+    )
+  }
+
+  test("SPARK-53401: repartitionById followed by groupBy should only have one shuffle") {
+    val df = spark.range(100)
+      .withColumn("id", col("id").cast("int"))
+      .toDF("id")
+    val repartitioned = df.repartitionById(10, $"id")
+    val grouped = repartitioned.groupBy($"id").count()
+
+    checkShuffleCount(grouped, 1)
+  }
+
+  test("SPARK-53401: groupBy on a superset of partition keys should reuse the shuffle") {
+    val df = spark.range(100)
+      .withColumn("id", col("id").cast("int"))
+      .select($"id" % 10 as "key1", $"id" as "value")
+    val grouped = df.repartitionById(10, $"key1").groupBy($"key1", lit(1)).count()
+    checkShuffleCount(grouped, 1)
+  }
+
+  test("SPARK-53401: shuffle reuse is not affected by spark.sql.shuffle.partitions") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      val df = spark.range(100)
+        .withColumn("id", col("id").cast("int"))
+        .select($"id" % 10 as "key", $"id" as "value")
+      val grouped = df.repartitionById(10, $"key").groupBy($"key").count()
+
+      checkShuffleCount(grouped, 1)
+      assert(grouped.rdd.getNumPartitions == 10)
+    }
+  }
+
+  test("SPARK-53401: join with id pass-through and hash partitioning requires shuffle") {
+    val df1 = spark.range(100)
+      .withColumn("id", col("id").cast("int"))
+      .select($"id" % 10 as "key", $"id" as "v1")
+      .repartitionById(10, $"key")
+
+    val df2 = spark.range(100)
+      .withColumn("id", col("id").cast("int"))
+      .select($"id" % 10 as "key", $"id" as "v2")
+      .repartition($"key")
+
+    val joined1 = df1.join(df2, "key")
+
+    val grouped = joined1.groupBy("key").count()
+
+    // Total shuffles: one for df1, one broadcast for df2, one for groupBy.
+    // The groupBy reuses the output partitioning after DirectShufflePartitionID.
+    checkShuffleCount(grouped, 3)
+
+    val joined2 = df2.join(df1, "key")
+
+    val grouped2 = joined2.groupBy("key").count()
+
+    checkShuffleCount(grouped2, 3)
+  }
+
+  test("SPARK-53401: shuffle reuse after a join doesn't preserve partitioning") {
+    val df1 =
+      spark
+        .range(100)
+        .withColumn("id", col("id").cast("int"))
+        .select($"id" % 10 as "key", $"id" as "v1")
+        .repartitionById(10, $"key")
+    val df2 =
+      spark
+        .range(100)
+        .withColumn("id", col("id").cast("int"))
+        .select($"id" % 10 as "key", $"id" as "v2")
+        .repartitionById(10, $"key")
+
+    val joined = df1.join(df2, "key")
+
+    val grouped = joined.groupBy("key").count()
+
+    // Total shuffles: one for df1, one for df2, one for groupBy.
+    // The groupBy reuses the output partitioning after DirectShufflePartitionID.
+    checkShuffleCount(grouped, 3)
+  }
 }

 // Used for unit-testing EnsureRequirements
