Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl<AggrMode> AggregateHashTable<AggrMode> {
acc + state.group_values.size()
+ state.batch_group_indices.allocated_size()
}
AggregateHashTableState::OutputtingMaterializedFinal(output) => {
AggregateHashTableState::OutputtingMaterialized(output) => {
output.memory_size()
}
AggregateHashTableState::Done => 0,
Expand Down Expand Up @@ -214,15 +214,6 @@ impl<AggrMode> AggregateHashTable<AggrMode> {
}
}

/// Picks the emission mode for a single output batch.
///
/// Yields `EmitTo::All` when `group_count` does not exceed `batch_size`;
/// otherwise yields `EmitTo::First(batch_size)`.
///
/// `batch_size` is expected to be non-zero (checked in debug builds only).
pub(super) fn emit_to_for_batch_size(batch_size: usize, group_count: usize) -> EmitTo {
    debug_assert!(batch_size > 0);
    // Guard clause: more groups than fit in one batch, so emit only a prefix.
    if group_count > batch_size {
        return EmitTo::First(batch_size);
    }
    EmitTo::All
}

/// State and argument information for a single Aggregate
///
/// For example, for `SELECT COUNT(x), SUM(y WHERE z > 10) ...` there would be two
Expand Down Expand Up @@ -304,24 +295,25 @@ pub(super) enum AggregateHashTableState {
Building(AggregateHashTableBuffer),
/// Emitting results directly from group keys and aggregate state.
Outputting(AggregateHashTableBuffer),
/// Materialize all the output results, and then incrementally output in the `OutputtingMaterializedFinal` state.
/// Materialize all the output results, and then incrementally output in the `OutputtingMaterialized` state.
///
/// Note this is a temporary solution until the `GroupValues` issue is solved:
/// Issue: <https://github.com/apache/datafusion/issues/23178>
OutputtingMaterializedFinal(MaterializedFinalOutput),
OutputtingMaterialized(MaterializedAggregateOutput),
Done,
}

