Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 131 additions & 21 deletions datafusion/physical-plan/src/windows/window_agg_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ impl ExecutionPlan for WindowAggExec {
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let batch_size = context.session_config().batch_size();
let input = self.input.execute(partition, context)?;
let stream = Box::pin(WindowAggStream::new(
Arc::clone(&self.schema),
Expand All @@ -273,6 +274,7 @@ impl ExecutionPlan for WindowAggExec {
BaselineMetrics::new(&self.metrics, partition),
self.partition_by_sort_keys()?,
self.ordered_partition_by_indices.clone(),
batch_size,
)?);
Ok(stream)
}
Expand Down Expand Up @@ -327,6 +329,15 @@ pub struct WindowAggStream {
partition_by_sort_keys: Vec<PhysicalSortExpr>,
baseline_metrics: BaselineMetrics,
ordered_partition_by_indices: Vec<usize>,
/// Target output batch size. The fully-computed result is emitted in
/// slices of at most this many rows so downstream operators are not forced
/// to hold one batch sized to the entire input.
batch_size: usize,
/// The fully-computed result, set once the input is exhausted. Emitted
/// incrementally as zero-copy slices via `emit_offset`.
computed: Option<RecordBatch>,
/// Number of result rows already emitted from `computed`.
emit_offset: usize,
}

impl WindowAggStream {
Expand All @@ -338,6 +349,7 @@ impl WindowAggStream {
baseline_metrics: BaselineMetrics,
partition_by_sort_keys: Vec<PhysicalSortExpr>,
ordered_partition_by_indices: Vec<usize>,
batch_size: usize,
) -> Result<Self> {
// In WindowAggExec all partition by columns should be ordered.
assert_eq_or_internal_err!(
Expand All @@ -354,6 +366,10 @@ impl WindowAggStream {
baseline_metrics,
partition_by_sort_keys,
ordered_partition_by_indices,
// Guard against a zero batch size, which would never make progress.
batch_size: batch_size.max(1),
computed: None,
emit_offset: 0,
})
}

Expand Down Expand Up @@ -425,29 +441,60 @@ impl WindowAggStream {
return Poll::Ready(None);
}

loop {
return Poll::Ready(Some(match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
self.batches.push(batch);
continue;
// Phase 1: drain the input and compute the full result exactly once.
if self.computed.is_none() {
loop {
match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
self.batches.push(batch);
}
Some(Err(e)) => {
self.finished = true;
return Poll::Ready(Some(Err(e)));
}
None => {
// Release the input pipeline's resources before
// computing the final aggregates.
let input_schema = self.input.schema();
self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
let Some(result) = self.compute_aggregates()? else {
self.finished = true;
return Poll::Ready(None);
};
self.computed = Some(result);
break;
}
}
Some(Err(e)) => Err(e),
None => {
// Release the input pipeline's resources before computing
// the final aggregates.
let input_schema = self.input.schema();
self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
let Some(result) = self.compute_aggregates()? else {
return Poll::Ready(None);
};
self.finished = true;
// Empty record batches should not be emitted.
// They need to be treated as [`Option<RecordBatch>`]es and handled separately
debug_assert!(result.num_rows() > 0);
Ok(result)
}
}));
}
}

// Phase 2: emit the computed result in `batch_size` chunks. Slicing is
// zero-copy, so this only bounds the batch each downstream operator
// must hold at once; it does not re-copy the data.
Poll::Ready(self.next_output_batch().transpose())
}

/// Hands out the computed result one chunk at a time.
///
/// Every call slices the next run of at most `batch_size` rows out of
/// `computed`, starting at `emit_offset`, and advances `emit_offset` past
/// it. Yields `Ok(None)` when nothing has been computed yet or when every
/// row has already been handed out; `finished` is raised as soon as the
/// last row has been emitted.
fn next_output_batch(&mut self) -> Result<Option<RecordBatch>> {
    let result = match self.computed.as_ref() {
        Some(result) => result,
        None => return Ok(None),
    };

    let start = self.emit_offset;
    // Rows not yet handed out; zero means the result is fully drained.
    let remaining = result.num_rows().saturating_sub(start);
    if remaining == 0 {
        self.finished = true;
        return Ok(None);
    }

    let chunk_len = remaining.min(self.batch_size);
    let chunk = result.slice(start, chunk_len);
    self.emit_offset = start + chunk_len;
    // A chunk that consumed every remaining row is the last one.
    if chunk_len == remaining {
        self.finished = true;
    }

    // Empty record batches must never be emitted. An empty result is
    // reported as `None` by `compute_aggregates`, so a chunk taken here is
    // expected to hold at least one row.
    debug_assert!(chunk.num_rows() > 0);
    Ok(Some(chunk))
}
}

Expand Down Expand Up @@ -500,4 +547,67 @@ mod tests {
));
Ok(())
}

#[tokio::test]
async fn window_agg_exec_emits_batch_size_chunks() -> Result<()> {
    use crate::common::collect;
    use arrow::array::{ArrayRef, Int64Array};
    use datafusion_execution::config::SessionConfig;

    // Input: a single non-nullable Int64 column `a` holding 0..10, delivered
    // as one batch in one partition (there is no PARTITION BY).
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let input_values: Vec<i64> = (0..10).collect();
    let input_column: ArrayRef = Arc::new(Int64Array::from(input_values));
    let input_batch = RecordBatch::try_new(Arc::clone(&schema), vec![input_column])?;
    let input =
        TestMemoryExec::try_new_exec(&[vec![input_batch]], Arc::clone(&schema), None)?;

    // Session whose batch size (4) is smaller than the input (10 rows), so
    // the single computed result has to be split across several output
    // batches rather than emitted as one batch sized to the whole input.
    let session_config = SessionConfig::new().with_batch_size(4);
    let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config));

    // Running COUNT(a) over ROWS UNBOUNDED PRECEDING .. CURRENT ROW, which
    // evaluates to 1, 2, ..., 10 over the ten input rows.
    let running_frame = WindowFrame::new_bounds(
        WindowFrameUnits::Rows,
        WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
        WindowFrameBound::CurrentRow,
    );
    let count_args = vec![crate::expressions::col("a", &schema)?];
    let window_expr = create_window_expr(
        &WindowFunctionDefinition::AggregateUDF(count_udaf()),
        "count(a)".to_string(),
        &count_args,
        &[],
        &[],
        Arc::new(running_frame),
        Arc::clone(&schema),
        false,
        false,
        None,
    )?;
    let window = Arc::new(WindowAggExec::try_new(vec![window_expr], input, true)?);

    let batches = collect(window.execute(0, task_ctx)?).await?;

    // 10 rows at batch size 4 must come out as chunks of 4, 4 and 2 rows.
    let chunk_sizes: Vec<usize> = batches.iter().map(|chunk| chunk.num_rows()).collect();
    assert_eq!(chunk_sizes, vec![4, 4, 2]);

    // Chunking must not disturb the values: stitched back together, the
    // running-count column has to read 1..=10.
    let stitched = concat_batches(&window.schema(), &batches)?;
    let running_count = stitched
        .column(1)
        .as_any()
        .downcast_ref::<Int64Array>()
        .expect("count column is Int64");
    let expected_values: Vec<i64> = (1..=10).collect();
    let expected = Int64Array::from(expected_values);
    assert_eq!(running_count, &expected);

    Ok(())
}
}
Loading