diff --git a/datafusion/functions-aggregate/benches/approx_distinct.rs b/datafusion/functions-aggregate/benches/approx_distinct.rs index 44b45431e3eb1..4608c39d548b9 100644 --- a/datafusion/functions-aggregate/benches/approx_distinct.rs +++ b/datafusion/functions-aggregate/benches/approx_distinct.rs @@ -19,10 +19,11 @@ use std::hint::black_box; use std::sync::Arc; use arrow::array::{ - ArrayRef, Int8Array, Int16Array, Int64Array, StringArray, StringViewArray, - UInt8Array, UInt16Array, + ArrayRef, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + Int8Array, Int16Array, Int64Array, StringArray, StringViewArray, UInt8Array, + UInt16Array, }; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema, i256}; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ @@ -44,6 +45,12 @@ const N_GROUPS: usize = 50_000; const AVG_ROWS_PER_GROUP: usize = 8; const STRING_POOL_SIZE: usize = 100_000; +const DECIMAL32_PRECISION: u8 = 9; +const DECIMAL64_PRECISION: u8 = 18; +const DECIMAL128_PRECISION: u8 = 10; +const DECIMAL256_PRECISION: u8 = 40; +const DECIMAL_SCALE: i8 = 2; + fn prepare_accumulator(data_type: DataType) -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", data_type, true)])); let expr = col("f", &schema).unwrap(); @@ -61,6 +68,52 @@ fn prepare_accumulator(data_type: DataType) -> Box { ApproxDistinct::new().accumulator(accumulator_args).unwrap() } +/// Creates a `Decimal32Array` from a pool of `n_distinct` values. +fn create_decimal32_array(n_distinct: usize) -> Decimal32Array { + let mut rng = StdRng::seed_from_u64(42); + let pool: Vec = (0..n_distinct).map(|i| i as i32 * 50).collect(); + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())])) + .collect::() + .with_precision_and_scale(DECIMAL32_PRECISION, DECIMAL_SCALE) + .unwrap() +} + +/// Creates a `Decimal64Array` from a pool of `n_distinct` values. +fn create_decimal64_array(n_distinct: usize) -> Decimal64Array { + let mut rng = StdRng::seed_from_u64(42); + let pool: Vec = (0..n_distinct).map(|i| i as i64 * 50).collect(); + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())])) + .collect::() + .with_precision_and_scale(DECIMAL64_PRECISION, DECIMAL_SCALE) + .unwrap() +} + +/// Creates a `Decimal128Array` from a pool of `n_distinct` values. +fn create_decimal128_array(n_distinct: usize) -> Decimal128Array { + let mut rng = StdRng::seed_from_u64(42); + let pool: Vec = (0..n_distinct).map(|i| i as i128 * 50).collect(); + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())])) + .collect::() + .with_precision_and_scale(DECIMAL128_PRECISION, DECIMAL_SCALE) + .unwrap() +} + +/// Creates a `Decimal256Array` from a pool of `n_distinct` values. +fn create_decimal256_array(n_distinct: usize) -> Decimal256Array { + let mut rng = StdRng::seed_from_u64(42); + let pool: Vec = (0..n_distinct) + .map(|i| i256::from_i128(i as i128 * 50)) + .collect(); + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())])) + .collect::() + .with_precision_and_scale(DECIMAL256_PRECISION, DECIMAL_SCALE) + .unwrap() +} + /// Creates an Int64Array where values are drawn from `0..n_distinct`. fn create_i64_array(n_distinct: usize) -> Int64Array { let mut rng = StdRng::seed_from_u64(42); @@ -224,6 +277,62 @@ fn approx_distinct_benchmark(c: &mut Criterion) { .unwrap() }) }); + + // Decimal32 + let values = Arc::new(create_decimal32_array(200)) as ArrayRef; + c.bench_function("approx_distinct decimal32", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Decimal32( + DECIMAL32_PRECISION, + DECIMAL_SCALE, + )); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // Decimal64 + let values = Arc::new(create_decimal64_array(200)) as ArrayRef; + c.bench_function("approx_distinct decimal64", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Decimal64( + DECIMAL64_PRECISION, + DECIMAL_SCALE, + )); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // Decimal128 + let values = Arc::new(create_decimal128_array(200)) as ArrayRef; + c.bench_function("approx_distinct decimal128", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Decimal128( + DECIMAL128_PRECISION, + DECIMAL_SCALE, + )); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // Decimal256 + let values = Arc::new(create_decimal256_array(200)) as ArrayRef; + c.bench_function("approx_distinct decimal256", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Decimal256( + DECIMAL256_PRECISION, + DECIMAL_SCALE, + )); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); } /// Build a `GroupsAccumulator` the same way the aggregate operator does: use the @@ -287,6 +396,34 @@ fn build_grouped_batches(data_type: &DataType) -> Vec<(ArrayRef, Vec)> { .map(|_| Some(pool[rng.random_range(0..pool.len())].as_str())) .collect::(), ), + DataType::Decimal32(p, s) => Arc::new( + (0..BATCH_SIZE) + .map(|_| Some(rng.random::())) + .collect::() + .with_precision_and_scale(*p, *s) + .unwrap(), + ), + DataType::Decimal64(p, s) => Arc::new( + (0..BATCH_SIZE) + .map(|_| Some(rng.random::())) + .collect::() + .with_precision_and_scale(*p, *s) + .unwrap(), + ), + DataType::Decimal128(p, s) => Arc::new( + (0..BATCH_SIZE) + .map(|_| Some(rng.random::() as i128)) + .collect::() + .with_precision_and_scale(*p, *s) + .unwrap(), + ), + DataType::Decimal256(p, s) => Arc::new( + (0..BATCH_SIZE) + .map(|_| Some(i256::from_i128(rng.random::() as i128))) + .collect::() + .with_precision_and_scale(*p, *s) + .unwrap(), + ), other => panic!("unsupported grouped bench type: {other}"), }; (values, group_indices) @@ -300,7 +437,15 @@ fn approx_distinct_grouped_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("approx_distinct_grouped"); group.sample_size(10); - for data_type in [DataType::Int64, DataType::Utf8, DataType::Utf8View] { + for data_type in [ + DataType::Int64, + DataType::Utf8, + DataType::Utf8View, + DataType::Decimal32(DECIMAL32_PRECISION, DECIMAL_SCALE), + DataType::Decimal64(DECIMAL64_PRECISION, DECIMAL_SCALE), + DataType::Decimal128(DECIMAL128_PRECISION, DECIMAL_SCALE), + DataType::Decimal256(DECIMAL256_PRECISION, DECIMAL_SCALE), + ] { let batches = build_grouped_batches(&data_type); let label = format!("{data_type:?} {N_GROUPS} groups"); group.bench_function(&label, |b| { diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 90cc8d0630af7..1062a478b7bea 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -24,9 +24,10 @@ use arrow::array::{ }; use arrow::buffer::NullBuffer; use arrow::datatypes::{ - ArrowPrimitiveType, DataType, Date32Type, Date64Type, Field, FieldRef, Int32Type, - Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + ArrowPrimitiveType, DataType, Date32Type, Date64Type, Decimal32Type, Decimal64Type, + Decimal128Type, Decimal256Type, Field, FieldRef, Int32Type, Int64Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt32Type, UInt64Type, }; use datafusion_common::ScalarValue; @@ -758,6 +759,18 @@ impl AggregateUDFImpl for ApproxDistinct { DataType::Timestamp(TimeUnit::Nanosecond, _) => { Box::new(NumericHLLAccumulator::::new()) } + DataType::Decimal32(_, _) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Decimal64(_, _) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Decimal128(_, _) => { + Box::new(NumericHLLAccumulator::::new()) + } + DataType::Decimal256(_, _) => { + Box::new(NumericHLLAccumulator::::new()) + } DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View @@ -818,6 +831,10 @@ fn is_hll_groups_type(data_type: &DataType) -> bool { | DataType::Timestamp(TimeUnit::Millisecond, _) | DataType::Timestamp(TimeUnit::Microsecond, _) | DataType::Timestamp(TimeUnit::Nanosecond, _) + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) | DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View @@ -834,7 +851,11 @@ mod tests { #[cfg(not(feature = "force_hash_collisions"))] mod real_hash_test { use super::*; - use arrow::array::{AsArray, Int64Array, StringViewArray}; + use arrow::array::{ + AsArray, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + Int64Array, StringViewArray, + }; + use arrow::datatypes::i256; use std::sync::Arc; // A string longer than the 12-byte inline limit const LONG: &str = "this string is definitely longer than twelve bytes"; @@ -846,6 +867,134 @@ mod tests { } } + fn assert_count_numerical_acc_and_group_acc(array: ArrayRef, expected: u64) + where + T: ArrowPrimitiveType + Debug, + T::Native: Hash, + { + assert!( + is_hll_groups_type(array.data_type()), + "{} should be groups-capable", + array.data_type() + ); + + let mut acc = NumericHLLAccumulator::::new(); + acc.update_batch(&[Arc::clone(&array)]).unwrap(); + let per_group_count = match acc.evaluate().unwrap() { + ScalarValue::UInt64(Some(v)) => v, + other => panic!("unexpected evaluate result: {other:?}"), + }; + + let group_indices = vec![0usize; array.len()]; + let mut acc = HllGroupsAccumulator::new(); + acc.update_batch(std::slice::from_ref(&array), &group_indices, None, 1) + .unwrap(); + let groups_count = acc + .evaluate(EmitTo::All) + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + + assert_eq!( + per_group_count, + groups_count, + "paths disagree for {}", + array.data_type() + ); + assert_eq!( + per_group_count, + expected, + "wrong count for {}", + array.data_type() + ); + } + + #[test] + fn decimal_support_numerical_acc_and_group_acc() { + let decimal_32: ArrayRef = Arc::new( + Decimal32Array::from(vec![ + 1i32, + 2, + 2, + 3, + 3, + 3, + 0, + 0, + 123_456_789, + 999_999_999, + 999_999_999, + ]) + .with_precision_and_scale(9, 2) + .unwrap(), + ); + assert_count_numerical_acc_and_group_acc::(decimal_32, 6); + + let decimal_64: ArrayRef = Arc::new( + Decimal64Array::from(vec![ + 1i64, + 2, + 2, + 3, + 3, + 3, + 0, + 0, + 1_234_567_890_123, + 9_999_999_999_999, + 9_999_999_999_999, + ]) + .with_precision_and_scale(18, 2) + .unwrap(), + ); + assert_count_numerical_acc_and_group_acc::(decimal_64, 6); + + let decimal_128: ArrayRef = Arc::new( + Decimal128Array::from(vec![ + 1i128, + 2, + 2, + 3, + 3, + 3, + 0, + 0, + 1_234_567_890, + 9_999_999_999, + 9_999_999_999, + ]) + .with_precision_and_scale(38, 2) + .unwrap(), + ); + assert_count_numerical_acc_and_group_acc::(decimal_128, 6); + + let big_256_a = + i256::from_string("123456789012345678901234567890123456").unwrap(); + let big_256_b = + i256::from_string("987654321098765432109876543210987654").unwrap(); + + let decimal_256: ArrayRef = Arc::new( + Decimal256Array::from(vec![ + i256::from_i128(1), + i256::from_i128(2), + i256::from_i128(2), + i256::from_i128(3), + i256::from_i128(3), + i256::from_i128(3), + i256::from_i128(0), + i256::from_i128(0), + big_256_a, + big_256_b, + big_256_b, + ]) + .with_precision_and_scale(40, 2) + .unwrap(), + ); + assert_count_numerical_acc_and_group_acc::(decimal_256, 6); + } + /// `approx_distinct(v) FILTER (WHERE nullable_bool)` — a NULL filter row /// must not be counted (null filter is treated the same as false). #[test] diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 18c09acf08887..b5e86daf28b61 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1978,6 +1978,35 @@ true statement ok DROP TABLE approx_distinct_dense_test; +# This test runs approx_distinct over decimal128 and decimal256 for the scalar and the grouped path. +statement ok +CREATE TABLE approx_distinct_decimal_test (g INT, dec128 DECIMAL(20, 2), dec256 DECIMAL(40, 2)) AS VALUES + (1, 12345678901234.56, 12345678901234567890123456.78), + (1, 98765432109876.54, 98765432109876543210987654.32), + (1, 98765432109876.54, 98765432109876543210987654.32), + (2, 55555555555555.55, 55555555555555555555555555.55), + (2, -0.0, -0.0), + (2, 0.0, 0.0); + +# Scalar path +query II +SELECT approx_distinct(dec128), approx_distinct(dec256) FROM approx_distinct_decimal_test; +---- +4 4 + +# Grouped path +query III +SELECT g, approx_distinct(dec128), approx_distinct(dec256) +FROM approx_distinct_decimal_test GROUP BY g ORDER BY g; +---- +1 2 2 +2 2 2 + +statement ok +DROP TABLE approx_distinct_decimal_test; + + + ## This test executes the APPROX_PERCENTILE_CONT aggregation against the test ## data, asserting the estimated quantiles are ±5% their actual values. ##