Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions fuzz/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ use tracing::debug;
use vortex_array::ArrayRef;
use vortex_array::DynArray;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
use vortex_array::aggregate_fn::fns::sum::sum;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::arbitrary::ArbitraryArray;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::compute::MinMaxResult;
use vortex_array::compute::min_max;
use vortex_array::compute::sum;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::scalar::Scalar;
Expand All @@ -68,6 +69,7 @@ use vortex_error::vortex_panic;
use vortex_mask::Mask;
use vortex_utils::aliases::hash_set::HashSet;

use crate::SESSION;
use crate::error::Backtrace;
use crate::error::VortexFuzzError;
use crate::error::VortexFuzzResult;
Expand Down Expand Up @@ -173,6 +175,8 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
let array = ArbitraryArray::arbitrary(u)?.0;
let mut current_array = array.to_array();

let mut ctx = SESSION.create_execution_ctx();

let mut valid_actions = actions_for_dtype(current_array.dtype())
.into_iter()
.collect::<Vec<_>>();
Expand Down Expand Up @@ -330,6 +334,7 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction {
current_array
.to_canonical()
.vortex_expect("to_canonical should succeed in fuzz test"),
&mut ctx,
)
.vortex_expect("sum_canonical_array should succeed in fuzz test");
(Action::Sum, ExpectedValue::Scalar(sum_result))
Expand Down Expand Up @@ -566,6 +571,8 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult<bool> {
let FuzzArrayAction { array, actions } = fuzz_action;
let mut current_array = array.to_array();

let mut ctx = SESSION.create_execution_ctx();

debug!(
"Initial array:\nTree:\n{}Values:\n{:#}",
current_array.display_tree(),
Expand Down Expand Up @@ -640,8 +647,8 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> VortexFuzzResult<bool> {
current_array = cast_result;
}
Action::Sum => {
let sum_result =
sum(&current_array).vortex_expect("sum operation should succeed in fuzz test");
let sum_result = sum(&current_array, &mut ctx)
.vortex_expect("sum operation should succeed in fuzz test");
assert_scalar_eq(&expected.scalar(), &sum_result, i)?;
}
Action::MinMax => {
Expand Down
7 changes: 4 additions & 3 deletions fuzz/src/array/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::Canonical;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray as _;
use vortex_array::compute::sum;
use vortex_array::aggregate_fn::fns::sum::sum;
use vortex_array::scalar::Scalar;
use vortex_error::VortexResult;

/// Compute sum on the canonical form of the array to get a consistent baseline.
///
/// Converts `canonical` into an array via `into_array` and returns whatever `sum` yields for
/// it, passing any error from `sum` through unchanged. Where the signature takes `ctx`, it is
/// forwarded to `sum` as-is.
pub fn sum_canonical_array(canonical: Canonical) -> VortexResult<Scalar> {
pub fn sum_canonical_array(canonical: Canonical, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
// TODO(joe): replace with baseline not using canonical
sum(&canonical.into_array())
sum(&canonical.into_array(), ctx)
}
160 changes: 22 additions & 138 deletions vortex-array/public-api.lock

Large diffs are not rendered by default.

47 changes: 25 additions & 22 deletions vortex-array/src/aggregate_fn/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_session::VortexSession;
use vortex_error::vortex_err;

use crate::AnyCanonical;
use crate::ArrayRef;
use crate::Columnar;
use crate::DynArray;
use crate::VortexSessionExecute;
use crate::ExecutionCtx;
use crate::aggregate_fn::AggregateFn;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::AggregateFnVTable;
Expand All @@ -35,19 +35,24 @@ pub struct Accumulator<V: AggregateFnVTable> {
partial_dtype: DType,
/// The partial state of the accumulator, updated after each accumulate/merge call.
partial: V::Partial,
/// A session used to lookup custom aggregate kernels.
session: VortexSession,
}

impl<V: AggregateFnVTable> Accumulator<V> {
pub fn try_new(
vtable: V,
options: V::Options,
dtype: DType,
session: VortexSession,
) -> VortexResult<Self> {
let return_dtype = vtable.return_dtype(&options, &dtype)?;
let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
vortex_err!(
"Aggregate function {} cannot be applied to dtype {}",
vtable.id(),
dtype
)
})?;
let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
vortex_err!(
"Aggregate function {} cannot be applied to dtype {}",
vtable.id(),
dtype
)
})?;
let partial = vtable.empty_partial(&options, &dtype)?;
let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();

Expand All @@ -58,7 +63,6 @@ impl<V: AggregateFnVTable> Accumulator<V> {
return_dtype,
partial_dtype,
partial,
session,
})
}
}
Expand All @@ -67,7 +71,7 @@ impl<V: AggregateFnVTable> Accumulator<V> {
/// function is not known at compile time.
pub trait DynAccumulator: 'static + Send {
/// Accumulate a new array into the accumulator's state.
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()>;
fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;

/// Whether the accumulator's result is fully determined.
fn is_saturated(&self) -> bool;
Expand All @@ -84,7 +88,7 @@ pub trait DynAccumulator: 'static + Send {
}

impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
fn accumulate(&mut self, batch: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
if self.is_saturated() {
return Ok(());
}
Expand All @@ -96,9 +100,9 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
batch.dtype()
);

let kernels = &self.session.aggregate_fns().kernels;
let session = ctx.session().clone();
let kernels = &session.aggregate_fns().kernels;

let mut ctx = self.session.create_execution_ctx();
let mut batch = batch.clone();
for _ in 0..*MAX_ITERATIONS {
if batch.is::<AnyCanonical>() {
Expand All @@ -112,7 +116,7 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
.or_else(|| kernels_r.get(&(batch_id, None)))
.and_then(|kernel| {
kernel
.aggregate(&self.aggregate_fn, &batch, &mut ctx)
.aggregate(&self.aggregate_fn, &batch, ctx)
.transpose()
})
.transpose()?
Expand All @@ -128,14 +132,13 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
}

// Execute one step and try again
batch = batch.execute(&mut ctx)?;
batch = batch.execute(ctx)?;
}

// Otherwise, execute the batch until it is columnar and accumulate it into the state.
let columnar = batch.execute::<Columnar>(&mut ctx)?;
let columnar = batch.execute::<Columnar>(ctx)?;

self.vtable
.accumulate(&mut self.partial, &columnar, &mut ctx)
self.vtable.accumulate(&mut self.partial, &columnar, ctx)
}

