diff --git a/core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala new file mode 100644 index 0000000000000..4cd37e52dd8be --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.io.{IOException, ObjectOutputStream} + +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.util.Utils + +/** + * An RDD that coalesces partitions while preserving ordering through k-way merge. + * + * Unlike CoalescedRDD which simply concatenates partitions, this RDD performs a sorted + * merge of multiple input partitions to maintain ordering. This is useful when input + * partitions are locally sorted and we want to preserve that ordering after coalescing. + * + * The merge is performed using a priority queue (min-heap) which provides O(n log k) + * time complexity, where n is the total number of elements and k is the number of + * partitions being merged. 
+ * + * @param prev The parent RDD + * @param numPartitions The number of output partitions after coalescing + * @param ordering The ordering to maintain during merge + * @param partitionCoalescer The coalescer defining how to group input partitions + * @tparam T The element type + */ +private[spark] class SortedMergeCoalescedRDD[T: ClassTag]( + @transient var prev: RDD[T], + numPartitions: Int, + partitionCoalescer: PartitionCoalescer, + ordering: Ordering[T]) + extends RDD[T](prev.context, Nil) { + + override def getPartitions: Array[Partition] = { + partitionCoalescer.coalesce(numPartitions, prev).zipWithIndex.map { + case (pg, i) => + val parentIndices = pg.partitions.map(_.index).toSeq + new SortedMergePartition(i, prev, parentIndices, pg.prefLoc) + } + } + + override def compute(partition: Partition, context: TaskContext): Iterator[T] = { + val mergePartition = partition.asInstanceOf[SortedMergePartition] + val parentPartitions = mergePartition.parents + + if (parentPartitions.isEmpty) { + Iterator.empty + } else if (parentPartitions.length == 1) { + // No merge needed for single partition + firstParent[T].iterator(parentPartitions.head, context) + } else { + // Perform k-way merge + new SortedMergeIterator[T]( + parentPartitions.map(p => firstParent[T].iterator(p, context)), + ordering + ) + } + } + + override def getDependencies: Seq[org.apache.spark.Dependency[_]] = { + Seq(new org.apache.spark.NarrowDependency(prev) { + def getParents(id: Int): Seq[Int] = + partitions(id).asInstanceOf[SortedMergePartition].parentsIndices + }) + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + partition.asInstanceOf[SortedMergePartition].prefLoc.toSeq + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + prev = null + } +} + +/** + * Partition for SortedMergeCoalescedRDD that tracks which parent partitions to merge. 
+ * @param index of this coalesced partition + * @param rdd which it belongs to + * @param parentsIndices list of indices in the parent that have been coalesced into this partition + * @param prefLoc the preferred location for this partition + */ +private[spark] class SortedMergePartition( + idx: Int, + @transient private val rdd: RDD[_], + val parentsIndices: Seq[Int], + val prefLoc: Option[String] = None) extends Partition { + override val index: Int = idx + var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_)) + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { + // Update the reference to parent partition at the time of task serialization + parents = parentsIndices.map(rdd.partitions(_)) + oos.defaultWriteObject() + } +} + +/** + * Iterator that performs k-way merge of sorted iterators. + * + * Uses a priority queue (min-heap) to efficiently find the next smallest element + * across all input iterators according to the specified ordering. This provides + * O(n log k) time complexity where n is the total number of elements and k is + * the number of iterators being merged. 
+ * + * @param iterators The sequence of sorted iterators to merge + * @param ordering The ordering to use for comparison + * @tparam T The element type + */ +private[spark] class SortedMergeIterator[T]( + iterators: Seq[Iterator[T]], + ordering: Ordering[T]) extends Iterator[T] { + + // Priority queue entry: (current element, iterator index) + private case class QueueEntry(element: T, iteratorIdx: Int) + + // Min-heap ordered by element according to the provided ordering + private implicit val queueOrdering: Ordering[QueueEntry] = new Ordering[QueueEntry] { + override def compare(x: QueueEntry, y: QueueEntry): Int = { + // Reverse for min-heap (PriorityQueue is max-heap by default) + ordering.compare(y.element, x.element) + } + } + + private val queue = mutable.PriorityQueue.empty[QueueEntry] + + // Initialize queue with first element from each non-empty iterator + iterators.zipWithIndex.foreach { case (iter, idx) => + if (iter.hasNext) { + queue.enqueue(QueueEntry(iter.next(), idx)) + } + } + + override def hasNext: Boolean = queue.nonEmpty + + override def next(): T = { + if (!hasNext) { + throw new NoSuchElementException("next on empty iterator") + } + + val entry = queue.dequeue() + val result = entry.element + + // If the iterator has more elements, add the next one to the queue + val iter = iterators(entry.iteratorIdx) + if (iter.hasNext) { + queue.enqueue(QueueEntry(iter.next(), entry.iteratorIdx)) + } + + result + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala new file mode 100644 index 0000000000000..1cb83af2b3b69 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{SharedSparkContext, SparkFunSuite} + +class SortedMergeCoalescedRDDSuite extends SparkFunSuite with SharedSparkContext { + + test("SPARK-55715: k-way merge maintains ordering - integers") { + // Create RDD with 4 partitions, each sorted + val data = Seq( + Seq(1, 5, 9, 13), // partition 0 + Seq(2, 6, 10, 14), // partition 1 + Seq(3, 7, 11, 15), // partition 2 + Seq(4, 8, 12, 16) // partition 3 + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + // Coalesce to 2 partitions using sorted merge + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1), Seq(2, 3))) + val merged = new SortedMergeCoalescedRDD[Int](rdd, 2, coalescer, Ordering.Int) + + // Verify per-partition contents: group (0,1) merges elements from partitions 0+1, + // group (2,3) merges elements from partitions 2+3 + val partitionData = merged + .mapPartitionsWithIndex { (idx, iter) => Iterator.single((idx, iter.toSeq)) } + .collect().toMap + + assert(partitionData(0) === Seq(1, 2, 5, 6, 9, 10, 13, 14)) + assert(partitionData(1) === Seq(3, 4, 7, 8, 11, 12, 15, 16)) + } + + test("SPARK-55715: k-way merge handles empty partitions") { + val data = Seq( + Seq(1, 5, 9), // partition 0 + Seq.empty[Int], // partition 1 - empty + Seq(3, 7, 11), 
// partition 2 + Seq.empty[Int] // partition 3 - empty + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2, 3))) + val merged = new SortedMergeCoalescedRDD[Int]( rdd, 1, coalescer, Ordering.Int) + + val result = merged.collect() + assert(result === Seq(1, 3, 5, 7, 9, 11)) + } + + test("SPARK-55715: k-way merge handles all-empty partitions in a group") { + val data = Seq( + Seq.empty[Int], // partition 0 - empty + Seq.empty[Int], // partition 1 - empty + Seq(1, 2, 3) // partition 2 - non-empty (different group) + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1), Seq(2))) + val merged = new SortedMergeCoalescedRDD[Int]( rdd, 2, coalescer, Ordering.Int) + + val partitionData = merged + .mapPartitionsWithIndex { (idx, iter) => Iterator.single((idx, iter.toSeq)) } + .collect().toMap + + assert(partitionData(0) === Seq.empty) + assert(partitionData(1) === Seq(1, 2, 3)) + } + + test("SPARK-55715: k-way merge with single partition per group - no merge needed") { + val data = Seq(1, 2, 3, 4, 5, 6) + val rdd = sc.parallelize(data, 3) + + // Each group has only 1 partition - should just pass through + val coalescer = new TestPartitionCoalescer(Seq(Seq(0), Seq(1), Seq(2))) + val merged = new SortedMergeCoalescedRDD[Int](rdd, 3, coalescer, Ordering.Int) + + assert(merged.collect() === data) + } + + test("SPARK-55715: k-way merge with reverse ordering") { + val data = Seq( + Seq(13, 9, 5, 1), // partition 0 - descending + Seq(14, 10, 6, 2), // partition 1 - descending + Seq(15, 11, 7, 3), // partition 2 - descending + Seq(16, 12, 8, 4) // partition 3 - descending + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + // Use reverse ordering + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2, 3))) + val merged = new SortedMergeCoalescedRDD[Int]( + rdd, 1, coalescer, Ordering.Int.reverse) + + val 
result = merged.collect() + assert(result === (1 to 16).reverse) + } + + test("SPARK-55715: k-way merge with many partitions") { + val numPartitions = 20 + val rowsPerPartition = 50 + + // Create sorted data where partition i starts at i and increments by numPartitions + val data = (0 until numPartitions).map { partIdx => + (0 until rowsPerPartition).map(i => partIdx + i * numPartitions) + } + + val rdd = sc.parallelize(data, numPartitions).flatMap(identity) + + // Coalesce all partitions into one + val coalescer = new TestPartitionCoalescer(Seq((0 until numPartitions))) + val merged = new SortedMergeCoalescedRDD[Int](rdd, 1, coalescer, Ordering.Int) + + val result = merged.collect() + assert(result.length === numPartitions * rowsPerPartition) + assert(result === result.sorted) + } + + test("SPARK-55715: k-way merge preserves duplicate elements across partitions") { + val data = Seq( + Seq(1, 2, 3), // partition 0 + Seq(1, 2, 3), // partition 1 - identical to partition 0 + Seq(2, 2, 4) // partition 2 - contains repeated value within partition + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2))) + val merged = new SortedMergeCoalescedRDD[Int](rdd, 1, coalescer, Ordering.Int) + + val result = merged.collect() + assert(result === Seq(1, 1, 2, 2, 2, 2, 3, 3, 4)) + } + + test("SPARK-55715: k-way merge with strings") { + val data = Seq( + Seq("apple", "cherry", "grape"), + Seq("banana", "date", "kiwi"), + Seq("apricot", "fig", "mango") + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2))) + val merged = new SortedMergeCoalescedRDD[String](rdd, 1, coalescer, Ordering.String) + + val result = merged.collect() + assert(result === result.sorted) + } + + test("SPARK-55715: k-way merge with tuples") { + val data = Seq( + Seq((1, "a"), (3, "c"), (5, "e")), + Seq((2, "b"), (4, "d"), (6, "f")), + Seq((1, "z"), (4, "y"), 
(7, "x")) + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + implicit val tupleOrdering: Ordering[(Int, String)] = Ordering.by[(Int, String), Int](_._1) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2))) + val merged = new SortedMergeCoalescedRDD[(Int, String)](rdd, 1, coalescer, tupleOrdering) + + val result = merged.collect() + val expected = data.flatten.sortBy(_._1) + assert(result === expected) + } + + test("SPARK-55715: SortedMergeIterator - next() on empty iterator throws " + + "NoSuchElementException") { + val iter = new SortedMergeIterator[Int](Seq.empty, Ordering.Int) + assert(!iter.hasNext) + intercept[NoSuchElementException] { iter.next() } + } + + test("SPARK-55715: SortedMergeIterator - empty iterators list") { + val iter = new SortedMergeIterator[Int](Seq(Iterator.empty, Iterator.empty), Ordering.Int) + assert(!iter.hasNext) + assert(iter.toSeq === Seq.empty) + } + + test("SPARK-55715: SortedMergeIterator - single iterator passes through unchanged") { + val iter = new SortedMergeIterator[Int](Seq(Iterator(3, 1, 2)), Ordering.Int) + assert(iter.toSeq === Seq(3, 1, 2)) + } +} + +/** + * Test partition coalescer that groups partitions according to a predefined plan. 
+ */ +class TestPartitionCoalescer(grouping: Seq[Seq[Int]]) extends PartitionCoalescer with Serializable { + override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = { + grouping.map { partitionIndices => + val pg = new PartitionGroup(None) + partitionIndices.foreach { idx => + pg.partitions += parent.partitions(idx) + } + pg + }.toArray + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java index 9e4b75c4bd112..0dba93af5f8d6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java @@ -47,4 +47,27 @@ public interface CustomTaskMetric { * Returns the long value of custom task metric. */ long value(); + + /** + * Merges this metric with another metric of the same name, returning a new + * {@link CustomTaskMetric} that represents the combined value. This is called when a task reads + * multiple partitions concurrently (e.g., k-way merge coalescing) to produce a single + * task-level value before reporting to the driver. + * + *

The default implementation returns a new metric whose value is the sum of the two values, + * which is correct for count-type metrics. Data sources with non-additive metrics (e.g., max, + * average) should override this method. + * + * @param other another metric with the same name to merge with + * @return a new metric representing the merged value + * @since 4.1.0 + */ + default CustomTaskMetric mergeWith(CustomTaskMetric other) { + final String metricName = this.name(); + final long mergedValue = this.value() + other.value(); + return new CustomTaskMetric() { + @Override public String name() { return metricName; } + @Override public long value() { return mergedValue; } + }; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java index c12bc14a49c44..8420b6bdbdaf5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java @@ -65,6 +65,12 @@ default CustomTaskMetric[] currentMetricsValues() { * {@link org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} and the reader * is initialized with the metrics returned by the previous reader that belongs to the same * partition. By default, this method does nothing. + * + * @deprecated Use {@link CustomTaskMetric#mergeWith(CustomTaskMetric)} instead. When a task + * reads multiple partitions concurrently or sequentially, {@code DataSourceRDD} now merges + * metrics from all readers via {@code mergeWith} at reporting time, removing the need to + * seed each new reader with the prior reader's values. 
*/ + @Deprecated(since = "4.2.0") default void initMetricsValues(CustomTaskMetric[] metrics) {} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5639b6bbfbf4f..7bd53ed4eda18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2191,6 +2191,21 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED = + buildConf("spark.sql.sources.v2.bucketing.preserveOrderingOnCoalesce.enabled") + .doc(s"When turned on, GroupPartitionsExec will use sorted merge to preserve full " + + s"ordering (as opposed to the key-derived ordering preserved by " + + s"${V2_BUCKETING_PRESERVE_KEY_ORDERING_ON_COALESCE_ENABLED.key}) when coalescing " + + s"multiple partitions with the same key. This allows eliminating downstream sorts when " + + s"data is both partitioned and sorted. However, sorted merge uses more resources " + + s"(priority queue, comparison overhead) than simple concatenation, especially when " + + s"coalescing many partitions. When turned off, only key-derived ordering is preserved " + + s"during coalescing. 
This config requires ${V2_BUCKETING_ENABLED.key} to be enabled.") + .version("4.2.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -7763,6 +7778,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingPreserveKeyOrderingOnCoalesceEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_PRESERVE_KEY_ORDERING_ON_COALESCE_ENABLED) + def v2BucketingPreserveOrderingOnCoalesceEnabled: Boolean = + getConf(SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index e7762565f47ec..decff8e2bcbf0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -429,8 +429,15 @@ abstract class InMemoryBaseTable( private var _pushedFilters: Array[Filter] = Array.empty override def build: Scan = { - val scan = InMemoryBatchScan( - data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options) + val scan = if (InMemoryBaseTable.this.ordering.nonEmpty) { + new InMemoryBatchScanWithOrdering( + data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, + options) + } else { + InMemoryBatchScan( + data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, + options) + } if (evaluableFilters.nonEmpty) { scan.filter(evaluableFilters) } @@ -596,6 +603,19 @@ abstract class InMemoryBaseTable( } } + // Extends InMemoryBatchScan with SupportsReportOrdering. 
Only instantiated when the table has a + // non-empty ordering, so that V2ScanPartitioningAndOrdering only sets ordering = Some(...) on the + // logical plan when there is actual ordering to report. + private class InMemoryBatchScanWithOrdering( + data: Seq[InputPartition], + readSchema: StructType, + tableSchema: StructType, + options: CaseInsensitiveStringMap) + extends InMemoryBatchScan(data, readSchema, tableSchema, options) + with SupportsReportOrdering { + override def outputOrdering(): Array[SortOrder] = InMemoryBaseTable.this.ordering + } + abstract class InMemoryWriterBuilder(val info: LogicalWriteInfo) extends SupportsTruncate with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend { @@ -846,8 +866,13 @@ private class BufferedRowsReader( private var index: Int = -1 private var rowsRead: Long = 0 + private var closed: Boolean = false + + private def checkNotClosed(op: String): Unit = + if (closed) throw new IllegalStateException(s"$op called on a closed BufferedRowsReader") override def next(): Boolean = { + checkNotClosed("next()") index += 1 val hasNext = index < partition.rows.length if (hasNext) rowsRead += 1 @@ -855,6 +880,7 @@ private class BufferedRowsReader( } override def get(): InternalRow = { + checkNotClosed("get()") val originalRow = partition.rows(index) val values = new Array[Any](nonMetadataColumns.length) nonMetadataColumns.zipWithIndex.foreach { case (col, idx) => @@ -864,7 +890,10 @@ private class BufferedRowsReader( addMetadata(new GenericInternalRow(values)) } - override def close(): Unit = {} + override def close(): Unit = { + checkNotClosed("close()") + closed = true + } private def extractFieldValue( field: StructField, @@ -995,15 +1024,8 @@ private class BufferedRowsReader( private def castElement(elem: Any, toType: DataType, fromType: DataType): Any = Cast(Literal(elem, fromType), toType, None, EvalMode.TRY).eval(null) - override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = { - metrics.foreach { m => 
- m.name match { - case "rows_read" => rowsRead = m.value() - } - } - } - override def currentMetricsValues(): Array[CustomTaskMetric] = { + checkNotClosed("currentMetricsValues()") val metric = new CustomTaskMetric { override def name(): String = "rows_read" override def value(): Long = rowsRead diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala new file mode 100644 index 0000000000000..08be66e3f4477 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * Marker trait for execution nodes that are safe to use with SortedMergeCoalescedRDD + * (concurrent k-way merge). Nodes implementing this trait must store no per-partition + * mutable state in shared plan-node instance fields; all per-partition state must be + * captured inside the partition's iterator closure (e.g. via the + * PartitionEvaluatorFactory pattern). 
+ */ +trait SafeForKWayMerge diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 246508965d3d6..c1c6f36475475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -501,7 +501,8 @@ trait InputRDDCodegen extends CodegenSupport { * This is the leaf node of a tree with WholeStageCodegen that is used to generate code * that consumes an RDD iterator of InternalRow. */ -case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCodegen { +case class InputAdapter(child: SparkPlan) + extends UnaryExecNode with InputRDDCodegen with SafeForKWayMerge { override def output: Seq[Attribute] = child.output @@ -633,7 +634,7 @@ object WholeStageCodegenExec { * used to generated code for [[BoundReference]]. */ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) - extends UnaryExecNode with CodegenSupport { + extends UnaryExecNode with CodegenSupport with SafeForKWayMerge { override def output: Seq[Attribute] = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 5e73cdd6da8f4..bafe7e568bf2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode with CodegenSupport + with SafeForKWayMerge with PartitioningPreservingUnaryExecNode with OrderPreservingUnaryExecNode { @@ -222,7 +223,7 @@ trait GeneratePredicateHelper 
extends PredicateHelper { /** Physical plan for Filter. */ case class FilterExec(condition: Expression, child: SparkPlan) - extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper { + extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper with SafeForKWayMerge { // Split out all the IsNotNulls from condition. private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 2fedb97e8461e..d7bb1d9093edf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -19,34 +19,111 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.concurrent.ConcurrentHashMap -import scala.language.existentials +import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.metric.CustomTaskMetric import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.ArrayImplicits._ class DataSourceRDDPartition(val index: Int, val inputPartition: Option[InputPartition]) extends Partition with Serializable /** - * Holds the state for a reader in a task, used by the completion listener to access the most - * recently created reader and iterator for final metrics updates and cleanup. 
+ * Holds all mutable state for a single Spark task reading from a {@link DataSourceRDD}: + *

* - * When `compute()` is called multiple times for the same task (e.g., when DataSourceRDD is - * coalesced), this state is updated on each call to track the most recent reader. The task - * completion listener then uses this most recent reader for final cleanup and metrics reporting. + *

When metrics are reported: + *

* - * @param reader The partition reader - * @param iterator The metrics iterator wrapping the reader + *

Why this works across all execution modes: + *

*/ -private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIterator[_]) +private class TaskState(customMetrics: Map[String, SQLMetric]) { + val partitionIterators = new ArrayBuffer[PartitionIterator[_]]() + + // Input metrics (recordsRead, bytesRead) tracked for this task. + private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics + private val startingBytesRead = inputMetrics.bytesRead + private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + private var recordsReadAtLastBytesUpdate = 0L + private var recordsReadAtLastCustomMetricsUpdate = 0L + + // Pre-merged custom metrics snapshot of all readers closed by natural exhaustion. + // Maintained as a map (one entry per metric name). + private val closedMetrics = new HashMap[String, CustomTaskMetric]() + + def updateMetrics(numRows: Int, force: Boolean = false): Unit = { + inputMetrics.incRecordsRead(numRows) + val shouldUpdateBytesRead = force || + inputMetrics.recordsRead - recordsReadAtLastBytesUpdate >= + SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS + if (shouldUpdateBytesRead) { + recordsReadAtLastBytesUpdate = inputMetrics.recordsRead + inputMetrics.setBytesRead(startingBytesRead + getBytesRead()) + } + val shouldUpdateCustomMetrics = force || + inputMetrics.recordsRead - recordsReadAtLastCustomMetricsUpdate >= + CustomMetrics.NUM_ROWS_PER_UPDATE + if (shouldUpdateCustomMetrics) { + recordsReadAtLastCustomMetricsUpdate = inputMetrics.recordsRead + mergeAndUpdateCustomMetrics() + } + } + + private def mergeAndUpdateCustomMetrics(): Unit = { + partitionIterators.filterInPlace { iter => + if (iter.isClosed) { + iter.finalMetrics.foreach { m => + closedMetrics.update(m.name(), closedMetrics.get(m.name()).fold(m)(_.mergeWith(m))) + } + false + } else true + } + val mergedMetrics = (partitionIterators.flatMap(_.currentMetricsValues) ++ closedMetrics.values) + .groupMapReduce(_.name())(identity)(_.mergeWith(_)) + .values + .toSeq + if 
(mergedMetrics.nonEmpty) { + CustomMetrics.updateMetrics(mergedMetrics, customMetrics) + } + } +} // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for // columnar scan. @@ -71,10 +148,8 @@ class DataSourceRDD( customMetrics: Map[String, SQLMetric]) extends RDD[InternalRow](sc, Nil) { - // Map from task attempt ID to the most recently created ReaderState for that task. - // When compute() is called multiple times for the same task (due to coalescing), the map entry - // is updated each time so the completion listener always closes the last reader. - @transient private lazy val taskReaderStates = new ConcurrentHashMap[Long, ReaderState]() + // One TaskState per task attempt. + @transient private lazy val taskStates = new ConcurrentHashMap[Long, TaskState]() override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { @@ -90,52 +165,39 @@ class DataSourceRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val taskAttemptId = context.taskAttemptId() - // Add completion listener only once per task attempt. When compute() is called a second time - // for the same task (e.g., due to coalescing), the first call will have already put a - // ReaderState into taskReaderStates, so containsKey returns true and we skip this block. - if (!taskReaderStates.containsKey(taskAttemptId)) { + // Ensure a TaskState exists for this task and register the completion listener on the + // first compute() call. computeIfAbsent is atomic; same-task calls are always on one + // thread, so partitionIterators.isEmpty reliably identifies the first call. + val taskState = taskStates.computeIfAbsent(taskAttemptId, _ => new TaskState(customMetrics)) + + if (taskState.partitionIterators.isEmpty) { context.addTaskCompletionListener[Unit] { ctx => - // In case of early stopping before consuming the entire iterator, - // we need to do one more metric update at the end of the task. 
+ // In case of early stopping, do a final metrics update and close all readers. try { - val readerState = taskReaderStates.get(ctx.taskAttemptId()) - if (readerState != null) { - CustomMetrics.updateMetrics( - readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics) - readerState.iterator.forceUpdateMetrics() - readerState.reader.close() + val taskState = taskStates.get(ctx.taskAttemptId()) + if (taskState != null) { + taskState.updateMetrics(0, force = true) + taskState.partitionIterators.foreach(_.close()) } } finally { - taskReaderStates.remove(ctx.taskAttemptId()) + taskStates.remove(ctx.taskAttemptId()) } } } castPartition(split).inputPartition.iterator.flatMap { inputPartition => - val (iter, reader) = if (columnarReads) { + val iter = if (columnarReads) { val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) - val iter = new MetricsBatchIterator( - new PartitionIterator[ColumnarBatch](batchReader, customMetrics)) - (iter, batchReader) + new PartitionIterator[ColumnarBatch](batchReader, _.numRows, taskState) } else { val rowReader = partitionReaderFactory.createReader(inputPartition) - val iter = new MetricsRowIterator( - new PartitionIterator[InternalRow](rowReader, customMetrics)) - (iter, rowReader) + new PartitionIterator[InternalRow](rowReader, _ => 1, taskState) } - // Flush metrics and close the previous reader before advancing to the next one. - // Pass the accumulated metrics to the new reader so they carry forward correctly. - val prevState = taskReaderStates.get(taskAttemptId) - if (prevState != null) { - val metrics = prevState.reader.currentMetricsValues - CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) - reader.initMetricsValues(metrics) - prevState.reader.close() - } - - // Update the map so the completion listener always references the latest reader. 
- taskReaderStates.put(taskAttemptId, ReaderState(reader, iter)) + // Track this iterator; early-stop close and final metrics-flush for iterators not yet + // naturally exhausted are handled by the task completion listener. This avoids closing + // live iterators prematurely in the concurrent k-way merge (SortedMergeCoalescedRDD). + taskState.partitionIterators += iter // TODO: SPARK-25083 remove the type erasure hack in data source scan new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) @@ -148,16 +210,36 @@ class DataSourceRDD( } private class PartitionIterator[T]( - reader: PartitionReader[T], - customMetrics: Map[String, SQLMetric]) extends Iterator[T] { - private[this] var valuePrepared = false - private[this] var hasMoreInput = true + private var reader: PartitionReader[T], + rowCount: T => Int, + taskState: TaskState) extends Iterator[T] { + private var valuePrepared = false + private var hasMoreInput = true + + // Cached final metrics snapshot, captured just before the reader is closed on natural + // exhaustion. Allows mergeAndUpdateCustomMetrics() to include this reader's contribution + // after close() has been called and currentMetricsValues() is no longer valid. 
+ private var cachedFinalMetrics: Array[CustomTaskMetric] = Array.empty + + def isClosed: Boolean = reader == null + + def finalMetrics: Array[CustomTaskMetric] = cachedFinalMetrics + + def close(): Unit = if (reader != null) { + reader.close() + reader = null + } - private var numRow = 0L + def currentMetricsValues: Array[CustomTaskMetric] = reader.currentMetricsValues override def hasNext: Boolean = { if (!valuePrepared && hasMoreInput) { hasMoreInput = reader.next() + if (!hasMoreInput) { + cachedFinalMetrics = reader.currentMetricsValues + taskState.updateMetrics(0, force = true) + close() + } valuePrepared = hasMoreInput } valuePrepared @@ -167,59 +249,9 @@ private class PartitionIterator[T]( if (!hasNext) { throw QueryExecutionErrors.endOfStreamError() } - if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { - CustomMetrics.updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics) - } - numRow += 1 valuePrepared = false - reader.get() - } -} - -private class MetricsHandler extends Logging with Serializable { - private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics - private val startingBytesRead = inputMetrics.bytesRead - private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - - def updateMetrics(numRows: Int, force: Boolean = false): Unit = { - inputMetrics.incRecordsRead(numRows) - val shouldUpdateBytesRead = - inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0 - if (shouldUpdateBytesRead || force) { - inputMetrics.setBytesRead(startingBytesRead + getBytesRead()) - } - } -} - -private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] { - protected val metricsHandler = new MetricsHandler - - override def hasNext: Boolean = { - if (iter.hasNext) { - true - } else { - forceUpdateMetrics() - false - } - } - - def forceUpdateMetrics(): Unit = metricsHandler.updateMetrics(0, force = true) -} - -private class MetricsRowIterator( - iter: 
Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) { - override def next(): InternalRow = { - val item = iter.next() - metricsHandler.updateMetrics(1) - item - } -} - -private class MetricsBatchIterator( - iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) { - override def next(): ColumnarBatch = { - val batch: ColumnarBatch = iter.next() - metricsHandler.updateMetrics(batch.numRows) - batch + val result = reader.get() + taskState.updateMetrics(rowCount(result)) + result } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index a1a6c6e022482..0bec918039775 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -24,14 +24,14 @@ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan} -import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SQLExecution} +import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SafeForKWayMerge, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils -trait DataSourceV2ScanExecBase extends LeafExecNode { +trait DataSourceV2ScanExecBase extends LeafExecNode with SafeForKWayMerge { lazy val customMetrics = scan.supportedCustomMetrics().map { customMetric => customMetric.name() -> 
SQLMetrics.createV2CustomMetric(sparkContext, customMetric) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 81981c29b2b31..e552c4f71641f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -21,13 +21,14 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Partition, SparkException} -import org.apache.spark.rdd.{CoalescedRDD, PartitionCoalescer, PartitionGroup, RDD} +import org.apache.spark.rdd.{CoalescedRDD, PartitionCoalescer, PartitionGroup, RDD, SortedMergeCoalescedRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, Partitioning} import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.functions.Reducer -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{SafeForKWayMerge, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -158,16 +159,91 @@ case class GroupPartitionsExec( @transient lazy val isGrouped: Boolean = groupedPartitionsTuple._2 + // Whether the child subtree is safe to use with SortedMergeCoalescedRDD (k-way merge). 
+ // + // --- The general problem --- + // + // Unlike a simple CoalescedRDD, which processes input partitions sequentially (open P0, + // exhaust P0, open P1, exhaust P1, ...), SortedMergeCoalescedRDD must perform a k-way merge: + // it opens *all* input partition iterators upfront and then interleaves reads across them to + // produce a globally sorted output partition. Crucially, all of this happens on a single JVM + // thread within a single Spark task. + // + // A SparkPlan object is shared across all partition computations of a task. When N partition + // iterators are live concurrently on the same thread, N independent "computations" are all + // operating through the same plan node instances. If any node in the subtree stores + // per-partition state in an instance field rather than as a local variable captured inside the + // partition's iterator closure, that field is aliased across all N concurrent computations. + // Whichever computation last wrote the field "wins", and any computation that then reads or + // frees it based on its own earlier write will operate on the wrong state. + // + // --- The correct pattern: PartitionEvaluatorFactory --- + // + // The PartitionEvaluatorFactory / PartitionEvaluator pattern is specifically designed to avoid + // this problem. The factory's createEvaluator() is called once per partition and returns a + // fresh PartitionEvaluator instance. All per-partition mutable state lives inside that + // evaluator instance, not on the shared plan node. Operators that follow this pattern + // exclusively (and hold no other mutable state on the plan node) are safe for k-way merge. + // + // --- Concrete example of an unsafe operator: SortExec + SortMergeJoinExec --- + // + // SortExec stores its active sorter in a plain var field (`rowSorter`) on the plan node. 
+ // When the k-way merge initialises its N partition iterators, each one drives the same SortExec + // instance and calls createSorter(), which assigns rowSorter = newSorter -- overwriting the + // field each time. After all N iterators are initialised, rowSorter holds only the sorter + // created for the *last* partition. + // + // SortMergeJoinExec performs eager resource cleanup: when a join partition is exhausted it + // calls cleanupResources() on its children, which reaches SortExec.cleanupResources(). That + // method calls rowSorter.cleanupResources() -- but rowSorter now holds the last-created sorter, + // not the one belonging to the just-exhausted partition. If that last sorter is still being + // actively read by another partition in the k-way merge, freeing it causes a use-after-free. + // + // To be conservative, we use a whitelist: unknown node types fall through to unsafe, causing + // a fallback to simple sequential coalescing. Only node types explicitly confirmed to store no + // per-partition state in shared (plan node) instance fields are permitted. + @transient private lazy val childIsSafeForKWayMerge: Boolean = + !child.exists { + case _: SafeForKWayMerge => false + case _ => true + } + override protected def doExecute(): RDD[InternalRow] = { if (groupedPartitions.isEmpty) { sparkContext.emptyRDD + } else if (SQLConf.get.v2BucketingPreserveOrderingOnCoalesceEnabled && + child.outputOrdering.nonEmpty && + childIsSafeForKWayMerge && + groupedPartitions.exists(_._2.size > 1)) { + // Use sorted merge when: + // 1. Config is enabled + // 2. Child has ordering + // 3. 
Actually coalescing multiple partitions + // Convert SortOrder expressions to Ordering[InternalRow] + val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) + val rowOrdering = new InterpretedOrdering(child.outputOrdering, child.output) + new SortedMergeCoalescedRDD[InternalRow]( + child.execute(), + groupedPartitions.size, + partitionCoalescer, + rowOrdering) } else { val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) + // Use simple coalescing when config is disabled, no ordering, or no actual coalescing new CoalescedRDD(child.execute(), groupedPartitions.size, Some(partitionCoalescer)) } } - override def supportsColumnar: Boolean = child.supportsColumnar + override def supportsColumnar: Boolean = { + // Don't use columnar when sorted merge coalescing is needed, since we can't preserve + // ordering with sorted merge for columnar batches + val needsSortedMerge = SQLConf.get.v2BucketingPreserveOrderingOnCoalesceEnabled && + child.outputOrdering.nonEmpty && + childIsSafeForKWayMerge && + groupedPartitions.exists(_._2.size > 1) + + child.supportsColumnar && !needsSortedMerge + } override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { if (groupedPartitions.isEmpty) { @@ -189,6 +265,12 @@ case class GroupPartitionsExec( // within-partition ordering is fully preserved (including any key-derived ordering that // `DataSourceV2ScanExecBase` already prepended). child.outputOrdering + } else if (SQLConf.get.v2BucketingPreserveOrderingOnCoalesceEnabled && + child.outputOrdering.nonEmpty && + childIsSafeForKWayMerge) { + // Coalescing with sorted merge: SortedMergeCoalescedRDD performs a k-way merge using the + // child's ordering, so the full within-partition ordering is preserved end-to-end. + child.outputOrdering } else { // Coalescing: multiple input partitions are merged into one output partition. 
The child's // within-partition ordering is lost due to concatenation -- for example, if two input diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala index 5d06c8786d894..7f51875b971fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala @@ -69,7 +69,8 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with Logging { private def ordering(plan: LogicalPlan) = plan.transformDown { case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportOrdering, _, _, _) => - val ordering = V2ExpressionUtils.toCatalystOrdering(scan.outputOrdering(), relation) + val ordering = + V2ExpressionUtils.toCatalystOrdering(scan.outputOrdering(), relation, relation.funCatalog) d.copy(ordering = Some(ordering)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 44cb3a23cfa7c..4a406322a5a19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -20,6 +20,7 @@ import java.sql.Timestamp import java.util.Collections import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.rdd.SortedMergeCoalescedRDD import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Literal, TransformExpression} @@ -252,9 +253,10 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase 
with table: String, columns: Array[Column], partitions: Array[Transform], + ordering: Array[SortOrder] = Array.empty, catalog: InMemoryTableCatalog = catalog): Unit = { catalog.createTable(Identifier.of(Array("ns"), table), - columns, partitions, emptyProps, Distributions.unspecified(), Array.empty, None, None, + columns, partitions, emptyProps, Distributions.unspecified(), ordering, None, None, numRowsPerSplit = 1) } @@ -3091,7 +3093,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with assert(groupPartitions(0).outputPartitioning.numPartitions === 2, "group partitions should have 2 partition groups") } - assert(metrics("number of rows read") == "3") + assert(metrics.collect { + case ((_, "BatchScan testcat.ns.items", "number of rows read"), v) => v + } === Seq("3")) } test("SPARK-55619: Custom metrics of coalesced partitions") { @@ -3106,7 +3110,68 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1) df.collect() } - assert(metrics("number of rows read") == "2") + assert(metrics.collect { + case ((_, "BatchScan testcat.ns.items", "number of rows read"), v) => v + } === Seq("2")) + } + + test("SPARK-55715: Custom metrics of sorted-merge coalesced partitions") { + // items has id=1 on three splits with interleaved arrive_times -- out of order across splits. + // purchases has item_id=1 on two splits, also out of order. Both sides coalesce under SMJ, + // using SortedMergeCoalescedRDD with multiple concurrent readers per task. This test verifies + // that all rows from both tables (5 + 4 = 9) are accounted for in the per-scan metrics. 
+ val itemOrdering = Array( + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("arrive_time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(items, itemsColumns, Array(identity("id")), itemOrdering) + // Rows inserted out of order: id=1 lands on partitions 1, 3, 4 with arrive_times + // [2022-03-10, 2021-05-20, 2025-09-01] -- out of order. + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(3, 'dd', 40.0, cast('2024-01-01' as timestamp)), " + + "(1, 'bb', 20.0, cast('2022-03-10' as timestamp)), " + + "(2, 'cc', 30.0, cast('2023-06-15' as timestamp)), " + + "(1, 'aa', 10.0, cast('2021-05-20' as timestamp)), " + + "(1, 'ee', 50.0, cast('2025-09-01' as timestamp))") + + val purchaseOrdering = Array( + sort(FieldReference("item_id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(purchases, purchasesColumns, Array(identity("item_id")), purchaseOrdering) + // item_id=1 lands on partitions 1 and 3 with times [2022-03-10, 2021-05-20] -- out of order. 
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(2, 30.0, cast('2023-06-15' as timestamp)), " +
+      "(1, 20.0, cast('2022-03-10' as timestamp)), " +
+      "(3, 40.0, cast('2024-01-01' as timestamp)), " +
+      "(1, 10.0, cast('2021-05-20' as timestamp))")
+
+    withSQLConf(
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true",
+      SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> "true") {
+      val metrics = runAndFetchMetrics {
+        val df = sql(
+          s"""${selectWithMergeJoinHint("i", "p")}
+             |i.id, i.name
+             |FROM testcat.ns.$items i
+             |JOIN testcat.ns.$purchases p ON p.item_id = i.id AND p.time = i.arrive_time
+             |""".stripMargin)
+        checkAnswer(df, Seq(Row(1, "aa"), Row(1, "bb"), Row(2, "cc"), Row(3, "dd")))
+        val plan = df.queryExecution.executedPlan
+        val groupPartitions = collectAllGroupPartitions(plan)
+        val coalescingGP = groupPartitions.filter(_.groupedPartitions.exists(_._2.size > 1))
+        assert(coalescingGP.nonEmpty, "expected a coalescing GroupPartitionsExec")
+        coalescingGP.foreach { gp =>
+          assert(gp.execute().isInstanceOf[SortedMergeCoalescedRDD[_]],
+            "should use SortedMergeCoalescedRDD when preserve-ordering config is enabled")
+        }
+      }
+      assert(metrics.collect {
+        case ((_, "BatchScan testcat.ns.items", "number of rows read"), v) => v
+      } === Seq("5"))
+      assert(metrics.collect {
+        case ((_, "BatchScan testcat.ns.purchases", "number of rows read"), v) => v
+      } === Seq("4"))
+    }
   }
 
   test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
@@ -3690,4 +3755,162 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
       }
     }
   }
+
+  test("SPARK-55715: preserve outputOrdering when coalescing partitions with sorted merge") {
+    // Both tables are partitioned by their id column and report, via SupportsReportOrdering,
+    // ordering [id ASC, arrive_time ASC] / [item_id ASC, time ASC]. Each has two rows with id=1
+    // (two splits), so GroupPartitionsExec must coalesce them. We join with SMJ on
+    // (id, arrive_time) = (item_id, time).
+    // With config enabled: SortedMergeCoalescedRDD performs a k-way merge preserving the full
+    // [id ASC, arrive_time ASC] ordering -> EnsureRequirements is satisfied -> no SortExec added.
+    // With config disabled: simple CoalescedRDD concatenates the splits and only the key-derived
+    // [id ASC] ordering survives -> arrive_time ordering is lost -> SortExec is added for it.
+    val itemOrdering = Array(
+      sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST),
+      sort(FieldReference("arrive_time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
+    createTable(items, itemsColumns, Array(identity("id")), itemOrdering)
+    // Rows inserted out of order: id values are interleaved and arrive_time is not monotone
+    // within each id group, so ordering by [id ASC, arrive_time ASC] is non-trivial.
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(2, 'cc', 30.0, cast('2023-06-15' as timestamp)), " +
+      "(1, 'bb', 20.0, cast('2022-03-10' as timestamp)), " +
+      "(3, 'dd', 40.0, cast('2024-01-01' as timestamp)), " +
+      "(1, 'aa', 10.0, cast('2021-05-20' as timestamp)), " +
+      "(2, 'ee', 50.0, cast('2025-09-01' as timestamp))")
+
+    val purchaseOrdering = Array(
+      sort(FieldReference("item_id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST),
+      sort(FieldReference("time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
+    createTable(purchases, purchasesColumns, Array(identity("item_id")), purchaseOrdering)
+    // Also inserted out of order
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(2, 50.0, cast('2025-09-01' as timestamp)), " +
+      "(1, 10.0, cast('2021-05-20' as timestamp)), " +
+      "(3, 40.0, cast('2024-01-01' as timestamp)), " +
+      "(2, 30.0, cast('2023-06-15' as timestamp)), " +
+      "(1, 20.0, cast('2022-03-10' as timestamp))")
+
+    Seq(true, false).foreach { preserveOrdering =>
+      withSQLConf(
+        
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> + preserveOrdering.toString) { + val df = sql( + s""" + |${selectWithMergeJoinHint("i", "p")} + |i.id, i.name + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON p.item_id = i.id AND p.time = i.arrive_time + |""".stripMargin) + checkAnswer(df, Seq( + Row(1, "aa"), Row(1, "bb"), Row(2, "cc"), Row(2, "ee"), Row(3, "dd"))) + + val plan = df.queryExecution.executedPlan + assert(collectAllShuffles(plan).isEmpty, "should not contain any shuffle") + + val groupPartitions = collectAllGroupPartitions(plan) + assert(groupPartitions.nonEmpty, "should contain GroupPartitionsExec for coalescing") + assert(groupPartitions.exists(_.groupedPartitions.exists(_._2.size > 1)), + "expected coalescing GroupPartitionsExec") + + val smjs = collect(plan) { case j: SortMergeJoinExec => j } + assert(smjs.nonEmpty, "expected SortMergeJoinExec in plan") + smjs.foreach { smj => + val sorts = smj.children.flatMap(child => collect(child) { case s: SortExec => s }) + if (preserveOrdering) { + assert(sorts.isEmpty, + "config enabled: SortedMergeCoalescedRDD preserves [id ASC, arrive_time ASC], " + + "no SortExec should be added before SMJ") + + // Also verify the k-way merge RDD is actually used + val coalescingGP = groupPartitions.filter(_.groupedPartitions.exists(_._2.size > 1)) + coalescingGP.foreach { gp => + assert(gp.execute().isInstanceOf[SortedMergeCoalescedRDD[_]], + "config enabled: should use SortedMergeCoalescedRDD") + } + } else { + assert(sorts.nonEmpty, + "config disabled: simple coalescing loses arrive_time ordering, " + + "SortExec should be added before SMJ") + } + } + } + } + } + + test("SPARK-55715: preserve outputOrdering when coalescing transform-partitioned splits") { + // Both tables are partitioned by years("arrive_time") / years("time") and report 
ordering + // [arrive_time ASC] / [time ASC]. Two rows share the same year bucket (2022 and 2023), so + // GroupPartitionsExec coalesces two splits per year. We join solely on + // p.time = i.arrive_time (the partition key expression) using SMJ. + // + // With config enabled: SortedMergeCoalescedRDD k-way merge preserves [arrive_time ASC] + // ordering -> EnsureRequirements is satisfied -> no SortExec added. + // With config disabled: simple CoalescedRDD only preserves the key-derived year ordering -> + // arrive_time ordering within a year is lost -> SortExec is added. + val itemOrdering = Array( + sort(FieldReference("arrive_time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(items, itemsColumns, Array(years("arrive_time")), itemOrdering) + // Inserted out of order: within year 2022, September is before March in insertion order + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(2, 'bb', 20.0, cast('2022-09-20' as timestamp)), " + + "(4, 'dd', 40.0, cast('2023-11-05' as timestamp)), " + + "(1, 'aa', 10.0, cast('2022-03-15' as timestamp)), " + + "(3, 'cc', 30.0, cast('2023-01-10' as timestamp))") + + val purchaseOrdering = Array( + sort(FieldReference("time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(purchases, purchasesColumns, Array(years("time")), purchaseOrdering) + // Also inserted out of order + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(2, 20.0, cast('2022-09-20' as timestamp)), " + + "(4, 40.0, cast('2023-11-05' as timestamp)), " + + "(1, 10.0, cast('2022-03-15' as timestamp)), " + + "(3, 30.0, cast('2023-01-10' as timestamp))") + + Seq(true, false).foreach { preserveOrdering => + withSQLConf( + SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> + preserveOrdering.toString) { + val df = sql( + s""" + |${selectWithMergeJoinHint("i", "p")} + |i.id, i.name + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON p.time = i.arrive_time + |""".stripMargin) + checkAnswer(df, 
Seq(Row(1, "aa"), Row(2, "bb"), Row(3, "cc"), Row(4, "dd"))) + + val plan = df.queryExecution.executedPlan + assert(collectAllShuffles(plan).isEmpty, "should not contain any shuffle") + + val groupPartitions = collectAllGroupPartitions(plan) + assert(groupPartitions.nonEmpty, "should contain GroupPartitionsExec for coalescing") + assert(groupPartitions.exists(_.groupedPartitions.exists(_._2.size > 1)), + "expected coalescing GroupPartitionsExec") + + val smjs = collect(plan) { case j: SortMergeJoinExec => j } + assert(smjs.nonEmpty, "expected SortMergeJoinExec in plan") + smjs.foreach { smj => + val sorts = smj.children.flatMap(child => collect(child) { case s: SortExec => s }) + if (preserveOrdering) { + assert(sorts.isEmpty, + "config enabled: SortedMergeCoalescedRDD preserves [arrive_time ASC], " + + "no SortExec should be added before SMJ") + + val coalescingGP = groupPartitions.filter(_.groupedPartitions.exists(_._2.size > 1)) + coalescingGP.foreach { gp => + assert(gp.execute().isInstanceOf[SortedMergeCoalescedRDD[_]], + "config enabled: should use SortedMergeCoalescedRDD") + } + } else { + assert(sorts.nonEmpty, + "config disabled: simple coalescing loses arrive_time ordering within a year, " + + "SortExec should be added before SMJ") + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala index 502424d58d2cb..06e55a179d4e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala @@ -53,7 +53,7 @@ class InMemoryTableMetricSuite Array(Column.create("i", IntegerType)), Array.empty[Transform], Collections.emptyMap[String, String]) - val metrics = runAndFetchMetrics(func("testcat.table_name")) + val metrics = 
runAndFetchMetrics(func("testcat.table_name")).map(m => m._1._3 -> m._2) checker(metrics) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala index c37e051929555..72b80cbe05d6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.datasources.v2 +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, PartitioningCollection} -import org.apache.spark.sql.execution.DummySparkPlan +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} +import org.apache.spark.sql.execution.{DummySparkPlan, LeafExecNode, SafeForKWayMerge} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.IntegerType @@ -141,4 +142,41 @@ class GroupPartitionsExecSuite extends SharedSparkSession { assert(!gpe.groupedPartitions.forall(_._2.size <= 1), "expected coalescing") assert(gpe.outputOrdering === Nil) } + + test("SPARK-55715: coalescing with sorted merge config enabled returns full child ordering") { + // Key 1 appears on partitions 0 and 2, causing coalescing. The child is a LeafExecNode + // so childIsSafeForKWayMerge = true. 
With the preserve-ordering config enabled, case 2 + // of outputOrdering kicks in and the full child ordering (including the non-key exprC) must + // be returned, not just the subset of key-expression orders. + val partitionKeys = Seq(row(1), row(2), row(1)) + val childOrdering = Seq(SortOrder(exprA, Ascending), SortOrder(exprC, Ascending)) + val child = DummyLeafSparkPlan( + outputPartitioning = KeyedPartitioning(Seq(exprA), partitionKeys), + outputOrdering = childOrdering) + val gpe = GroupPartitionsExec(child) + + assert(!gpe.groupedPartitions.forall(_._2.size <= 1), "expected coalescing") + withSQLConf(SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> "true") { + // Config enabled: k-way merge preserves full ordering including non-key exprC. + assert(gpe.outputOrdering === childOrdering) + } + withSQLConf( + SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_PRESERVE_KEY_ORDERING_ON_COALESCE_ENABLED.key -> "true") { + // Sorted-merge config disabled, key-ordering config enabled: only key-expression orders + // survive simple concatenation (non-key exprC is dropped). 
+ val ordering = gpe.outputOrdering + assert(ordering.length === 1) + assert(ordering.head.child === exprA) + } + } +} + +private case class DummyLeafSparkPlan( + override val outputOrdering: Seq[SortOrder] = Nil, + override val outputPartitioning: Partitioning = UnknownPartitioning(0) + ) extends LeafExecNode with SafeForKWayMerge { + override protected def doExecute(): RDD[InternalRow] = + throw new UnsupportedOperationException + override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 456e7ed3478a6..8f5157acca7d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -54,7 +54,9 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase { } } - def runAndFetchMetrics(func: => Unit): Map[String, String] = { + // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that + // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. + def runAndFetchMetrics(func: => Unit): Map[(Long, String, String), String] = { val statusStore = spark.sharedState.statusStore val oldCount = statusStore.executionsList().size @@ -73,7 +75,9 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase { val exec = statusStore.executionsList().last val execId = exec.executionId - val sqlMetrics = exec.metrics.map { metric => metric.accumulatorId -> metric.name }.toMap + val sqlMetrics = statusStore.planGraph(execId).allNodes + .flatMap(n => n.metrics.map(m => (m.accumulatorId, (n.id, n.name, m.name)))) + .toMap statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v } } }