@@ -24,10 +24,11 @@ use arrow::array::{
2424} ;
2525use arrow:: buffer:: NullBuffer ;
2626use arrow:: datatypes:: {
27- ArrowPrimitiveType , DataType , Date32Type , Date64Type , Field , FieldRef , Int32Type ,
28- Int64Type , Time32MillisecondType , Time32SecondType , Time64MicrosecondType ,
29- Time64NanosecondType , TimeUnit , TimestampMicrosecondType , TimestampMillisecondType ,
30- TimestampNanosecondType , TimestampSecondType , UInt32Type , UInt64Type ,
27+ ArrowPrimitiveType , DataType , Date32Type , Date64Type , Decimal128Type , Decimal256Type ,
28+ Field , FieldRef , Int32Type , Int64Type , Time32MillisecondType , Time32SecondType ,
29+ Time64MicrosecondType , Time64NanosecondType , TimeUnit , TimestampMicrosecondType ,
30+ TimestampMillisecondType , TimestampNanosecondType , TimestampSecondType , UInt32Type ,
31+ UInt64Type ,
3132} ;
3233use datafusion_common:: ScalarValue ;
3334use datafusion_common:: hash_utils:: create_hashes;
@@ -758,6 +759,12 @@ impl AggregateUDFImpl for ApproxDistinct {
758759 DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => {
759760 Box :: new ( NumericHLLAccumulator :: < TimestampNanosecondType > :: new ( ) )
760761 }
762+ DataType :: Decimal128 ( _, _) => {
763+ Box :: new ( NumericHLLAccumulator :: < Decimal128Type > :: new ( ) )
764+ }
765+ DataType :: Decimal256 ( _, _) => {
766+ Box :: new ( NumericHLLAccumulator :: < Decimal256Type > :: new ( ) )
767+ }
761768 DataType :: Utf8
762769 | DataType :: LargeUtf8
763770 | DataType :: Utf8View
@@ -818,6 +825,8 @@ fn is_hll_groups_type(data_type: &DataType) -> bool {
818825 | DataType :: Timestamp ( TimeUnit :: Millisecond , _)
819826 | DataType :: Timestamp ( TimeUnit :: Microsecond , _)
820827 | DataType :: Timestamp ( TimeUnit :: Nanosecond , _)
828+ | DataType :: Decimal128 ( _, _)
829+ | DataType :: Decimal256 ( _, _)
821830 | DataType :: Utf8
822831 | DataType :: LargeUtf8
823832 | DataType :: Utf8View
@@ -834,7 +843,10 @@ mod tests {
834843 #[ cfg( not( feature = "force_hash_collisions" ) ) ]
835844 mod real_hash_test {
836845 use super :: * ;
837- use arrow:: array:: { AsArray , Int64Array , StringViewArray } ;
846+ use arrow:: array:: {
847+ AsArray , Decimal128Array , Decimal256Array , Int64Array , StringViewArray ,
848+ } ;
849+ use arrow:: datatypes:: i256;
838850 use std:: sync:: Arc ;
839851 // A string longer than the 12-byte inline limit
840852 const LONG : & str = "this string is definitely longer than twelve bytes" ;
@@ -846,6 +858,96 @@ mod tests {
846858 }
847859 }
848860
861+ fn assert_count_numerical_acc_and_group_acc < T > ( array : ArrayRef , expected : u64 )
862+ where
863+ T : ArrowPrimitiveType + Debug ,
864+ T :: Native : Hash ,
865+ {
866+ assert ! (
867+ is_hll_groups_type( array. data_type( ) ) ,
868+ "{} should be groups-capable" ,
869+ array. data_type( )
870+ ) ;
871+
872+ let mut acc = NumericHLLAccumulator :: < T > :: new ( ) ;
873+ acc. update_batch ( & [ Arc :: clone ( & array. clone ( ) ) ] ) . unwrap ( ) ;
874+ let per_group_count = match acc. evaluate ( ) . unwrap ( ) {
875+ ScalarValue :: UInt64 ( Some ( v) ) => v,
876+ other => panic ! ( "unexpected evaluate result: {other:?}" ) ,
877+ } ;
878+
879+ let group_indices = vec ! [ 0usize ; array. len( ) ] ;
880+ let mut acc = HllGroupsAccumulator :: new ( ) ;
881+ acc. update_batch ( & [ array. clone ( ) ] , & group_indices, None , 1 )
882+ . unwrap ( ) ;
883+ let groups_count = acc
884+ . evaluate ( EmitTo :: All )
885+ . unwrap ( )
886+ . as_any ( )
887+ . downcast_ref :: < UInt64Array > ( )
888+ . unwrap ( )
889+ . value ( 0 ) ;
890+
891+ assert_eq ! (
892+ per_group_count,
893+ groups_count,
894+ "paths disagree for {}" ,
895+ array. data_type( )
896+ ) ;
897+ assert_eq ! (
898+ per_group_count,
899+ expected,
900+ "wrong count for {}" ,
901+ array. data_type( )
902+ ) ;
903+ }
904+
/// `Decimal128` and `Decimal256` inputs must be accepted by both the
/// per-row and the groups accumulator, and both must report the same
/// distinct count.
#[test]
fn decimal_support_numerical_acc_and_group_acc() {
    // 11 rows holding 6 distinct values:
    // 0, 1, 2, 3, 1_234_567_890 and 9_999_999_999.
    let values_128: Vec<i128> = vec![
        1,
        2,
        2,
        3,
        3,
        3,
        0,
        0,
        1_234_567_890,
        9_999_999_999,
        9_999_999_999,
    ];
    let array_128 = Decimal128Array::from(values_128)
        .with_precision_and_scale(38, 2)
        .unwrap();
    assert_count_numerical_acc_and_group_acc::<Decimal128Type>(Arc::new(array_128), 6);

    // Two 36-digit values, given as decimal strings.
    let wide_a = i256::from_string("123456789012345678901234567890123456").unwrap();
    let wide_b = i256::from_string("987654321098765432109876543210987654").unwrap();

    // Same shape as the 128-bit case: 11 rows, 6 distinct values
    // (0, 1, 2, 3, `wide_a`, `wide_b`).
    let values_256: Vec<i256> = [1i128, 2, 2, 3, 3, 3, 0, 0]
        .into_iter()
        .map(i256::from_i128)
        .chain([wide_a, wide_b, wide_b])
        .collect();
    let array_256 = Decimal256Array::from(values_256)
        .with_precision_and_scale(40, 2)
        .unwrap();
    assert_count_numerical_acc_and_group_acc::<Decimal256Type>(Arc::new(array_256), 6);
}
950+
849951 /// `approx_distinct(v) FILTER (WHERE nullable_bool)` — a NULL filter row
850952 /// must not be counted (null filter is treated the same as false).
851953 #[ test]
0 commit comments