From 038c839895dace0eaa942c6a69398f1d5fe3dfb4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 1 Apr 2026 12:04:59 +0200 Subject: [PATCH 1/2] [SPARK-56321][SQL] Pass function catalog to toCatalystOrdering in V2ScanPartitioningAndOrdering ### What changes were proposed in this pull request? `V2ScanPartitioningAndOrdering.ordering` was calling `V2ExpressionUtils.toCatalystOrdering` without the `funCatalog` argument. This meant that function-based sort expressions reported by a data source via `SupportsReportOrdering` (e.g. transform functions like `bucket(n, col)`) could not be resolved against the function catalog and would be silently dropped. The fix passes `relation.funCatalog` as the third argument, consistent with how `toCatalystOpt` is already called in the `partitioning` rule of the same object. ### Why are the changes needed? Without the function catalog, sort orders involving catalog functions reported by `SupportsReportOrdering` are not resolved, causing them to be ignored by the planner even when the data source correctly reports them. ### Does this PR introduce _any_ user-facing change? Yes. Data sources implementing `SupportsReportOrdering` with function-based sort expressions that require the function catalog will now have those sort orders correctly recognized by Spark, potentially eliminating unnecessary sort operations. ### How was this patch tested? `WriteDistributionAndOrderingSuite` already covers this because `InMemoryBaseTable` is updated to use `InMemoryBatchScanWithOrdering` (a new inner class implementing `SupportsReportOrdering`) when a table ordering is configured. ### Was this patch authored or co-authored using generative AI tooling? 
Generated-by: Claude Sonnet 4.6 --- .../connector/catalog/InMemoryBaseTable.scala | 24 +++++++++++++++++-- .../v2/V2ScanPartitioningAndOrdering.scala | 3 ++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index e7762565f47ec..9cc32efedf82a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -429,8 +429,15 @@ abstract class InMemoryBaseTable( private var _pushedFilters: Array[Filter] = Array.empty override def build: Scan = { - val scan = InMemoryBatchScan( - data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options) + val scan = if (InMemoryBaseTable.this.ordering.nonEmpty) { + new InMemoryBatchScanWithOrdering( + data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, + options) + } else { + InMemoryBatchScan( + data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, + options) + } if (evaluableFilters.nonEmpty) { scan.filter(evaluableFilters) } @@ -596,6 +603,19 @@ abstract class InMemoryBaseTable( } } + // Extends InMemoryBatchScan with SupportsReportOrdering. Only instantiated when the table has a + // non-empty ordering, so that V2ScanPartitioningAndOrdering only sets ordering = Some(...) on the + // logical plan when there is actual ordering to report. 
+ private class InMemoryBatchScanWithOrdering( + data: Seq[InputPartition], + readSchema: StructType, + tableSchema: StructType, + options: CaseInsensitiveStringMap) + extends InMemoryBatchScan(data, readSchema, tableSchema, options) + with SupportsReportOrdering { + override def outputOrdering(): Array[SortOrder] = InMemoryBaseTable.this.ordering + } + abstract class InMemoryWriterBuilder(val info: LogicalWriteInfo) extends SupportsTruncate with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala index 5d06c8786d894..7f51875b971fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala @@ -69,7 +69,8 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with Logging { private def ordering(plan: LogicalPlan) = plan.transformDown { case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportOrdering, _, _, _) => - val ordering = V2ExpressionUtils.toCatalystOrdering(scan.outputOrdering(), relation) + val ordering = + V2ExpressionUtils.toCatalystOrdering(scan.outputOrdering(), relation, relation.funCatalog) d.copy(ordering = Some(ordering)) } } From 1c9e6ff77287330ada0f3cfc86c6cb26d402d5d1 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 31 Mar 2026 19:07:07 +0200 Subject: [PATCH 2/2] [SPARK-55715][SQL] Keep outputOrdering when GroupPartitionsExec coalesces partitions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? 
#### Background `GroupPartitionsExec` coalesces multiple input partitions that share the same partition key into a single output partition. Before this PR, `outputOrdering` was always discarded after coalescing: even when the child reported ordering (e.g. via `SupportsReportOrdering`) or when ordering was derived from `KeyedPartitioning` key expressions (via `spark.sql.sources.v2.bucketing.partitionKeyOrdering.enabled`), coalescing by simple concatenation destroyed the within-partition ordering. This forced `EnsureRequirements` to inject an extra `SortExec` before `SortMergeJoinExec`, defeating the purpose of using a storage-partitioned join. #### k-way merge: SortedMergeCoalescedRDD This PR introduces `SortedMergeCoalescedRDD`, a new RDD that coalesces partitions by performing a k-way merge instead of simple concatenation. When multiple input partitions share the same key, a priority-queue-based merge interleaves their rows in sorted order, producing a single output partition whose row order matches the child's `outputOrdering`. `GroupPartitionsExec.doExecute()` uses `SortedMergeCoalescedRDD` when all of the following hold: 1. `spark.sql.sources.v2.bucketing.preserveOrderingOnCoalesce.enabled` is `true`. 2. The child reports a non-empty `outputOrdering`. 3. The child subtree is safe for concurrent partition reads (`childIsSafeForKWayMerge`). 4. At least one output partition actually coalesces multiple input partitions. When the config is enabled, the k-way merge is always applied regardless of whether the parent operator actually requires the ordering. Making this dynamic (only merge-sort when required) will be addressed in a follow-up ticket. #### Why k-way merge safety matters: SafeForKWayMerge Unlike `CoalescedRDD`, which processes input partitions sequentially, `SortedMergeCoalescedRDD` opens all N input partition iterators upfront and interleaves reads across them — all on a single JVM thread within a single Spark task. 
A `SparkPlan` object is shared across all partition computations, so any plan node that stores per-partition mutable state in an instance field rather than inside the partition's iterator closure is aliased across all N concurrent computations. The last writer wins, and any computation that reads or frees state based on its own earlier write will operate on incorrect state (a use-after-free). To avoid this class of bugs, `GroupPartitionsExec` uses a whitelist approach via a new marker trait `SafeForKWayMerge`. Nodes implementing this trait guarantee that all per-partition mutable state is captured inside the partition's iterator closure (e.g. via the `PartitionEvaluatorFactory` pattern), never in shared plan-node instance fields. Unknown node types fall through to unsafe, causing a silent fallback to simple sequential coalescing. The following nodes implement `SafeForKWayMerge`: - `DataSourceV2ScanExecBase` (leaf nodes reading from V2 sources) - `ProjectExec`, `FilterExec` (stateless row-by-row operators) - `WholeStageCodegenExec`, `InputAdapter` (code-gen wrappers that delegate to the above) #### GroupPartitionsExec.outputOrdering `GroupPartitionsExec.outputOrdering` is updated to reflect what ordering is preserved: 1. **No coalescing** (all groups ≤ 1 partition): `child.outputOrdering` is passed through unchanged. 2. **Coalescing with k-way merge** (config enabled + `childIsSafeForKWayMerge`): `child.outputOrdering` is returned in full — the k-way merge produces a globally sorted partition. 3. **Coalescing without k-way merge, no reducers**: only sort orders whose expression is a partition key expression are returned. These key expressions evaluate to the same constant value within every merged partition (all merged splits share the same key), so their sort orders remain valid after concatenation. This is the ordering preserved by the existing `spark.sql.sources.v2.bucketing.preserveKeyOrderingOnCoalesce.enabled` config. 4. 
**Coalescing without k-way merge, with reducers**: `super.outputOrdering` (empty) — the reduced key can take different values within the output partition, so no ordering is guaranteed. #### DataSourceRDD: concurrent-reader metrics support `SortedMergeCoalescedRDD` opens multiple `PartitionReader`s concurrently within a single Spark task. The existing `DataSourceRDD` assumed at most one active reader per task at a time, causing only the last reader's custom metrics to be reported (the previous readers' metrics were overwritten and lost). `DataSourceRDD` is refactored to support concurrent readers: - A new `TaskState` class (one per task) holds an `ArrayBuffer[PartitionIterator[_]]` (`partitionIterators`) tracking all readers opened for the task, Spark input metrics (`InputMetrics`), and a `closedMetrics` map accumulating final metric values from already-closed readers. - `mergeAndUpdateCustomMetrics()` runs in two phases: (1) drain closed iterators into `closedMetrics`; (2) merge live readers' current values with `closedMetrics` via the new `CustomTaskMetric.mergeWith()` and push the result to the Spark UI accumulators. - This works correctly in all three execution modes: single partition per task, sequential coalescing (one reader at a time), and concurrent k-way merge (N readers simultaneously). #### CustomTaskMetric.mergeWith A new default method `mergeWith(CustomTaskMetric other)` is added to `CustomTaskMetric`. The default implementation sums the two values, which is correct for count-type metrics. Data sources with non-additive metrics (e.g. max, average) should override this method. This replaces the previously proposed `PartitionReader.initMetricsValues` mechanism (which threaded prior metric values into the next reader's constructor) with a cleaner, pull-based merge at reporting time. `PartitionReader.initMetricsValues` becomes deprecated as it is no longer needed. ### Why are the changes needed? 
Without this fix, `GroupPartitionsExec` always discards ordering when coalescing, forcing `EnsureRequirements` to inject an extra `SortExec` before `SortMergeJoinExec` even when the data is already sorted by the join key within each partition. With `SortedMergeCoalescedRDD`, the full child ordering is preserved end-to-end, eliminating these redundant sorts and making storage-partitioned joins with ordering fully efficient. `spark.sql.sources.v2.bucketing.preserveKeyOrderingOnCoalesce.enabled` (introduced earlier) preserves only sort orders over partition key expressions, which remain constant within a merged partition. This PR goes further: by performing a k-way merge, the full `outputOrdering` — including secondary sort columns beyond the partition key — is preserved end-to-end. ### Does this PR introduce _any_ user-facing change? Yes. A new SQL configuration is added: - `spark.sql.sources.v2.bucketing.preserveOrderingOnCoalesce.enabled` (default: `false`): when enabled, `GroupPartitionsExec` uses a k-way merge to coalesce partitions while preserving the full child ordering, avoiding extra sort steps for operations like `SortMergeJoin`. ### How was this patch tested? - **`SortedMergeCoalescedRDDSuite`**: unit tests for the new RDD covering correctness of the k-way merge, empty partitions, single partition, and ordering guarantees. - **`GroupPartitionsExecSuite`**: unit tests covering all four branches of `outputOrdering` (no coalescing; k-way merge enabled; key-expression ordering only; reducers present). - **`KeyGroupedPartitioningSuite`**: SQL-level tests verifying that no extra `SortExec` is injected when `SortedMergeCoalescedRDD` is used, and a new test (`SPARK-55715: Custom metrics of sorted-merge coalesced partitions`) that verifies per-scan custom metrics are correctly reported across concurrent readers in the k-way merge case. 
- **`BufferedRowsReader` hardening**: the test-framework reader in `InMemoryBaseTable` now tracks a `closed` flag and throws `IllegalStateException` for reads, double-closes, or metric fetches on a closed reader. This ensures future tests catch reader lifecycle bugs that were previously hidden by the noop `close()`. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Claude Sonnet 4.6 --- .../spark/rdd/SortedMergeCoalescedRDD.scala | 172 +++++++++++++ .../rdd/SortedMergeCoalescedRDDSuite.scala | 219 ++++++++++++++++ .../connector/metric/CustomTaskMetric.java | 23 ++ .../sql/connector/read/PartitionReader.java | 6 + .../apache/spark/sql/internal/SQLConf.scala | 18 ++ .../connector/catalog/InMemoryBaseTable.scala | 20 +- .../sql/execution/SafeForKWayMerge.scala | 27 ++ .../sql/execution/WholeStageCodegenExec.scala | 5 +- .../execution/basicPhysicalOperators.scala | 3 +- .../datasources/v2/DataSourceRDD.scala | 242 ++++++++++-------- .../v2/DataSourceV2ScanExecBase.scala | 4 +- .../datasources/v2/GroupPartitionsExec.scala | 88 ++++++- .../KeyGroupedPartitioningSuite.scala | 229 ++++++++++++++++- .../InMemoryTableMetricSuite.scala | 2 +- .../v2/GroupPartitionsExecSuite.scala | 44 +++- .../spark/sql/test/SharedSparkSession.scala | 8 +- 16 files changed, 979 insertions(+), 131 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala new file mode 100644 index 0000000000000..4cd37e52dd8be --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SortedMergeCoalescedRDD.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.io.{IOException, ObjectOutputStream} + +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.util.Utils + +/** + * An RDD that coalesces partitions while preserving ordering through k-way merge. + * + * Unlike CoalescedRDD which simply concatenates partitions, this RDD performs a sorted + * merge of multiple input partitions to maintain ordering. This is useful when input + * partitions are locally sorted and we want to preserve that ordering after coalescing. + * + * The merge is performed using a priority queue (min-heap) which provides O(n log k) + * time complexity, where n is the total number of elements and k is the number of + * partitions being merged. 
+ * + * @param prev The parent RDD + * @param numPartitions The number of output partitions after coalescing + * @param ordering The ordering to maintain during merge + * @param partitionCoalescer The coalescer defining how to group input partitions + * @tparam T The element type + */ +private[spark] class SortedMergeCoalescedRDD[T: ClassTag]( + @transient var prev: RDD[T], + numPartitions: Int, + partitionCoalescer: PartitionCoalescer, + ordering: Ordering[T]) + extends RDD[T](prev.context, Nil) { + + override def getPartitions: Array[Partition] = { + partitionCoalescer.coalesce(numPartitions, prev).zipWithIndex.map { + case (pg, i) => + val parentIndices = pg.partitions.map(_.index).toSeq + new SortedMergePartition(i, prev, parentIndices, pg.prefLoc) + } + } + + override def compute(partition: Partition, context: TaskContext): Iterator[T] = { + val mergePartition = partition.asInstanceOf[SortedMergePartition] + val parentPartitions = mergePartition.parents + + if (parentPartitions.isEmpty) { + Iterator.empty + } else if (parentPartitions.length == 1) { + // No merge needed for single partition + firstParent[T].iterator(parentPartitions.head, context) + } else { + // Perform k-way merge + new SortedMergeIterator[T]( + parentPartitions.map(p => firstParent[T].iterator(p, context)), + ordering + ) + } + } + + override def getDependencies: Seq[org.apache.spark.Dependency[_]] = { + Seq(new org.apache.spark.NarrowDependency(prev) { + def getParents(id: Int): Seq[Int] = + partitions(id).asInstanceOf[SortedMergePartition].parentsIndices + }) + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + partition.asInstanceOf[SortedMergePartition].prefLoc.toSeq + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + prev = null + } +} + +/** + * Partition for SortedMergeCoalescedRDD that tracks which parent partitions to merge. 
+ * @param idx index of this coalesced partition + * @param rdd which it belongs to + * @param parentsIndices list of indices in the parent that have been coalesced into this partition + * @param prefLoc the preferred location for this partition + */ +private[spark] class SortedMergePartition( + idx: Int, + @transient private val rdd: RDD[_], + val parentsIndices: Seq[Int], + val prefLoc: Option[String] = None) extends Partition { + override val index: Int = idx + var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_)) + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { + // Update the reference to parent partition at the time of task serialization + parents = parentsIndices.map(rdd.partitions(_)) + oos.defaultWriteObject() + } +} + +/** + * Iterator that performs k-way merge of sorted iterators. + * + * Uses a priority queue (min-heap) to efficiently find the next smallest element + * across all input iterators according to the specified ordering. This provides + * O(n log k) time complexity where n is the total number of elements and k is + * the number of iterators being merged. 
+ * + * @param iterators The sequence of sorted iterators to merge + * @param ordering The ordering to use for comparison + * @tparam T The element type + */ +private[spark] class SortedMergeIterator[T]( + iterators: Seq[Iterator[T]], + ordering: Ordering[T]) extends Iterator[T] { + + // Priority queue entry: (current element, iterator index) + private case class QueueEntry(element: T, iteratorIdx: Int) + + // Min-heap ordered by element according to the provided ordering + private implicit val queueOrdering: Ordering[QueueEntry] = new Ordering[QueueEntry] { + override def compare(x: QueueEntry, y: QueueEntry): Int = { + // Reverse for min-heap (PriorityQueue is max-heap by default) + ordering.compare(y.element, x.element) + } + } + + private val queue = mutable.PriorityQueue.empty[QueueEntry] + + // Initialize queue with first element from each non-empty iterator + iterators.zipWithIndex.foreach { case (iter, idx) => + if (iter.hasNext) { + queue.enqueue(QueueEntry(iter.next(), idx)) + } + } + + override def hasNext: Boolean = queue.nonEmpty + + override def next(): T = { + if (!hasNext) { + throw new NoSuchElementException("next on empty iterator") + } + + val entry = queue.dequeue() + val result = entry.element + + // If the iterator has more elements, add the next one to the queue + val iter = iterators(entry.iteratorIdx) + if (iter.hasNext) { + queue.enqueue(QueueEntry(iter.next(), entry.iteratorIdx)) + } + + result + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala new file mode 100644 index 0000000000000..1cb83af2b3b69 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/SortedMergeCoalescedRDDSuite.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{SharedSparkContext, SparkFunSuite} + +class SortedMergeCoalescedRDDSuite extends SparkFunSuite with SharedSparkContext { + + test("SPARK-55715: k-way merge maintains ordering - integers") { + // Create RDD with 4 partitions, each sorted + val data = Seq( + Seq(1, 5, 9, 13), // partition 0 + Seq(2, 6, 10, 14), // partition 1 + Seq(3, 7, 11, 15), // partition 2 + Seq(4, 8, 12, 16) // partition 3 + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + // Coalesce to 2 partitions using sorted merge + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1), Seq(2, 3))) + val merged = new SortedMergeCoalescedRDD[Int](rdd, 2, coalescer, Ordering.Int) + + // Verify per-partition contents: group (0,1) merges elements from partitions 0+1, + // group (2,3) merges elements from partitions 2+3 + val partitionData = merged + .mapPartitionsWithIndex { (idx, iter) => Iterator.single((idx, iter.toSeq)) } + .collect().toMap + + assert(partitionData(0) === Seq(1, 2, 5, 6, 9, 10, 13, 14)) + assert(partitionData(1) === Seq(3, 4, 7, 8, 11, 12, 15, 16)) + } + + test("SPARK-55715: k-way merge handles empty partitions") { + val data = Seq( + Seq(1, 5, 9), // partition 0 + Seq.empty[Int], // partition 1 - empty + Seq(3, 7, 11), 
// partition 2 + Seq.empty[Int] // partition 3 - empty + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2, 3))) + val merged = new SortedMergeCoalescedRDD[Int]( rdd, 1, coalescer, Ordering.Int) + + val result = merged.collect() + assert(result === Seq(1, 3, 5, 7, 9, 11)) + } + + test("SPARK-55715: k-way merge handles all-empty partitions in a group") { + val data = Seq( + Seq.empty[Int], // partition 0 - empty + Seq.empty[Int], // partition 1 - empty + Seq(1, 2, 3) // partition 2 - non-empty (different group) + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1), Seq(2))) + val merged = new SortedMergeCoalescedRDD[Int]( rdd, 2, coalescer, Ordering.Int) + + val partitionData = merged + .mapPartitionsWithIndex { (idx, iter) => Iterator.single((idx, iter.toSeq)) } + .collect().toMap + + assert(partitionData(0) === Seq.empty) + assert(partitionData(1) === Seq(1, 2, 3)) + } + + test("SPARK-55715: k-way merge with single partition per group - no merge needed") { + val data = Seq(1, 2, 3, 4, 5, 6) + val rdd = sc.parallelize(data, 3) + + // Each group has only 1 partition - should just pass through + val coalescer = new TestPartitionCoalescer(Seq(Seq(0), Seq(1), Seq(2))) + val merged = new SortedMergeCoalescedRDD[Int](rdd, 3, coalescer, Ordering.Int) + + assert(merged.collect() === data) + } + + test("SPARK-55715: k-way merge with reverse ordering") { + val data = Seq( + Seq(13, 9, 5, 1), // partition 0 - descending + Seq(14, 10, 6, 2), // partition 1 - descending + Seq(15, 11, 7, 3), // partition 2 - descending + Seq(16, 12, 8, 4) // partition 3 - descending + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + // Use reverse ordering + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2, 3))) + val merged = new SortedMergeCoalescedRDD[Int]( + rdd, 1, coalescer, Ordering.Int.reverse) + + val 
result = merged.collect() + assert(result === (1 to 16).reverse) + } + + test("SPARK-55715: k-way merge with many partitions") { + val numPartitions = 20 + val rowsPerPartition = 50 + + // Create sorted data where partition i starts at i and increments by numPartitions + val data = (0 until numPartitions).map { partIdx => + (0 until rowsPerPartition).map(i => partIdx + i * numPartitions) + } + + val rdd = sc.parallelize(data, numPartitions).flatMap(identity) + + // Coalesce all partitions into one + val coalescer = new TestPartitionCoalescer(Seq((0 until numPartitions))) + val merged = new SortedMergeCoalescedRDD[Int](rdd, 1, coalescer, Ordering.Int) + + val result = merged.collect() + assert(result.length === numPartitions * rowsPerPartition) + assert(result === result.sorted) + } + + test("SPARK-55715: k-way merge preserves duplicate elements across partitions") { + val data = Seq( + Seq(1, 2, 3), // partition 0 + Seq(1, 2, 3), // partition 1 - identical to partition 0 + Seq(2, 2, 4) // partition 2 - contains repeated value within partition + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2))) + val merged = new SortedMergeCoalescedRDD[Int](rdd, 1, coalescer, Ordering.Int) + + val result = merged.collect() + assert(result === Seq(1, 1, 2, 2, 2, 2, 3, 3, 4)) + } + + test("SPARK-55715: k-way merge with strings") { + val data = Seq( + Seq("apple", "cherry", "grape"), + Seq("banana", "date", "kiwi"), + Seq("apricot", "fig", "mango") + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2))) + val merged = new SortedMergeCoalescedRDD[String](rdd, 1, coalescer, Ordering.String) + + val result = merged.collect() + assert(result === result.sorted) + } + + test("SPARK-55715: k-way merge with tuples") { + val data = Seq( + Seq((1, "a"), (3, "c"), (5, "e")), + Seq((2, "b"), (4, "d"), (6, "f")), + Seq((1, "z"), (4, "y"), 
(7, "x")) + ) + + val rdd = sc.parallelize(data, data.size).flatMap(identity) + + implicit val tupleOrdering: Ordering[(Int, String)] = Ordering.by[(Int, String), Int](_._1) + + val coalescer = new TestPartitionCoalescer(Seq(Seq(0, 1, 2))) + val merged = new SortedMergeCoalescedRDD[(Int, String)](rdd, 1, coalescer, tupleOrdering) + + val result = merged.collect() + val expected = data.flatten.sortBy(_._1) + assert(result === expected) + } + + test("SPARK-55715: SortedMergeIterator - next() on empty iterator throws " + + "NoSuchElementException") { + val iter = new SortedMergeIterator[Int](Seq.empty, Ordering.Int) + assert(!iter.hasNext) + intercept[NoSuchElementException] { iter.next() } + } + + test("SPARK-55715: SortedMergeIterator - empty iterators list") { + val iter = new SortedMergeIterator[Int](Seq(Iterator.empty, Iterator.empty), Ordering.Int) + assert(!iter.hasNext) + assert(iter.toSeq === Seq.empty) + } + + test("SPARK-55715: SortedMergeIterator - single iterator passes through unchanged") { + val iter = new SortedMergeIterator[Int](Seq(Iterator(3, 1, 2)), Ordering.Int) + assert(iter.toSeq === Seq(3, 1, 2)) + } +} + +/** + * Test partition coalescer that groups partitions according to a predefined plan. 
+ */ +class TestPartitionCoalescer(grouping: Seq[Seq[Int]]) extends PartitionCoalescer with Serializable { + override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = { + grouping.map { partitionIndices => + val pg = new PartitionGroup(None) + partitionIndices.foreach { idx => + pg.partitions += parent.partitions(idx) + } + pg + }.toArray + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java index 9e4b75c4bd112..0dba93af5f8d6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomTaskMetric.java @@ -47,4 +47,27 @@ public interface CustomTaskMetric { * Returns the long value of custom task metric. */ long value(); + + /** + * Merges this metric with another metric of the same name, returning a new + * {@link CustomTaskMetric} that represents the combined value. This is called when a task reads + * multiple partitions concurrently (e.g., k-way merge coalescing) to produce a single + * task-level value before reporting to the driver. + * + *

The default implementation returns a new metric whose value is the sum of the two values, + * which is correct for count-type metrics. Data sources with non-additive metrics (e.g., max, + * average) should override this method. + * + * @param other another metric with the same name to merge with + * @return a new metric representing the merged value + * @since 4.2.0 + */ + default CustomTaskMetric mergeWith(CustomTaskMetric other) { + final String metricName = this.name(); + final long mergedValue = this.value() + other.value(); + return new CustomTaskMetric() { + @Override public String name() { return metricName; } + @Override public long value() { return mergedValue; } + }; + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java index c12bc14a49c44..8420b6bdbdaf5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java @@ -65,6 +65,12 @@ default CustomTaskMetric[] currentMetricsValues() { * {@link org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} and the reader * is initialized with the metrics returned by the previous reader that belongs to the same * partition. By default, this method does nothing. + * + * @deprecated Use {@link CustomTaskMetric#mergeWith(CustomTaskMetric)} instead. When a task + * reads multiple partitions concurrently or sequentially, {@code DataSourceRDD} now merges + * metrics from all readers via {@code mergeWith} at reporting time, removing the need to + * seed each new reader with the prior reader's values. 
*/ + @Deprecated(since = "4.2.0") default void initMetricsValues(CustomTaskMetric[] metrics) {} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5639b6bbfbf4f..7bd53ed4eda18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2191,6 +2191,21 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED = + buildConf("spark.sql.sources.v2.bucketing.preserveOrderingOnCoalesce.enabled") + .doc(s"When turned on, GroupPartitionsExec will use sorted merge to preserve full " + + s"ordering (as opposed to the key-derived ordering preserved by " + + s"${V2_BUCKETING_PRESERVE_KEY_ORDERING_ON_COALESCE_ENABLED.key}) when coalescing " + + s"multiple partitions with the same key. This allows eliminating downstream sorts when " + + s"data is both partitioned and sorted. However, sorted merge uses more resources " + + s"(priority queue, comparison overhead) than simple concatenation, especially when " + + s"coalescing many partitions. When turned off, only key-derived ordering is preserved " + + s"during coalescing. 
This config requires ${V2_BUCKETING_ENABLED.key} to be enabled.") + .version("4.2.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -7763,6 +7778,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingPreserveKeyOrderingOnCoalesceEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_PRESERVE_KEY_ORDERING_ON_COALESCE_ENABLED) + def v2BucketingPreserveOrderingOnCoalesceEnabled: Boolean = + getConf(SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 9cc32efedf82a..decff8e2bcbf0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -866,8 +866,13 @@ private class BufferedRowsReader( private var index: Int = -1 private var rowsRead: Long = 0 + private var closed: Boolean = false + + private def checkNotClosed(op: String): Unit = + if (closed) throw new IllegalStateException(s"$op called on a closed BufferedRowsReader") override def next(): Boolean = { + checkNotClosed("next()") index += 1 val hasNext = index < partition.rows.length if (hasNext) rowsRead += 1 @@ -875,6 +880,7 @@ private class BufferedRowsReader( } override def get(): InternalRow = { + checkNotClosed("get()") val originalRow = partition.rows(index) val values = new Array[Any](nonMetadataColumns.length) nonMetadataColumns.zipWithIndex.foreach { case (col, idx) => @@ -884,7 +890,10 @@ private class BufferedRowsReader( 
addMetadata(new GenericInternalRow(values)) } - override def close(): Unit = {} + override def close(): Unit = { + checkNotClosed("close()") + closed = true + } private def extractFieldValue( field: StructField, @@ -1015,15 +1024,8 @@ private class BufferedRowsReader( private def castElement(elem: Any, toType: DataType, fromType: DataType): Any = Cast(Literal(elem, fromType), toType, None, EvalMode.TRY).eval(null) - override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = { - metrics.foreach { m => - m.name match { - case "rows_read" => rowsRead = m.value() - } - } - } - override def currentMetricsValues(): Array[CustomTaskMetric] = { + checkNotClosed("currentMetricsValues()") val metric = new CustomTaskMetric { override def name(): String = "rows_read" override def value(): Long = rowsRead diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala new file mode 100644 index 0000000000000..08be66e3f4477 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SafeForKWayMerge.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +/** + * Marker trait for execution nodes that are safe to use with SortedMergeCoalescedRDD + * (concurrent k-way merge). Nodes implementing this trait must store no per-partition + * mutable state in shared plan-node instance fields; all per-partition state must be + * captured inside the partition's iterator closure (e.g. via the + * PartitionEvaluatorFactory pattern). + */ +trait SafeForKWayMerge diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 246508965d3d6..c1c6f36475475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -501,7 +501,8 @@ trait InputRDDCodegen extends CodegenSupport { * This is the leaf node of a tree with WholeStageCodegen that is used to generate code * that consumes an RDD iterator of InternalRow. */ -case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCodegen { +case class InputAdapter(child: SparkPlan) + extends UnaryExecNode with InputRDDCodegen with SafeForKWayMerge { override def output: Seq[Attribute] = child.output @@ -633,7 +634,7 @@ object WholeStageCodegenExec { * used to generated code for [[BoundReference]]. 
*/ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int) - extends UnaryExecNode with CodegenSupport { + extends UnaryExecNode with CodegenSupport with SafeForKWayMerge { override def output: Seq[Attribute] = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 5e73cdd6da8f4..bafe7e568bf2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode with CodegenSupport + with SafeForKWayMerge with PartitioningPreservingUnaryExecNode with OrderPreservingUnaryExecNode { @@ -222,7 +223,7 @@ trait GeneratePredicateHelper extends PredicateHelper { /** Physical plan for Filter. */ case class FilterExec(condition: Expression, child: SparkPlan) - extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper { + extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper with SafeForKWayMerge { // Split out all the IsNotNulls from condition. 
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 2fedb97e8461e..d7bb1d9093edf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -19,34 +19,111 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.concurrent.ConcurrentHashMap -import scala.language.existentials +import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.metric.CustomTaskMetric import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.ArrayImplicits._ class DataSourceRDDPartition(val index: Int, val inputPartition: Option[InputPartition]) extends Partition with Serializable /** - * Holds the state for a reader in a task, used by the completion listener to access the most - * recently created reader and iterator for final metrics updates and cleanup. + * Holds all mutable state for a single Spark task reading from a {@link DataSourceRDD}: + *
+ *  - every {@code PartitionIterator} created by a {@code compute()} call for the task,
+ *  - the task's input-metrics bookkeeping (records and bytes read),
+ *  - a pre-merged snapshot of custom metrics from readers closed on natural exhaustion.
 *
- * When `compute()` is called multiple times for the same task (e.g., when DataSourceRDD is
- * coalesced), this state is updated on each call to track the most recent reader. The task
- * completion listener then uses this most recent reader for final cleanup and metrics reporting.
+ * When metrics are reported:
+ *
+ *  - live iterators contribute {@code currentMetricsValues()},
+ *  - already-closed iterators contribute their cached final metrics snapshot,
+ *  - metrics with the same name are combined via {@code CustomTaskMetric#mergeWith}.
+ *
- * @param reader The partition reader
- * @param iterator The metrics iterator wrapping the reader
+ * Why this works across all execution modes: all compute() calls for one task attempt run on a
+ * single thread, so this state is accessed only by that thread and, at the very end of the task,
+ * by the task completion listener.
+ *
*/ -private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIterator[_]) +private class TaskState(customMetrics: Map[String, SQLMetric]) { + val partitionIterators = new ArrayBuffer[PartitionIterator[_]]() + + // Input metrics (recordsRead, bytesRead) tracked for this task. + private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics + private val startingBytesRead = inputMetrics.bytesRead + private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + private var recordsReadAtLastBytesUpdate = 0L + private var recordsReadAtLastCustomMetricsUpdate = 0L + + // Pre-merged custom metrics snapshot of all readers closed by natural exhaustion. + // Maintained as a map (one entry per metric name). + private val closedMetrics = new HashMap[String, CustomTaskMetric]() + + def updateMetrics(numRows: Int, force: Boolean = false): Unit = { + inputMetrics.incRecordsRead(numRows) + val shouldUpdateBytesRead = force || + inputMetrics.recordsRead - recordsReadAtLastBytesUpdate >= + SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS + if (shouldUpdateBytesRead) { + recordsReadAtLastBytesUpdate = inputMetrics.recordsRead + inputMetrics.setBytesRead(startingBytesRead + getBytesRead()) + } + val shouldUpdateCustomMetrics = force || + inputMetrics.recordsRead - recordsReadAtLastCustomMetricsUpdate >= + CustomMetrics.NUM_ROWS_PER_UPDATE + if (shouldUpdateCustomMetrics) { + recordsReadAtLastCustomMetricsUpdate = inputMetrics.recordsRead + mergeAndUpdateCustomMetrics() + } + } + + private def mergeAndUpdateCustomMetrics(): Unit = { + partitionIterators.filterInPlace { iter => + if (iter.isClosed) { + iter.finalMetrics.foreach { m => + closedMetrics.update(m.name(), closedMetrics.get(m.name()).fold(m)(_.mergeWith(m))) + } + false + } else true + } + val mergedMetrics = (partitionIterators.flatMap(_.currentMetricsValues) ++ closedMetrics.values) + .groupMapReduce(_.name())(identity)(_.mergeWith(_)) + .values + .toSeq + if 
(mergedMetrics.nonEmpty) { + CustomMetrics.updateMetrics(mergedMetrics, customMetrics) + } + } +} // TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for // columnar scan. @@ -71,10 +148,8 @@ class DataSourceRDD( customMetrics: Map[String, SQLMetric]) extends RDD[InternalRow](sc, Nil) { - // Map from task attempt ID to the most recently created ReaderState for that task. - // When compute() is called multiple times for the same task (due to coalescing), the map entry - // is updated each time so the completion listener always closes the last reader. - @transient private lazy val taskReaderStates = new ConcurrentHashMap[Long, ReaderState]() + // One TaskState per task attempt. + @transient private lazy val taskStates = new ConcurrentHashMap[Long, TaskState]() override protected def getPartitions: Array[Partition] = { inputPartitions.zipWithIndex.map { @@ -90,52 +165,39 @@ class DataSourceRDD( override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val taskAttemptId = context.taskAttemptId() - // Add completion listener only once per task attempt. When compute() is called a second time - // for the same task (e.g., due to coalescing), the first call will have already put a - // ReaderState into taskReaderStates, so containsKey returns true and we skip this block. - if (!taskReaderStates.containsKey(taskAttemptId)) { + // Ensure a TaskState exists for this task and register the completion listener on the + // first compute() call. computeIfAbsent is atomic; same-task calls are always on one + // thread, so partitionIterators.isEmpty reliably identifies the first call. + val taskState = taskStates.computeIfAbsent(taskAttemptId, _ => new TaskState(customMetrics)) + + if (taskState.partitionIterators.isEmpty) { context.addTaskCompletionListener[Unit] { ctx => - // In case of early stopping before consuming the entire iterator, - // we need to do one more metric update at the end of the task. 
+ // In case of early stopping, do a final metrics update and close all readers. try { - val readerState = taskReaderStates.get(ctx.taskAttemptId()) - if (readerState != null) { - CustomMetrics.updateMetrics( - readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics) - readerState.iterator.forceUpdateMetrics() - readerState.reader.close() + val taskState = taskStates.get(ctx.taskAttemptId()) + if (taskState != null) { + taskState.updateMetrics(0, force = true) + taskState.partitionIterators.foreach(_.close()) } } finally { - taskReaderStates.remove(ctx.taskAttemptId()) + taskStates.remove(ctx.taskAttemptId()) } } } castPartition(split).inputPartition.iterator.flatMap { inputPartition => - val (iter, reader) = if (columnarReads) { + val iter = if (columnarReads) { val batchReader = partitionReaderFactory.createColumnarReader(inputPartition) - val iter = new MetricsBatchIterator( - new PartitionIterator[ColumnarBatch](batchReader, customMetrics)) - (iter, batchReader) + new PartitionIterator[ColumnarBatch](batchReader, _.numRows, taskState) } else { val rowReader = partitionReaderFactory.createReader(inputPartition) - val iter = new MetricsRowIterator( - new PartitionIterator[InternalRow](rowReader, customMetrics)) - (iter, rowReader) + new PartitionIterator[InternalRow](rowReader, _ => 1, taskState) } - // Flush metrics and close the previous reader before advancing to the next one. - // Pass the accumulated metrics to the new reader so they carry forward correctly. - val prevState = taskReaderStates.get(taskAttemptId) - if (prevState != null) { - val metrics = prevState.reader.currentMetricsValues - CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics) - reader.initMetricsValues(metrics) - prevState.reader.close() - } - - // Update the map so the completion listener always references the latest reader. 
- taskReaderStates.put(taskAttemptId, ReaderState(reader, iter)) + // Track this iterator; early-stop close and final metrics-flush for iterators not yet + // naturally exhausted are handled by the task completion listener. This avoids closing + // live iterators prematurely in the concurrent k-way merge (SortedMergeCoalescedRDD). + taskState.partitionIterators += iter // TODO: SPARK-25083 remove the type erasure hack in data source scan new InterruptibleIterator(context, iter.asInstanceOf[Iterator[InternalRow]]) @@ -148,16 +210,36 @@ class DataSourceRDD( } private class PartitionIterator[T]( - reader: PartitionReader[T], - customMetrics: Map[String, SQLMetric]) extends Iterator[T] { - private[this] var valuePrepared = false - private[this] var hasMoreInput = true + private var reader: PartitionReader[T], + rowCount: T => Int, + taskState: TaskState) extends Iterator[T] { + private var valuePrepared = false + private var hasMoreInput = true + + // Cached final metrics snapshot, captured just before the reader is closed on natural + // exhaustion. Allows mergeAndUpdateCustomMetrics() to include this reader's contribution + // after close() has been called and currentMetricsValues() is no longer valid. 
+ private var cachedFinalMetrics: Array[CustomTaskMetric] = Array.empty + + def isClosed: Boolean = reader == null + + def finalMetrics: Array[CustomTaskMetric] = cachedFinalMetrics + + def close(): Unit = if (reader != null) { + reader.close() + reader = null + } - private var numRow = 0L + def currentMetricsValues: Array[CustomTaskMetric] = reader.currentMetricsValues override def hasNext: Boolean = { if (!valuePrepared && hasMoreInput) { hasMoreInput = reader.next() + if (!hasMoreInput) { + cachedFinalMetrics = reader.currentMetricsValues + taskState.updateMetrics(0, force = true) + close() + } valuePrepared = hasMoreInput } valuePrepared @@ -167,59 +249,9 @@ private class PartitionIterator[T]( if (!hasNext) { throw QueryExecutionErrors.endOfStreamError() } - if (numRow % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) { - CustomMetrics.updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics) - } - numRow += 1 valuePrepared = false - reader.get() - } -} - -private class MetricsHandler extends Logging with Serializable { - private val inputMetrics = TaskContext.get().taskMetrics().inputMetrics - private val startingBytesRead = inputMetrics.bytesRead - private val getBytesRead = SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - - def updateMetrics(numRows: Int, force: Boolean = false): Unit = { - inputMetrics.incRecordsRead(numRows) - val shouldUpdateBytesRead = - inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0 - if (shouldUpdateBytesRead || force) { - inputMetrics.setBytesRead(startingBytesRead + getBytesRead()) - } - } -} - -private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] { - protected val metricsHandler = new MetricsHandler - - override def hasNext: Boolean = { - if (iter.hasNext) { - true - } else { - forceUpdateMetrics() - false - } - } - - def forceUpdateMetrics(): Unit = metricsHandler.updateMetrics(0, force = true) -} - -private class MetricsRowIterator( - iter: 
Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) { - override def next(): InternalRow = { - val item = iter.next() - metricsHandler.updateMetrics(1) - item - } -} - -private class MetricsBatchIterator( - iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) { - override def next(): ColumnarBatch = { - val batch: ColumnarBatch = iter.next() - metricsHandler.updateMetrics(batch.numRows) - batch + val result = reader.get() + taskState.updateMetrics(rowCount(result)) + result } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index a1a6c6e022482..0bec918039775 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -24,14 +24,14 @@ import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.KeyedPartitioning import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan} -import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SQLExecution} +import org.apache.spark.sql.execution.{ExplainUtils, LeafExecNode, SafeForKWayMerge, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils -trait DataSourceV2ScanExecBase extends LeafExecNode { +trait DataSourceV2ScanExecBase extends LeafExecNode with SafeForKWayMerge { lazy val customMetrics = scan.supportedCustomMetrics().map { customMetric => customMetric.name() -> 
SQLMetrics.createV2CustomMetric(sparkContext, customMetric) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 81981c29b2b31..e552c4f71641f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -21,13 +21,14 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Partition, SparkException} -import org.apache.spark.rdd.{CoalescedRDD, PartitionCoalescer, PartitionGroup, RDD} +import org.apache.spark.rdd.{CoalescedRDD, PartitionCoalescer, PartitionGroup, RDD, SortedMergeCoalescedRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, Partitioning} import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.functions.Reducer -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{SafeForKWayMerge, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -158,16 +159,91 @@ case class GroupPartitionsExec( @transient lazy val isGrouped: Boolean = groupedPartitionsTuple._2 + // Whether the child subtree is safe to use with SortedMergeCoalescedRDD (k-way merge). 
+ // + // --- The general problem --- + // + // Unlike a simple CoalescedRDD, which processes input partitions sequentially (open P0, + // exhaust P0, open P1, exhaust P1, ...), SortedMergeCoalescedRDD must perform a k-way merge: + // it opens *all* input partition iterators upfront and then interleaves reads across them to + // produce a globally sorted output partition. Crucially, all of this happens on a single JVM + // thread within a single Spark task. + // + // A SparkPlan object is shared across all partition computations of a task. When N partition + // iterators are live concurrently on the same thread, N independent "computations" are all + // operating through the same plan node instances. If any node in the subtree stores + // per-partition state in an instance field rather than as a local variable captured inside the + // partition's iterator closure, that field is aliased across all N concurrent computations. + // Whichever computation last wrote the field "wins", and any computation that then reads or + // frees it based on its own earlier write will operate on the wrong state. + // + // --- The correct pattern: PartitionEvaluatorFactory --- + // + // The PartitionEvaluatorFactory / PartitionEvaluator pattern is specifically designed to avoid + // this problem. The factory's createEvaluator() is called once per partition and returns a + // fresh PartitionEvaluator instance. All per-partition mutable state lives inside that + // evaluator instance, not on the shared plan node. Operators that follow this pattern + // exclusively (and hold no other mutable state on the plan node) are safe for k-way merge. + // + // --- Concrete example of an unsafe operator: SortExec + SortMergeJoinExec --- + // + // SortExec stores its active sorter in a plain var field (`rowSorter`) on the plan node. 
+ // When the k-way merge initialises its N partition iterators, each one drives the same SortExec + // instance and calls createSorter(), which assigns rowSorter = newSorter -- overwriting the + // field each time. After all N iterators are initialised, rowSorter holds only the sorter + // created for the *last* partition. + // + // SortMergeJoinExec performs eager resource cleanup: when a join partition is exhausted it + // calls cleanupResources() on its children, which reaches SortExec.cleanupResources(). That + // method calls rowSorter.cleanupResources() -- but rowSorter now holds the last-created sorter, + // not the one belonging to the just-exhausted partition. If that last sorter is still being + // actively read by another partition in the k-way merge, freeing it causes a use-after-free. + // + // To be conservative, we use a whitelist: unknown node types fall through to unsafe, causing + // a fallback to simple sequential coalescing. Only node types explicitly confirmed to store no + // per-partition state in shared (plan node) instance fields are permitted. + @transient private lazy val childIsSafeForKWayMerge: Boolean = + !child.exists { + case _: SafeForKWayMerge => false + case _ => true + } + override protected def doExecute(): RDD[InternalRow] = { if (groupedPartitions.isEmpty) { sparkContext.emptyRDD + } else if (SQLConf.get.v2BucketingPreserveOrderingOnCoalesceEnabled && + child.outputOrdering.nonEmpty && + childIsSafeForKWayMerge && + groupedPartitions.exists(_._2.size > 1)) { + // Use sorted merge when: + // 1. Config is enabled + // 2. Child has ordering + // 3. 
Actually coalescing multiple partitions + // Convert SortOrder expressions to Ordering[InternalRow] + val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) + val rowOrdering = new InterpretedOrdering(child.outputOrdering, child.output) + new SortedMergeCoalescedRDD[InternalRow]( + child.execute(), + groupedPartitions.size, + partitionCoalescer, + rowOrdering) } else { val partitionCoalescer = new GroupedPartitionCoalescer(groupedPartitions.map(_._2)) + // Use simple coalescing when config is disabled, no ordering, or no actual coalescing new CoalescedRDD(child.execute(), groupedPartitions.size, Some(partitionCoalescer)) } } - override def supportsColumnar: Boolean = child.supportsColumnar + override def supportsColumnar: Boolean = { + // Don't use columnar when sorted merge coalescing is needed, since we can't preserve + // ordering with sorted merge for columnar batches + val needsSortedMerge = SQLConf.get.v2BucketingPreserveOrderingOnCoalesceEnabled && + child.outputOrdering.nonEmpty && + childIsSafeForKWayMerge && + groupedPartitions.exists(_._2.size > 1) + + child.supportsColumnar && !needsSortedMerge + } override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { if (groupedPartitions.isEmpty) { @@ -189,6 +265,12 @@ case class GroupPartitionsExec( // within-partition ordering is fully preserved (including any key-derived ordering that // `DataSourceV2ScanExecBase` already prepended). child.outputOrdering + } else if (SQLConf.get.v2BucketingPreserveOrderingOnCoalesceEnabled && + child.outputOrdering.nonEmpty && + childIsSafeForKWayMerge) { + // Coalescing with sorted merge: SortedMergeCoalescedRDD performs a k-way merge using the + // child's ordering, so the full within-partition ordering is preserved end-to-end. + child.outputOrdering } else { // Coalescing: multiple input partitions are merged into one output partition. 
The child's // within-partition ordering is lost due to concatenation -- for example, if two input diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 44cb3a23cfa7c..4a406322a5a19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -20,6 +20,7 @@ import java.sql.Timestamp import java.util.Collections import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.rdd.SortedMergeCoalescedRDD import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Literal, TransformExpression} @@ -252,9 +253,10 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with table: String, columns: Array[Column], partitions: Array[Transform], + ordering: Array[SortOrder] = Array.empty, catalog: InMemoryTableCatalog = catalog): Unit = { catalog.createTable(Identifier.of(Array("ns"), table), - columns, partitions, emptyProps, Distributions.unspecified(), Array.empty, None, None, + columns, partitions, emptyProps, Distributions.unspecified(), ordering, None, None, numRowsPerSplit = 1) } @@ -3091,7 +3093,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with assert(groupPartitions(0).outputPartitioning.numPartitions === 2, "group partitions should have 2 partition groups") } - assert(metrics("number of rows read") == "3") + assert(metrics.collect { + case ((_, "BatchScan testcat.ns.items", "number of rows read"), v) => v + } === Seq("3")) } test("SPARK-55619: Custom metrics of coalesced partitions") { @@ -3106,7 +3110,68 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with 
val df = sql(s"SELECT * FROM testcat.ns.$items").coalesce(1) df.collect() } - assert(metrics("number of rows read") == "2") + assert(metrics.collect { + case ((_, "BatchScan testcat.ns.items", "number of rows read"), v) => v + } === Seq("2")) + } + + test("SPARK-55715: Custom metrics of sorted-merge coalesced partitions") { + // items has id=1 on three splits with interleaved arrive_times -- out of order across splits. + // purchases has item_id=1 on two splits, also out of order. Both sides coalesce under SMJ, + // using SortedMergeCoalescedRDD with multiple concurrent readers per task. This test verifies + // that all rows from both tables (5 + 4 = 9) are accounted for in the per-scan metrics. + val itemOrdering = Array( + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("arrive_time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(items, itemsColumns, Array(identity("id")), itemOrdering) + // Rows inserted out of order: id=1 lands on partitions 1, 3, 4 with arrive_times + // [2022-03-10, 2021-05-20, 2025-09-01] -- out of order. + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(3, 'dd', 40.0, cast('2024-01-01' as timestamp)), " + + "(1, 'bb', 20.0, cast('2022-03-10' as timestamp)), " + + "(2, 'cc', 30.0, cast('2023-06-15' as timestamp)), " + + "(1, 'aa', 10.0, cast('2021-05-20' as timestamp)), " + + "(1, 'ee', 50.0, cast('2025-09-01' as timestamp))") + + val purchaseOrdering = Array( + sort(FieldReference("item_id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(purchases, purchasesColumns, Array(identity("item_id")), purchaseOrdering) + // item_id=1 lands on partitions 1 and 3 with times [2022-03-10, 2021-05-20] -- out of order. 
+ sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(2, 30.0, cast('2023-06-15' as timestamp)), " + + "(1, 20.0, cast('2022-03-10' as timestamp)), " + + "(3, 40.0, cast('2024-01-01' as timestamp)), " + + "(1, 10.0, cast('2021-05-20' as timestamp))") + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> "true") { + val metrics = runAndFetchMetrics { + val df = sql( + s"""${selectWithMergeJoinHint("i", "p")} + |i.id, i.name + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON p.item_id = i.id AND p.time = i.arrive_time + |""".stripMargin) + checkAnswer(df, Seq(Row(1, "aa"), Row(1, "bb"), Row(2, "cc"), Row(3, "dd"))) + val plan = df.queryExecution.executedPlan + val groupPartitions = collectAllGroupPartitions(plan) + val coalescingGP = groupPartitions.filter(_.groupedPartitions.exists(_._2.size > 1)) + assert(coalescingGP.nonEmpty, "expected a coalescing GroupPartitionsExec") + coalescingGP.foreach { gp => + assert(gp.execute().isInstanceOf[SortedMergeCoalescedRDD[_]], + "should use SortedMergeCoalescedRDD when preserve-ordering config is enabled") + } + } + assert(metrics.collect { + case ((_, "BatchScan testcat.ns.items", "number of rows read"), v) => v + } === Seq("5")) + assert(metrics.collect { + case ((_, "BatchScan testcat.ns.purchases", "number of rows read"), v) => v + } === Seq("4")) + } } test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " + @@ -3690,4 +3755,162 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } } + + test("SPARK-55715: preserve outputOrdering when coalescing partitions with sorted merge") { + // Both tables are partitioned by their id column and report ordering [id ASC, price ASC] + // via SupportsReportOrdering. 
Each has two rows with id=1 (two splits), so GroupPartitionsExec + // must coalesce them. We join on (id, arrive_time) = (item_id, time) using SMJ. + // + // With config enabled: SortedMergeCoalescedRDD performs a k-way merge preserving the full + // [id ASC, arrive_time ASC] ordering -> EnsureRequirements is satisfied -> no SortExec added. + // With config disabled: simple CoalescedRDD concatenates the splits and only the key-derived + // [id ASC] ordering survives -> arrive_time ordering is lost -> SortExec is added for arrive_time. + val itemOrdering = Array( + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("arrive_time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(items, itemsColumns, Array(identity("id")), itemOrdering) + // Rows inserted out of order: id values are interleaved and arrive_time is not monotone + // within each id group, so ordering by [id ASC, arrive_time ASC] is non-trivial. + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(2, 'cc', 30.0, cast('2023-06-15' as timestamp)), " + + "(1, 'bb', 20.0, cast('2022-03-10' as timestamp)), " + + "(3, 'dd', 40.0, cast('2024-01-01' as timestamp)), " + + "(1, 'aa', 10.0, cast('2021-05-20' as timestamp)), " + + "(2, 'ee', 50.0, cast('2025-09-01' as timestamp))") + + val purchaseOrdering = Array( + sort(FieldReference("item_id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(purchases, purchasesColumns, Array(identity("item_id")), purchaseOrdering) + // Also inserted out of order + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(2, 50.0, cast('2025-09-01' as timestamp)), " + + "(1, 10.0, cast('2021-05-20' as timestamp)), " + + "(3, 40.0, cast('2024-01-01' as timestamp)), " + + "(2, 30.0, cast('2023-06-15' as timestamp)), " + + "(1, 20.0, cast('2022-03-10' as timestamp))") + + Seq(true, false).foreach { preserveOrdering => + withSQLConf( + 
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> + preserveOrdering.toString) { + val df = sql( + s""" + |${selectWithMergeJoinHint("i", "p")} + |i.id, i.name + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON p.item_id = i.id AND p.time = i.arrive_time + |""".stripMargin) + checkAnswer(df, Seq( + Row(1, "aa"), Row(1, "bb"), Row(2, "cc"), Row(2, "ee"), Row(3, "dd"))) + + val plan = df.queryExecution.executedPlan + assert(collectAllShuffles(plan).isEmpty, "should not contain any shuffle") + + val groupPartitions = collectAllGroupPartitions(plan) + assert(groupPartitions.nonEmpty, "should contain GroupPartitionsExec for coalescing") + assert(groupPartitions.exists(_.groupedPartitions.exists(_._2.size > 1)), + "expected coalescing GroupPartitionsExec") + + val smjs = collect(plan) { case j: SortMergeJoinExec => j } + assert(smjs.nonEmpty, "expected SortMergeJoinExec in plan") + smjs.foreach { smj => + val sorts = smj.children.flatMap(child => collect(child) { case s: SortExec => s }) + if (preserveOrdering) { + assert(sorts.isEmpty, + "config enabled: SortedMergeCoalescedRDD preserves [id ASC, arrive_time ASC], " + + "no SortExec should be added before SMJ") + + // Also verify the k-way merge RDD is actually used + val coalescingGP = groupPartitions.filter(_.groupedPartitions.exists(_._2.size > 1)) + coalescingGP.foreach { gp => + assert(gp.execute().isInstanceOf[SortedMergeCoalescedRDD[_]], + "config enabled: should use SortedMergeCoalescedRDD") + } + } else { + assert(sorts.nonEmpty, + "config disabled: simple coalescing loses arrive_time ordering, " + + "SortExec should be added before SMJ") + } + } + } + } + } + + test("SPARK-55715: preserve outputOrdering when coalescing transform-partitioned splits") { + // Both tables are partitioned by years("arrive_time") / years("time") and report 
ordering + // [arrive_time ASC] / [time ASC]. Two rows share the same year bucket (2022 and 2023), so + // GroupPartitionsExec coalesces two splits per year. We join solely on + // p.time = i.arrive_time (the partition key expression) using SMJ. + // + // With config enabled: SortedMergeCoalescedRDD k-way merge preserves [arrive_time ASC] + // ordering -> EnsureRequirements is satisfied -> no SortExec added. + // With config disabled: simple CoalescedRDD only preserves the key-derived year ordering -> + // arrive_time ordering within a year is lost -> SortExec is added. + val itemOrdering = Array( + sort(FieldReference("arrive_time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(items, itemsColumns, Array(years("arrive_time")), itemOrdering) + // Inserted out of order: within year 2022, September is before March in insertion order + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(2, 'bb', 20.0, cast('2022-09-20' as timestamp)), " + + "(4, 'dd', 40.0, cast('2023-11-05' as timestamp)), " + + "(1, 'aa', 10.0, cast('2022-03-15' as timestamp)), " + + "(3, 'cc', 30.0, cast('2023-01-10' as timestamp))") + + val purchaseOrdering = Array( + sort(FieldReference("time"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + createTable(purchases, purchasesColumns, Array(years("time")), purchaseOrdering) + // Also inserted out of order + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(2, 20.0, cast('2022-09-20' as timestamp)), " + + "(4, 40.0, cast('2023-11-05' as timestamp)), " + + "(1, 10.0, cast('2022-03-15' as timestamp)), " + + "(3, 30.0, cast('2023-01-10' as timestamp))") + + Seq(true, false).foreach { preserveOrdering => + withSQLConf( + SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> + preserveOrdering.toString) { + val df = sql( + s""" + |${selectWithMergeJoinHint("i", "p")} + |i.id, i.name + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON p.time = i.arrive_time + |""".stripMargin) + checkAnswer(df, 
Seq(Row(1, "aa"), Row(2, "bb"), Row(3, "cc"), Row(4, "dd"))) + + val plan = df.queryExecution.executedPlan + assert(collectAllShuffles(plan).isEmpty, "should not contain any shuffle") + + val groupPartitions = collectAllGroupPartitions(plan) + assert(groupPartitions.nonEmpty, "should contain GroupPartitionsExec for coalescing") + assert(groupPartitions.exists(_.groupedPartitions.exists(_._2.size > 1)), + "expected coalescing GroupPartitionsExec") + + val smjs = collect(plan) { case j: SortMergeJoinExec => j } + assert(smjs.nonEmpty, "expected SortMergeJoinExec in plan") + smjs.foreach { smj => + val sorts = smj.children.flatMap(child => collect(child) { case s: SortExec => s }) + if (preserveOrdering) { + assert(sorts.isEmpty, + "config enabled: SortedMergeCoalescedRDD preserves [arrive_time ASC], " + + "no SortExec should be added before SMJ") + + val coalescingGP = groupPartitions.filter(_.groupedPartitions.exists(_._2.size > 1)) + coalescingGP.foreach { gp => + assert(gp.execute().isInstanceOf[SortedMergeCoalescedRDD[_]], + "config enabled: should use SortedMergeCoalescedRDD") + } + } else { + assert(sorts.nonEmpty, + "config disabled: simple coalescing loses arrive_time ordering within a year, " + + "SortExec should be added before SMJ") + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala index 502424d58d2cb..06e55a179d4e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala @@ -53,7 +53,7 @@ class InMemoryTableMetricSuite Array(Column.create("i", IntegerType)), Array.empty[Transform], Collections.emptyMap[String, String]) - val metrics = runAndFetchMetrics(func("testcat.table_name")) + val metrics = 
runAndFetchMetrics(func("testcat.table_name")).map(m => m._1._3 -> m._2) checker(metrics) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala index c37e051929555..72b80cbe05d6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.datasources.v2 +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, PartitioningCollection} -import org.apache.spark.sql.execution.DummySparkPlan +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} +import org.apache.spark.sql.execution.{DummySparkPlan, LeafExecNode, SafeForKWayMerge} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.IntegerType @@ -141,4 +142,41 @@ class GroupPartitionsExecSuite extends SharedSparkSession { assert(!gpe.groupedPartitions.forall(_._2.size <= 1), "expected coalescing") assert(gpe.outputOrdering === Nil) } + + test("SPARK-55715: coalescing with sorted merge config enabled returns full child ordering") { + // Key 1 appears on partitions 0 and 2, causing coalescing. The child is a LeafExecNode + // so childIsSafeForKWayMerge = true. 
With the preserve-ordering config enabled, case 2 + // of outputOrdering kicks in and the full child ordering (including the non-key exprC) must + // be returned, not just the subset of key-expression orders. + val partitionKeys = Seq(row(1), row(2), row(1)) + val childOrdering = Seq(SortOrder(exprA, Ascending), SortOrder(exprC, Ascending)) + val child = DummyLeafSparkPlan( + outputPartitioning = KeyedPartitioning(Seq(exprA), partitionKeys), + outputOrdering = childOrdering) + val gpe = GroupPartitionsExec(child) + + assert(!gpe.groupedPartitions.forall(_._2.size <= 1), "expected coalescing") + withSQLConf(SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> "true") { + // Config enabled: k-way merge preserves full ordering including non-key exprC. + assert(gpe.outputOrdering === childOrdering) + } + withSQLConf( + SQLConf.V2_BUCKETING_PRESERVE_ORDERING_ON_COALESCE_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_PRESERVE_KEY_ORDERING_ON_COALESCE_ENABLED.key -> "true") { + // Sorted-merge config disabled, key-ordering config enabled: only key-expression orders + // survive simple concatenation (non-key exprC is dropped). 
+ val ordering = gpe.outputOrdering + assert(ordering.length === 1) + assert(ordering.head.child === exprA) + } + } +} + +private case class DummyLeafSparkPlan( + override val outputOrdering: Seq[SortOrder] = Nil, + override val outputPartitioning: Partitioning = UnknownPartitioning(0) + ) extends LeafExecNode with SafeForKWayMerge { + override protected def doExecute(): RDD[InternalRow] = + throw new UnsupportedOperationException + override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 456e7ed3478a6..8f5157acca7d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -54,7 +54,9 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase { } } - def runAndFetchMetrics(func: => Unit): Map[String, String] = { + // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that + // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. + def runAndFetchMetrics(func: => Unit): Map[(Long, String, String), String] = { val statusStore = spark.sharedState.statusStore val oldCount = statusStore.executionsList().size @@ -73,7 +75,9 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase { val exec = statusStore.executionsList().last val execId = exec.executionId - val sqlMetrics = exec.metrics.map { metric => metric.accumulatorId -> metric.name }.toMap + val sqlMetrics = statusStore.planGraph(execId).allNodes + .flatMap(n => n.metrics.map(m => (m.accumulatorId, (n.id, n.name, m.name)))) + .toMap statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v } } }