diff --git a/core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala
new file mode 100644
index 0000000000000..4cd37e52dd8be
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.{IOException, ObjectOutputStream}
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.util.Utils
+
+/**
+ * An RDD that coalesces partitions while preserving ordering through k-way merge.
+ *
+ * Unlike CoalescedRDD which simply concatenates partitions, this RDD performs a sorted
+ * merge of multiple input partitions to maintain ordering. This is useful when input
+ * partitions are locally sorted and we want to preserve that ordering after coalescing.
+ *
+ * The merge is performed using a priority queue (min-heap) which provides O(n log k)
+ * time complexity, where n is the total number of elements and k is the number of
+ * partitions being merged.
+ *
+ * @param prev The parent RDD
+ * @param numPartitions The number of output partitions after coalescing
+ * @param partitionCoalescer The coalescer defining how to group input partitions
+ * @param ordering The ordering to maintain during merge
+ * @tparam T The element type
+ */
+private[spark] class SortedMergeCoalescedRDD[T: ClassTag](
+ @transient var prev: RDD[T],
+ numPartitions: Int,
+ partitionCoalescer: PartitionCoalescer,
+ ordering: Ordering[T])
+ extends RDD[T](prev.context, Nil) {
+
+ override def getPartitions: Array[Partition] = {
+ partitionCoalescer.coalesce(numPartitions, prev).zipWithIndex.map {
+ case (pg, i) =>
+ val parentIndices = pg.partitions.map(_.index).toSeq
+ new SortedMergePartition(i, prev, parentIndices, pg.prefLoc)
+ }
+ }
+
+ override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
+ val mergePartition = partition.asInstanceOf[SortedMergePartition]
+ val parentPartitions = mergePartition.parents
+
+ if (parentPartitions.isEmpty) {
+ Iterator.empty
+ } else if (parentPartitions.length == 1) {
+ // No merge needed for single partition
+ firstParent[T].iterator(parentPartitions.head, context)
+ } else {
+ // Perform k-way merge
+ new SortedMergeIterator[T](
+ parentPartitions.map(p => firstParent[T].iterator(p, context)),
+ ordering
+ )
+ }
+ }
+
+ override def getDependencies: Seq[org.apache.spark.Dependency[_]] = {
+ Seq(new org.apache.spark.NarrowDependency(prev) {
+ def getParents(id: Int): Seq[Int] =
+ partitions(id).asInstanceOf[SortedMergePartition].parentsIndices
+ })
+ }
+
+ override def getPreferredLocations(partition: Partition): Seq[String] = {
+ partition.asInstanceOf[SortedMergePartition].prefLoc.toSeq
+ }
+
+ override def clearDependencies(): Unit = {
+ super.clearDependencies()
+ prev = null
+ }
+}
+
+/**
+ * Partition for SortedMergeCoalescedRDD that tracks which parent partitions to merge.
+ * @param index of this coalesced partition
+ * @param rdd which it belongs to
+ * @param parentsIndices list of indices in the parent that have been coalesced into this partition
+ * @param prefLoc the preferred location for this partition
+ */
+private[spark] class SortedMergePartition(
+ idx: Int,
+ @transient private val rdd: RDD[_],
+ val parentsIndices: Seq[Int],
+ val prefLoc: Option[String] = None) extends Partition {
+ override val index: Int = idx
+ var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
+ // Update the reference to parent partition at the time of task serialization
+ parents = parentsIndices.map(rdd.partitions(_))
+ oos.defaultWriteObject()
+ }
+}
+
+/**
+ * Iterator that performs k-way merge of sorted iterators.
+ *
+ * Uses a priority queue (min-heap) to efficiently find the next smallest element
+ * across all input iterators according to the specified ordering. This provides
+ * O(n log k) time complexity where n is the total number of elements and k is
+ * the number of iterators being merged.
+ *
+ * @param iterators The sequence of sorted iterators to merge
+ * @param ordering The ordering to use for comparison
+ * @tparam T The element type
+ */
+private[spark] class SortedMergeIterator[T](
+ iterators: Seq[Iterator[T]],
+ ordering: Ordering[T]) extends Iterator[T] {
+
+ // Priority queue entry: (current element, iterator index)
+ private case class QueueEntry(element: T, iteratorIdx: Int)
+
+ // Min-heap ordered by element according to the provided ordering
+ private implicit val queueOrdering: Ordering[QueueEntry] = new Ordering[QueueEntry] {
+ override def compare(x: QueueEntry, y: QueueEntry): Int = {
+ // Reverse for min-heap (PriorityQueue is max-heap by default)
+ ordering.compare(y.element, x.element)
+ }
+ }
+
+ private val queue = mutable.PriorityQueue.empty[QueueEntry]
+
+ // Initialize queue with first element from each non-empty iterator
+ iterators.zipWithIndex.foreach { case (iter, idx) =>
+ if (iter.hasNext) {
+ queue.enqueue(QueueEntry(iter.next(), idx))
+ }
+ }
+
+ override def hasNext: Boolean = queue.nonEmpty
+
+ override def next(): T = {
+ if (!hasNext) {
+ throw new NoSuchElementException("next on empty iterator")
+ }
+
+ val entry = queue.dequeue()
+ val result = entry.element
+
+ // If the iterator has more elements, add the next one to the queue
+ val iter = iterators(entry.iteratorIdx)
+ if (iter.hasNext) {
+ queue.enqueue(QueueEntry(iter.next(), entry.iteratorIdx))
+ }
+
+ result
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala
new file mode 100644
index 0000000000000..1cb83af2b3b69
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+
+class SortedMergeCoalescedRDDSuite extends SparkFunSuite with SharedSparkContext {
+
+ test("SPARK-55715: k-way merge maintains ordering - integers") {
+ // Create RDD with 4 partitions, each sorted
+ val data = Seq(
+ Seq(1, 5, 9, 13), // partition 0
+ Seq(2, 6, 10, 14), // partition 1
+ Seq(3, 7, 11, 15), // partition 2
+ Seq(4, 8, 12, 16) // partition 3
+ )
+
+ val rdd = sc.parallelize(data, data.size).flatMap(identity)
+
+ // Coalesce to 2 partitions using sorted merge
+ val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1), Seq(2, 3)))
+ val merged = new SortedMergeCoalescedRDD[Int](rdd, 2, coalescer, Ordering.Int)
+
+ // Verify per-partition contents: group (0,1) merges elements from partitions 0+1,
+ // group (2,3) merges elements from partitions 2+3
+ val partitionData = merged
+ .mapPartitionsWithIndex { (idx, iter) => Iterator.single((idx, iter.toSeq)) }
+ .collect().toMap
+
+ assert(partitionData(0) === Seq(1, 2, 5, 6, 9, 10, 13, 14))
+ assert(partitionData(1) === Seq(3, 4, 7, 8, 11, 12, 15, 16))
+ }
+
+ test("SPARK-55715: k-way merge handles empty partitions") {
+ val data = Seq(
+ Seq(1, 5, 9), // partition 0
+ Seq.empty[Int], // partition 1 - empty
+ Seq(3, 7, 11), // partition 2
+ Seq.empty[Int] // partition 3 - empty
+ )
+
+ val rdd = sc.parallelize(data, data.size).flatMap(identity)
+
+ val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2, 3)))
+ val merged = new SortedMergeCoalescedRDD[Int](rdd, 1, coalescer, Ordering.Int)
+
+ val result = merged.collect()
+ assert(result === Seq(1, 3, 5, 7, 9, 11))
+ }
+
+ test("SPARK-55715: k-way merge handles all-empty partitions in a group") {
+ val data = Seq(
+ Seq.empty[Int], // partition 0 - empty
+ Seq.empty[Int], // partition 1 - empty
+ Seq(1, 2, 3) // partition 2 - non-empty (different group)
+ )
+
+ val rdd = sc.parallelize(data, data.size).flatMap(identity)
+
+ val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1), Seq(2)))
+ val merged = new SortedMergeCoalescedRDD[Int](rdd, 2, coalescer, Ordering.Int)
+
+ val partitionData = merged
+ .mapPartitionsWithIndex { (idx, iter) => Iterator.single((idx, iter.toSeq)) }
+ .collect().toMap
+
+ assert(partitionData(0) === Seq.empty)
+ assert(partitionData(1) === Seq(1, 2, 3))
+ }
+
+ test("SPARK-55715: k-way merge with single partition per group - no merge needed") {
+ val data = Seq(1, 2, 3, 4, 5, 6)
+ val rdd = sc.parallelize(data, 3)
+
+ // Each group has only 1 partition - should just pass through
+ val coalescer = new TestPartitionCoalescer(Seq(Seq(0), Seq(1), Seq(2)))
+ val merged = new SortedMergeCoalescedRDD[Int](rdd, 3, coalescer, Ordering.Int)
+
+ assert(merged.collect() === data)
+ }
+
+ test("SPARK-55715: k-way merge with reverse ordering") {
+ val data = Seq(
+ Seq(13, 9, 5, 1), // partition 0 - descending
+ Seq(14, 10, 6, 2), // partition 1 - descending
+ Seq(15, 11, 7, 3), // partition 2 - descending
+ Seq(16, 12, 8, 4) // partition 3 - descending
+ )
+
+ val rdd = sc.parallelize(data, data.size).flatMap(identity)
+
+ // Use reverse ordering
+ val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2, 3)))
+ val merged = new SortedMergeCoalescedRDD[Int](
+ rdd, 1, coalescer, Ordering.Int.reverse)
+
+ val result = merged.collect()
+ assert(result === (1 to 16).reverse)
+ }
+
+ test("SPARK-55715: k-way merge with many partitions") {
+ val numPartitions = 20
+ val rowsPerPartition = 50
+
+ // Create sorted data where partition i starts at i and increments by numPartitions
+ val data = (0 until numPartitions).map { partIdx =>
+ (0 until rowsPerPartition).map(i => partIdx + i * numPartitions)
+ }
+
+ val rdd = sc.parallelize(data, numPartitions).flatMap(identity)
+
+ // Coalesce all partitions into one
+ val coalescer = new TestPartitionCoalescer(Seq(0 until numPartitions))
+ val merged = new SortedMergeCoalescedRDD[Int](rdd, 1, coalescer, Ordering.Int)
+
+ val result = merged.collect()
+ assert(result.length === numPartitions * rowsPerPartition)
+ assert(result === result.sorted)
+ }
+
+ test("SPARK-55715: k-way merge preserves duplicate elements across partitions") {
+ val data = Seq(
+ Seq(1, 2, 3), // partition 0
+ Seq(1, 2, 3), // partition 1 - identical to partition 0
+ Seq(2, 2, 4) // partition 2 - contains repeated value within partition
+ )
+
+ val rdd = sc.parallelize(data, data.size).flatMap(identity)
+
+ val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2)))
+ val merged = new SortedMergeCoalescedRDD[Int](rdd, 1, coalescer, Ordering.Int)
+
+ val result = merged.collect()
+ assert(result === Seq(1, 1, 2, 2, 2, 2, 3, 3, 4))
+ }
+
+ test("SPARK-55715: k-way merge with strings") {
+ val data = Seq(
+ Seq("apple", "cherry", "grape"),
+ Seq("banana", "date", "kiwi"),
+ Seq("apricot", "fig", "mango")
+ )
+
+ val rdd = sc.parallelize(data, data.size).flatMap(identity)
+
+ val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2)))
+ val merged = new SortedMergeCoalescedRDD[String](rdd, 1, coalescer, Ordering.String)
+
+ val result = merged.collect()
+ assert(result === result.sorted)
+ }
+
+ test("SPARK-55715: k-way merge with tuples") {
+ val data = Seq(
+ Seq((1, "a"), (3, "c"), (5, "e")),
+ Seq((2, "b"), (4, "d"), (6, "f")),
+ Seq((1, "z"), (4, "y"), (7, "x"))
+ )
+
+ val rdd = sc.parallelize(data, data.size).flatMap(identity)
+
+ implicit val tupleOrdering: Ordering[(Int, String)] = Ordering.by[(Int, String), Int](_._1)
+
+ val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2)))
+ val merged = new SortedMergeCoalescedRDD[(Int, String)](rdd, 1, coalescer, tupleOrdering)
+
+ val result = merged.collect()
+ val expected = data.flatten.sortBy(_._1)
+ assert(result === expected)
+ }
+
+ test("SPARK-55715: SortedMergeIterator - next() on empty iterator throws " +
+ "NoSuchElementException") {
+ val iter = new SortedMergeIterator[Int](Seq.empty, Ordering.Int)
+ assert(!iter.hasNext)
+ intercept[NoSuchElementException] { iter.next() }
+ }
+
+ test("SPARK-55715: SortedMergeIterator - empty iterators list") {
+ val iter = new SortedMergeIterator[Int](Seq(Iterator.empty, Iterator.empty), Ordering.Int)
+ assert(!iter.hasNext)
+ assert(iter.toSeq === Seq.empty)
+ }
+
+ test("SPARK-55715: SortedMergeIterator - single iterator passes through unchanged") {
+ val iter = new SortedMergeIterator[Int](Seq(Iterator(3, 1, 2)), Ordering.Int)
+ assert(iter.toSeq === Seq(3, 1, 2))
+ }
+}
+
+/**
+ * Test partition coalescer that groups partitions according to a predefined plan.
+ */
+class TestPartitionCoalescer(grouping: Seq[Seq[Int]]) extends PartitionCoalescer with Serializable {
+ override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = {
+ grouping.map { partitionIndices =>
+ val pg = new PartitionGroup(None)
+ partitionIndices.foreach { idx =>
+ pg.partitions += parent.partitions(idx)
+ }
+ pg
+ }.toArray
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java
index 9e4b75c4bd112..0dba93af5f8d6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java
@@ -47,4 +47,27 @@ public interface CustomTaskMetric {
* Returns the long value of custom task metric.
*/
long value();
+
+ /**
+ * Merges this metric with another metric of the same name, returning a new
+ * {@link CustomTaskMetric} that represents the combined value. This is called when a task reads
+ * multiple partitions concurrently (e.g., k-way merge coalescing) to produce a single
+ * task-level value before reporting to the driver.
+ *
+ * <p>
+ * The default implementation returns a new metric whose value is the sum of the two values,
+ * which is correct for count-type metrics. Data sources with non-additive metrics (e.g., max,
+ * average) should override this method.
+ *
+ * @param other another metric with the same name to merge with
+ * @return a new metric representing the merged value
+ * @since 4.1.0
+ */
+ default CustomTaskMetric mergeWith(CustomTaskMetric other) {
+ final String metricName = this.name();
+ final long mergedValue = this.value() + other.value();
+ return new CustomTaskMetric() {
+ @Override public String name() { return metricName; }
+ @Override public long value() { return mergedValue; }
+ };
+ }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
index c12bc14a49c44..8420b6bdbdaf5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
@@ -65,6 +65,12 @@ default CustomTaskMetric[] currentMetricsValues() {
* {@link org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} and the reader
* is initialized with the metrics returned by the previous reader that belongs to the same
* partition. By default, this method does nothing.
+ *
+ * @deprecated Use {@link CustomTaskMetric#mergeWith(CustomTaskMetric)} instead. When a task
+ * reads multiple partitions concurrently or sequentially, {@code DataSourceRDD} now merges
+ * metrics from all readers via {@code mergeWith} at reporting time, removing the need to
+ * seed each new reader with the prior reader's values.
*/
+ @Deprecated(since = "4.2.0")
default void initMetricsValues(CustomTaskMetric[] metrics) {}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5639b6bbfbf4f..7bd53ed4eda18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2191,6 +2191,21 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED =
+ buildConf("spark.sql.sources.v2.bucketing.preserveOrderingOnCoalesce.enabled")
+ .doc("When turned on, GroupPartitionsExec will use sorted merge to preserve full " +
+ "ordering (as opposed to the key-derived ordering preserved by " +
+ s"${V2_BUCKETING_PRESERVE_KEY_ORDERING_ON_COALESCE_ENABLED.key}) when coalescing " +
+ "multiple partitions with the same key. This allows eliminating downstream sorts when " +
+ "data is both partitioned and sorted. However, sorted merge uses more resources " +
+ "(priority queue, comparison overhead) than simple concatenation, especially when " +
+ "coalescing many partitions. When turned off, only key-derived ordering is preserved " +
+ s"during coalescing. This config requires ${V2_BUCKETING_ENABLED.key} to be enabled.")
+ .version("4.2.0")
+ .withBindingPolicy(ConfigBindingPolicy.SESSION)
+ .booleanConf
+ .createWithDefault(false)
+
val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -7763,6 +7778,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingPreserveKeyOrderingOnCoalesceEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_PRESERVE_KEY_ORDERING_ON_COALESCE_ENABLED)
+ def v2BucketingPreserveOrderingOnCoalesceEnabled: Boolean =
+ getConf(SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED)
+
def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index e7762565f47ec..decff8e2bcbf0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -429,8 +429,15 @@ abstract class InMemoryBaseTable(
private var _pushedFilters: Array[Filter] = Array.empty
override def build: Scan = {
- val scan = InMemoryBatchScan(
- data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options)
+ val scan = if (InMemoryBaseTable.this.ordering.nonEmpty) {
+ new InMemoryBatchScanWithOrdering(
+ data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema,
+ options)
+ } else {
+ InMemoryBatchScan(
+ data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema,
+ options)
+ }
if (evaluableFilters.nonEmpty) {
scan.filter(evaluableFilters)
}
@@ -596,6 +603,19 @@ abstract class InMemoryBaseTable(
}
}
+ // Extends InMemoryBatchScan with SupportsReportOrdering. Only instantiated when the table has a
+ // non-empty ordering, so that V2ScanPartitioningAndOrdering only sets ordering = Some(...) on the
+ // logical plan when there is actual ordering to report.
+ private class InMemoryBatchScanWithOrdering(
+ data: Seq[InputPartition],
+ readSchema: StructType,
+ tableSchema: StructType,
+ options: CaseInsensitiveStringMap)
+ extends InMemoryBatchScan(data, readSchema, tableSchema, options)
+ with SupportsReportOrdering {
+ override def outputOrdering(): Array[SortOrder] = InMemoryBaseTable.this.ordering
+ }
+
abstract class InMemoryWriterBuilder(val info: LogicalWriteInfo)
extends SupportsTruncate with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend {
@@ -846,8 +866,13 @@ private class BufferedRowsReader(
private var index: Int = -1
private var rowsRead: Long = 0
+ private var closed: Boolean = false
+
+ private def checkNotClosed(op: String): Unit =
+ if (closed) throw new IllegalStateException(s"$op called on a closed BufferedRowsReader")
override def next(): Boolean = {
+ checkNotClosed("next()")
index += 1
val hasNext = index < partition.rows.length
if (hasNext) rowsRead += 1
@@ -855,6 +880,7 @@ private class BufferedRowsReader(
}
override def get(): InternalRow = {
+ checkNotClosed("get()")
val originalRow = partition.rows(index)
val values = new Array[Any](nonMetadataColumns.length)
nonMetadataColumns.zipWithIndex.foreach { case (col, idx) =>
@@ -864,7 +890,10 @@ private class BufferedRowsReader(
addMetadata(new GenericInternalRow(values))
}
- override def close(): Unit = {}
+ override def close(): Unit = {
+ checkNotClosed("close()")
+ closed = true
+ }
private def extractFieldValue(
field: StructField,
@@ -995,15 +1024,8 @@ private class BufferedRowsReader(
private def castElement(elem: Any, toType: DataType, fromType: DataType): Any =
Cast(Literal(elem, fromType), toType, None, EvalMode.TRY).eval(null)
- override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = {
- metrics.foreach { m =>
- m.name match {
- case "rows_read" => rowsRead = m.value()
- }
- }
- }
-
override def currentMetricsValues(): Array[CustomTaskMetric] = {
+ checkNotClosed("currentMetricsValues()")
val metric = new CustomTaskMetric {
override def name(): String = "rows_read"
override def value(): Long = rowsRead
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala
new file mode 100644
index 0000000000000..08be66e3f4477
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+/**
+ * Marker trait for execution nodes that are safe to use with SortedMergeCoalescedRDD
+ * (concurrent k-way merge). Nodes implementing this trait must store no per-partition
+ * mutable state in shared plan-node instance fields; all per-partition state must be
+ * captured inside the partition's iterator closure (e.g. via the
+ * PartitionEvaluatorFactory pattern).
+ */
+trait SafeForKWayMerge
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 246508965d3d6..c1c6f36475475 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -501,7 +501,8 @@ trait InputRDDCodegen extends CodegenSupport {
* This is the leaf node of a tree with WholeStageCodegen that is used to generate code
* that consumes an RDD iterator of InternalRow.
*/
-case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCodegen {
+case class InputAdapter(child: SparkPlan)
+ extends UnaryExecNode with InputRDDCodegen with SafeForKWayMerge {
override def output: Seq[Attribute] = child.output
@@ -633,7 +634,7 @@ object WholeStageCodegenExec {
* used to generated code for [[BoundReference]].
*/
case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
- extends UnaryExecNode with CodegenSupport {
+ extends UnaryExecNode with CodegenSupport with SafeForKWayMerge {
override def output: Seq[Attribute] = child.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 5e73cdd6da8f4..bafe7e568bf2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -42,6 +42,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryExecNode
with CodegenSupport
+ with SafeForKWayMerge
with PartitioningPreservingUnaryExecNode
with OrderPreservingUnaryExecNode {
@@ -222,7 +223,7 @@ trait GeneratePredicateHelper extends PredicateHelper {
/** Physical plan for Filter. */
case class FilterExec(condition: Expression, child: SparkPlan)
- extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper {
+ extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper with SafeForKWayMerge {
// Split out all the IsNotNulls from condition.
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 2fedb97e8461e..d7bb1d9093edf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -19,34 +19,111 @@ package org.apache.spark.sql.execution.datasources.v2
import java.util.concurrent.ConcurrentHashMap
-import scala.language.existentials
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.util.ArrayImplicits._
class DataSourceRDDPartition(val index: Int, val inputPartition: Option[InputPartition])
extends Partition with Serializable
/**
- * Holds the state for a reader in a task, used by the completion listener to access the most
- * recently created reader and iterator for final metrics updates and cleanup.
+ * Holds all mutable state for a single Spark task reading from a {@link DataSourceRDD}:
+ *
+ * - {@code partitionIterators}: all {@link PartitionIterator}s opened so far for this task.
+ * One iterator is appended per {@code compute()} call (one per input partition).
+ * - Spark input metrics ({@code recordsRead}, {@code bytesRead}): owned exclusively by this
+ * object so that {@code setBytesRead} -- a set, not an increment -- is called from a single
+ * owner even when multiple iterators are live concurrently.
+ * - {@code closedMetrics}: a pre-merged map of custom metrics from readers closed by natural
+ * exhaustion, kept so iterator references can be released as readers finish.
+ *
*
- * When `compute()` is called multiple times for the same task (e.g., when DataSourceRDD is
- * coalesced), this state is updated on each call to track the most recent reader. The task
- * completion listener then uses this most recent reader for final cleanup and metrics reporting.
+ * When metrics are reported:
+ *
+ * - Periodically: {@link #updateMetrics} is called on every {@code next()} and
+ * throttles updates to once per {@code UPDATE_INPUT_METRICS_INTERVAL_RECORDS} rows (input
+ * metrics) and {@code NUM_ROWS_PER_UPDATE} rows (custom metrics).
+ * - On natural exhaustion: when a reader's iterator is fully consumed, {@code hasNext}
+ * calls {@code updateMetrics(0, force=true)} to flush both input and custom metrics
+ * immediately, then closes the reader and drops the reference.
+ * - At task completion: the task completion listener calls
+ * {@code updateMetrics(0, force=true)} for a final flush, then closes any iterators not
+ * yet exhausted.
+ *
*
- * @param reader The partition reader
- * @param iterator The metrics iterator wrapping the reader
+ * Why this works across all execution modes:
+ *
+ * - One partition per task: a single iterator is opened and closed; periodic + final
+ * updates cover the full read.
+ * - Sequential coalescing ({@link CoalescedRDD}): partitions are read one at a time.
+ * Each reader is naturally exhausted before the next opens, so its final metrics are folded
+ * into {@code closedMetrics} before its reference is released. The merged view in
+ * {@code mergeAndUpdateCustomMetrics} therefore always includes all partitions read so
+ * far.
+ * - Concurrent k-way merge ({@link SortedMergeCoalescedRDD}): all N iterators are
+ * opened upfront and interleaved on a single thread. All live readers' current metrics are
+ * merged together on each update, so none are lost. When individual readers exhaust their
+ * metrics are folded into {@code closedMetrics} for continued accounting.
+ *
*/
-private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIterator[_])
+private class TaskState(customMetrics: Map[String, SQLMetric]) {
+ val partitionIterators = new ArrayBuffer[PartitionIterator[_]]()
+
+ // Input metrics (recordsRead, bytesRead) tracked for this task.
+ private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics
+ private val startingBytesRead = inputMetrics.bytesRead
+ private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ private var recordsReadAtLastBytesUpdate = 0L
+ private var recordsReadAtLastCustomMetricsUpdate = 0L
+
+ // Pre-merged custom metrics snapshot of all readers closed by natural exhaustion.
+ // Maintained as a map (one entry per metric name).
+ private val closedMetrics = new HashMap[String, CustomTaskMetric]()
+
+ def updateMetrics(numRows: Int, force: Boolean = false): Unit = {
+ inputMetrics.incRecordsRead(numRows)
+ val shouldUpdateBytesRead = force ||
+ inputMetrics.recordsRead - recordsReadAtLastBytesUpdate >=
+ SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS
+ if (shouldUpdateBytesRead) {
+ recordsReadAtLastBytesUpdate = inputMetrics.recordsRead
+ inputMetrics.setBytesRead(startingBytesRead + getBytesRead())
+ }
+ val shouldUpdateCustomMetrics = force ||
+ inputMetrics.recordsRead - recordsReadAtLastCustomMetricsUpdate >=
+ CustomMetrics.NUM_ROWS_PER_UPDATE
+ if (shouldUpdateCustomMetrics) {
+ recordsReadAtLastCustomMetricsUpdate = inputMetrics.recordsRead
+ mergeAndUpdateCustomMetrics()
+ }
+ }
+
+ private def mergeAndUpdateCustomMetrics(): Unit = {
+ partitionIterators.filterInPlace { iter =>
+ if (iter.isClosed) {
+ iter.finalMetrics.foreach { m =>
+ closedMetrics.update(m.name(), closedMetrics.get(m.name()).fold(m)(_.mergeWith(m)))
+ }
+ false
+ } else true
+ }
+ val mergedMetrics = (partitionIterators.flatMap(_.currentMetricsValues) ++ closedMetrics.values)
+ .groupMapReduce(_.name())(identity)(_.mergeWith(_))
+ .values
+ .toSeq
+ if (mergedMetrics.nonEmpty) {
+ CustomMetrics.updateMetrics(mergedMetrics, customMetrics)
+ }
+ }
+}
// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for
// columnar scan.
@@ -71,10 +148,8 @@ class DataSourceRDD(
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {
- // Map from task attempt ID to the most recently created ReaderState for that task.
- // When compute() is called multiple times for the same task (due to coalescing), the map entry
- // is updated each time so the completion listener always closes the last reader.
- @transient private lazy val taskReaderStates = new ConcurrentHashMap[Long, ReaderState]()
+ // One TaskState per task attempt.
+ @transient private lazy val taskStates = new ConcurrentHashMap[Long, TaskState]()
override protected def getPartitions: Array[Partition] = {
inputPartitions.zipWithIndex.map {
@@ -90,52 +165,39 @@ class DataSourceRDD(
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val taskAttemptId = context.taskAttemptId()
- // Add completion listener only once per task attempt. When compute() is called a second time
- // for the same task (e.g., due to coalescing), the first call will have already put a
- // ReaderState into taskReaderStates, so containsKey returns true and we skip this block.
- if (!taskReaderStates.containsKey(taskAttemptId)) {
+ // Ensure a TaskState exists for this task and register the completion listener on the
+ // first compute() call. computeIfAbsent is atomic; same-task calls are always on one
+ // thread, so partitionIterators.isEmpty reliably identifies the first call.
+ val taskState = taskStates.computeIfAbsent(taskAttemptId, _ => new TaskState(customMetrics))
+
+ if (taskState.partitionIterators.isEmpty) {
context.addTaskCompletionListener[Unit] { ctx =>
- // In case of early stopping before consuming the entire iterator,
- // we need to do one more metric update at the end of the task.
+ // In case of early stopping, do a final metrics update and close all readers.
try {
- val readerState = taskReaderStates.get(ctx.taskAttemptId())
- if (readerState != null) {
- CustomMetrics.updateMetrics(
- readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
- readerState.iterator.forceUpdateMetrics()
- readerState.reader.close()
+ val taskState = taskStates.get(ctx.taskAttemptId())
+ if (taskState != null) {
+ taskState.updateMetrics(0, force = true)
+ taskState.partitionIterators.foreach(_.close())
}
} finally {
- taskReaderStates.remove(ctx.taskAttemptId())
+ taskStates.remove(ctx.taskAttemptId())
}
}
}
castPartition(split).inputPartition.iterator.flatMap { inputPartition =>
- val (iter, reader) = if (columnarReads) {
+ val iter = if (columnarReads) {
val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
- val iter = new MetricsBatchIterator(
- new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
- (iter, batchReader)
+ new PartitionIterator[ColumnarBatch](batchReader, _.numRows, taskState)
} else {
val rowReader = partitionReaderFactory.createReader(inputPartition)
- val iter = new MetricsRowIterator(
- new PartitionIterator[InternalRow](rowReader, customMetrics))
- (iter, rowReader)
+ new PartitionIterator[InternalRow](rowReader, _ => 1, taskState)
}
- // Flush metrics and close the previous reader before advancing to the next one.
- // Pass the accumulated metrics to the new reader so they carry forward correctly.
- val prevState = taskReaderStates.get(taskAttemptId)
- if (prevState != null) {
- val metrics = prevState.reader.currentMetricsValues
- CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
- reader.initMetricsValues(metrics)
- prevState.reader.close()
- }
-
- // Update the map so the completion listener always references the latest reader.
- taskReaderStates.put(taskAttemptId, ReaderState(reader, iter))
+ // Track this iterator; early-stop close and final metrics-flush for iterators not yet
+ // naturally exhausted are handled by the task completion listener. This avoids closing
+ // live iterators prematurely in the concurrent k-way merge (SortedMergeCoalescedRDD).
+ taskState.partitionIterators += iter
// TODO: SPARK-25083 remove the type erasure hack in data source scan
new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]])
@@ -148,16 +210,36 @@ class DataSourceRDD(
}
/**
 * Iterator over one input partition's rows/batches that reports progress to the shared
 * [[TaskState]]. `rowCount` maps an element to the number of rows it represents
 * (1 for a row scan, `batch.numRows` for a columnar scan).
 */
private class PartitionIterator[T](
    private var reader: PartitionReader[T],
    rowCount: T => Int,
    taskState: TaskState) extends Iterator[T] {

  private var valuePrepared = false
  private var hasMoreInput = true

  // Snapshot of the reader's custom metrics taken right before it is closed on natural
  // exhaustion, so this reader's contribution stays available to metric merges after
  // close() has made currentMetricsValues invalid.
  private var cachedFinalMetrics: Array[CustomTaskMetric] = Array.empty

  /** True once the underlying reader has been released. */
  def isClosed: Boolean = reader == null

  def finalMetrics: Array[CustomTaskMetric] = cachedFinalMetrics

  /** Idempotent: releases the reader on the first call, no-op afterwards. */
  def close(): Unit = {
    if (reader != null) {
      reader.close()
      reader = null
    }
  }

  // Only valid while the reader is open; callers must check isClosed first.
  def currentMetricsValues: Array[CustomTaskMetric] = reader.currentMetricsValues

  override def hasNext: Boolean = {
    if (!valuePrepared && hasMoreInput) {
      hasMoreInput = reader.next()
      if (!hasMoreInput) {
        // Capture and flush metrics while the reader is still open, then release it.
        cachedFinalMetrics = reader.currentMetricsValues
        taskState.updateMetrics(0, force = true)
        close()
      }
      valuePrepared = hasMoreInput
    }
    valuePrepared
  }

  override def next(): T = {
    if (!hasNext) {
      throw QueryExecutionErrors.endOfStreamError()
    }
    valuePrepared = false
    val value = reader.get()
    taskState.updateMetrics(rowCount(value))
    value
  }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index a1a6c6e022482..0bec918039775 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
@@ -24,14 +24,14 @@ import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan}
-import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SQLExecution}
+import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SafeForKWayMerge, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.connector.SupportsMetadata
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils
-trait DataSourceV2ScanExecBase extends LeafExecNode {
+trait DataSourceV2ScanExecBase extends LeafExecNode with SafeForKWayMerge {
lazy val customMetrics = scan.supportedCustomMetrics().map { customMetric =>
customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
index 81981c29b2b31..e552c4f71641f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala
@@ -21,13 +21,14 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Partition, SparkException}
-import org.apache.spark.rdd.{CoalescedRDD, PartitionCoalescer, PartitionGroup, RDD}
+import org.apache.spark.rdd.{CoalescedRDD, PartitionCoalescer, PartitionGroup, RDD, SortedMergeCoalescedRDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
import org.apache.spark.sql.connector.catalog.functions.Reducer
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{SafeForKWayMerge, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -158,16 +159,91 @@ case class GroupPartitionsExec(
@transient lazy val isGrouped: Boolean = groupedPartitionsTuple._2
+ // Whether the child subtree is safe to use with SortedMergeCoalescedRDD (k-way merge).
+ //
+ // --- The general problem ---
+ //
+ // Unlike a simple CoalescedRDD, which processes input partitions sequentially (open P0,
+ // exhaust P0, open P1, exhaust P1, ...), SortedMergeCoalescedRDD must perform a k-way merge:
+ // it opens *all* input partition iterators upfront and then interleaves reads across them to
+ // produce a globally sorted output partition. Crucially, all of this happens on a single JVM
+ // thread within a single Spark task.
+ //
+ // A SparkPlan object is shared across all partition computations of a task. When N partition
+ // iterators are live concurrently on the same thread, N independent "computations" are all
+ // operating through the same plan node instances. If any node in the subtree stores
+ // per-partition state in an instance field rather than as a local variable captured inside the
+ // partition's iterator closure, that field is aliased across all N concurrent computations.
+ // Whichever computation last wrote the field "wins", and any computation that then reads or
+ // frees it based on its own earlier write will operate on the wrong state.
+ //
+ // --- The correct pattern: PartitionEvaluatorFactory ---
+ //
+ // The PartitionEvaluatorFactory / PartitionEvaluator pattern is specifically designed to avoid
+ // this problem. The factory's createEvaluator() is called once per partition and returns a
+ // fresh PartitionEvaluator instance. All per-partition mutable state lives inside that
+ // evaluator instance, not on the shared plan node. Operators that follow this pattern
+ // exclusively (and hold no other mutable state on the plan node) are safe for k-way merge.
+ //
+ // --- Concrete example of an unsafe operator: SortExec + SortMergeJoinExec ---
+ //
+ // SortExec stores its active sorter in a plain var field (`rowSorter`) on the plan node.
+ // When the k-way merge initialises its N partition iterators, each one drives the same SortExec
+ // instance and calls createSorter(), which assigns rowSorter = newSorter -- overwriting the
+ // field each time. After all N iterators are initialised, rowSorter holds only the sorter
+ // created for the *last* partition.
+ //
+ // SortMergeJoinExec performs eager resource cleanup: when a join partition is exhausted it
+ // calls cleanupResources() on its children, which reaches SortExec.cleanupResources(). That
+ // method calls rowSorter.cleanupResources() -- but rowSorter now holds the last-created sorter,
+ // not the one belonging to the just-exhausted partition. If that last sorter is still being
+ // actively read by another partition in the k-way merge, freeing it causes a use-after-free.
+ //
+ // To be conservative, we use a whitelist: unknown node types fall through to unsafe, causing
+ // a fallback to simple sequential coalescing. Only node types explicitly confirmed to store no
+ // per-partition state in shared (plan node) instance fields are permitted.
+ @transient private lazy val childIsSafeForKWayMerge: Boolean =
+ !child.exists {
+ case _: SafeForKWayMerge => false
+ case _ => true
+ }
+
override protected def doExecute(): RDD[InternalRow] = {
if (groupedPartitions.isEmpty) {
sparkContext.emptyRDD
+ } else if (SQLConf.get.v2BucketingPreserveOrderingOnCoalesceEnabled &&
+ child.outputOrdering.nonEmpty &&
+ childIsSafeForKWayMerge &&
+ groupedPartitions.exists(_._2.size > 1)) {
+ // Use sorted merge when:
+ // 1. Config is enabled
+ // 2. Child has ordering
+ // 3. Actually coalescing multiple partitions
+ // Convert SortOrder expressions to Ordering[InternalRow]
+ val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2))
+ val rowOrdering = new InterpretedOrdering(child.outputOrdering, child.output)
+ new SortedMergeCoalescedRDD[InternalRow](
+ child.execute(),
+ groupedPartitions.size,
+ partitionCoalescer,
+ rowOrdering)
} else {
val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2))
+ // Use simple coalescing when config is disabled, no ordering, or no actual coalescing
new CoalescedRDD(child.execute(), groupedPartitions.size, Some(partitionCoalescer))
}
}
- override def supportsColumnar: Boolean = child.supportsColumnar
+ override def supportsColumnar: Boolean = {
+ // Don't use columnar when sorted merge coalescing is needed, since we can't preserve
+ // ordering with sorted merge for columnar batches
+ val needsSortedMerge = SQLConf.get.v2BucketingPreserveOrderingOnCoalesceEnabled &&
+ child.outputOrdering.nonEmpty &&
+ childIsSafeForKWayMerge &&
+ groupedPartitions.exists(_._2.size > 1)
+
+ child.supportsColumnar && !needsSortedMerge
+ }
override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
if (groupedPartitions.isEmpty) {
@@ -189,6 +265,12 @@ case class GroupPartitionsExec(
// within-partition ordering is fully preserved (including any key-derived ordering that
// `DataSourceV2ScanExecBase` already prepended).
child.outputOrdering
+ } else if (SQLConf.get.v2BucketingPreserveOrderingOnCoalesceEnabled &&
+ child.outputOrdering.nonEmpty &&
+ childIsSafeForKWayMerge) {
+ // Coalescing with sorted merge: SortedMergeCoalescedRDD performs a k-way merge using the
+ // child's ordering, so the full within-partition ordering is preserved end-to-end.
+ child.outputOrdering
} else {
// Coalescing: multiple input partitions are merged into one output partition. The child's
// within-partition ordering is lost due to concatenation -- for example, if two input
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
index 5d06c8786d894..7f51875b971fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
@@ -69,7 +69,8 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with Logging {
private def ordering(plan: LogicalPlan) = plan.transformDown {
case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportOrdering, _, _, _) =>
- val ordering = V2ExpressionUtils.toCatalystOrdering(scan.outputOrdering(), relation)
+ val ordering =
+ V2ExpressionUtils.toCatalystOrdering(scan.outputOrdering(), relation, relation.funCatalog)
d.copy(ordering = Some(ordering))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 44cb3a23cfa7c..4a406322a5a19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -20,6 +20,7 @@ import java.sql.Timestamp
import java.util.Collections
import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.rdd.SortedMergeCoalescedRDD
import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Literal, TransformExpression}
@@ -252,9 +253,10 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
table: String,
columns: Array[Column],
partitions: Array[Transform],
+ ordering: Array[SortOrder] = Array.empty,
catalog: InMemoryTableCatalog = catalog): Unit = {
catalog.createTable(Identifier.of(Array("ns"), table),
- columns, partitions, emptyProps, Distributions.unspecified(), Array.empty, None, None,
+ columns, partitions, emptyProps, Distributions.unspecified(), ordering, None, None,
numRowsPerSplit = 1)
}
@@ -3091,7 +3093,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
assert(groupPartitions(0).outputPartitioning.numPartitions === 2,
"group partitions should have 2 partition groups")
}
- assert(metrics("number of rows read") == "3")
+ assert(metrics.collect {
+ case ((_, "BatchScan testcat.ns.items", "number of rows read"), v) => v
+ } === Seq("3"))
}
test("SPARK-55619: Custom metrics of coalesced partitions") {
@@ -3106,7 +3110,68 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1)
df.collect()
}
- assert(metrics("number of rows read") == "2")
+ assert(metrics.collect {
+ case ((_, "BatchScan testcat.ns.items", "number of rows read"), v) => v
+ } === Seq("2"))
+ }
+
  test("SPARK-55715: Custom metrics of sorted-merge coalesced partitions") {
    // items has id=1 on three splits with interleaved arrive_times -- out of order across splits.
    // purchases has item_id=1 on two splits, also out of order. Both sides coalesce under SMJ,
    // using SortedMergeCoalescedRDD with multiple concurrent readers per task. This test verifies
    // that all rows from both tables (5 + 4 = 9) are accounted for in the per-scan metrics.
    val itemOrdering = Array(
      sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST),
      sort(FieldReference("arrive_time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
    createTable(items, itemsColumns, Array(identity("id")), itemOrdering)
    // Rows inserted out of order: id=1 lands on partitions 1, 3, 4 with arrive_times
    // [2022-03-10, 2021-05-20, 2025-09-01] -- out of order.
    sql(s"INSERT INTO testcat.ns.$items VALUES " +
      "(3, 'dd', 40.0, cast('2024-01-01' as timestamp)), " +
      "(1, 'bb', 20.0, cast('2022-03-10' as timestamp)), " +
      "(2, 'cc', 30.0, cast('2023-06-15' as timestamp)), " +
      "(1, 'aa', 10.0, cast('2021-05-20' as timestamp)), " +
      "(1, 'ee', 50.0, cast('2025-09-01' as timestamp))")

    val purchaseOrdering = Array(
      sort(FieldReference("item_id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST),
      sort(FieldReference("time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
    createTable(purchases, purchasesColumns, Array(identity("item_id")), purchaseOrdering)
    // item_id=1 lands on partitions 1 and 3 with times [2022-03-10, 2021-05-20] -- out of order.
    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
      "(2, 30.0, cast('2023-06-15' as timestamp)), " +
      "(1, 20.0, cast('2022-03-10' as timestamp)), " +
      "(3, 40.0, cast('2024-01-01' as timestamp)), " +
      "(1, 10.0, cast('2021-05-20' as timestamp))")

    withSQLConf(
      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true",
      SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> "true") {
      val metrics = runAndFetchMetrics {
        val df = sql(
          s"""${selectWithMergeJoinHint("i", "p")}
             |i.id, i.name
             |FROM testcat.ns.$items i
             |JOIN testcat.ns.$purchases p ON p.item_id = i.id AND p.time = i.arrive_time
             |""".stripMargin)
        checkAnswer(df, Seq(Row(1, "aa"), Row(1, "bb"), Row(2, "cc"), Row(3, "dd")))
        val plan = df.queryExecution.executedPlan
        val groupPartitions = collectAllGroupPartitions(plan)
        val coalescingGP = groupPartitions.filter(_.groupedPartitions.exists(_._2.size > 1))
        assert(coalescingGP.nonEmpty, "expected a coalescing GroupPartitionsExec")
        coalescingGP.foreach { gp =>
          assert(gp.execute().isInstanceOf[SortedMergeCoalescedRDD[_]],
            "should use SortedMergeCoalescedRDD when preserve-ordering config is enabled")
        }
      }
      // Per-scan totals: every row must be counted exactly once even though multiple readers
      // of the same scan were open concurrently within one task.
      assert(metrics.collect {
        case ((_, "BatchScan testcat.ns.items", "number of rows read"), v) => v
      } === Seq("5"))
      assert(metrics.collect {
        case ((_, "BatchScan testcat.ns.purchases", "number of rows read"), v) => v
      } === Seq("4"))
    }
  }
test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
@@ -3690,4 +3755,162 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
}
}
}
+
  test("SPARK-55715: preserve outputOrdering when coalescing partitions with sorted merge") {
    // Both tables are partitioned by their id column and report ordering
    // [id ASC, arrive_time ASC] / [item_id ASC, time ASC] via SupportsReportOrdering. Each has
    // two rows with id=1 (two splits), so GroupPartitionsExec must coalesce them. We join on
    // (id, arrive_time) = (item_id, time) using SMJ.
    //
    // With config enabled: SortedMergeCoalescedRDD performs a k-way merge preserving the full
    // [id ASC, arrive_time ASC] ordering -> EnsureRequirements is satisfied -> no SortExec added.
    // With config disabled: simple CoalescedRDD concatenates the splits and only the key-derived
    // [id ASC] ordering survives -> arrive_time ordering is lost -> SortExec is added.
    val itemOrdering = Array(
      sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST),
      sort(FieldReference("arrive_time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
    createTable(items, itemsColumns, Array(identity("id")), itemOrdering)
    // Rows inserted out of order: id values are interleaved and arrive_time is not monotone
    // within each id group, so ordering by [id ASC, arrive_time ASC] is non-trivial.
    sql(s"INSERT INTO testcat.ns.$items VALUES " +
      "(2, 'cc', 30.0, cast('2023-06-15' as timestamp)), " +
      "(1, 'bb', 20.0, cast('2022-03-10' as timestamp)), " +
      "(3, 'dd', 40.0, cast('2024-01-01' as timestamp)), " +
      "(1, 'aa', 10.0, cast('2021-05-20' as timestamp)), " +
      "(2, 'ee', 50.0, cast('2025-09-01' as timestamp))")

    val purchaseOrdering = Array(
      sort(FieldReference("item_id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST),
      sort(FieldReference("time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
    createTable(purchases, purchasesColumns, Array(identity("item_id")), purchaseOrdering)
    // Also inserted out of order
    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
      "(2, 50.0, cast('2025-09-01' as timestamp)), " +
      "(1, 10.0, cast('2021-05-20' as timestamp)), " +
      "(3, 40.0, cast('2024-01-01' as timestamp)), " +
      "(2, 30.0, cast('2023-06-15' as timestamp)), " +
      "(1, 20.0, cast('2022-03-10' as timestamp))")

    Seq(true, false).foreach { preserveOrdering =>
      withSQLConf(
        SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
        SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true",
        SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key ->
          preserveOrdering.toString) {
        val df = sql(
          s"""
             |${selectWithMergeJoinHint("i", "p")}
             |i.id, i.name
             |FROM testcat.ns.$items i
             |JOIN testcat.ns.$purchases p ON p.item_id = i.id AND p.time = i.arrive_time
             |""".stripMargin)
        checkAnswer(df, Seq(
          Row(1, "aa"), Row(1, "bb"), Row(2, "cc"), Row(2, "ee"), Row(3, "dd")))

        val plan = df.queryExecution.executedPlan
        assert(collectAllShuffles(plan).isEmpty, "should not contain any shuffle")

        val groupPartitions = collectAllGroupPartitions(plan)
        assert(groupPartitions.nonEmpty, "should contain GroupPartitionsExec for coalescing")
        assert(groupPartitions.exists(_.groupedPartitions.exists(_._2.size > 1)),
          "expected coalescing GroupPartitionsExec")

        val smjs = collect(plan) { case j: SortMergeJoinExec => j }
        assert(smjs.nonEmpty, "expected SortMergeJoinExec in plan")
        smjs.foreach { smj =>
          val sorts = smj.children.flatMap(child => collect(child) { case s: SortExec => s })
          if (preserveOrdering) {
            assert(sorts.isEmpty,
              "config enabled: SortedMergeCoalescedRDD preserves [id ASC, arrive_time ASC], " +
              "no SortExec should be added before SMJ")

            // Also verify the k-way merge RDD is actually used
            val coalescingGP = groupPartitions.filter(_.groupedPartitions.exists(_._2.size > 1))
            coalescingGP.foreach { gp =>
              assert(gp.execute().isInstanceOf[SortedMergeCoalescedRDD[_]],
                "config enabled: should use SortedMergeCoalescedRDD")
            }
          } else {
            assert(sorts.nonEmpty,
              "config disabled: simple coalescing loses arrive_time ordering, " +
              "SortExec should be added before SMJ")
          }
        }
      }
    }
  }
+
  test("SPARK-55715: preserve outputOrdering when coalescing transform-partitioned splits") {
    // Both tables are partitioned by years("arrive_time") / years("time") and report ordering
    // [arrive_time ASC] / [time ASC]. Two rows share the same year bucket (2022 and 2023), so
    // GroupPartitionsExec coalesces two splits per year. We join solely on
    // p.time = i.arrive_time (the partition key expression) using SMJ.
    //
    // With config enabled: SortedMergeCoalescedRDD k-way merge preserves [arrive_time ASC]
    // ordering -> EnsureRequirements is satisfied -> no SortExec added.
    // With config disabled: simple CoalescedRDD only preserves the key-derived year ordering ->
    // arrive_time ordering within a year is lost -> SortExec is added.
    val itemOrdering = Array(
      sort(FieldReference("arrive_time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
    createTable(items, itemsColumns, Array(years("arrive_time")), itemOrdering)
    // Inserted out of order: within year 2022, September is before March in insertion order
    sql(s"INSERT INTO testcat.ns.$items VALUES " +
      "(2, 'bb', 20.0, cast('2022-09-20' as timestamp)), " +
      "(4, 'dd', 40.0, cast('2023-11-05' as timestamp)), " +
      "(1, 'aa', 10.0, cast('2022-03-15' as timestamp)), " +
      "(3, 'cc', 30.0, cast('2023-01-10' as timestamp))")

    val purchaseOrdering = Array(
      sort(FieldReference("time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
    createTable(purchases, purchasesColumns, Array(years("time")), purchaseOrdering)
    // Also inserted out of order
    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
      "(2, 20.0, cast('2022-09-20' as timestamp)), " +
      "(4, 40.0, cast('2023-11-05' as timestamp)), " +
      "(1, 10.0, cast('2022-03-15' as timestamp)), " +
      "(3, 30.0, cast('2023-01-10' as timestamp))")

    Seq(true, false).foreach { preserveOrdering =>
      withSQLConf(
        SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key ->
          preserveOrdering.toString) {
        val df = sql(
          s"""
             |${selectWithMergeJoinHint("i", "p")}
             |i.id, i.name
             |FROM testcat.ns.$items i
             |JOIN testcat.ns.$purchases p ON p.time = i.arrive_time
             |""".stripMargin)
        checkAnswer(df, Seq(Row(1, "aa"), Row(2, "bb"), Row(3, "cc"), Row(4, "dd")))

        val plan = df.queryExecution.executedPlan
        assert(collectAllShuffles(plan).isEmpty, "should not contain any shuffle")

        val groupPartitions = collectAllGroupPartitions(plan)
        assert(groupPartitions.nonEmpty, "should contain GroupPartitionsExec for coalescing")
        assert(groupPartitions.exists(_.groupedPartitions.exists(_._2.size > 1)),
          "expected coalescing GroupPartitionsExec")

        val smjs = collect(plan) { case j: SortMergeJoinExec => j }
        assert(smjs.nonEmpty, "expected SortMergeJoinExec in plan")
        smjs.foreach { smj =>
          val sorts = smj.children.flatMap(child => collect(child) { case s: SortExec => s })
          if (preserveOrdering) {
            assert(sorts.isEmpty,
              "config enabled: SortedMergeCoalescedRDD preserves [arrive_time ASC], " +
              "no SortExec should be added before SMJ")

            // Verify the k-way merge RDD is actually used by the coalescing node.
            val coalescingGP = groupPartitions.filter(_.groupedPartitions.exists(_._2.size > 1))
            coalescingGP.foreach { gp =>
              assert(gp.execute().isInstanceOf[SortedMergeCoalescedRDD[_]],
                "config enabled: should use SortedMergeCoalescedRDD")
            }
          } else {
            assert(sorts.nonEmpty,
              "config disabled: simple coalescing loses arrive_time ordering within a year, " +
              "SortExec should be added before SMJ")
          }
        }
      }
    }
  }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
index 502424d58d2cb..06e55a179d4e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
@@ -53,7 +53,7 @@ class InMemoryTableMetricSuite
Array(Column.create("i", IntegerType)),
Array.empty[Transform], Collections.emptyMap[String, String])
- val metrics = runAndFetchMetrics(func("testcat.table_name"))
+ val metrics = runAndFetchMetrics(func("testcat.table_name")).map(m => m._1._3 -> m._2)
checker(metrics)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
index c37e051929555..72b80cbe05d6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.datasources.v2
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
-import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, PartitioningCollection}
-import org.apache.spark.sql.execution.DummySparkPlan
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
+import org.apache.spark.sql.execution.{DummySparkPlan, LeafExecNode, SafeForKWayMerge}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.IntegerType
@@ -141,4 +142,41 @@ class GroupPartitionsExecSuite extends SharedSparkSession {
assert(!gpe.groupedPartitions.forall(_._2.size <= 1), "expected coalescing")
assert(gpe.outputOrdering === Nil)
}
+
+ test("SPARK-55715: coalescing with sorted merge config enabled returns full child ordering") {
+    // Key 1 appears on partitions 0 and 2, causing coalescing. The child is a LeafExecNode that
+    // mixes in SafeForKWayMerge, so childIsSafeForKWayMerge = true. With the preserve-ordering
+    // config enabled, case 2 of outputOrdering kicks in and the full child ordering (including
+    // the non-key exprC) must be returned, not just the subset of key-expression orders.
+ val partitionKeys = Seq(row(1), row(2), row(1))
+ val childOrdering = Seq(SortOrder(exprA, Ascending), SortOrder(exprC, Ascending))
+ val child = DummyLeafSparkPlan(
+ outputPartitioning = KeyedPartitioning(Seq(exprA), partitionKeys),
+ outputOrdering = childOrdering)
+ val gpe = GroupPartitionsExec(child)
+
+ assert(!gpe.groupedPartitions.forall(_._2.size <= 1), "expected coalescing")
+ withSQLConf(SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> "true") {
+ // Config enabled: k-way merge preserves full ordering including non-key exprC.
+ assert(gpe.outputOrdering === childOrdering)
+ }
+ withSQLConf(
+ SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> "false",
+ SQLConf.V2_BUCKETING_PRESERVE_KEY_ORDERING_ON_COALESCE_ENABLED.key -> "true") {
+ // Sorted-merge config disabled, key-ordering config enabled: only key-expression orders
+ // survive simple concatenation (non-key exprC is dropped).
+ val ordering = gpe.outputOrdering
+ assert(ordering.length === 1)
+ assert(ordering.head.child === exprA)
+ }
+ }
+}
+
+private case class DummyLeafSparkPlan(
+ override val outputOrdering: Seq[SortOrder] = Nil,
+ override val outputPartitioning: Partitioning = UnknownPartitioning(0)
+ ) extends LeafExecNode with SafeForKWayMerge {
+ override protected def doExecute(): RDD[InternalRow] =
+ throw new UnsupportedOperationException
+ override def output: Seq[Attribute] = Seq.empty
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index 456e7ed3478a6..8f5157acca7d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -54,7 +54,9 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase {
}
}
- def runAndFetchMetrics(func: => Unit): Map[String, String] = {
+  // Runs `func` (which must trigger exactly one SQL execution) and returns that execution's SQL
+  // metrics as a map from (planNodeId, planNodeName, metricName) to the metric's string value.
+ def runAndFetchMetrics(func: => Unit): Map[(Long, String, String), String] = {
val statusStore = spark.sharedState.statusStore
val oldCount = statusStore.executionsList().size
@@ -73,7 +75,9 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase {
val exec = statusStore.executionsList().last
val execId = exec.executionId
- val sqlMetrics = exec.metrics.map { metric => metric.accumulatorId -> metric.name }.toMap
+ val sqlMetrics = statusStore.planGraph(execId).allNodes
+ .flatMap(n => n.metrics.map(m => (m.accumulatorId, (n.id, n.name, m.name))))
+ .toMap
statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v }
}
}