Skip to content

Commit e15371c

Browse files
committed
feat: Support decimal in approx_distinct
1 parent 284ae30 commit e15371c

3 files changed

Lines changed: 216 additions & 9 deletions

File tree

datafusion/functions-aggregate/benches/approx_distinct.rs

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ use std::hint::black_box;
1919
use std::sync::Arc;
2020

2121
use arrow::array::{
22-
ArrayRef, Int8Array, Int16Array, Int64Array, StringArray, StringViewArray,
23-
UInt8Array, UInt16Array,
22+
ArrayRef, Decimal128Array, Decimal256Array, Int8Array, Int16Array, Int64Array,
23+
StringArray, StringViewArray, UInt8Array, UInt16Array,
2424
};
25-
use arrow::datatypes::{DataType, Field, Schema};
25+
use arrow::datatypes::{DataType, Field, Schema, i256};
2626
use criterion::{Criterion, criterion_group, criterion_main};
2727
use datafusion_expr::function::AccumulatorArgs;
2828
use datafusion_expr::{
@@ -61,6 +61,34 @@ fn prepare_accumulator(data_type: DataType) -> Box<dyn Accumulator> {
6161
ApproxDistinct::new().accumulator(accumulator_args).unwrap()
6262
}
6363

64+
const DECIMAL128_PRECISION: u8 = 10;
65+
const DECIMAL256_PRECISION: u8 = 40;
66+
const DECIMAL_SCALE: i8 = 2;
67+
68+
/// Creates a `Decimal128Array` from a pool of `n_distinct` values.
69+
fn create_decimal128_array(n_distinct: usize) -> Decimal128Array {
70+
let mut rng = StdRng::seed_from_u64(42);
71+
let pool: Vec<i128> = (0..n_distinct).map(|i| i as i128 * 50).collect();
72+
(0..BATCH_SIZE)
73+
.map(|_| Some(pool[rng.random_range(0..pool.len())]))
74+
.collect::<Decimal128Array>()
75+
.with_precision_and_scale(DECIMAL128_PRECISION, DECIMAL_SCALE)
76+
.unwrap()
77+
}
78+
79+
/// Creates a `Decimal256Array` from a pool of `n_distinct` values.
80+
fn create_decimal256_array(n_distinct: usize) -> Decimal256Array {
81+
let mut rng = StdRng::seed_from_u64(42);
82+
let pool: Vec<i256> = (0..n_distinct)
83+
.map(|i| i256::from_i128(i as i128 * 50))
84+
.collect();
85+
(0..BATCH_SIZE)
86+
.map(|_| Some(pool[rng.random_range(0..pool.len())]))
87+
.collect::<Decimal256Array>()
88+
.with_precision_and_scale(DECIMAL256_PRECISION, DECIMAL_SCALE)
89+
.unwrap()
90+
}
91+
6492
/// Creates an Int64Array where values are drawn from `0..n_distinct`.
6593
fn create_i64_array(n_distinct: usize) -> Int64Array {
6694
let mut rng = StdRng::seed_from_u64(42);
@@ -224,6 +252,34 @@ fn approx_distinct_benchmark(c: &mut Criterion) {
224252
.unwrap()
225253
})
226254
});
255+
256+
// Decimal128
257+
let values = Arc::new(create_decimal128_array(200)) as ArrayRef;
258+
c.bench_function("approx_distinct decimal128", |b| {
259+
b.iter(|| {
260+
let mut accumulator = prepare_accumulator(DataType::Decimal128(
261+
DECIMAL128_PRECISION,
262+
DECIMAL_SCALE,
263+
));
264+
accumulator
265+
.update_batch(std::slice::from_ref(&values))
266+
.unwrap()
267+
})
268+
});
269+
270+
// Decimal256
271+
let values = Arc::new(create_decimal256_array(200)) as ArrayRef;
272+
c.bench_function("approx_distinct decimal256", |b| {
273+
b.iter(|| {
274+
let mut accumulator = prepare_accumulator(DataType::Decimal256(
275+
DECIMAL256_PRECISION,
276+
DECIMAL_SCALE,
277+
));
278+
accumulator
279+
.update_batch(std::slice::from_ref(&values))
280+
.unwrap()
281+
})
282+
});
227283
}
228284

229285
/// Build a `GroupsAccumulator` the same way the aggregate operator does: use the
@@ -287,6 +343,20 @@ fn build_grouped_batches(data_type: &DataType) -> Vec<(ArrayRef, Vec<usize>)> {
287343
.map(|_| Some(pool[rng.random_range(0..pool.len())].as_str()))
288344
.collect::<StringViewArray>(),
289345
),
346+
DataType::Decimal128(p, s) => Arc::new(
347+
(0..BATCH_SIZE)
348+
.map(|_| Some(rng.random::<i64>() as i128))
349+
.collect::<Decimal128Array>()
350+
.with_precision_and_scale(*p, *s)
351+
.unwrap(),
352+
),
353+
DataType::Decimal256(p, s) => Arc::new(
354+
(0..BATCH_SIZE)
355+
.map(|_| Some(i256::from_i128(rng.random::<i64>() as i128)))
356+
.collect::<Decimal256Array>()
357+
.with_precision_and_scale(*p, *s)
358+
.unwrap(),
359+
),
290360
other => panic!("unsupported grouped bench type: {other}"),
291361
};
292362
(values, group_indices)
@@ -300,7 +370,13 @@ fn approx_distinct_grouped_benchmark(c: &mut Criterion) {
300370
let mut group = c.benchmark_group("approx_distinct_grouped");
301371
group.sample_size(10);
302372

303-
for data_type in [DataType::Int64, DataType::Utf8, DataType::Utf8View] {
373+
for data_type in [
374+
DataType::Int64,
375+
DataType::Utf8,
376+
DataType::Utf8View,
377+
DataType::Decimal128(DECIMAL128_PRECISION, DECIMAL_SCALE),
378+
DataType::Decimal256(DECIMAL256_PRECISION, DECIMAL_SCALE),
379+
] {
304380
let batches = build_grouped_batches(&data_type);
305381
let label = format!("{data_type:?} {N_GROUPS} groups");
306382
group.bench_function(&label, |b| {

datafusion/functions-aggregate/src/approx_distinct.rs

Lines changed: 107 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ use arrow::array::{
2424
};
2525
use arrow::buffer::NullBuffer;
2626
use arrow::datatypes::{
27-
ArrowPrimitiveType, DataType, Date32Type, Date64Type, Field, FieldRef, Int32Type,
28-
Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
29-
Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
30-
TimestampNanosecondType, TimestampSecondType, UInt32Type, UInt64Type,
27+
ArrowPrimitiveType, DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
28+
Field, FieldRef, Int32Type, Int64Type, Time32MillisecondType, Time32SecondType,
29+
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
30+
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt32Type,
31+
UInt64Type,
3132
};
3233
use datafusion_common::ScalarValue;
3334
use datafusion_common::hash_utils::create_hashes;
@@ -758,6 +759,12 @@ impl AggregateUDFImpl for ApproxDistinct {
758759
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
759760
Box::new(NumericHLLAccumulator::<TimestampNanosecondType>::new())
760761
}
762+
DataType::Decimal128(_, _) => {
763+
Box::new(NumericHLLAccumulator::<Decimal128Type>::new())
764+
}
765+
DataType::Decimal256(_, _) => {
766+
Box::new(NumericHLLAccumulator::<Decimal256Type>::new())
767+
}
761768
DataType::Utf8
762769
| DataType::LargeUtf8
763770
| DataType::Utf8View
@@ -818,6 +825,8 @@ fn is_hll_groups_type(data_type: &DataType) -> bool {
818825
| DataType::Timestamp(TimeUnit::Millisecond, _)
819826
| DataType::Timestamp(TimeUnit::Microsecond, _)
820827
| DataType::Timestamp(TimeUnit::Nanosecond, _)
828+
| DataType::Decimal128(_, _)
829+
| DataType::Decimal256(_, _)
821830
| DataType::Utf8
822831
| DataType::LargeUtf8
823832
| DataType::Utf8View
@@ -834,7 +843,10 @@ mod tests {
834843
#[cfg(not(feature = "force_hash_collisions"))]
835844
mod real_hash_test {
836845
use super::*;
837-
use arrow::array::{AsArray, Int64Array, StringViewArray};
846+
use arrow::array::{
847+
AsArray, Decimal128Array, Decimal256Array, Int64Array, StringViewArray,
848+
};
849+
use arrow::datatypes::i256;
838850
use std::sync::Arc;
839851
// A string longer than the 12-byte inline limit
840852
const LONG: &str = "this string is definitely longer than twelve bytes";
@@ -846,6 +858,96 @@ mod tests {
846858
}
847859
}
848860

861+
fn assert_count_numerical_acc_and_group_acc<T>(array: ArrayRef, expected: u64)
862+
where
863+
T: ArrowPrimitiveType + Debug,
864+
T::Native: Hash,
865+
{
866+
assert!(
867+
is_hll_groups_type(array.data_type()),
868+
"{} should be groups-capable",
869+
array.data_type()
870+
);
871+
872+
let mut acc = NumericHLLAccumulator::<T>::new();
873+
acc.update_batch(&[Arc::clone(&array.clone())]).unwrap();
874+
let per_group_count = match acc.evaluate().unwrap() {
875+
ScalarValue::UInt64(Some(v)) => v,
876+
other => panic!("unexpected evaluate result: {other:?}"),
877+
};
878+
879+
let group_indices = vec![0usize; array.len()];
880+
let mut acc = HllGroupsAccumulator::new();
881+
acc.update_batch(&[array.clone()], &group_indices, None, 1)
882+
.unwrap();
883+
let groups_count = acc
884+
.evaluate(EmitTo::All)
885+
.unwrap()
886+
.as_any()
887+
.downcast_ref::<UInt64Array>()
888+
.unwrap()
889+
.value(0);
890+
891+
assert_eq!(
892+
per_group_count,
893+
groups_count,
894+
"paths disagree for {}",
895+
array.data_type()
896+
);
897+
assert_eq!(
898+
per_group_count,
899+
expected,
900+
"wrong count for {}",
901+
array.data_type()
902+
);
903+
}
904+
905+
/// Decimal128 and Decimal256 inputs must yield the same distinct count on the
/// scalar and the grouped accumulator, and that count must be exact for this
/// small input (six distinct values, several of them repeated).
#[test]
fn decimal_support_numerical_acc_and_group_acc() {
    // Shared prefix of both inputs: the distinct values 0, 1, 2, 3 with repeats.
    let small_values: [i128; 8] = [1, 2, 2, 3, 3, 3, 0, 0];

    // Decimal128: add two larger distinct values, one of them duplicated.
    let mut raw_128 = small_values.to_vec();
    raw_128.extend([1_234_567_890, 9_999_999_999, 9_999_999_999]);
    let dec128_array = Decimal128Array::from(raw_128)
        .with_precision_and_scale(38, 2)
        .unwrap();
    assert_count_numerical_acc_and_group_acc::<Decimal128Type>(Arc::new(dec128_array), 6);

    // Decimal256: same shape, with two 36-digit values as the large entries.
    let wide_a = i256::from_string("123456789012345678901234567890123456").unwrap();
    let wide_b = i256::from_string("987654321098765432109876543210987654").unwrap();
    let mut raw_256: Vec<i256> = small_values
        .iter()
        .copied()
        .map(i256::from_i128)
        .collect();
    raw_256.extend([wide_a, wide_b, wide_b]);
    let dec256_array = Decimal256Array::from(raw_256)
        .with_precision_and_scale(40, 2)
        .unwrap();
    assert_count_numerical_acc_and_group_acc::<Decimal256Type>(Arc::new(dec256_array), 6);
}
950+
849951
/// `approx_distinct(v) FILTER (WHERE nullable_bool)` — a NULL filter row
850952
/// must not be counted (null filter is treated the same as false).
851953
#[test]

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1978,6 +1978,35 @@ true
19781978
statement ok
19791979
DROP TABLE approx_distinct_dense_test;
19801980

1981+
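# This test runs approx_distinct over decimal128 and decimal256 columns on both the scalar
# (ungrouped) and the grouped path. The `-0.0` and `0.0` rows are the same decimal value and
# must be counted as a single distinct value.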
1982+
statement ok
1983+
CREATE TABLE approx_distinct_decimal_test (g INT, dec128 DECIMAL(20, 2), dec256 DECIMAL(40, 2)) AS VALUES
1984+
(1, 12345678901234.56, 12345678901234567890123456.78),
1985+
(1, 98765432109876.54, 98765432109876543210987654.32),
1986+
(1, 98765432109876.54, 98765432109876543210987654.32),
1987+
(2, 55555555555555.55, 55555555555555555555555555.55),
1988+
(2, -0.0, -0.0),
1989+
(2, 0.0, 0.0);
1990+
1991+
# Scalar path
1992+
query II
1993+
SELECT approx_distinct(dec128), approx_distinct(dec256) FROM approx_distinct_decimal_test;
1994+
----
1995+
4 4
1996+
1997+
# Grouped path
1998+
query III
1999+
SELECT g, approx_distinct(dec128), approx_distinct(dec256)
2000+
FROM approx_distinct_decimal_test GROUP BY g ORDER BY g;
2001+
----
2002+
1 2 2
2003+
2 2 2
2004+
2005+
statement ok
2006+
DROP TABLE approx_distinct_decimal_test;
2007+
2008+
2009+
19812010
## This test executes the APPROX_PERCENTILE_CONT aggregation against the test
19822011
## data, asserting the estimated quantiles are ±5% their actual values.
19832012
##

0 commit comments

Comments
 (0)