Skip to content

Commit 60baa66

Browse files
committed
enhancement: integrate batch coalescer with repartition exec
1 parent d24eb4a commit 60baa66

File tree

2 files changed

+152
-81
lines changed

2 files changed

+152
-81
lines changed

datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,38 @@ impl ExecutionPlan for CustomPlan {
134134
_partition: usize,
135135
_context: Arc<TaskContext>,
136136
) -> Result<SendableRecordBatchStream> {
137-
Ok(Box::pin(RecordBatchStreamAdapter::new(
138-
self.schema(),
139-
futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
140-
)))
137+
if self.batches.is_empty() {
138+
Ok(Box::pin(RecordBatchStreamAdapter::new(
139+
self.schema(),
140+
futures::stream::empty(),
141+
)))
142+
} else {
143+
let batch_schema = self.batches[0].schema();
144+
let projection: Vec<usize> = self
145+
.schema()
146+
.fields()
147+
.iter()
148+
.filter_map(|field| batch_schema.index_of(field.name()).ok())
149+
.collect();
150+
151+
Ok(Box::pin(RecordBatchStreamAdapter::new(
152+
self.schema(),
153+
futures::stream::iter(self.batches.clone().into_iter().map(
154+
move |batch| {
155+
let res = batch.project(&projection);
156+
match res {
157+
Ok(b) => Ok(b),
158+
Err(e) => {
159+
Err(datafusion_common::DataFusionError::ArrowError(
160+
Box::new(e),
161+
None,
162+
))
163+
}
164+
}
165+
},
166+
)),
167+
)))
168+
}
141169
}
142170

143171
fn statistics(&self) -> Result<Statistics> {

datafusion/physical-plan/src/repartition/mod.rs

Lines changed: 120 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
3030
use super::{
3131
DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
3232
};
33+
use crate::coalesce::LimitedBatchCoalescer;
3334
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
3435
use crate::hash_utils::create_hashes;
3536
use crate::metrics::{BaselineMetrics, SpillMetrics};
@@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
932933
spill_stream,
933934
1, // Each receiver handles one input partition
934935
BaselineMetrics::new(&metrics, partition),
936+
context.session_config().batch_size() / num_input_partitions,
935937
)) as SendableRecordBatchStream
936938
})
937939
.collect::<Vec<_>>();
@@ -959,7 +961,6 @@ impl ExecutionPlan for RepartitionExec {
959961
.into_iter()
960962
.next()
961963
.expect("at least one spill reader should exist");
962-
963964
Ok(Box::pin(PerPartitionStream::new(
964965
schema_captured,
965966
rx.into_iter()
@@ -970,6 +971,7 @@ impl ExecutionPlan for RepartitionExec {
970971
spill_stream,
971972
num_input_partitions,
972973
BaselineMetrics::new(&metrics, partition),
974+
context.session_config().batch_size(),
973975
)) as SendableRecordBatchStream)
974976
}
975977
})
@@ -1427,9 +1429,12 @@ struct PerPartitionStream {
14271429

14281430
/// Execution metrics
14291431
baseline_metrics: BaselineMetrics,
1432+
1433+
batch_coalescer: LimitedBatchCoalescer,
14301434
}
14311435

