diff --git a/encodings/alp/src/alp/compute/mask.rs b/encodings/alp/src/alp/compute/mask.rs index 4d59e5bf73a..d4f63475c9e 100644 --- a/encodings/alp/src/alp/compute/mask.rs +++ b/encodings/alp/src/alp/compute/mask.rs @@ -32,7 +32,7 @@ impl MaskKernel for ALP { mask: &ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult> { - let vortex_mask = Validity::Array(mask.not()?).to_mask(array.len()); + let vortex_mask = Validity::Array(mask.not()?).execute_mask(array.len(), ctx)?; let masked_encoded = array.encoded().clone().mask(mask.clone())?; let masked_patches = array .patches() diff --git a/encodings/datetime-parts/src/canonical.rs b/encodings/datetime-parts/src/canonical.rs index 51e2d0447ba..3a27dd8b4d0 100644 --- a/encodings/datetime-parts/src/canonical.rs +++ b/encodings/datetime-parts/src/canonical.rs @@ -147,19 +147,17 @@ mod test { )) .unwrap(); - assert_eq!( - date_times.validity_mask().unwrap(), - validity.to_mask(date_times.len()) - ); - let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + assert!(date_times.validity()?.mask_eq(&validity, &mut ctx)?); + let primitive_values = decode_to_temporal(&date_times, &mut ctx)? .temporal_values() .clone() .execute::(&mut ctx)?; assert_arrays_eq!(primitive_values, milliseconds); - assert_eq!(primitive_values.validity(), &validity); + assert!(primitive_values.validity().mask_eq(&validity, &mut ctx)?); Ok(()) } } diff --git a/encodings/datetime-parts/src/compress.rs b/encodings/datetime-parts/src/compress.rs index 65bbea1c116..6010f6adeaf 100644 --- a/encodings/datetime-parts/src/compress.rs +++ b/encodings/datetime-parts/src/compress.rs @@ -77,7 +77,9 @@ impl TryFrom for DateTimePartsArray { mod tests { use rstest::rstest; use vortex_array::IntoArray; + use vortex_array::LEGACY_SESSION; use vortex_array::ToCanonical; + use vortex_array::VortexSessionExecute; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::TemporalArray; use vortex_array::extension::datetime::TimeUnit; @@ -110,8 +112,21 @@ mod tests { seconds, subseconds, } = split_temporal(temporal_array).unwrap(); - assert_eq!(days.to_primitive().validity(), &validity); - assert_eq!(seconds.to_primitive().validity(), &Validity::NonNullable); - assert_eq!(subseconds.to_primitive().validity(), &Validity::NonNullable); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert!( + days.to_primitive() + .validity() + .mask_eq(&validity, &mut ctx) + .unwrap() + ); + assert!(matches!( + seconds.to_primitive().validity(), + Validity::NonNullable + )); + assert!(matches!( + subseconds.to_primitive().validity(), + Validity::NonNullable + )); } } diff --git a/encodings/fastlanes/src/rle/array/mod.rs b/encodings/fastlanes/src/rle/array/mod.rs index 83f75ad45ae..5e46d7f6438 100644 --- a/encodings/fastlanes/src/rle/array/mod.rs +++ b/encodings/fastlanes/src/rle/array/mod.rs @@ -218,7 +218,9 @@ mod tests { use vortex_array::ArrayContext; use vortex_array::DynArray; use vortex_array::IntoArray; + use vortex_array::LEGACY_SESSION; use vortex_array::ToCanonical; + use vortex_array::VortexSessionExecute; use vortex_array::arrays::PrimitiveArray; use vortex_array::assert_arrays_eq; use vortex_array::dtype::DType; @@ -397,7 +399,12 @@ mod tests { let sliced_array = rle_array.slice(1..4).unwrap(); let validity_mask = sliced_array.validity_mask().unwrap(); - let expected_mask = Validity::from_iter([false, true, false]).to_mask(3); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let expected_mask = Validity::from_iter([false, true, false]) + .execute_mask(3, &mut ctx) + .unwrap(); + assert_eq!(validity_mask.len(), expected_mask.len()); + assert_eq!(validity_mask, expected_mask); assert_eq!(validity_mask.len(), expected_mask.len()); assert_eq!(validity_mask, expected_mask); } diff --git a/encodings/pco/public-api.lock b/encodings/pco/public-api.lock index fc17c5124b7..b82093e811b 100644 --- a/encodings/pco/public-api.lock +++ b/encodings/pco/public-api.lock @@ -46,7 +46,7 @@ pub fn vortex_pco::Pco::deserialize(bytes: &[u8], _dtype: &vortex_array::dtype:: pub fn vortex_pco::Pco::dtype(array: &vortex_pco::PcoArray) -> &vortex_array::dtype::DType -pub fn vortex_pco::Pco::execute(array: &Self::Array, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_pco::Pco::execute(array: &Self::Array, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_pco::Pco::id(_array: &Self::Array) -> vortex_array::vtable::dyn_::ArrayId @@ -74,7 +74,7 @@ pub struct vortex_pco::PcoArray impl vortex_pco::PcoArray -pub fn vortex_pco::PcoArray::decompress(&self) -> vortex_error::VortexResult +pub fn vortex_pco::PcoArray::decompress(&self, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_pco::PcoArray::from_array(array: vortex_array::array::ArrayRef, level: usize, nums_per_page: usize) -> vortex_error::VortexResult diff --git a/encodings/pco/src/array.rs b/encodings/pco/src/array.rs index cfd8a6fafb8..85ba2bbe6cb 100644 --- a/encodings/pco/src/array.rs +++ b/encodings/pco/src/array.rs @@ -22,9 +22,11 @@ use vortex_array::DynArray; use vortex_array::ExecutionCtx; use vortex_array::ExecutionStep; use vortex_array::IntoArray; +use vortex_array::LEGACY_SESSION; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::ToCanonical; +use vortex_array::VortexSessionExecute; use vortex_array::arrays::Primitive; use vortex_array::arrays::PrimitiveArray; use vortex_array::buffer::BufferHandle; @@ -263,8 +265,8 @@ impl VTable for Pco { Ok(()) } - fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult { - Ok(ExecutionStep::Done(array.decompress()?.into_array())) + fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { + Ok(ExecutionStep::Done(array.decompress(ctx)?.into_array())) } fn reduce_parent( @@ -437,14 +439,14 @@ impl PcoArray { } } - pub fn decompress(&self) -> VortexResult { + pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult { // To start, we figure out which chunks and pages we need to decompress, and with // what value offset into the first such page. let number_type = number_type_from_dtype(&self.dtype); let values_byte_buffer = match_number_enum!( number_type, NumberType => { - self.decompress_values_typed::()? + self.decompress_values_typed::(ctx)? } ); @@ -457,11 +459,14 @@ impl PcoArray { )) } - fn decompress_values_typed(&self) -> VortexResult { + fn decompress_values_typed( + &self, + ctx: &mut ExecutionCtx, + ) -> VortexResult { // To start, we figure out what range of values we need to decompress. let slice_value_indices = self .unsliced_validity - .to_mask(self.unsliced_n_rows) + .execute_mask(self.unsliced_n_rows, ctx)? .valid_counts_for_indices(&[self.slice_start, self.slice_stop]); let slice_value_start = slice_value_indices[0]; let slice_value_stop = slice_value_indices[1]; @@ -564,7 +569,11 @@ impl ValiditySliceHelper for PcoArray { impl OperationsVTable for Pco { fn scalar_at(array: &PcoArray, index: usize) -> VortexResult { - array._slice(index, index + 1).decompress()?.scalar_at(0) + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + array + ._slice(index, index + 1) + .decompress(&mut ctx)? + .scalar_at(0) } } diff --git a/encodings/pco/src/test.rs b/encodings/pco/src/test.rs index a87a43cc13d..2cc7c830a5e 100644 --- a/encodings/pco/src/test.rs +++ b/encodings/pco/src/test.rs @@ -6,6 +6,7 @@ use std::sync::LazyLock; use vortex_array::ArrayContext; use vortex_array::IntoArray; +use vortex_array::LEGACY_SESSION; use vortex_array::ToCanonical; use vortex_array::VortexSessionExecute; use vortex_array::arrays::BoolArray; @@ -47,7 +48,8 @@ fn test_compress_decompress() { assert!(compressed.pages.len() < array.nbytes() as usize); // check full decompression works - let decompressed = compressed.decompress().unwrap(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let decompressed = compressed.decompress(&mut ctx).unwrap(); assert_arrays_eq!(decompressed, PrimitiveArray::from_iter(data)); // check slicing works @@ -69,7 +71,8 @@ fn test_compress_decompress_small() { let expected = array.into_array(); assert_arrays_eq!(compressed, expected); - let decompressed = compressed.decompress().unwrap(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let decompressed = compressed.decompress(&mut ctx).unwrap(); assert_arrays_eq!(decompressed, expected); } @@ -78,7 +81,8 @@ fn test_empty() { let data: Vec = vec![]; let array = PrimitiveArray::from_iter(data.clone()); let compressed = PcoArray::from_primitive(&array, 3, 100).unwrap(); - let primitive = compressed.decompress().unwrap(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let primitive = compressed.decompress(&mut ctx).unwrap(); assert_arrays_eq!(primitive, PrimitiveArray::from_iter(data)); } @@ -118,9 +122,16 @@ fn test_validity_and_multiple_chunks_and_pages() { assert_nth_scalar!(slice, 0, 100); assert_nth_scalar!(slice, 2, 102); let primitive = slice.to_primitive(); - assert_eq!( - primitive.validity(), - &Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()) + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert!( + primitive + .validity() + .mask_eq( + &Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()), + &mut ctx, + ) + .unwrap() ); } diff --git a/encodings/zstd/public-api.lock b/encodings/zstd/public-api.lock index 192a4e54065..d7390a53c90 100644 --- a/encodings/zstd/public-api.lock +++ b/encodings/zstd/public-api.lock @@ -74,7 +74,7 @@ pub struct vortex_zstd::ZstdArray impl vortex_zstd::ZstdArray -pub fn vortex_zstd::ZstdArray::decompress(&self) -> vortex_error::VortexResult +pub fn vortex_zstd::ZstdArray::decompress(&self, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_zstd::ZstdArray::from_array(array: vortex_array::array::ArrayRef, level: i32, values_per_frame: usize) -> vortex_error::VortexResult diff --git a/encodings/zstd/src/array.rs b/encodings/zstd/src/array.rs index 194d7ce1122..9c2e133d50d 100644 --- a/encodings/zstd/src/array.rs +++ b/encodings/zstd/src/array.rs @@ -15,9 +15,11 @@ use vortex_array::DynArray; use vortex_array::ExecutionCtx; use vortex_array::ExecutionStep; use vortex_array::IntoArray; +use vortex_array::LEGACY_SESSION; use vortex_array::Precision; use vortex_array::ProstMetadata; use vortex_array::ToCanonical; +use vortex_array::VortexSessionExecute; use vortex_array::accessor::ArrayAccessor; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; @@ -274,7 +276,7 @@ impl VTable for Zstd { fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult { array - .decompress()? + .decompress(ctx)? .execute::(ctx) .map(ExecutionStep::Done) } @@ -702,14 +704,14 @@ impl ZstdArray { } } - pub fn decompress(&self) -> VortexResult { + pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult { // To start, we figure out which frames we need to decompress, and with // what row offset into the first such frame. let byte_width = self.byte_width(); let slice_n_rows = self.slice_stop - self.slice_start; let slice_value_indices = self .unsliced_validity - .to_mask(self.unsliced_n_rows) + .execute_mask(self.unsliced_n_rows, ctx)? .valid_counts_for_indices(&[self.slice_start, self.slice_stop]); let slice_value_idx_start = slice_value_indices[0]; @@ -787,14 +789,14 @@ impl ZstdArray { // // We ensure that the validity of the decompressed array ALWAYS matches the validity // implied by the DType. - if !self.dtype().is_nullable() && slice_validity != Validity::NonNullable { + if !self.dtype().is_nullable() && !matches!(slice_validity, Validity::NonNullable) { assert!( - slice_validity.all_valid(slice_n_rows)?, + matches!(slice_validity, Validity::AllValid), "ZSTD array expects to be non-nullable but there are nulls after decompression" ); slice_validity = Validity::NonNullable; - } else if self.dtype.is_nullable() && slice_validity == Validity::NonNullable { + } else if self.dtype.is_nullable() && matches!(slice_validity, Validity::NonNullable) { slice_validity = Validity::AllValid; } // @@ -817,7 +819,7 @@ impl ZstdArray { Ok(primitive.into_array()) } DType::Binary(_) | DType::Utf8(_) => { - match slice_validity.to_mask(slice_n_rows).indices() { + match slice_validity.execute_mask(slice_n_rows, ctx)?.indices() { AllOr::All => { // the decompressed buffer is a bunch of interleaved u32 lengths // and strings of those lengths, we need to reconstruct the @@ -937,6 +939,10 @@ impl ValiditySliceHelper for ZstdArray { impl OperationsVTable for Zstd { fn scalar_at(array: &ZstdArray, index: usize) -> VortexResult { - array._slice(index, index + 1).decompress()?.scalar_at(0) + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + array + ._slice(index, index + 1) + .decompress(&mut ctx)? + .scalar_at(0) } } diff --git a/encodings/zstd/src/compute/cast.rs b/encodings/zstd/src/compute/cast.rs index 61cee2f8cd2..f57637df88c 100644 --- a/encodings/zstd/src/compute/cast.rs +++ b/encodings/zstd/src/compute/cast.rs @@ -6,6 +6,7 @@ use vortex_array::IntoArray; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::scalar_fn::fns::cast::CastReduce; +use vortex_array::validity::Validity; use vortex_error::VortexResult; use crate::Zstd; @@ -45,11 +46,10 @@ impl CastReduce for Zstd { } (Nullability::Nullable, Nullability::NonNullable) => { // null => non-null works if there are no nulls in the sliced range - let sliced_len = array.slice_stop() - array.slice_start(); - let has_nulls = !array - .unsliced_validity - .slice(array.slice_start()..array.slice_stop())? - .all_valid(sliced_len)?; + let has_nulls = !matches!( + array.validity()?, + Validity::AllValid | Validity::NonNullable + ); // We don't attempt to handle casting when there are nulls. if has_nulls { diff --git a/encodings/zstd/src/test.rs b/encodings/zstd/src/test.rs index 7d3b4fe1d1d..d17fe512824 100644 --- a/encodings/zstd/src/test.rs +++ b/encodings/zstd/src/test.rs @@ -3,7 +3,9 @@ #![allow(clippy::cast_possible_truncation)] use vortex_array::IntoArray; +use vortex_array::LEGACY_SESSION; use vortex_array::ToCanonical; +use vortex_array::VortexSessionExecute; use vortex_array::arrays::BoolArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::VarBinViewArray; @@ -30,7 +32,8 @@ fn test_zstd_compress_decompress() { assert!(compressed.dictionary.is_none()); // check full decompression works - let decompressed = compressed.decompress().unwrap(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let decompressed = compressed.decompress(&mut ctx).unwrap(); assert_arrays_eq!(decompressed, PrimitiveArray::from_iter(data)); // check slicing works @@ -75,11 +78,17 @@ fn test_zstd_with_validity_and_multi_frame() { assert_nth_scalar!(compressed, 10, None::); assert_nth_scalar!(compressed, 177, 177); - let decompressed = compressed.decompress().unwrap().to_primitive(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let decompressed = compressed.decompress(&mut ctx).unwrap().to_primitive(); let decompressed_values = decompressed.as_slice::(); assert_eq!(decompressed_values[3], 3); assert_eq!(decompressed_values[177], 177); - assert_eq!(decompressed.validity(), array.validity()); + assert!( + decompressed + .validity() + .mask_eq(array.validity(), &mut ctx) + .unwrap() + ); // check slicing works let slice = compressed.slice(176..179).unwrap(); @@ -88,9 +97,14 @@ fn test_zstd_with_validity_and_multi_frame() { i32::try_from(&primitive.scalar_at(1).unwrap()).unwrap(), 177 ); - assert_eq!( - primitive.validity(), - &Validity::Array(BoolArray::from_iter(vec![false, true, false]).into_array()) + assert!( + primitive + .validity() + .mask_eq( + &Validity::Array(BoolArray::from_iter(vec![false, true, false]).into_array()), + &mut ctx + ) + .unwrap() ); } @@ -107,7 +121,8 @@ fn test_zstd_with_dict() { assert_nth_scalar!(compressed, 0, 0); assert_nth_scalar!(compressed, 199, 199); - let decompressed = compressed.decompress().unwrap().to_primitive(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let decompressed = compressed.decompress(&mut ctx).unwrap().to_primitive(); assert_arrays_eq!(decompressed, PrimitiveArray::from_iter(data)); // check slicing works @@ -177,7 +192,9 @@ fn test_zstd_decompress_var_bin_view() { assert_nth_scalar!(compressed, 2, None::); assert_nth_scalar!(compressed, 3, "Lorem ipsum dolor sit amet"); assert_nth_scalar!(compressed, 4, "baz"); - let decompressed = compressed.decompress().unwrap().to_varbinview(); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let decompressed = compressed.decompress(&mut ctx).unwrap().to_varbinview(); assert_nth_scalar!(decompressed, 0, "foo"); assert_nth_scalar!(decompressed, 1, "bar"); assert_nth_scalar!(decompressed, 2, None::); diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index c27baf13951..e7e66efbe27 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -20110,10 +20110,6 @@ impl vortex_array::validity::Validity pub const vortex_array::validity::Validity::DTYPE: vortex_array::dtype::DType -pub fn vortex_array::validity::Validity::all_invalid(&self, len: usize) -> vortex_error::VortexResult - -pub fn vortex_array::validity::Validity::all_valid(&self, len: usize) -> vortex_error::VortexResult - pub fn vortex_array::validity::Validity::and(self, rhs: vortex_array::validity::Validity) -> vortex_error::VortexResult pub fn vortex_array::validity::Validity::as_array(&self) -> core::option::Option<&vortex_array::ArrayRef> @@ -20122,6 +20118,8 @@ pub fn vortex_array::validity::Validity::cast_nullability(self, nullability: vor pub fn vortex_array::validity::Validity::copy_from_array(array: &vortex_array::ArrayRef) -> vortex_error::VortexResult +pub fn vortex_array::validity::Validity::execute_mask(&self, length: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub fn vortex_array::validity::Validity::filter(&self, mask: &vortex_mask::Mask) -> vortex_error::VortexResult pub fn vortex_array::validity::Validity::into_array(self) -> core::option::Option @@ -20134,6 +20132,8 @@ pub fn vortex_array::validity::Validity::is_null(&self, index: usize) -> vortex_ pub fn vortex_array::validity::Validity::is_valid(&self, index: usize) -> vortex_error::VortexResult +pub fn vortex_array::validity::Validity::mask_eq(&self, other: &vortex_array::validity::Validity, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub fn vortex_array::validity::Validity::maybe_len(&self) -> core::option::Option pub fn vortex_array::validity::Validity::not(&self) -> vortex_error::VortexResult @@ -20148,8 +20148,6 @@ pub fn vortex_array::validity::Validity::take(&self, indices: &vortex_array::Arr pub fn vortex_array::validity::Validity::to_array(&self, len: usize) -> vortex_array::ArrayRef -pub fn vortex_array::validity::Validity::to_mask(&self, length: usize) -> vortex_mask::Mask - pub fn vortex_array::validity::Validity::uncompressed_size(&self) -> usize pub fn vortex_array::validity::Validity::union_nullability(self, nullability: vortex_array::dtype::Nullability) -> Self @@ -20168,10 +20166,6 @@ impl core::clone::Clone for vortex_array::validity::Validity pub fn vortex_array::validity::Validity::clone(&self) -> vortex_array::validity::Validity -impl core::cmp::PartialEq for vortex_array::validity::Validity - -pub fn vortex_array::validity::Validity::eq(&self, other: &Self) -> bool - impl core::convert::From<&vortex_array::dtype::Nullability> for vortex_array::validity::Validity pub fn vortex_array::validity::Validity::from(value: &vortex_array::dtype::Nullability) -> Self diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 17d165e3746..b2b9cf38b35 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -200,7 +200,7 @@ impl GroupedAccumulator { let elements = elements.execute::(ctx)?.into_array(); let offsets = groups.offsets(); let sizes = groups.sizes().cast(offsets.dtype().clone())?; - let validity = groups.validity().to_mask(offsets.len()); + let validity = groups.validity().execute_mask(offsets.len(), ctx)?; match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { let offsets = offsets.clone().execute::>(ctx)?; @@ -285,7 +285,7 @@ impl GroupedAccumulator { // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. let elements = elements.execute::(ctx)?.into_array(); - let validity = groups.validity().to_mask(groups.len()); + let validity = groups.validity().execute_mask(groups.len(), ctx)?; let mut accumulator = Accumulator::try_new( self.vtable.clone(), diff --git a/vortex-array/src/arrays/constant/vtable/canonical.rs b/vortex-array/src/arrays/constant/vtable/canonical.rs index 3540257fb85..f8e2d7f21cf 100644 --- a/vortex-array/src/arrays/constant/vtable/canonical.rs +++ b/vortex-array/src/arrays/constant/vtable/canonical.rs @@ -126,7 +126,7 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult { - assert!(validity.all_invalid(array.len())?); + assert!(matches!(validity, Validity::AllInvalid)); // The struct is entirely null, so fields just need placeholder values with the // correct dtype. We use `default_value` which returns a zero for non-nullable // dtypes and null for nullable dtypes, preserving each field's nullability. @@ -496,7 +496,7 @@ mod tests { assert_eq!(canonical.len(), 4); assert_eq!(canonical.list_size(), 3); - assert_eq!(canonical.validity(), &Validity::NonNullable); + assert!(matches!(canonical.validity(), Validity::NonNullable)); // Check that each list is [10, 20, 30]. for i in 0..4 { @@ -523,7 +523,7 @@ mod tests { assert_eq!(canonical.len(), 3); assert_eq!(canonical.list_size(), 2); - assert_eq!(canonical.validity(), &Validity::AllValid); + assert!(matches!(canonical.validity(), Validity::AllValid)); // Check elements. let elements = canonical.elements().to_primitive(); @@ -547,7 +547,7 @@ mod tests { assert_eq!(canonical.len(), 5); assert_eq!(canonical.list_size(), 4); - assert_eq!(canonical.validity(), &Validity::AllInvalid); + assert!(matches!(canonical.validity(), Validity::AllInvalid)); // Elements should be defaults (zeros). let elements = canonical.elements().to_primitive(); @@ -569,7 +569,7 @@ mod tests { assert_eq!(canonical.len(), 10); assert_eq!(canonical.list_size(), 0); - assert_eq!(canonical.validity(), &Validity::NonNullable); + assert!(matches!(canonical.validity(), Validity::NonNullable)); // Elements array should be empty. assert!(canonical.elements().is_empty()); @@ -635,7 +635,7 @@ mod tests { assert_eq!(canonical.len(), 3); assert_eq!(canonical.list_size(), 3); - assert_eq!(canonical.validity(), &Validity::NonNullable); + assert!(matches!(canonical.validity(), Validity::NonNullable)); // Check elements including nulls. let elements = canonical.elements().to_primitive(); diff --git a/vortex-array/src/arrays/datetime/test.rs b/vortex-array/src/arrays/datetime/test.rs index 66dbacfaa29..b77da6b622a 100644 --- a/vortex-array/src/arrays/datetime/test.rs +++ b/vortex-array/src/arrays/datetime/test.rs @@ -6,6 +6,7 @@ use vortex_buffer::buffer; use vortex_error::VortexResult; use crate::IntoArray; +use crate::Precision; use crate::ToCanonical; use crate::array::DynArray; use crate::arrays::PrimitiveArray; @@ -18,6 +19,7 @@ use crate::extension::datetime::TemporalMetadata; use crate::extension::datetime::TimeUnit; use crate::extension::datetime::Timestamp; use crate::extension::datetime::TimestampOptions; +use crate::hash::ArrayEq; use crate::scalar::Scalar; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -194,9 +196,13 @@ fn test_validity_preservation(#[case] validity: Validity) { .into_array(); let temporal_array = TemporalArray::new_timestamp(milliseconds, TimeUnit::Milliseconds, Some("UTC".into())); - assert_eq!( - temporal_array.temporal_values().to_primitive().validity(), - &validity + + assert!( + temporal_array + .temporal_values() + .to_primitive() + .validity() + .array_eq(&validity, Precision::Ptr) ); } diff --git a/vortex-array/src/arrays/decimal/compute/cast.rs b/vortex-array/src/arrays/decimal/compute/cast.rs index c00b65db587..cea97ac91ae 100644 --- a/vortex-array/src/arrays/decimal/compute/cast.rs +++ b/vortex-array/src/arrays/decimal/compute/cast.rs @@ -176,7 +176,7 @@ mod tests { .to_decimal(); assert_eq!(casted.dtype(), &nullable_dtype); - assert_eq!(casted.validity(), &Validity::AllValid); + assert!(matches!(casted.validity(), Validity::AllValid)); assert_eq!(casted.len(), 3); } @@ -196,7 +196,7 @@ mod tests { .to_decimal(); assert_eq!(casted.dtype(), &non_nullable_dtype); - assert_eq!(casted.validity(), &Validity::NonNullable); + assert!(matches!(casted.validity(), Validity::NonNullable)); } #[test] diff --git a/vortex-array/src/arrays/listview/conversion.rs b/vortex-array/src/arrays/listview/conversion.rs index ef3718d49e1..cddee9a2ec5 100644 --- a/vortex-array/src/arrays/listview/conversion.rs +++ b/vortex-array/src/arrays/listview/conversion.rs @@ -287,8 +287,10 @@ mod tests { use super::super::tests::common::create_nullable_listview; use super::super::tests::common::create_overlapping_listview; use super::recursive_list_from_list_view; + use crate::ArrayEq; use crate::IntoArray; use crate::LEGACY_SESSION; + use crate::Precision; use crate::VortexSessionExecute; use crate::arrays::BoolArray; use crate::arrays::FixedSizeListArray; @@ -386,7 +388,11 @@ mod tests { let nullable_list_view = list_view_from_list(nullable_list.clone(), &mut ctx)?; // Verify validity is preserved. - assert_eq!(nullable_list_view.validity(), &validity); + assert!( + nullable_list_view + .validity() + .array_eq(&validity, Precision::Ptr) + ); assert_eq!(nullable_list_view.len(), 3); // Round-trip conversion. diff --git a/vortex-array/src/arrays/masked/execute.rs b/vortex-array/src/arrays/masked/execute.rs index a4cf0f06f7c..936a4ef8d65 100644 --- a/vortex-array/src/arrays/masked/execute.rs +++ b/vortex-array/src/arrays/masked/execute.rs @@ -37,112 +37,147 @@ pub fn mask_validity_canonical( ) -> VortexResult { Ok(match canonical { Canonical::Null(a) => Canonical::Null(mask_validity_null(a, validity_mask)), - Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity_mask)), - Canonical::Primitive(a) => Canonical::Primitive(mask_validity_primitive(a, validity_mask)), - Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity_mask)), + Canonical::Bool(a) => Canonical::Bool(mask_validity_bool(a, validity_mask, ctx)?), + Canonical::Primitive(a) => { + Canonical::Primitive(mask_validity_primitive(a, validity_mask, ctx)?) + } + Canonical::Decimal(a) => Canonical::Decimal(mask_validity_decimal(a, validity_mask, ctx)?), Canonical::VarBinView(a) => { - Canonical::VarBinView(mask_validity_varbinview(a, validity_mask)) + Canonical::VarBinView(mask_validity_varbinview(a, validity_mask, ctx)?) } - Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity_mask)), + Canonical::List(a) => Canonical::List(mask_validity_listview(a, validity_mask, ctx)?), Canonical::FixedSizeList(a) => { - Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity_mask)) + Canonical::FixedSizeList(mask_validity_fixed_size_list(a, validity_mask, ctx)?) } - Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity_mask)), + Canonical::Struct(a) => Canonical::Struct(mask_validity_struct(a, validity_mask, ctx)?), Canonical::Extension(a) => { Canonical::Extension(mask_validity_extension(a, validity_mask, ctx)?) } }) } -fn combine_validity(validity: &Validity, mask: &Mask, len: usize) -> Validity { - let current_mask = validity.to_mask(len); +fn combine_validity( + validity: &Validity, + mask: &Mask, + len: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let current_mask = validity.execute_mask(len, ctx)?; let combined = current_mask.bitand(mask); - Validity::from_mask(combined, Nullability::Nullable) + Ok(Validity::from_mask(combined, Nullability::Nullable)) } fn mask_validity_null(array: NullArray, _mask: &Mask) -> NullArray { array } -fn mask_validity_bool(array: BoolArray, mask: &Mask) -> BoolArray { +fn mask_validity_bool( + array: BoolArray, + mask: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { let len = array.len(); - let new_validity = combine_validity(array.validity(), mask, len); - BoolArray::new(array.to_bit_buffer(), new_validity) + let new_validity = combine_validity(array.validity(), mask, len, ctx)?; + Ok(BoolArray::new(array.to_bit_buffer(), new_validity)) } -fn mask_validity_primitive(array: PrimitiveArray, mask: &Mask) -> PrimitiveArray { +fn mask_validity_primitive( + array: PrimitiveArray, + mask: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { let len = array.len(); let ptype = array.ptype(); - let new_validity = combine_validity(array.validity(), mask, len); + let new_validity = combine_validity(array.validity(), mask, len, ctx)?; // SAFETY: validity has same length as values - unsafe { + Ok(unsafe { PrimitiveArray::new_unchecked_from_handle( array.buffer_handle().clone(), ptype, new_validity, ) - } + }) } -fn mask_validity_decimal(array: DecimalArray, mask: &Mask) -> DecimalArray { +fn mask_validity_decimal( + array: DecimalArray, + mask: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { let len = array.len(); let dec_dtype = array.decimal_dtype(); let values_type = array.values_type(); - let new_validity = combine_validity(array.validity(), mask, len); + let new_validity = combine_validity(array.validity(), mask, len, ctx)?; // SAFETY: We're only changing validity, not the data structure - match_each_decimal_value_type!(values_type, |T| { + Ok(match_each_decimal_value_type!(values_type, |T| { let buffer = array.buffer::(); unsafe { DecimalArray::new_unchecked(buffer, dec_dtype, new_validity) } - }) + })) } /// Mask validity for VarBinViewArray. -fn mask_validity_varbinview(array: VarBinViewArray, mask: &Mask) -> VarBinViewArray { +fn mask_validity_varbinview( + array: VarBinViewArray, + mask: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { let len = array.len(); let dtype = array.dtype().as_nullable(); - let new_validity = combine_validity(array.validity(), mask, len); + let new_validity = combine_validity(array.validity(), mask, len, ctx)?; // SAFETY: We're only changing validity, not the data structure - unsafe { + Ok(unsafe { VarBinViewArray::new_handle_unchecked( array.views_handle().clone(), array.buffers().clone(), dtype, new_validity, ) - } + }) } -fn mask_validity_listview(array: ListViewArray, mask: &Mask) -> ListViewArray { +fn mask_validity_listview( + array: ListViewArray, + mask: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { let len = array.len(); - let new_validity = combine_validity(array.validity(), mask, len); + let new_validity = combine_validity(array.validity(), mask, len, ctx)?; // SAFETY: We're only changing validity, not the data structure - unsafe { + Ok(unsafe { ListViewArray::new_unchecked( array.elements().clone(), array.offsets().clone(), array.sizes().clone(), new_validity, ) - } + }) } -fn mask_validity_fixed_size_list(array: FixedSizeListArray, mask: &Mask) -> FixedSizeListArray { +fn mask_validity_fixed_size_list( + array: FixedSizeListArray, + mask: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { let len = array.len(); let list_size = array.list_size(); - let new_validity = combine_validity(array.validity(), mask, len); + let new_validity = combine_validity(array.validity(), mask, len, ctx)?; // SAFETY: We're only changing validity, not the data structure - unsafe { + Ok(unsafe { FixedSizeListArray::new_unchecked(array.elements().clone(), list_size, new_validity, len) - } + }) } -fn mask_validity_struct(array: StructArray, mask: &Mask) -> StructArray { +fn mask_validity_struct( + array: StructArray, + mask: &Mask, + ctx: &mut ExecutionCtx, +) -> VortexResult { let len = array.len(); - let new_validity = combine_validity(array.validity(), mask, len); + let new_validity = combine_validity(array.validity(), mask, len, ctx)?; let fields = array.unmasked_fields().clone(); let struct_fields = array.struct_fields().clone(); // SAFETY: We're only changing validity, not the data structure - unsafe { StructArray::new_unchecked(fields, struct_fields, len, new_validity) } + Ok(unsafe { StructArray::new_unchecked(fields, struct_fields, len, new_validity) }) } fn mask_validity_extension( diff --git a/vortex-array/src/arrays/masked/tests.rs b/vortex-array/src/arrays/masked/tests.rs index 84a41586173..6e3452b3568 100644 --- a/vortex-array/src/arrays/masked/tests.rs +++ b/vortex-array/src/arrays/masked/tests.rs @@ -7,7 +7,9 @@ use vortex_error::VortexResult; use super::*; use crate::DynArray; use crate::IntoArray; +use crate::LEGACY_SESSION; use crate::ToCanonical as _; +use crate::VortexSessionExecute; use crate::arrays::PrimitiveArray; use crate::assert_arrays_eq; use crate::dtype::DType; @@ -97,5 +99,13 @@ fn test_masked_child_preserves_length(#[case] validity: Validity) { let array = MaskedArray::try_new(child, validity.clone()).unwrap(); assert_eq!(array.len(), len); - assert_eq!(array.validity_mask().unwrap(), validity.to_mask(len)); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert!( + array + .validity() + .unwrap() + .mask_eq(&validity, &mut ctx) + .unwrap(), + ); } diff --git a/vortex-array/src/arrays/primitive/array/cast.rs b/vortex-array/src/arrays/primitive/array/cast.rs index b72d5ff8640..33fbb3f3a09 100644 --- a/vortex-array/src/arrays/primitive/array/cast.rs +++ b/vortex-array/src/arrays/primitive/array/cast.rs @@ -171,7 +171,7 @@ mod tests { result.dtype(), &DType::Primitive(PType::U8, Nullability::Nullable) ); - assert_eq!(result.validity, Validity::AllInvalid); + assert!(matches!(result.validity, Validity::AllInvalid)); } #[rstest] @@ -254,7 +254,7 @@ mod tests { let array2 = PrimitiveArray::new(Buffer::::empty(), Validity::NonNullable); let result2 = array2.narrow().unwrap(); // Empty arrays should not have their validity changed - assert_eq!(result.validity, Validity::AllInvalid); - assert_eq!(result2.validity, Validity::NonNullable); + assert!(matches!(result.validity, Validity::AllInvalid)); + assert!(matches!(result2.validity, Validity::NonNullable)); } } diff --git a/vortex-array/src/arrays/primitive/compute/cast.rs b/vortex-array/src/arrays/primitive/compute/cast.rs index 83e49b86cf5..932d120d50f 100644 --- a/vortex-array/src/arrays/primitive/compute/cast.rs +++ b/vortex-array/src/arrays/primitive/compute/cast.rs @@ -115,6 +115,7 @@ mod test { use crate::validity::Validity; use crate::vtable::ValidityHelper; + #[allow(clippy::cognitive_complexity)] #[test] fn cast_u32_u8() { let arr = buffer![0u32, 10, 200].into_array(); @@ -122,7 +123,7 @@ mod test { // cast from u32 to u8 let p = arr.cast(PType::U8.into()).unwrap().to_primitive(); assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200])); - assert_eq!(p.validity(), &Validity::NonNullable); + assert!(matches!(p.validity(), Validity::NonNullable)); // to nullable let p = p @@ -134,7 +135,7 @@ mod test { p, PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid) ); - assert_eq!(p.validity(), &Validity::AllValid); + assert!(matches!(p.validity(), Validity::AllValid)); // back to non-nullable let p = p @@ -143,7 +144,7 @@ mod test { .unwrap() .to_primitive(); assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200])); - assert_eq!(p.validity(), &Validity::NonNullable); + assert!(matches!(p.validity(), Validity::NonNullable)); // to nullable u32 let p = p @@ -155,7 +156,7 @@ mod test { p, PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid) ); - assert_eq!(p.validity(), &Validity::AllValid); + assert!(matches!(p.validity(), Validity::AllValid)); // to non-nullable u8 let p = p @@ -164,7 +165,7 @@ mod test { .unwrap() .to_primitive(); assert_arrays_eq!(p, PrimitiveArray::from_iter([0u8, 10, 200])); - assert_eq!(p.validity(), &Validity::NonNullable); + assert!(matches!(p.validity(), Validity::NonNullable)); } #[test] diff --git a/vortex-array/src/arrays/struct_/compute/zip.rs b/vortex-array/src/arrays/struct_/compute/zip.rs index 27bae02ef45..425d07a28a3 100644 --- a/vortex-array/src/arrays/struct_/compute/zip.rs +++ b/vortex-array/src/arrays/struct_/compute/zip.rs @@ -47,8 +47,8 @@ impl ZipKernel for Struct { (v1, v2) => { let mask_mask = mask.try_to_mask_fill_null_false(ctx)?; - let v1m = v1.to_mask(if_true.len()); - let v2m = v2.to_mask(if_false.len()); + let v1m = v1.execute_mask(if_true.len(), ctx)?; + let v2m = v2.execute_mask(if_false.len(), ctx)?; let combined = (v1m.bitand(&mask_mask)).bitor(&v2m.bitand(&mask_mask.not())); Validity::from_mask( diff --git a/vortex-array/src/arrays/varbin/array.rs b/vortex-array/src/arrays/varbin/array.rs index 9c72e80487d..5c08545d6ba 100644 --- a/vortex-array/src/arrays/varbin/array.rs +++ b/vortex-array/src/arrays/varbin/array.rs @@ -186,7 +186,7 @@ impl VarBinArray { // Check nullability matches vortex_ensure!( - dtype.is_nullable() != (validity == &Validity::NonNullable), + dtype.is_nullable() != matches!(validity, Validity::NonNullable), InvalidArgument: "incorrect validity {:?} for dtype {}", validity, dtype diff --git a/vortex-array/src/builders/bool.rs b/vortex-array/src/builders/bool.rs index eb12eb6f166..cd8001948d1 100644 --- a/vortex-array/src/builders/bool.rs +++ b/vortex-array/src/builders/bool.rs @@ -188,15 +188,20 @@ mod tests { let chunk_count = 10; let chunk = make_opt_bool_chunks(len, chunk_count); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); let mut builder = builder_with_capacity(chunk.dtype(), len * chunk_count); chunk .clone() - .append_to_builder(builder.as_mut(), &mut LEGACY_SESSION.create_execution_ctx())?; + .append_to_builder(builder.as_mut(), &mut ctx)?; let canon_into = builder.finish().to_bool(); let into_canon = chunk.to_bool(); - assert_eq!(canon_into.validity(), into_canon.validity()); + assert!( + canon_into + .validity() + .mask_eq(into_canon.validity(), &mut ctx)? + ); assert_eq!(canon_into.to_bit_buffer(), into_canon.to_bit_buffer()); Ok(()) } diff --git a/vortex-array/src/builders/list.rs b/vortex-array/src/builders/list.rs index 07838775e73..37c44b3cada 100644 --- a/vortex-array/src/builders/list.rs +++ b/vortex-array/src/builders/list.rs @@ -314,6 +314,7 @@ mod tests { use vortex_buffer::buffer; use crate::IntoArray; + use crate::LEGACY_SESSION; use crate::ToCanonical; use crate::array::DynArray; use crate::arrays::ChunkedArray; @@ -326,6 +327,7 @@ mod tests { use crate::dtype::IntegerPType; use crate::dtype::Nullability; use crate::dtype::PType::I32; + use crate::executor::VortexSessionExecute; use crate::scalar::Scalar; use crate::validity::Validity; use crate::vtable::ValidityHelper; @@ -436,8 +438,9 @@ mod tests { .unwrap(); assert_eq!(list.len(), 3); - let mut builder = ListBuilder::::with_capacity(Arc::new(I32.into()), Nullable, 18, 9); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mut builder = ListBuilder::::with_capacity(Arc::new(I32.into()), Nullable, 18, 9); builder.extend_from_array(&list); builder.extend_from_array(&list); builder.extend_from_array(&list.slice(0..0).unwrap()); @@ -465,7 +468,12 @@ mod tests { assert_arrays_eq!(actual.offsets(), expected.offsets()); - assert_eq!(actual.validity(), expected.validity()) + assert!( + actual + .validity() + .mask_eq(expected.validity(), &mut ctx) + .unwrap(), + ); } #[test] diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index 6a512d2748c..dda5aa1e7c9 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -34,6 +34,7 @@ use crate::compute::is_sorted; use crate::dtype::DType; use crate::dtype::IntegerPType; use crate::dtype::NativePType; +use crate::dtype::Nullability; use crate::dtype::Nullability::NonNullable; use crate::dtype::PType; use crate::dtype::UnsignedPType; @@ -829,9 +830,16 @@ impl Patches { let Some((new_sparse_indices, value_indices)) = match_each_unsigned_integer_ptype!(indices.ptype(), |Indices| { match_each_integer_ptype!(take_indices.ptype(), |TakeIndices| { + let take_validity = take_indices + .validity() + .execute_mask(take_indices.len(), ctx)?; + let take_nullability = take_indices.validity().nullability(); + let take_slice = take_indices.as_slice::(); take_map::<_, TakeIndices>( indices.as_slice::(), - take_indices, + take_slice, + take_validity, + take_nullability, self.offset(), min_index, max_index, @@ -1007,9 +1015,12 @@ unsafe fn apply_patches_to_buffer_inner( } } +#[allow(clippy::too_many_arguments)] // private function, can clean up one day fn take_map, T: NativePType>( indices: &[I], - take_indices: PrimitiveArray, + take_indices: &[T], + take_validity: Mask, + take_nullability: Nullability, indices_offset: usize, min_index: usize, max_index: usize, @@ -1019,10 +1030,6 @@ where usize: TryFrom, VortexError: From<>::Error>, { - let take_indices_len = take_indices.len(); - let take_indices_validity = take_indices.validity(); - let take_indices_validity_mask = take_indices_validity.to_mask(take_indices_len); - let take_indices = take_indices.as_slice::(); let offset_i = I::try_from(indices_offset)?; let sparse_index_to_value_index: HashMap = indices @@ -1041,7 +1048,7 @@ where .map_err(|_| vortex_err!("Failed to convert index to usize"))?; // If we have to take nulls the take index doesn't matter, make it 0 for consistency - let is_null = match take_indices_validity_mask.bit_buffer() { + let is_null = match take_validity.bit_buffer() { AllOr::All => false, AllOr::None => true, AllOr::Some(buf) => !buf.value(idx_in_take), @@ -1066,7 +1073,8 @@ where } let new_sparse_indices = new_sparse_indices.into_array(); - let values_validity = take_indices_validity.take(&new_sparse_indices)?; + let values_validity = + Validity::from_mask(take_validity, take_nullability).take(&new_sparse_indices)?; Ok(Some(( new_sparse_indices, PrimitiveArray::new(value_indices, values_validity).into_array(), diff --git a/vortex-array/src/scalar_fn/fns/pack.rs b/vortex-array/src/scalar_fn/fns/pack.rs index a3c4ac2bad1..42741d2476f 100644 --- a/vortex-array/src/scalar_fn/fns/pack.rs +++ b/vortex-array/src/scalar_fn/fns/pack.rs @@ -231,7 +231,7 @@ mod tests { let actual_array = test_array().apply(&expr).unwrap().to_struct(); assert_eq!(actual_array.names(), ["one", "two", "three"]); - assert_eq!(actual_array.validity(), &Validity::NonNullable); + assert!(matches!(actual_array.validity(), Validity::NonNullable)); assert_arrays_eq!( primitive_field(&actual_array.clone().into_array(), &["one"]).unwrap(), @@ -302,7 +302,7 @@ mod tests { let actual_array = test_array().apply(&expr).unwrap().to_struct(); assert_eq!(actual_array.names(), ["one", "two", "three"]); - assert_eq!(actual_array.validity(), &Validity::AllValid); + assert!(matches!(actual_array.validity(), Validity::AllValid)); } #[test] diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 2b88a142144..6b7b0ad820a 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -21,12 +21,10 @@ use crate::Canonical; use crate::DynArray; use crate::ExecutionCtx; use crate::IntoArray; -use crate::ToCanonical; use crate::arrays::BoolArray; use crate::arrays::ConstantArray; use crate::arrays::scalar_fn::ScalarFnArrayExt; use crate::builtins::ArrayBuiltins; -use crate::compute::sum; use crate::dtype::DType; use crate::dtype::Nullability; use crate::optimizer::ArrayOptimizer; @@ -122,34 +120,6 @@ impl Validity { } } - #[inline] - pub fn all_valid(&self, len: usize) -> VortexResult { - Ok(match self { - _ if len == 0 => true, - Validity::NonNullable | Validity::AllValid => true, - Validity::AllInvalid => false, - Validity::Array(array) => { - usize::try_from(&sum(array).vortex_expect("must have sum for bool array")) - .vortex_expect("sum must be a usize") - == array.len() - } - }) - } - - #[inline] - pub fn all_invalid(&self, len: usize) -> VortexResult { - Ok(match self { - _ if len == 0 => true, - Validity::NonNullable | Validity::AllValid => false, - Validity::AllInvalid => true, - Validity::Array(array) => { - usize::try_from(&sum(array).vortex_expect("must have sum for bool array")) - .vortex_expect("sum must be a usize") - == 0 - } - }) - } - /// Returns whether the `index` item is valid. #[inline] pub fn is_valid(&self, index: usize) -> VortexResult { @@ -234,21 +204,37 @@ impl Validity { } } - #[inline] - pub fn to_mask(&self, length: usize) -> Mask { + pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult { match self { - Self::NonNullable | Self::AllValid => Mask::AllTrue(length), - Self::AllInvalid => Mask::AllFalse(length), - Self::Array(is_valid) => { + Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)), + Self::AllInvalid => Ok(Mask::AllFalse(length)), + Self::Array(arr) => { assert_eq!( - is_valid.len(), + arr.len(), length, "Validity::Array length must equal to_logical's argument: {}, {}.", - is_valid.len(), + arr.len(), length, ); - is_valid.to_bool().to_mask() + // TODO(ngates): I'm not sure execution should take arrays by ownership. + // If so we should fix call sites to clone and this function takes self. + arr.clone().execute::(ctx) + } + } + } + + /// Compare two Validity values of the same length by executing them into masks if necessary. + pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult { + match (self, other) { + (Validity::NonNullable, Validity::NonNullable) => Ok(true), + (Validity::AllValid, Validity::AllValid) => Ok(true), + (Validity::AllInvalid, Validity::AllInvalid) => Ok(true), + (Validity::Array(a), Validity::Array(b)) => { + let a = a.clone().execute::(ctx)?; + let b = b.clone().execute::(ctx)?; + Ok(a == b) } + _ => Ok(false), } } @@ -299,7 +285,7 @@ impl Validity { _ => {} }; - let own_nullability = if self == Validity::NonNullable { + let own_nullability = if matches!(self, Validity::NonNullable) { Nullability::NonNullable } else { Nullability::Nullable @@ -418,23 +404,6 @@ impl Validity { } } -impl PartialEq for Validity { - #[inline] - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::NonNullable, Self::NonNullable) => true, - (Self::AllValid, Self::AllValid) => true, - (Self::AllInvalid, Self::AllInvalid) => true, - (Self::Array(a), Self::Array(b)) => { - let a = a.to_bool(); - let b = b.to_bool(); - a.to_bit_buffer() == b.to_bit_buffer() - } - _ => false, - } - } -} - impl From for Validity { #[inline] fn from(value: BitBuffer) -> Self { @@ -603,17 +572,21 @@ mod tests { ) { let indices = PrimitiveArray::new(Buffer::copy_from(positions), Validity::NonNullable).into_array(); - assert_eq!( + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + assert!( validity .patch( len, 0, &indices, &patches, - &mut LEGACY_SESSION.create_execution_ctx() + &mut LEGACY_SESSION.create_execution_ctx(), ) - .unwrap(), - expected + .unwrap() + .mask_eq(&expected, &mut ctx) + .unwrap() ); } @@ -671,6 +644,13 @@ mod tests { #[case] indices: ArrayRef, #[case] expected: Validity, ) { - assert_eq!(validity.take(&indices).unwrap(), expected); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert!( + validity + .take(&indices) + .unwrap() + .mask_eq(&expected, &mut ctx) + .unwrap() + ); } } diff --git a/vortex-cuda/src/kernel/encodings/zstd.rs b/vortex-cuda/src/kernel/encodings/zstd.rs index 81472aebb39..682bba222c3 100644 --- a/vortex-cuda/src/kernel/encodings/zstd.rs +++ b/vortex-cuda/src/kernel/encodings/zstd.rs @@ -205,7 +205,7 @@ impl CudaExecute for ZstdExecutor { dtype = %_other, "Only Binary/Utf8 ZSTD arrays supported on GPU, falling back to CPU" ); - zstd.decompress()?.to_canonical() + zstd.decompress(ctx.execution_ctx())?.to_canonical() } } } @@ -313,14 +313,17 @@ async fn decode_zstd(array: ZstdArray, ctx: &mut CudaExecutionCtx) -> VortexResu .await?; let slice_value_indices = validity - .to_mask(n_rows) + .execute_mask(n_rows, ctx.execution_ctx())? .valid_counts_for_indices(&[slice_start, slice_stop]); let slice_value_idx_start = slice_value_indices[0]; let slice_value_idx_stop = slice_value_indices[1]; let sliced_validity = validity.slice(slice_start..slice_stop)?; - match sliced_validity.to_mask(slice_stop - slice_start).indices() { + match sliced_validity + .execute_mask(slice_stop - slice_start, ctx.execution_ctx())? + .indices() + { AllOr::All => { let all_views = vortex::encodings::zstd::reconstruct_views(&host_buffer); let sliced_views = all_views.slice(slice_value_idx_start..slice_value_idx_stop); @@ -368,7 +371,9 @@ mod tests { let zstd_array = ZstdArray::from_var_bin_view(&strings, 3, 0)?; - let cpu_result = zstd_array.decompress()?.to_canonical()?; + let cpu_result = zstd_array + .decompress(cuda_ctx.execution_ctx())? + .to_canonical()?; let gpu_result = ZstdExecutor .execute(zstd_array.into_array(), &mut cuda_ctx) .await?; @@ -403,7 +408,9 @@ mod tests { // 14 strings and 3 values per frame = ceil(14/3) = 5 frames. let zstd_array = ZstdArray::from_var_bin_view(&strings, 3, 3)?; - let cpu_result = zstd_array.decompress()?.to_canonical()?; + let cpu_result = zstd_array + .decompress(cuda_ctx.execution_ctx())? + .to_canonical()?; let gpu_result = ZstdExecutor .execute(zstd_array.into_array(), &mut cuda_ctx) .await?; diff --git a/vortex-duckdb/src/convert/vector.rs b/vortex-duckdb/src/convert/vector.rs index e5c948f7ed4..d6b008efe64 100644 --- a/vortex-duckdb/src/convert/vector.rs +++ b/vortex-duckdb/src/convert/vector.rs @@ -29,6 +29,7 @@ use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_bail; use vortex::extension::datetime::TimeUnit; +use vortex::mask::Mask; use crate::cpp::DUCKDB_TYPE; use crate::cpp::duckdb_date; @@ -149,14 +150,14 @@ fn convert_valid_list_entry( /// downstream operations like converting `ListView` to `List`. fn process_duckdb_lists( entries: &[duckdb_list_entry], - validity: &Validity, -) -> (Buffer, Buffer, usize) { + validity: &Mask, +) -> VortexResult<(Buffer, Buffer, usize)> { let len = entries.len(); let mut offsets = BufferMut::with_capacity(len); let mut sizes = BufferMut::with_capacity(len); match validity { - Validity::NonNullable | Validity::AllValid => { + Mask::AllTrue(_) => { // All entries are valid, so there is no need to check the validity. let mut child_min_length = 0; let mut previous_end = 0; @@ -170,41 +171,37 @@ fn process_duckdb_lists( sizes.push_unchecked(size); } } - (offsets.freeze(), sizes.freeze(), child_min_length) + Ok((offsets.freeze(), sizes.freeze(), child_min_length)) } - Validity::AllInvalid => { + Mask::AllFalse(_) => { // All entries are null, so we can just set offset=0 and size=0. // SAFETY: We allocated enough capacity above. unsafe { offsets.push_n_unchecked(0, len); sizes.push_n_unchecked(0, len); } - (offsets.freeze(), sizes.freeze(), 0) + Ok((offsets.freeze(), sizes.freeze(), 0)) } - Validity::Array(_) => { + Mask::Values(values) => { // We have some number of nulls, so make sure to check validity before updating info. - let mask = validity.to_mask(len); - let child_min_length = mask.iter_bools(|validity_iter| { - let mut child_min_length = 0; - let mut previous_end = 0; - - for (entry, is_valid) in entries.iter().zip(validity_iter) { - let (offset, size) = if is_valid { - convert_valid_list_entry(entry, &mut child_min_length, &mut previous_end) - } else { - (previous_end, 0) - }; - - // SAFETY: We allocated enough capacity above. - unsafe { - offsets.push_unchecked(offset); - sizes.push_unchecked(size); - } + let mut child_min_length = 0; + let mut previous_end = 0; + + for (entry, is_valid) in entries.iter().zip(values.bit_buffer().iter()) { + let (offset, size) = if is_valid { + convert_valid_list_entry(entry, &mut child_min_length, &mut previous_end) + } else { + (previous_end, 0) + }; + + // SAFETY: We allocated enough capacity above. + unsafe { + offsets.push_unchecked(offset); + sizes.push_unchecked(size); } + } - child_min_length - }); - (offsets.freeze(), sizes.freeze(), child_min_length) + Ok((offsets.freeze(), sizes.freeze(), child_min_length)) } } } @@ -321,10 +318,10 @@ pub fn flat_vector_to_vortex(vector: &VectorRef, len: usize) -> VortexResult { - let validity = vector.validity_ref(len).to_validity(); + let validity = vector.validity_ref(len).to_mask(); let entries = vector.as_slice_with_len::(len); - let (offsets, sizes, child_min_length) = process_duckdb_lists(entries, &validity); + let (offsets, sizes, child_min_length) = process_duckdb_lists(entries, &validity)?; let child_data = flat_vector_to_vortex(vector.list_vector_get_child(), child_min_length)?; @@ -332,7 +329,7 @@ pub fn flat_vector_to_vortex(vector: &VectorRef, len: usize) -> VortexResult { (validity_entry & (1u64 << idx_in_entry)) != 0 } - /// Creates a Validity directly from the DuckDB validity mask for optimal performance. - pub fn to_validity(&self) -> Validity { + /// Creates a mask directly from the DuckDB validity mask for optimal performance. + pub fn to_mask(&self) -> Mask { let Some(validity) = self.validity else { // All values are valid - return Validity::AllValid; + return Mask::AllTrue(self.len); }; - Validity::from(BitBuffer::new( + Mask::from_buffer(BitBuffer::new( Buffer::::copy_from(validity).into_byte_buffer(), self.len, )) } + + pub fn to_validity(&self) -> Validity { + Validity::from_mask(self.to_mask(), Nullability::Nullable) + } } #[cfg(test)] mod tests { use vortex::mask::Mask; + use vortex_array::LEGACY_SESSION; + use vortex_array::VortexSessionExecute; use super::*; use crate::cpp::DUCKDB_TYPE; @@ -351,9 +359,8 @@ mod tests { let validity = vector.validity_ref(len); let validity = validity.to_validity(); - assert_eq!( - validity, - Validity::AllValid, + assert!( + matches!(validity, Validity::AllValid), "Expected None for all-valid vector" ); } @@ -373,8 +380,9 @@ mod tests { let validity = validity.to_validity(); assert_eq!(validity.maybe_len(), Some(len)); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); assert_eq!( - validity.to_mask(len), + validity.execute_mask(len, &mut ctx).unwrap(), Mask::from_indices(len, vec![0, 2, 4, 5, 6, 8, 9]) ); } @@ -414,7 +422,7 @@ mod tests { let validity = vector.validity_ref(len); let validity = validity.to_validity(); - assert_eq!(validity, Validity::AllValid); + assert!(matches!(validity, Validity::AllValid)); } #[test] @@ -430,7 +438,7 @@ mod tests { let validity = vector.validity_ref(len); let validity = validity.to_validity(); - assert_eq!(validity, Validity::AllInvalid); + assert!(matches!(validity, Validity::AllInvalid)); } #[test] diff --git a/vortex-duckdb/src/exporter/mod.rs b/vortex-duckdb/src/exporter/mod.rs index d777ce3bb54..f21d4882a8d 100644 --- a/vortex-duckdb/src/exporter/mod.rs +++ b/vortex-duckdb/src/exporter/mod.rs @@ -57,8 +57,9 @@ impl ArrayExporter { cache: &ConversionCache, mut ctx: ExecutionCtx, ) -> VortexResult { - let all_valid = array.validity().all_valid(array.len())?; - assert!(all_valid); + let validity = array.validity().execute_mask(array.len(), &mut ctx)?; + assert!(validity.all_true()); + let fields = array .unmasked_fields() .iter() diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index 18c8b5bbb05..a532fc1adad 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -186,7 +186,8 @@ mod test { use vortex_array::ArrayRef; use vortex_array::IntoArray; - use vortex_array::ToCanonical; + use vortex_array::LEGACY_SESSION; + use vortex_array::VortexSessionExecute; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::dtype::FieldNames; @@ -332,8 +333,15 @@ mod test { .await?; assert_eq!(recovered_array.len(), array.len()); - let recovered_primitive = recovered_array.to_primitive(); - assert_eq!(recovered_primitive.validity(), array.validity()); + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + let recovered_primitive = recovered_array.execute::(&mut ctx)?; + assert!( + recovered_primitive + .validity() + .mask_eq(array.validity(), &mut ctx)? + ); assert_eq!( recovered_primitive.to_buffer::(), array.to_buffer::()