diff --git a/datafusion/physical-optimizer/src/coalesce_batches.rs b/datafusion/physical-optimizer/src/coalesce_batches.rs index 61e4c0e7f180..8daff3172103 100644 --- a/datafusion/physical-optimizer/src/coalesce_batches.rs +++ b/datafusion/physical-optimizer/src/coalesce_batches.rs @@ -29,7 +29,7 @@ use datafusion_common::{ use datafusion_physical_expr::Partitioning; use datafusion_physical_plan::{ async_func::AsyncFuncExec, coalesce_batches::CoalesceBatchesExec, - joins::HashJoinExec, repartition::RepartitionExec, ExecutionPlan, + repartition::RepartitionExec, ExecutionPlan, }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -58,17 +58,15 @@ impl PhysicalOptimizerRule for CoalesceBatches { let target_batch_size = config.execution.batch_size; plan.transform_up(|plan| { let plan_any = plan.as_any(); - let wrap_in_coalesce = plan_any.downcast_ref::().is_some() - // Don't need to add CoalesceBatchesExec after a round robin RepartitionExec - || plan_any - .downcast_ref::() - .map(|repart_exec| { - !matches!( - repart_exec.partitioning().clone(), - Partitioning::RoundRobinBatch(_) - ) - }) - .unwrap_or(false); + let wrap_in_coalesce = plan_any + .downcast_ref::() + .map(|repart_exec| { + !matches!( + repart_exec.partitioning().clone(), + Partitioning::RoundRobinBatch(_) + ) + }) + .unwrap_or(false); if wrap_in_coalesce { Ok(Transformed::yes(Arc::new(CoalesceBatchesExec::new( diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index 860fc81bf665..3954ab4b6fe5 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -68,6 +68,10 @@ impl LimitedBatchCoalescer { } } + pub fn is_finished(&self) -> bool { + self.finished + } + /// Return the schema of the output batches pub fn schema(&self) -> SchemaRef { self.inner.schema() @@ -125,6 +129,11 @@ impl LimitedBatchCoalescer { self.inner.is_empty() } + /// Return the total number of rows that have been pushed to this coalescer + pub fn total_rows(&self) -> usize { + self.total_rows + } + /// Complete the current buffered batch and finish the coalescer /// /// Any subsequent calls to `push_batch()` will return an Err diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index a50a6551db4d..bd14693f74aa 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use std::task::Poll; +use crate::coalesce::LimitedBatchCoalescer; use crate::joins::hash_join::exec::JoinLeftData; use crate::joins::hash_join::shared_bounds::{ PartitionBuildDataReport, SharedBuildAccumulator, @@ -216,6 +217,11 @@ pub(super) struct HashJoinStream { /// Partitioning mode to use mode: PartitionMode, + + /// Batch coalescer for coalescing batches from the probe side + batch_coalescer: LimitedBatchCoalescer, + /// Flag to track if we've handled empty output to avoid returning empty schema + handled_empty_output: bool, } impl RecordBatchStream for HashJoinStream { @@ -351,7 +357,7 @@ impl HashJoinStream { ) -> Self { Self { partition, - schema, + schema: Arc::clone(&schema), on_right, filter, join_type, @@ -368,6 +374,8 @@ impl HashJoinStream { build_accumulator, build_waiter: None, mode, + batch_coalescer: LimitedBatchCoalescer::new(schema, batch_size, None), + handled_empty_output: false, } } @@ -396,7 +404,18 @@ impl HashJoinStream { let poll = handle_state!(self.process_unmatched_build_batch()); self.join_metrics.baseline.record_poll(poll) } - HashJoinStreamState::Completed => Poll::Ready(None), + HashJoinStreamState::Completed => { + if !self.handled_empty_output { + self.handled_empty_output = true; + if self.batch_coalescer.total_rows() == 0 { + let empty_batch = + RecordBatch::new_empty(Arc::clone(&self.schema)); + let poll = Poll::Ready(Some(Ok(empty_batch))); + return self.join_metrics.baseline.record_poll(poll); + } + } + Poll::Ready(None) + } }; } } @@ -518,6 +537,11 @@ impl HashJoinStream { fn process_probe_batch( &mut self, ) -> Result>> { + // Flush any remaining buffered batch + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + let state = self.state.try_as_process_probe_batch_mut()?; let build_side = self.build_side.try_as_ready_mut()?; @@ -650,6 +674,8 @@ impl HashJoinStream { )? }; + self.batch_coalescer.push_batch(result)?; + timer.done(); if next_offset.is_none() { @@ -662,7 +688,11 @@ impl HashJoinStream { ) }; - Ok(StatefulStreamResult::Ready(Some(result))) + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + Ok(StatefulStreamResult::Continue) } /// Processes unmatched build-side rows for certain join types and produces output batch @@ -671,15 +701,37 @@ impl HashJoinStream { fn process_unmatched_build_batch( &mut self, ) -> Result>> { + // Flush any remaining buffered batch before processing unmatched rows let timer = self.join_metrics.join_time.timer(); + // This will finish any remaining buffered batch and in the case where batch coalescer + // is finished + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + // Return if the batch coalescer is finished + if self.batch_coalescer.is_finished() { + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } + if !need_produce_result_in_final(self.join_type) { + self.batch_coalescer.finish()?; + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + self.state = HashJoinStreamState::Completed; return Ok(StatefulStreamResult::Continue); } let build_side = self.build_side.try_as_ready()?; if !build_side.left_data.report_probe_completed() { + self.batch_coalescer.finish()?; + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } self.state = HashJoinStreamState::Completed; return Ok(StatefulStreamResult::Continue); } @@ -706,11 +758,19 @@ impl HashJoinStream { self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); } + + self.batch_coalescer.push_batch(result?)?; + + self.batch_coalescer.finish()?; + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + timer.done(); self.state = HashJoinStreamState::Completed; - Ok(StatefulStreamResult::Ready(Some(result?))) + Ok(StatefulStreamResult::Continue) } }