14321436
impl PerPartitionStream {
1437+
#[allow(clippy::too_many_arguments)]
14331438
fn new(
14341439
schema: SchemaRef,
14351440
receiver: DistributionReceiver<MaybeBatch>,
@@ -1438,16 +1443,29 @@ impl PerPartitionStream {
14381443
spill_stream: SendableRecordBatchStream,
14391444
num_input_partitions: usize,
14401445
baseline_metrics: BaselineMetrics,
1446+
batch_size: usize,
14411447
) -> Self {
14421448
Self {
1443-
schema,
1449+
schema: Arc::clone(&schema),
14441450
receiver,
14451451
_drop_helper: drop_helper,
14461452
reservation,
14471453
spill_stream,
14481454
state: StreamState::ReadingMemory,
14491455
remaining_partitions: num_input_partitions,
14501456
baseline_metrics,
1457+
batch_coalescer: LimitedBatchCoalescer::new(schema, batch_size, None),
1458+
}
1459+
}
1460+
1461+
fn flush_remaining_batch(
1462+
&mut self,
1463+
) -> Poll<Option<std::result::Result<RecordBatch, DataFusionError>>> {
1464+
// Flush any remaining buffered batch
1465+
match self.batch_coalescer.finish() {
1466+
Ok(()) => Poll::Ready(self.batch_coalescer.next_completed_batch().map(Ok)),
1467+
1468+
Err(e) => Poll::Ready(Some(Err(e))),
14511469
}
14521470
}
14531471

@@ -1460,75 +1478,82 @@ impl PerPartitionStream {
14601478
let _timer = cloned_time.timer();
14611479

14621480
loop {
1463-
match self.state {
1464-
StreamState::ReadingMemory => {
1465-
// Poll the memory channel for next message
1466-
let value = match self.receiver.recv().poll_unpin(cx) {
1467-
Poll::Ready(v) => v,
1468-
Poll::Pending => {
1469-
// Nothing from channel, wait
1470-
return Poll::Pending;
1471-
}
1472-
};
1473-
1474-
match value {
1475-
Some(Some(v)) => match v {
1476-
Ok(RepartitionBatch::Memory(batch)) => {
1477-
// Release memory and return batch
1478-
self.reservation
1479-
.lock()
1480-
.shrink(batch.get_array_memory_size());
1481-
return Poll::Ready(Some(Ok(batch)));
1481+
loop {
1482+
match self.state {
1483+
StreamState::ReadingMemory => {
1484+
// Poll the memory channel for next message
1485+
let value = match self.receiver.recv().poll_unpin(cx) {
1486+
Poll::Ready(v) => v,
1487+
Poll::Pending => {
1488+
// Nothing from channel, wait
1489+
return Poll::Pending;
14821490
}
1483-
Ok(RepartitionBatch::Spilled) => {
1484-
// Batch was spilled, transition to reading from spill stream
1485-
// We must block on spill stream until we get the batch
1486-
// to preserve ordering
1487-
self.state = StreamState::ReadingSpilled;
1491+
};
1492+
1493+
match value {
1494+
Some(Some(v)) => match v {
1495+
Ok(RepartitionBatch::Memory(batch)) => {
1496+
// Release memory and return batch
1497+
self.reservation
1498+
.lock()
1499+
.shrink(batch.get_array_memory_size());
1500+
self.batch_coalescer.push_batch(batch)?;
1501+
break;
1502+
}
1503+
Ok(RepartitionBatch::Spilled) => {
1504+
// Batch was spilled, transition to reading from spill stream
1505+
// We must block on spill stream until we get the batch
1506+
// to preserve ordering
1507+
self.state = StreamState::ReadingSpilled;
1508+
continue;
1509+
}
1510+
Err(e) => {
1511+
return Poll::Ready(Some(Err(e)));
1512+
}
1513+
},
1514+
Some(None) => {
1515+
// One input partition finished
1516+
self.remaining_partitions -= 1;
1517+
if self.remaining_partitions == 0 {
1518+
// All input partitions finished
1519+
return self.flush_remaining_batch();
1520+
}
1521+
// Continue to poll for more data from other partitions
14881522
continue;
14891523
}
1490-
Err(e) => {
1491-
return Poll::Ready(Some(Err(e)));
1524+
None => {
1525+
// Channel closed unexpectedly
1526+
return self.flush_remaining_batch();
14921527
}
1493-
},
1494-
Some(None) => {
1495-
// One input partition finished
1496-
self.remaining_partitions -= 1;
1497-
if self.remaining_partitions == 0 {
1498-
// All input partitions finished
1499-
return Poll::Ready(None);
1500-
}
1501-
// Continue to poll for more data from other partitions
1502-
continue;
1503-
}
1504-
None => {
1505-
// Channel closed unexpectedly
1506-
return Poll::Ready(None);
15071528
}
15081529
}
1509-
}
1510-
StreamState::ReadingSpilled => {
1511-
// Poll spill stream for the spilled batch
1512-
match self.spill_stream.poll_next_unpin(cx) {
1513-
Poll::Ready(Some(Ok(batch))) => {
1514-
self.state = StreamState::ReadingMemory;
1515-
return Poll::Ready(Some(Ok(batch)));
1516-
}
1517-
Poll::Ready(Some(Err(e))) => {
1518-
return Poll::Ready(Some(Err(e)));
1519-
}
1520-
Poll::Ready(None) => {
1521-
// Spill stream ended, keep draining the memory channel
1522-
self.state = StreamState::ReadingMemory;
1523-
}
1524-
Poll::Pending => {
1525-
// Spilled batch not ready yet, must wait
1526-
// This preserves ordering by blocking until spill data arrives
1527-
return Poll::Pending;
1530+
StreamState::ReadingSpilled => {
1531+
// Poll spill stream for the spilled batch
1532+
match self.spill_stream.poll_next_unpin(cx) {
1533+
Poll::Ready(Some(Ok(batch))) => {
1534+
self.state = StreamState::ReadingMemory;
1535+
self.batch_coalescer.push_batch(batch)?;
1536+
break;
1537+
}
1538+
Poll::Ready(Some(Err(e))) => {
1539+
return Poll::Ready(Some(Err(e)));
1540+
}
1541+
Poll::Ready(None) => {
1542+
// Spill stream ended, keep draining the memory channel
1543+
self.state = StreamState::ReadingMemory;
1544+
}
1545+
Poll::Pending => {
1546+
// Spilled batch not ready yet, must wait
1547+
// This preserves ordering by blocking until spill data arrives
1548+
return Poll::Pending;
1549+
}
15281550
}
15291551
}
15301552
}
15311553
}
1554+
if let Some(batch) = self.batch_coalescer.next_completed_batch() {
1555+
return Poll::Ready(Some(Ok(batch)));
1556+
}
15321557
}
15331558
}
15341559
}
@@ -1570,14 +1595,15 @@ mod tests {
15701595
};
15711596

15721597
use arrow::array::{ArrayRef, StringArray, UInt32Array};
1598+
use arrow::compute::sort;
15731599
use arrow::datatypes::{DataType, Field, Schema};
15741600
use datafusion_common::cast::as_string_array;
15751601
use datafusion_common::exec_err;
15761602
use datafusion_common::test_util::batches_to_sort_string;
15771603
use datafusion_common_runtime::JoinSet;
1604+
use datafusion_execution::config::SessionConfig;
15781605
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
15791606
use insta::assert_snapshot;
1580-
use itertools::Itertools;
15811607

15821608
#[tokio::test]
15831609
async fn one_to_many_round_robin() -> Result<()> {
@@ -1588,7 +1614,7 @@ mod tests {
15881614

15891615
// repartition from 1 input to 4 output
15901616
let output_partitions =
1591-
repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
1617+
repartition(&schema, partitions, Partitioning::RoundRobinBatch(4), 8).await?;
15921618

15931619
assert_eq!(4, output_partitions.len());
15941620
assert_eq!(13, output_partitions[0].len());
@@ -1608,7 +1634,7 @@ mod tests {
16081634

16091635
// repartition from 3 input to 1 output
16101636
let output_partitions =
1611-
repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
1637+
repartition(&schema, partitions, Partitioning::RoundRobinBatch(1), 8).await?;
16121638

16131639
assert_eq!(1, output_partitions.len());
16141640
assert_eq!(150, output_partitions[0].len());
@@ -1625,7 +1651,7 @@ mod tests {
16251651

16261652
// repartition from 3 input to 5 output
16271653
let output_partitions =
1628-
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
1654+
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8).await?;
16291655

16301656
assert_eq!(5, output_partitions.len());
16311657
assert_eq!(30, output_partitions[0].len());
@@ -1648,6 +1674,7 @@ mod tests {
16481674
&schema,
16491675
partitions,
16501676
Partitioning::Hash(vec![col("c0", &schema)?], 8),
1677+
8,
16511678
)
16521679
.await?;
16531680

@@ -1670,8 +1697,11 @@ mod tests {
16701697
schema: &SchemaRef,
16711698
input_partitions: Vec<Vec<RecordBatch>>,
16721699
partitioning: Partitioning,
1700+
batch_size: usize,
16731701
) -> Result<Vec<Vec<RecordBatch>>> {
1674-
let task_ctx = Arc::new(TaskContext::default());
1702+
let session_config = SessionConfig::new().with_batch_size(batch_size);
1703+
let task_ctx =
1704+
Arc::new(TaskContext::default().with_session_config(session_config));
16751705
// create physical plan
16761706
let exec =
16771707
TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
@@ -1702,7 +1732,8 @@ mod tests {
17021732
vec![partition.clone(), partition.clone(), partition.clone()];
17031733

17041734
// repartition from 3 input to 5 output
1705-
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
1735+
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8)
1736+
.await
17061737
});
17071738

17081739
let output_partitions = handle.join().await.unwrap().unwrap();
@@ -1898,7 +1929,9 @@ mod tests {
18981929
// with different compilers, we will compare the same execution with
18991930
// and without dropping the output stream.
19001931
async fn hash_repartition_with_dropping_output_stream() {
1901-
let task_ctx = Arc::new(TaskContext::default());
1932+
let session_config = SessionConfig::new().with_batch_size(4);
1933+
let task_ctx =
1934+
Arc::new(TaskContext::default().with_session_config(session_config));
19021935
let partitioning = Partitioning::Hash(
19031936
vec![Arc::new(crate::expressions::Column::new(
19041937
"my_awesome_field",
@@ -1950,14 +1983,21 @@ mod tests {
19501983
});
19511984
let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
19521985

1953-
fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
1954-
batch
1955-
.into_iter()
1956-
.sorted_by_key(|b| format!("{b:?}"))
1957-
.collect()
1958-
}
1986+
assert_eq!(1, batches_with_drop.len());
1987+
assert_eq!(1, batches_without_drop.len());
19591988

1960-
assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
1989+
assert_eq!(
1990+
sort(batches_without_drop[0].column(0), None)
1991+
.unwrap()
1992+
.as_any()
1993+
.downcast_ref::<StringArray>()
1994+
.unwrap(),
1995+
sort(batches_with_drop[0].column(0), None)
1996+
.unwrap()
1997+
.as_any()
1998+
.downcast_ref::<StringArray>()
1999+
.unwrap()
2000+
);
19612001
}
19622002

19632003
fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
@@ -2396,6 +2436,7 @@ mod test {
23962436
use arrow::compute::SortOptions;
23972437
use arrow::datatypes::{DataType, Field, Schema};
23982438
use datafusion_common::assert_batches_eq;
2439+
use datafusion_execution::config::SessionConfig;
23992440

24002441
use super::*;
24012442
use crate::test::TestMemoryExec;
@@ -2507,8 +2548,10 @@ mod test {
25072548
let runtime = RuntimeEnvBuilder::default()
25082549
.with_memory_limit(64, 1.0)
25092550
.build_arc()?;
2510-
2511-
let task_ctx = TaskContext::default().with_runtime(runtime);
2551+
let session_config = SessionConfig::new().with_batch_size(4);
2552+
let task_ctx = TaskContext::default()
2553+
.with_runtime(runtime)
2554+
.with_session_config(session_config);
25122555
let task_ctx = Arc::new(task_ctx);
25132556

25142557
// Create physical plan with order preservation

0 commit comments

Comments
 (0)