@@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use super::{
     DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
 };
+use crate::coalesce::LimitedBatchCoalescer;
 use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
 use crate::hash_utils::create_hashes;
 use crate::metrics::{BaselineMetrics, SpillMetrics};
@@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
                             spill_stream,
                             1, // Each receiver handles one input partition
                             BaselineMetrics::new(&metrics, partition),
+                            context.session_config().batch_size() / num_input_partitions,
                         )) as SendableRecordBatchStream
                     })
                     .collect::<Vec<_>>();
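A quick arithmetic note on the divided target above: each of these streams handles a single input partition (the `1` argument), so with the default `batch_size` of 8192 and, say, 4 input partitions, every `PerPartitionStream` coalesces toward 2048-row batches. The second call site below passes the full `batch_size` because there a single stream drains all input partitions.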
@@ -959,7 +961,6 @@ impl ExecutionPlan for RepartitionExec {
                     .into_iter()
                     .next()
                     .expect("at least one spill reader should exist");
-
                 Ok(Box::pin(PerPartitionStream::new(
                     schema_captured,
                     rx.into_iter()
@@ -970,6 +971,7 @@ impl ExecutionPlan for RepartitionExec {
                     spill_stream,
                     num_input_partitions,
                     BaselineMetrics::new(&metrics, partition),
+                    context.session_config().batch_size(),
                 )) as SendableRecordBatchStream)
             }
         })
@@ -1427,9 +1429,12 @@ struct PerPartitionStream {
 
     /// Execution metrics
     baseline_metrics: BaselineMetrics,
+
+    batch_coalescer: LimitedBatchCoalescer,
 }
 
 impl PerPartitionStream {
+    #[allow(clippy::too_many_arguments)]
     fn new(
         schema: SchemaRef,
         receiver: DistributionReceiver<MaybeBatch>,
@@ -1438,16 +1443,29 @@ impl PerPartitionStream {
         spill_stream: SendableRecordBatchStream,
         num_input_partitions: usize,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Self {
         Self {
-            schema,
+            schema: Arc::clone(&schema),
             receiver,
             _drop_helper: drop_helper,
             reservation,
             spill_stream,
             state: StreamState::ReadingMemory,
             remaining_partitions: num_input_partitions,
             baseline_metrics,
+            batch_coalescer: LimitedBatchCoalescer::new(schema, batch_size, None),
+        }
+    }
+
+    fn flush_remaining_batch(
+        &mut self,
+    ) -> Poll<Option<std::result::Result<RecordBatch, DataFusionError>>> {
+        // Flush any remaining buffered batch
+        match self.batch_coalescer.finish() {
+            Ok(()) => Poll::Ready(self.batch_coalescer.next_completed_batch().map(Ok)),
+
+            Err(e) => Poll::Ready(Some(Err(e))),
         }
     }
 
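To make the coalescer's contract concrete, here is a small stand-alone sketch of the calls the patch relies on (`new(schema, batch_size, None)`, `push_batch`, `next_completed_batch`, `finish`). It is illustrative only: the `datafusion_physical_plan::coalesce` import path and the exact sizes of the emitted batches are assumptions, not something this diff guarantees.

```rust
use std::sync::Arc;

use arrow::array::UInt32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_plan::coalesce::LimitedBatchCoalescer; // assumed public path

fn coalesce_demo() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
    // Target batch size of 8 rows; `None` means no row limit (fetch).
    let mut coalescer = LimitedBatchCoalescer::new(Arc::clone(&schema), 8, None);

    // Push three 3-row batches, draining completed batches as they become ready,
    // just as the poll loop checks `next_completed_batch` after each `push_batch`.
    for start in [0u32, 3, 6] {
        let array = UInt32Array::from_iter_values(start..start + 3);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)])?;
        coalescer.push_batch(batch)?;
        while let Some(out) = coalescer.next_completed_batch() {
            println!("completed batch: {} rows", out.num_rows());
        }
    }

    // Mirrors `flush_remaining_batch`: finish, then emit whatever is still buffered.
    coalescer.finish()?;
    if let Some(out) = coalescer.next_completed_batch() {
        println!("final batch: {} rows", out.num_rows());
    }
    Ok(())
}
```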
@@ -1460,75 +1478,82 @@ impl PerPartitionStream {
         let _timer = cloned_time.timer();
 
         loop {
-            match self.state {
-                StreamState::ReadingMemory => {
-                    // Poll the memory channel for next message
-                    let value = match self.receiver.recv().poll_unpin(cx) {
-                        Poll::Ready(v) => v,
-                        Poll::Pending => {
-                            // Nothing from channel, wait
-                            return Poll::Pending;
-                        }
-                    };
-
-                    match value {
-                        Some(Some(v)) => match v {
-                            Ok(RepartitionBatch::Memory(batch)) => {
-                                // Release memory and return batch
-                                self.reservation
-                                    .lock()
-                                    .shrink(batch.get_array_memory_size());
-                                return Poll::Ready(Some(Ok(batch)));
+            loop {
+                match self.state {
+                    StreamState::ReadingMemory => {
+                        // Poll the memory channel for next message
+                        let value = match self.receiver.recv().poll_unpin(cx) {
+                            Poll::Ready(v) => v,
+                            Poll::Pending => {
+                                // Nothing from channel, wait
+                                return Poll::Pending;
                             }
-                            Ok(RepartitionBatch::Spilled) => {
-                                // Batch was spilled, transition to reading from spill stream
-                                // We must block on spill stream until we get the batch
-                                // to preserve ordering
-                                self.state = StreamState::ReadingSpilled;
+                        };
+
+                        match value {
+                            Some(Some(v)) => match v {
+                                Ok(RepartitionBatch::Memory(batch)) => {
+                                    // Release memory and return batch
+                                    self.reservation
+                                        .lock()
+                                        .shrink(batch.get_array_memory_size());
+                                    self.batch_coalescer.push_batch(batch)?;
+                                    break;
+                                }
+                                Ok(RepartitionBatch::Spilled) => {
+                                    // Batch was spilled, transition to reading from spill stream
+                                    // We must block on spill stream until we get the batch
+                                    // to preserve ordering
+                                    self.state = StreamState::ReadingSpilled;
+                                    continue;
+                                }
+                                Err(e) => {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
+                            },
+                            Some(None) => {
+                                // One input partition finished
+                                self.remaining_partitions -= 1;
+                                if self.remaining_partitions == 0 {
+                                    // All input partitions finished
+                                    return self.flush_remaining_batch();
+                                }
+                                // Continue to poll for more data from other partitions
                                 continue;
                             }
-                            Err(e) => {
-                                return Poll::Ready(Some(Err(e)));
+                            None => {
+                                // Channel closed unexpectedly
+                                return self.flush_remaining_batch();
                             }
-                        },
-                        Some(None) => {
-                            // One input partition finished
-                            self.remaining_partitions -= 1;
-                            if self.remaining_partitions == 0 {
-                                // All input partitions finished
-                                return Poll::Ready(None);
-                            }
-                            // Continue to poll for more data from other partitions
-                            continue;
-                        }
-                        None => {
-                            // Channel closed unexpectedly
-                            return Poll::Ready(None);
                         }
                     }
-                }
-                StreamState::ReadingSpilled => {
-                    // Poll spill stream for the spilled batch
-                    match self.spill_stream.poll_next_unpin(cx) {
-                        Poll::Ready(Some(Ok(batch))) => {
-                            self.state = StreamState::ReadingMemory;
-                            return Poll::Ready(Some(Ok(batch)));
-                        }
-                        Poll::Ready(Some(Err(e))) => {
-                            return Poll::Ready(Some(Err(e)));
-                        }
-                        Poll::Ready(None) => {
-                            // Spill stream ended, keep draining the memory channel
-                            self.state = StreamState::ReadingMemory;
-                        }
-                        Poll::Pending => {
-                            // Spilled batch not ready yet, must wait
-                            // This preserves ordering by blocking until spill data arrives
-                            return Poll::Pending;
+                    StreamState::ReadingSpilled => {
+                        // Poll spill stream for the spilled batch
+                        match self.spill_stream.poll_next_unpin(cx) {
+                            Poll::Ready(Some(Ok(batch))) => {
+                                self.state = StreamState::ReadingMemory;
+                                self.batch_coalescer.push_batch(batch)?;
+                                break;
+                            }
+                            Poll::Ready(Some(Err(e))) => {
+                                return Poll::Ready(Some(Err(e)));
+                            }
+                            Poll::Ready(None) => {
+                                // Spill stream ended, keep draining the memory channel
+                                self.state = StreamState::ReadingMemory;
+                            }
+                            Poll::Pending => {
+                                // Spilled batch not ready yet, must wait
+                                // This preserves ordering by blocking until spill data arrives
+                                return Poll::Pending;
+                            }
                         }
                     }
                 }
             }
+            if let Some(batch) = self.batch_coalescer.next_completed_batch() {
+                return Poll::Ready(Some(Ok(batch)));
+            }
         }
     }
 }
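The restructuring above is easier to see stripped of Arrow and async machinery: the inner loop only pulls and buffers, and the outer loop is the single place an output is emitted (or, at end of input, flushed). A rough synchronous analogue, with made-up names and plain `Vec<u32>` rows standing in for record batches:

```rust
/// Toy model of the nested-loop structure in the stream's poll implementation above.
fn coalescing_pull(
    mut input: impl Iterator<Item = Vec<u32>>,
    target: usize,
) -> Vec<Vec<u32>> {
    let mut buffer: Vec<u32> = Vec::new();
    let mut output: Vec<Vec<u32>> = Vec::new();
    loop {
        // Inner loop: fetch the next small batch (ReadingMemory / ReadingSpilled),
        // push it into the buffer, then break out to check for a completed chunk.
        loop {
            match input.next() {
                Some(rows) => {
                    buffer.extend(rows); // ~ batch_coalescer.push_batch(batch)?
                    break;
                }
                None => {
                    // Input exhausted: flush whatever is buffered
                    // (~ flush_remaining_batch).
                    if !buffer.is_empty() {
                        output.push(std::mem::take(&mut buffer));
                    }
                    return output;
                }
            }
        }
        // Outer loop: emit a chunk only once the target size has been reached
        // (~ next_completed_batch() returning Some).
        if buffer.len() >= target {
            output.push(buffer.drain(..target).collect());
        }
    }
}
```

Under these assumptions, `coalescing_pull(vec![vec![1, 2, 3]; 3].into_iter(), 8)` yields one 8-row chunk followed by a 1-row flush, which is the behavior the patch wants for many small repartitioned batches.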
@@ -1570,14 +1595,15 @@ mod tests {
     };
 
     use arrow::array::{ArrayRef, StringArray, UInt32Array};
+    use arrow::compute::sort;
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::cast::as_string_array;
     use datafusion_common::exec_err;
     use datafusion_common::test_util::batches_to_sort_string;
     use datafusion_common_runtime::JoinSet;
+    use datafusion_execution::config::SessionConfig;
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use insta::assert_snapshot;
-    use itertools::Itertools;
 
     #[tokio::test]
     async fn one_to_many_round_robin() -> Result<()> {
@@ -1588,7 +1614,7 @@ mod tests {
 
         // repartition from 1 input to 4 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4), 8).await?;
 
         assert_eq!(4, output_partitions.len());
         assert_eq!(13, output_partitions[0].len());
@@ -1608,7 +1634,7 @@ mod tests {
 
         // repartition from 3 input to 1 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1), 8).await?;
 
         assert_eq!(1, output_partitions.len());
         assert_eq!(150, output_partitions[0].len());
@@ -1625,7 +1651,7 @@ mod tests {
 
         // repartition from 3 input to 5 output
         let output_partitions =
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8).await?;
 
         assert_eq!(5, output_partitions.len());
         assert_eq!(30, output_partitions[0].len());
@@ -1648,6 +1674,7 @@ mod tests {
             &schema,
             partitions,
             Partitioning::Hash(vec![col("c0", &schema)?], 8),
+            8,
         )
         .await?;
 
@@ -1670,8 +1697,11 @@ mod tests {
         schema: &SchemaRef,
         input_partitions: Vec<Vec<RecordBatch>>,
         partitioning: Partitioning,
+        batch_size: usize,
     ) -> Result<Vec<Vec<RecordBatch>>> {
-        let task_ctx = Arc::new(TaskContext::default());
+        let session_config = SessionConfig::new().with_batch_size(batch_size);
+        let task_ctx =
+            Arc::new(TaskContext::default().with_session_config(session_config));
         // create physical plan
         let exec =
             TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
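Usage-wise, the only knob the new code reads is `context.session_config().batch_size()`, so callers control coalescing exactly the way the updated tests do. A minimal helper mirroring that setup (crate paths as in the test imports above):

```rust
use std::sync::Arc;

use datafusion_execution::config::SessionConfig;
use datafusion_execution::TaskContext;

/// Build a task context whose configured batch size drives the repartition coalescer.
fn task_ctx_with_batch_size(batch_size: usize) -> Arc<TaskContext> {
    let session_config = SessionConfig::new().with_batch_size(batch_size);
    Arc::new(TaskContext::default().with_session_config(session_config))
}
```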
@@ -1702,7 +1732,8 @@ mod tests {
                 vec![partition.clone(), partition.clone(), partition.clone()];
 
             // repartition from 3 input to 5 output
-            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
+            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5), 8)
+                .await
         });
 
         let output_partitions = handle.join().await.unwrap().unwrap();
@@ -1898,7 +1929,9 @@ mod tests {
     // with different compilers, we will compare the same execution with
     // and without dropping the output stream.
     async fn hash_repartition_with_dropping_output_stream() {
-        let task_ctx = Arc::new(TaskContext::default());
+        let session_config = SessionConfig::new().with_batch_size(4);
+        let task_ctx =
+            Arc::new(TaskContext::default().with_session_config(session_config));
         let partitioning = Partitioning::Hash(
             vec![Arc::new(crate::expressions::Column::new(
                 "my_awesome_field",
@@ -1950,14 +1983,21 @@ mod tests {
         });
         let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
 
-        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
-            batch
-                .into_iter()
-                .sorted_by_key(|b| format!("{b:?}"))
-                .collect()
-        }
+        assert_eq!(1, batches_with_drop.len());
+        assert_eq!(1, batches_without_drop.len());
 
-        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
+        assert_eq!(
+            sort(batches_without_drop[0].column(0), None)
+                .unwrap()
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .unwrap(),
+            sort(batches_with_drop[0].column(0), None)
+                .unwrap()
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .unwrap()
+        );
     }
 
     fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
@@ -2396,6 +2436,7 @@ mod test {
     use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::assert_batches_eq;
+    use datafusion_execution::config::SessionConfig;
 
     use super::*;
     use crate::test::TestMemoryExec;
@@ -2507,8 +2548,10 @@ mod test {
         let runtime = RuntimeEnvBuilder::default()
             .with_memory_limit(64, 1.0)
             .build_arc()?;
-
-        let task_ctx = TaskContext::default().with_runtime(runtime);
+        let session_config = SessionConfig::new().with_batch_size(4);
+        let task_ctx = TaskContext::default()
+            .with_runtime(runtime)
+            .with_session_config(session_config);
         let task_ctx = Arc::new(task_ctx);
 
         // Create physical plan with order preservation