22 changes: 10 additions & 12 deletions datafusion/physical-optimizer/src/coalesce_batches.rs
@@ -29,7 +29,7 @@ use datafusion_common::{
use datafusion_physical_expr::Partitioning;
use datafusion_physical_plan::{
async_func::AsyncFuncExec, coalesce_batches::CoalesceBatchesExec,
joins::HashJoinExec, repartition::RepartitionExec, ExecutionPlan,
repartition::RepartitionExec, ExecutionPlan,
};

use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -58,17 +58,15 @@ impl PhysicalOptimizerRule for CoalesceBatches {
let target_batch_size = config.execution.batch_size;
plan.transform_up(|plan| {
let plan_any = plan.as_any();
let wrap_in_coalesce = plan_any.downcast_ref::<HashJoinExec>().is_some()
// Don't need to add CoalesceBatchesExec after a round robin RepartitionExec
|| plan_any
.downcast_ref::<RepartitionExec>()
.map(|repart_exec| {
!matches!(
repart_exec.partitioning().clone(),
Partitioning::RoundRobinBatch(_)
)
})
.unwrap_or(false);
let wrap_in_coalesce = plan_any
.downcast_ref::<RepartitionExec>()
.map(|repart_exec| {
!matches!(
repart_exec.partitioning().clone(),
Partitioning::RoundRobinBatch(_)
)
})
.unwrap_or(false);

if wrap_in_coalesce {
Ok(Transformed::yes(Arc::new(CoalesceBatchesExec::new(
9 changes: 9 additions & 0 deletions datafusion/physical-plan/src/coalesce/mod.rs
@@ -68,6 +68,10 @@ impl LimitedBatchCoalescer {
}
}

/// Return true if the coalescer has been finished and will not accept further batches
pub fn is_finished(&self) -> bool {
self.finished
}

/// Return the schema of the output batches
pub fn schema(&self) -> SchemaRef {
self.inner.schema()
@@ -125,6 +129,11 @@ impl LimitedBatchCoalescer {
self.inner.is_empty()
}

/// Return the total number of rows that have been pushed to this coalescer
pub fn total_rows(&self) -> usize {
self.total_rows
}

/// Complete the current buffered batch and finish the coalescer
///
/// Any subsequent calls to `push_batch()` will return an Err
68 changes: 64 additions & 4 deletions datafusion/physical-plan/src/joins/hash_join/stream.rs
@@ -23,6 +23,7 @@
use std::sync::Arc;
use std::task::Poll;

use crate::coalesce::LimitedBatchCoalescer;
use crate::joins::hash_join::exec::JoinLeftData;
use crate::joins::hash_join::shared_bounds::{
PartitionBuildDataReport, SharedBuildAccumulator,
@@ -216,6 +217,11 @@ pub(super) struct HashJoinStream {

/// Partitioning mode to use
mode: PartitionMode,

/// Batch coalescer for coalescing batches from the probe side
batch_coalescer: LimitedBatchCoalescer,
/// Flag to track whether the empty-output case has already been handled, so a
/// join that produces no rows still emits one empty batch with the output schema
Review comment from the contributor (PR author): Filed #18859
handled_empty_output: bool,
}

impl RecordBatchStream for HashJoinStream {
@@ -351,7 +357,7 @@ impl HashJoinStream {
) -> Self {
Self {
partition,
schema,
schema: Arc::clone(&schema),
Review comment from a member: Do you need to clone here? Usually this happens at the call site.
on_right,
filter,
join_type,
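
Regarding the review question above about the Arc::clone in the constructor: the schema handle is needed twice inside new (the stream's own schema field plus the LimitedBatchCoalescer), so exactly one extra clone is required wherever it is placed. A minimal, self-contained sketch with toy types, not the DataFusion API, of the two possible placements:

use std::sync::Arc;

// Toy stand-ins for SchemaRef, LimitedBatchCoalescer and HashJoinStream,
// whose real definitions are not part of this diff.
type Schema = String;

struct Coalescer {
    schema: Arc<Schema>,
}

struct Stream {
    schema: Arc<Schema>,
    coalescer: Coalescer,
}

impl Stream {
    // Placement used in the PR: a single argument, cloned inside the
    // constructor because it is stored in two places.
    fn new(schema: Arc<Schema>) -> Self {
        Self {
            schema: Arc::clone(&schema),
            coalescer: Coalescer { schema },
        }
    }

    // Placement the reviewer alludes to: the caller performs the clone and
    // the constructor stores exactly what it receives.
    fn new_cloned_at_call_site(schema: Arc<Schema>, coalescer_schema: Arc<Schema>) -> Self {
        Self {
            schema,
            coalescer: Coalescer {
                schema: coalescer_schema,
            },
        }
    }
}

fn main() {
    let schema = Arc::new("id INT, name TEXT".to_string());
    let _inside = Stream::new(Arc::clone(&schema));
    let _at_call_site = Stream::new_cloned_at_call_site(Arc::clone(&schema), Arc::clone(&schema));
}
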
@@ -368,6 +374,8 @@ impl HashJoinStream {
build_accumulator,
build_waiter: None,
mode,
batch_coalescer: LimitedBatchCoalescer::new(schema, batch_size, None),
handled_empty_output: false,
}
}

@@ -396,7 +404,18 @@ impl HashJoinStream {
let poll = handle_state!(self.process_unmatched_build_batch());
self.join_metrics.baseline.record_poll(poll)
}
HashJoinStreamState::Completed => Poll::Ready(None),
HashJoinStreamState::Completed => {
if !self.handled_empty_output {
self.handled_empty_output = true;
if self.batch_coalescer.total_rows() == 0 {
let empty_batch =
RecordBatch::new_empty(Arc::clone(&self.schema));
let poll = Poll::Ready(Some(Ok(empty_batch)));
return self.join_metrics.baseline.record_poll(poll);
}
}
Poll::Ready(None)
}
};
}
}
@@ -518,6 +537,11 @@ impl HashJoinStream {
fn process_probe_batch(
&mut self,
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
// Flush any remaining buffered batch
if let Some(batch) = self.batch_coalescer.next_completed_batch() {
return Ok(StatefulStreamResult::Ready(Some(batch)));
}

let state = self.state.try_as_process_probe_batch_mut()?;
let build_side = self.build_side.try_as_ready_mut()?;

@@ -650,6 +674,8 @@ impl HashJoinStream {
)?
};

self.batch_coalescer.push_batch(result)?;
Review comment from a member: Maybe add a comment that ignoring the PushBatchStatus::LimitReached result is intentional. Same at line 762 below.
timer.done();

if next_offset.is_none() {
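
On the member's review comment above about PushBatchStatus::LimitReached: one possible shape for the suggested comment, sketched against the push_batch call in the previous hunk. The status type name is taken from the review comment itself and is not otherwise verified here.

// Ignoring a PushBatchStatus::LimitReached result from push_batch is
// intentional: the limit is observed later through next_completed_batch()
// and is_finished(), which is also where the final partial batch is drained.
// The same reasoning applies to the second push_batch call in
// process_unmatched_build_batch.
self.batch_coalescer.push_batch(result)?;
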
@@ -662,7 +688,11 @@ impl HashJoinStream {
)
};

Ok(StatefulStreamResult::Ready(Some(result)))
if let Some(batch) = self.batch_coalescer.next_completed_batch() {
return Ok(StatefulStreamResult::Ready(Some(batch)));
}

Ok(StatefulStreamResult::Continue)
}

/// Processes unmatched build-side rows for certain join types and produces output batch
@@ -671,15 +701,37 @@ impl HashJoinStream {
fn process_unmatched_build_batch(
&mut self,
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
// Flush any remaining buffered batch before processing unmatched rows
let timer = self.join_metrics.join_time.timer();

// Emit any batch the coalescer has already completed; this also drains output
// that is left over once the coalescer has been finished
if let Some(batch) = self.batch_coalescer.next_completed_batch() {
return Ok(StatefulStreamResult::Ready(Some(batch)));
}

// Return if the batch coalescer is finished
if self.batch_coalescer.is_finished() {
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
}

if !need_produce_result_in_final(self.join_type) {
self.batch_coalescer.finish()?;
if let Some(batch) = self.batch_coalescer.next_completed_batch() {
return Ok(StatefulStreamResult::Ready(Some(batch)));
}

self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
}

let build_side = self.build_side.try_as_ready()?;
if !build_side.left_data.report_probe_completed() {
self.batch_coalescer.finish()?;
if let Some(batch) = self.batch_coalescer.next_completed_batch() {
return Ok(StatefulStreamResult::Ready(Some(batch)));
}
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
}
@@ … @@
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
}

self.batch_coalescer.push_batch(result?)?;

self.batch_coalescer.finish()?;
if let Some(batch) = self.batch_coalescer.next_completed_batch() {
return Ok(StatefulStreamResult::Ready(Some(batch)));
}

timer.done();

self.state = HashJoinStreamState::Completed;

Ok(StatefulStreamResult::Ready(Some(result?)))
Ok(StatefulStreamResult::Continue)
}
}
