Description
Is your feature request related to a problem? Please describe.
Not sure if this is a feature request or a bug. Sorry if this is a duplicate; I browsed through the issues but couldn't find a related one. It seems UDFs are not supported with agg on a GroupBy:
df.groupby("groups")["values"].agg(my_udf)
Perhaps naively, I would expect this to behave the same way as a rolling-window UDF, i.e. the callable takes in a group and returns a single value. However, if I pass a custom function, I instead get:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[61], line 11
---> 11 grp["value_1"].agg(geometric_mean)
File [cudf/core/groupby/groupby.py:3429], in SeriesGroupBy.agg(self, func, engine, engine_kwargs, *args, **kwargs)
3428 def agg(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
-> 3429 result = super().agg(
3430 func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
3431 )
3433 # downcast the result to a Series:
3434 if result._num_columns:
File [cudf/utils/performance_tracking.py:51], in _performance_tracking.<locals>.wrapper(*args, **kwargs)
43 if nvtx.enabled():
44 stack.enter_context(
45 nvtx.annotate(
46 message=func.__qualname__,
(...) 49 )
50 )
---> 51 return func(*args, **kwargs)
File [cudf/core/groupby/groupby.py:1012], in GroupBy.agg(self, func, engine, engine_kwargs, *args, **kwargs)
1002 orig_dtypes = tuple(c.dtype for c in columns)
1004 # Note: When there are no key columns, the below produces
1005 # an Index with float64 dtype, while Pandas returns
1006 # an Index with int64 dtype.
1007 # (GH: 6945)
1008 (
1009 result_columns,
1010 grouped_key_cols,
1011 included_aggregations,
-> 1012 ) = self._aggregate(columns, normalized_aggs)
1014 result_index = self.grouping.keys._from_columns_like_self(
1015 grouped_key_cols,
1016 )
1018 multilevel = _is_multi_agg(func)
File [cudf/core/groupby/groupby.py:830], in GroupBy._aggregate(self, values, aggregations)
826 if _is_unsupported_agg_for_type(col.dtype, str_agg):
827 raise TypeError(
828 f"{col.dtype} type does not support {agg} operations"
829 )
--> 830 agg_obj = aggregation.make_aggregation(agg)
831 if (
832 valid_aggregations == "ALL"
833 or agg_obj.kind in valid_aggregations
834 ):
835 included_aggregations_i.append((agg, agg_obj.kind))
File [cudf/core/_internals/aggregation.py:287], in make_aggregation(op, kwargs)
285 return Aggregation.from_udf(op, **kwargs)
286 else:
--> 287 return op(Aggregation)
288 raise TypeError(f"Unknown aggregation {op}")
Cell In[61], line 4, in geometric_mean(values)
2 n = values.size
3 product = 1
----> 4 for value in values:
5 product *= value
6 return (product) ** (1/n)
TypeError: 'type' object is not iterable
The final error is raised inside my UDF, but only because make_aggregation calls the callable as op(Aggregation), so values is the Aggregation class rather than the group's data. I would expect an error more in line with "we don't support this at the moment". This is confusing because the agg docstring states that it accepts "callables", but it is not clear which callables are acceptable.
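Judging only from the traceback frames, the callable branch of make_aggregation hands the callable the Aggregation class itself. The minimal mock below reproduces the same TypeError. It assumes a hypothetical stand-in Aggregation class with sum and size classmethods, and a hypothetical make_aggregation_sketch; neither is cuDF's real code, and the dtype condition for the from_udf path is an assumption based on the visible frame.

```python
class Aggregation:
    """Hypothetical stand-in for cuDF's internal Aggregation class."""

    @classmethod
    def sum(cls):
        return "sum"

    @classmethod
    def size(cls):
        return "size"


def make_aggregation_sketch(op, kwargs=None):
    """Simplified mock of the callable branch seen in the traceback."""
    kwargs = kwargs or {}
    if callable(op):
        # Assumed condition: an output dtype selects the real from_udf path,
        # which is not mocked here.
        if "dtype" in kwargs:
            raise NotImplementedError("from_udf path is not mocked")
        # Otherwise the callable receives the Aggregation *class*,
        # not the values of a group.
        return op(Aggregation)
    raise TypeError(f"Unknown aggregation {op}")


def geometric_mean(values):
    n = values.size        # on the class, .size resolves to a classmethod, so this passes
    product = 1
    for value in values:   # on the class, this raises TypeError: 'type' object is not iterable
        product *= value
    return product ** (1 / n)
```

In this mock a callable such as lambda agg: agg.sum() passes through without error. That suggests the branch is meant for callables that build an aggregation rather than for row-level UDFs, though this is an inference from the traceback, not documented behaviour.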
I had a quick look at aggregation.make_aggregation, which also seems to be used for the rolling-window UDFs, and it appears to support UDFs when an output data type is supplied. That is not done at line 830 of cudf/python/cudf/cudf/core/groupby/groupby.py (commit f4e35ca):
    agg_obj = aggregation.make_aggregation(agg)
Basically, I'm wondering whether this is an oversight or there is a good reason this doesn't work.
Describe the solution you'd like
I would like UDFs to work in GroupBy.transform and GroupBy.agg the same way they do in Rolling.apply. Alternatively, if that is not possible for good reasons, a more informative error stating that UDFs are not supported should be raised, or the docs should clarify which kinds of callables are supported.
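For reference, the requested semantics can be written as a small CPU sketch in plain Python. With agg, the callable receives one group's values and returns one scalar per group. With transform, that scalar is broadcast back to every row of its group, roughly how pandas treats a scalar-returning callable. The helpers group_indices, agg_udf and transform_udf are hypothetical and are not cuDF API. geometric_mean is a list-based version of the UDF from the traceback.

```python
from collections import defaultdict
from typing import Callable, Hashable, Sequence


def group_indices(keys: Sequence[Hashable]) -> dict:
    """Map each key to the row positions where it occurs, in first-seen order."""
    groups = defaultdict(list)
    for i, key in enumerate(keys):
        groups[key].append(i)
    return dict(groups)


def agg_udf(keys: Sequence[Hashable], values: Sequence[float], func: Callable) -> dict:
    """One result per group: func is called with that group's values."""
    return {
        key: func([values[i] for i in rows])
        for key, rows in group_indices(keys).items()
    }


def transform_udf(keys: Sequence[Hashable], values: Sequence[float], func: Callable) -> list:
    """Broadcast each group's scalar result back to every row of the group."""
    per_group = agg_udf(keys, values, func)
    return [per_group[key] for key in keys]


def geometric_mean(values):
    """List-based version of the UDF from the traceback."""
    product = 1.0
    for value in values:
        product *= value
    return product ** (1 / len(values))


keys = ["a", "a", "b", "b"]
values = [1, 4, 2, 8]
print(agg_udf(keys, values, geometric_mean))        # {'a': 2.0, 'b': 4.0}
print(transform_udf(keys, values, geometric_mean))  # [2.0, 2.0, 4.0, 4.0]
```

A GPU implementation would of course not loop in Python. The sketch only pins down the expected contract: the callable receives a group's values, not an aggregation object.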
Describe alternatives you've considered
I'm exploring the groupby functionality specifically for something I'm writing up about cudf, and bumped into this issue. Of course, one can always write a custom CUDA kernel.
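For a geometric mean specifically, there may be a middle ground short of a custom kernel. The identity exp(mean(log x)) only needs an elementwise log, a per-group mean and an elementwise exp. The plain-Python sketch below checks that identity per group. It assumes strictly positive values, and geometric_mean_via_log is a hypothetical helper name. The analogous cuDF steps would be applying numpy.log to the column, taking a groupby mean, then applying numpy.exp; treat that mapping as an untested assumption.

```python
import math
from collections import defaultdict


def geometric_mean_via_log(keys, values):
    """Per-group geometric mean computed as exp(mean(log(x))).

    Assumes every value is strictly positive (log is undefined otherwise).
    """
    log_sums = defaultdict(float)
    counts = defaultdict(int)
    for key, value in zip(keys, values):
        log_sums[key] += math.log(value)  # elementwise log, accumulated per group
        counts[key] += 1                  # per-group count; sum / count is the mean
    # Exponentiate each group's mean log to undo the log transform.
    return {key: math.exp(log_sums[key] / counts[key]) for key in log_sums}


print(geometric_mean_via_log(["a", "a", "b", "b"], [1, 4, 2, 8]))
```

Working in log space also avoids the overflow that a running product, like the one in the original UDF, can hit on large groups.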
Additional context
I'm on stable version 25.10