fn is_saturated(&self) -> bool {
Expand Down
64 changes: 34 additions & 30 deletions vortex-array/src/aggregate_fn/accumulator_grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_error::vortex_panic;
use vortex_mask::Mask;
use vortex_session::VortexSession;

use crate::AnyCanonical;
use crate::ArrayRef;
Expand All @@ -18,7 +18,6 @@ use crate::Columnar;
use crate::DynArray;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::VortexSessionExecute;
use crate::aggregate_fn::Accumulator;
use crate::aggregate_fn::AggregateFn;
use crate::aggregate_fn::AggregateFnRef;
Expand Down Expand Up @@ -58,20 +57,25 @@ pub struct GroupedAccumulator<V: AggregateFnVTable> {
partial_dtype: DType,
/// The accumulated state for prior batches of groups.
partials: Vec<ArrayRef>,
/// A session used to lookup custom aggregate kernels.
session: VortexSession,
}

impl<V: AggregateFnVTable> GroupedAccumulator<V> {
pub fn try_new(
vtable: V,
options: V::Options,
dtype: DType,
session: VortexSession,
) -> VortexResult<Self> {
pub fn try_new(vtable: V, options: V::Options, dtype: DType) -> VortexResult<Self> {
let aggregate_fn = AggregateFn::new(vtable.clone(), options.clone()).erased();
let return_dtype = vtable.return_dtype(&options, &dtype)?;
let partial_dtype = vtable.partial_dtype(&options, &dtype)?;
let return_dtype = vtable.return_dtype(&options, &dtype).ok_or_else(|| {
vortex_err!(
"Aggregate function {} cannot be applied to dtype {}",
vtable.id(),
dtype
)
})?;
let partial_dtype = vtable.partial_dtype(&options, &dtype).ok_or_else(|| {
vortex_err!(
"Aggregate function {} cannot be applied to dtype {}",
vtable.id(),
dtype
)
})?;

Ok(Self {
vtable,
Expand All @@ -81,7 +85,6 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
return_dtype,
partial_dtype,
partials: vec![],
session,
})
}
}
Expand All @@ -90,7 +93,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
/// function is not known at compile time.
pub trait DynGroupedAccumulator: 'static + Send {
/// Accumulate a list of groups into the accumulator.
fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()>;
fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>;

/// Finish the accumulation and return the partial aggregate results for all groups.
/// Resets the accumulator state for the next round of accumulation.
Expand All @@ -102,7 +105,7 @@ pub trait DynGroupedAccumulator: 'static + Send {
}

impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
fn accumulate_list(&mut self, groups: &ArrayRef) -> VortexResult<()> {
fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> {
let elements_dtype = match groups.dtype() {
DType::List(elem, _) => elem,
DType::FixedSizeList(elem, ..) => elem,
Expand All @@ -118,17 +121,15 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
elements_dtype
);

let mut ctx = self.session.create_execution_ctx();

// We first execute the groups until it is a ListView or FixedSizeList, since we only
// dispatch the aggregate kernel over the elements of these arrays.
let canonical = match groups.clone().execute::<Columnar>(&mut ctx)? {
let canonical = match groups.clone().execute::<Columnar>(ctx)? {
Columnar::Canonical(c) => c,
Columnar::Constant(c) => c.into_array().execute::<Canonical>(&mut ctx)?,
Columnar::Constant(c) => c.into_array().execute::<Canonical>(ctx)?,
};
match canonical {
Canonical::List(groups) => self.accumulate_list_view(&groups, &mut ctx),
Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, &mut ctx),
Canonical::List(groups) => self.accumulate_list_view(&groups, ctx),
Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx),
_ => vortex_panic!("We checked the DType above, so this should never happen"),
}
}
Expand Down Expand Up @@ -160,8 +161,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
ctx: &mut ExecutionCtx,
) -> VortexResult<()> {
let mut elements = groups.elements().clone();
let session = self.session.clone();

let session = ctx.session().clone();
let kernels = &session.aggregate_fns().grouped_kernels;

for _ in 0..*MAX_ITERATIONS {
Expand Down Expand Up @@ -205,7 +205,13 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
let offsets = offsets.clone().execute::<Buffer<O>>(ctx)?;
let sizes = sizes.execute::<Buffer<O>>(ctx)?;
self.accumulate_list_view_typed(&elements, offsets.as_ref(), sizes.as_ref(), &validity)
self.accumulate_list_view_typed(
&elements,
offsets.as_ref(),
sizes.as_ref(),
&validity,
ctx,
)
})
}

Expand All @@ -215,12 +221,12 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
offsets: &[O],
sizes: &[O],
validity: &Mask,
ctx: &mut ExecutionCtx,
) -> VortexResult<()> {
let mut accumulator = Accumulator::try_new(
self.vtable.clone(),
self.options.clone(),
self.dtype.clone(),
self.session.clone(),
)?;
let mut states = builder_with_capacity(&self.partial_dtype, offsets.len());

Expand All @@ -230,7 +236,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {

if validity.value(offset) {
let group = elements.slice(offset..offset + size)?;
accumulator.accumulate(&group)?;
accumulator.accumulate(&group, ctx)?;
states.append_scalar(&accumulator.finish()?)?;
} else {
states.append_null()
Expand All @@ -246,8 +252,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
ctx: &mut ExecutionCtx,
) -> VortexResult<()> {
let mut elements = groups.elements().clone();

let session = self.session.clone();
let session = ctx.session().clone();
let kernels = &session.aggregate_fns().grouped_kernels;

for _ in 0..64 {
Expand Down Expand Up @@ -291,7 +296,6 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
self.vtable.clone(),
self.options.clone(),
self.dtype.clone(),
self.session.clone(),
)?;
let mut states = builder_with_capacity(&self.partial_dtype, groups.len());

Expand All @@ -304,7 +308,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
for i in 0..groups.len() {
if validity.value(i) {
let group = elements.slice(offset..offset + size)?;
accumulator.accumulate(&group)?;
accumulator.accumulate(&group, ctx)?;
states.append_scalar(&accumulator.finish()?)?;
} else {
states.append_null()
Expand Down
Loading
Loading