@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -28,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
import org.apache.spark.sql.functions._
@@ -1406,6 +1407,183 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
    assert(planned.exists(_.isInstanceOf[GlobalLimitExec]))
    assert(planned.exists(_.isInstanceOf[LocalLimitExec]))
  }
+
+  test("SPARK-53401: repartitionById - should partition rows to the specified partition ID") {
+    val numPartitions = 10
+    val df = spark.range(100).withColumn("expected_p_id", col("id") % numPartitions)
+
+    val repartitioned = df.repartitionById(numPartitions, $"expected_p_id".cast("int"))
+    val result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+    assert(result.filter(col("expected_p_id") =!= col("actual_p_id")).count() == 0)
+
+    assert(result.rdd.getNumPartitions == numPartitions)
+  }
+
+  test("SPARK-53401: repartitionById should handle negative partition ids correctly with pmod") {
+    val df = spark.range(10).toDF("id")
+    val repartitioned = df.repartitionById(10, ($"id" - 5).cast("int"))
+
+    // With pmod, negative values should be converted to positive values:
+    // (-5) pmod 10 = 5, (-4) pmod 10 = 6
+    val result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+    assert(result.forall(row => {
+      val actualPartitionId = row.getAs[Int]("actual_p_id")
+      val id = row.getAs[Long]("id")
+      val expectedPartitionId = {
+        val mod = (id - 5) % 10
+        if (mod < 0) mod + 10 else mod
+      }
+      actualPartitionId == expectedPartitionId
+    }))
+  }
+
+  test("SPARK-53401: repartitionById should fail analysis for non-integral types") {
+    val df = spark.range(5).withColumn("s", lit("a"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.repartitionById(5, $"s").collect()
+      },
+      condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"direct_shuffle_partition_id(s)\"",
+        "paramIndex" -> "first",
+        "requiredType" -> "\"INT\"",
+        "inputType" -> "\"STRING\"",
+        "inputSql" -> "\"s\""
+      )
+    )
+  }
+
+  test("SPARK-53401: repartitionById should send null partition ids to partition 0") {
+    val df = spark.range(10).toDF("id")
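+    // ids < 5 keep their own value as the partition id; ids >= 5 get a null partition id.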
+    val partitionExpr = when($"id" < 5, $"id").otherwise(lit(null)).cast("int")
+    val repartitioned = df.repartitionById(10, partitionExpr)
+
+    val result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+    val nullRows = result.filter(_.getAs[Long]("id") >= 5)
+    assert(nullRows.nonEmpty, "Should have rows with null partition expression")
+    assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0),
+      "All null partition id rows should go to partition 0")
+
+    val nonNullRows = result.filter(_.getAs[Long]("id") < 5)
+    nonNullRows.foreach { row =>
+      val id = row.getAs[Long]("id").toInt
+      val actualPartitionId = row.getAs[Int]("actual_p_id")
+      assert(actualPartitionId == id % 10,
+        s"Row with id=$id should be in partition ${id % 10}, " +
+          s"but was in partition $actualPartitionId")
+    }
+  }
+
+  test("SPARK-53401: repartitionById should not" +
+    " throw an exception for partition id >= numPartitions") {
+    val numPartitions = 10
+    val df = spark.range(20).toDF("id")
+    val repartitioned = df.repartitionById(numPartitions, $"id".cast("int"))
+
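+    // Half of the ids are >= numPartitions; the call must still succeed, return all
+    // 20 rows, and produce exactly numPartitions partitions.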
+    assert(repartitioned.collect().length == 20)
+    assert(repartitioned.rdd.getNumPartitions == numPartitions)
+  }
+
+  /**
+   * A helper function to check the number of shuffle and broadcast exchanges
+   * in a physical plan.
+   *
+   * @param df The DataFrame whose physical plan will be examined.
+   * @param expectedShuffles The expected number of shuffle and broadcast exchanges.
+   */
+  private def checkShuffleCount(df: DataFrame, expectedShuffles: Int): Unit = {
+    val plan = df.queryExecution.executedPlan
+    val shuffles = collect(plan) {
+      case s: ShuffleExchangeLike => s
+      case s: BroadcastExchangeLike => s
+    }
+    assert(
+      shuffles.size == expectedShuffles,
+      s"Expected $expectedShuffles shuffle(s), but found ${shuffles.size} in the plan:\n$plan"
+    )
+  }
+
+  test("SPARK-53401: repartitionById followed by groupBy should only have one shuffle") {
+    val df = spark.range(100)
+      .withColumn("id", col("id").cast("int"))
+      .toDF("id")
+    val repartitioned = df.repartitionById(10, $"id")
+    val grouped = repartitioned.groupBy($"id").count()
+
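+    // The groupBy key matches the repartitionById key, so only the single
+    // user-specified shuffle should appear in the plan.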
+    checkShuffleCount(grouped, 1)
+  }
+
+  test("SPARK-53401: groupBy on a superset of partition keys should reuse the shuffle") {
+    val df = spark.range(100)
+      .withColumn("id", col("id").cast("int"))
+      .select($"id" % 10 as "key1", $"id" as "value")
+    val grouped = df.repartitionById(10, $"key1").groupBy($"key1", lit(1)).count()
+    checkShuffleCount(grouped, 1)
+  }
+
+  test("SPARK-53401: shuffle reuse is not affected by spark.sql.shuffle.partitions") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      val df = spark.range(100)
+        .withColumn("id", col("id").cast("int"))
+        .select($"id" % 10 as "key", $"id" as "value")
+      val grouped = df.repartitionById(10, $"key").groupBy($"key").count()
+
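+      // The configured shuffle partition number (5) must not override the 10 partitions
+      // requested by repartitionById, and no extra shuffle should be added.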
+      checkShuffleCount(grouped, 1)
+      assert(grouped.rdd.getNumPartitions == 10)
+    }
+  }
+
+  test("SPARK-53401: join with id pass-through and hash partitioning requires shuffle") {
+    val df1 = spark.range(100)
+      .withColumn("id", col("id").cast("int"))
+      .select($"id" % 10 as "key", $"id" as "v1")
+      .repartitionById(10, $"key")
+
+    val df2 = spark.range(100)
+      .withColumn("id", col("id").cast("int"))
+      .select($"id" % 10 as "key", $"id" as "v2")
+      .repartition($"key")
+
+    val joined1 = df1.join(df2, "key")
+
+    val grouped = joined1.groupBy("key").count()
+
+    // Total shuffles: one for df1, one broadcast for df2, one for groupBy.
+    // The groupBy cannot reuse the output partitioning after DirectShufflePartitionID,
+    // so it needs its own shuffle.
+    checkShuffleCount(grouped, 3)
+
+    val joined2 = df2.join(df1, "key")
+
+    val grouped2 = joined2.groupBy("key").count()
+
+    checkShuffleCount(grouped2, 3)
+  }
+
+  test("SPARK-53401: shuffle reuse after a join doesn't preserve partitioning") {
+    val df1 =
+      spark
+        .range(100)
+        .withColumn("id", col("id").cast("int"))
+        .select($"id" % 10 as "key", $"id" as "v1")
+        .repartitionById(10, $"key")
+    val df2 =
+      spark
+        .range(100)
+        .withColumn("id", col("id").cast("int"))
+        .select($"id" % 10 as "key", $"id" as "v2")
+        .repartitionById(10, $"key")
+
+    val joined = df1.join(df2, "key")
+
+    val grouped = joined.groupBy("key").count()
+
+    // Total shuffles: one for df1, one for df2, one for groupBy.
+    // The groupBy cannot reuse the output partitioning after DirectShufflePartitionID,
+    // so it needs its own shuffle.
+    checkShuffleCount(grouped, 3)
+  }
}

// Used for unit-testing EnsureRequirements