/// Fully evaluated final aggregate output and the next row offset to emit.
/// Fully evaluated aggregate output and the next row offset to emit.
///
/// Final aggregate evaluation consumes accumulator state, so final output is
/// materialized once and then sliced to honor `batch_size` across output polls.
pub(super) struct MaterializedFinalOutput {
/// Final aggregate evaluation consumes accumulator state, and partial terminal
/// output should not repeatedly renumber group values with `EmitTo::First`.
/// Materialize once and then slice to honor `batch_size` across output polls.
pub(super) struct MaterializedAggregateOutput {
batch: RecordBatch,
offset: usize,
}

impl MaterializedFinalOutput {
impl MaterializedAggregateOutput {
pub(super) fn new(batch: RecordBatch) -> Self {
Self { batch, offset: 0 }
}
Expand Down Expand Up @@ -496,7 +488,7 @@ mod tests {
use super::*;

#[test]
fn materialized_final_output_slices_batches_until_exhausted() -> Result<()> {
fn materialized_aggregate_output_slices_batches_until_exhausted() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new(
"group_col",
DataType::Int32,
Expand All @@ -506,7 +498,7 @@ mod tests {
schema,
vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
)?;
let mut output = MaterializedFinalOutput::new(batch);
let mut output = MaterializedAggregateOutput::new(batch);

assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![1, 2]);
assert_eq!(int32_values(&output.next_batch(2).unwrap(), 0), vec![3, 4]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::aggregates::AggregateExec;

use super::common::{
AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState, FinalMarker,
MaterializedFinalOutput,
MaterializedAggregateOutput,
};

/// Methods specific to the aggregate hash table used in the final aggregation stage.
Expand Down Expand Up @@ -57,8 +57,8 @@ impl AggregateHashTable<FinalMarker> {
) -> Result<Option<RecordBatch>> {
let output_schema = Arc::clone(&self.output_schema);
let batch_size = self.batch_size;
// Take ownership of the output state. Note `emit_next_materialized_batch`
// updates state after it emits a materialized slice.
// Take ownership of the output state. `emit_next_materialized_batch`
// restores `self.state` to `OutputtingMaterialized` or `Done`.
match std::mem::replace(&mut self.state, AggregateHashTableState::Done) {
AggregateHashTableState::Outputting(state) => {
if state.group_values.is_empty() {
Expand All @@ -68,7 +68,7 @@ impl AggregateHashTable<FinalMarker> {
let output = self.materialize_final_output(state, output_schema)?;
Ok(self.emit_next_materialized_batch(output, batch_size))
}
AggregateHashTableState::OutputtingMaterializedFinal(output) => {
AggregateHashTableState::OutputtingMaterialized(output) => {
Ok(self.emit_next_materialized_batch(output, batch_size))
}
AggregateHashTableState::Done => Ok(None),
Expand All @@ -82,7 +82,7 @@ impl AggregateHashTable<FinalMarker> {
&self,
mut state: AggregateHashTableBuffer,
output_schema: SchemaRef,
) -> Result<MaterializedFinalOutput> {
) -> Result<MaterializedAggregateOutput> {
// Final aggregate evaluation consumes accumulator state. Evaluate all
// groups once, then slice the materialized batch on subsequent polls.
let emit_to = EmitTo::All;
Expand All @@ -96,19 +96,19 @@ impl AggregateHashTable<FinalMarker> {

let batch = RecordBatch::try_new(output_schema, output)?;
debug_assert!(batch.num_rows() > 0);
Ok(MaterializedFinalOutput::new(batch))
Ok(MaterializedAggregateOutput::new(batch))
}

fn emit_next_materialized_batch(
&mut self,
mut output: MaterializedFinalOutput,
mut output: MaterializedAggregateOutput,
batch_size: usize,
) -> Option<RecordBatch> {
let batch = output.next_batch(batch_size);
if output.is_exhausted() {
self.state = AggregateHashTableState::Done;
} else {
self.state = AggregateHashTableState::OutputtingMaterializedFinal(output);
self.state = AggregateHashTableState::OutputtingMaterialized(output);
}
batch
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@ use arrow::array::{ArrayRef, BooleanArray, new_null_array};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
use datafusion_expr::EmitTo;

use crate::aggregates::group_values::new_group_values;
use crate::aggregates::order::GroupOrdering;
use crate::aggregates::{AggregateExec, group_id_array, max_duplicate_ordinal};

use super::common::{
AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState,
EvaluatedAccumulatorArgs, HashAggregateAccumulator, PartialMarker, PartialSkipMarker,
emit_to_for_batch_size,
EvaluatedAccumulatorArgs, HashAggregateAccumulator, MaterializedAggregateOutput,
PartialMarker, PartialSkipMarker,
};

/// Methods specific to the aggregate hash table used in the partial aggregation stage.
Expand Down Expand Up @@ -62,43 +63,60 @@ impl AggregateHashTable<PartialMarker> {
) -> Result<Option<RecordBatch>> {
let output_schema = Arc::clone(&self.output_schema);
let batch_size = self.batch_size;
match &mut self.state {
// Take ownership of the output state. `emit_next_materialized_batch`

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@2010YOUY01 I think your plan to refactor the aggregate code so it is easier to work with / understand is already paying off 🙏

// restores `self.state` to `OutputtingMaterialized` or `Done`.
match std::mem::replace(&mut self.state, AggregateHashTableState::Done) {
AggregateHashTableState::Outputting(state) => {
if state.group_values.is_empty() {
self.state = AggregateHashTableState::Done;
return Ok(None);
}

let emit_to =
emit_to_for_batch_size(batch_size, state.group_values.len());
let timer = self.group_by_metrics.emitting_time.timer();
let mut output = state.group_values.emit(emit_to)?;

for acc in state.accumulators.iter_mut() {
output.extend(acc.state(emit_to)?);
}
let done = state.group_values.is_empty();
drop(timer);

let batch = RecordBatch::try_new(output_schema, output)?;
debug_assert!(batch.num_rows() > 0);
if done {
self.state = AggregateHashTableState::Done;
}
Ok(Some(batch))
let output = self.materialize_partial_output(state, output_schema)?;
Ok(self.emit_next_materialized_batch(output, batch_size))
}
AggregateHashTableState::OutputtingMaterialized(output) => {
Ok(self.emit_next_materialized_batch(output, batch_size))
}
AggregateHashTableState::Done => Ok(None),
AggregateHashTableState::Building(_) => {
internal_err!("next_output_batch must be called in the outputting state")
}
AggregateHashTableState::OutputtingMaterializedFinal(_) => {
internal_err!(
"partial aggregate output should not materialize final output"
)
}
}
}

/// Builds one `RecordBatch` holding every group and wraps it in a
/// `MaterializedAggregateOutput` so it can be handed out in slices later.
///
/// The columns are the result of `group_values.emit(EmitTo::All)`,
/// followed by whatever each accumulator's `state(EmitTo::All)` returns, in
/// accumulator order. The emission work is measured by the
/// `group_by_metrics.emitting_time` timer.
///
/// # Errors
///
/// Propagates any error from emitting group values, from an accumulator's
/// `state` call, or from `RecordBatch::try_new`.
fn materialize_partial_output(
    &self,
    mut state: AggregateHashTableBuffer,
    output_schema: SchemaRef,
) -> Result<MaterializedAggregateOutput> {
    let columns = {
        // The timer lives only for this block, so batch construction below
        // is not counted as emitting time.
        let _emit_timer = self.group_by_metrics.emitting_time.timer();
        let mut columns = state.group_values.emit(EmitTo::All)?;
        for accumulator in state.accumulators.iter_mut() {
            columns.extend(accumulator.state(EmitTo::All)?);
        }
        columns
    };

    let materialized = RecordBatch::try_new(output_schema, columns)?;
    debug_assert!(materialized.num_rows() > 0);
    Ok(MaterializedAggregateOutput::new(materialized))
}

/// Takes the next slice (at most `batch_size` rows, per `next_batch`) from the
/// materialized output and records where the table goes next.
///
/// After the slice is taken, `self.state` becomes `Done` if `output` reports
/// it is exhausted; otherwise `output` is stored back in
/// `OutputtingMaterialized` for the next poll. Returns whatever
/// `next_batch` produced.
fn emit_next_materialized_batch(
    &mut self,
    mut output: MaterializedAggregateOutput,
    batch_size: usize,
) -> Option<RecordBatch> {
    let next = output.next_batch(batch_size);
    // Decide the follow-up state only after slicing, since `next_batch`
    // is what advances the output.
    self.state = if output.is_exhausted() {
        AggregateHashTableState::Done
    } else {
        AggregateHashTableState::OutputtingMaterialized(output)
    };
    next
}

pub(in crate::aggregates) fn can_skip_aggregation(&self) -> bool {
self.state
.building()
Expand Down
Loading