Commit 0840763

timsaucer, claude, and ntjohnson1 authored
feat: accept distinct kwarg on sum and avg (#1556)
* feat: accept distinct kwarg on sum and avg

  Upstream exposes `sum_distinct` / `avg_distinct` / `count_distinct` as sibling functions that call the same underlying UDAF with `distinct: bool = true`. The Rust binding side already routes `distinct=Some(true)` through the aggregate builder for `sum`, `avg`, and `count`, but only `count` exposed the kwarg on the Python wrapper.

  Add `distinct: bool = False` to `sum()` and `avg()` mirroring the existing `count()` signature, and update SKILL.md so the check-upstream audit does not re-flag the three upstream `*_distinct` shortcuts as gaps.

  The plan emitted by `sum(col, distinct=True)` matches what upstream's `sum_distinct(col)` builds.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* test: fold sum/avg distinct tests into parameterized aggregation test

  Move the standalone test_sum_distinct_kwarg and test_avg_distinct_kwarg from test_functions.py into the existing test_aggregation::test_aggregation parameterization, matching how distinct is already covered for median, array_agg, count, and bit_xor.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: clarify distinct kwarg on sum and avg

  Drop the unhelpful "upstream avg_distinct/sum_distinct shortcut" reference in favor of describing the actual behavior.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs: note sum/avg distinct argument-order breaking change

  distinct is inserted before filter on sum and avg for consistency with the other aggregate functions, breaking positional filter callers. Add a DataFusion 54.0.0 upgrade-guide entry covering the migration.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Update docs/source/user-guide/upgrade-guides.rst

  Co-authored-by: Nick <24689722+ntjohnson1@users.noreply.github.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Nick <24689722+ntjohnson1@users.noreply.github.com>
1 parent 744dd23 commit 0840763
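
As a rough illustration of what the `distinct` kwarg described in the commit message changes, the sketch below reproduces the two new docstring examples in plain Python. The helpers `distinct_sum` and `distinct_avg` are hypothetical stand-ins written for this note; they are not DataFusion functions and say nothing about how the UDAF is implemented.

```python
def distinct_sum(values):
    """Sum each distinct value once (illustrative stand-in, not DataFusion code)."""
    return sum(set(values))


def distinct_avg(values):
    """Average the distinct values (illustrative stand-in, not DataFusion code)."""
    unique = set(values)
    return sum(unique) / len(unique)


ints = [1, 1, 2, 3]
floats = [1.0, 1.0, 2.0, 3.0]

# Plain aggregate versus the distinct variant on the same input.
print(sum(ints), distinct_sum(ints))                    # 7 6
print(sum(floats) / len(floats), distinct_avg(floats))  # 1.75 2.0
```

The distinct results, 6 and 2.0, match the values expected by the doctests added to `sum` and `avg` in this commit.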

4 files changed

Lines changed: 54 additions & 5 deletions

.ai/skills/check-upstream/SKILL.md

Lines changed: 8 additions & 1 deletion
@@ -88,11 +88,18 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all"
 - Python API: `python/datafusion/functions.py` (aggregate functions are mixed in with scalar functions)
 - Rust bindings: `crates/core/src/functions.rs`
 
+**Evaluated and not requiring separate Python exposure:**
+- `count_distinct` — covered by `count(expr, distinct=True)`. Both forms call
+  `count_udaf` with `distinct: bool = true` and produce the same logical plan.
+- `sum_distinct` — covered by `sum(expr, distinct=True)`.
+- `avg_distinct` — covered by `avg(expr, distinct=True)`.
+
 **How to check:**
 1. Fetch the upstream aggregate function documentation page
 2. Compare against aggregate functions in `python/datafusion/functions.py` (check `__all__` list and function definitions)
 3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding
-4. Report only functions missing from the Python API
+4. Check against the "evaluated and not requiring exposure" list before flagging as a gap
+5. Report only functions missing from the Python API
 
 ### 3. Window Functions
 
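
The updated "How to check" steps in the hunk above amount to a set difference. A minimal sketch follows, assuming the upstream names, the Python API names, and the evaluated list are already available as plain collections. The function `find_gaps` and all sample inputs, including `example_new_agg`, are hypothetical and are not part of the skill file.

```python
def find_gaps(upstream, python_api, evaluated):
    """Return upstream names that are neither exposed nor already evaluated."""
    return sorted(set(upstream) - set(python_api) - set(evaluated))


# All three collections below are made-up sample inputs.
upstream = [
    "sum", "avg", "count",
    "sum_distinct", "avg_distinct", "count_distinct",
    "example_new_agg",  # hypothetical function with no Python coverage
]
python_api = ["sum", "avg", "count"]
evaluated = ["count_distinct", "sum_distinct", "avg_distinct"]

print(find_gaps(upstream, python_api, evaluated))  # ['example_new_agg']
```

Without the `evaluated` list, the three `*_distinct` names would be reported as gaps on every audit run, which is the re-flagging the commit message sets out to prevent.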

docs/source/user-guide/upgrade-guides.rst

Lines changed: 20 additions & 0 deletions
@@ -45,6 +45,26 @@ After:
     config = SessionConfig().set("datafusion.execution.batch_size", "4096")
     ctx = SessionContext(config)
 
+The aggregate functions :py:func:`~datafusion.functions.sum` and
+:py:func:`~datafusion.functions.avg` now accept a ``distinct`` argument, matching
+the other aggregate functions. ``distinct`` is inserted *before* ``filter`` in the
+argument list, so any code that passed ``filter`` positionally must be updated to
+pass it as a keyword argument. The types are distinct so a type checker should flag this.
+
+Before:
+
+.. code-block:: python
+
+    f.sum(column("a"), my_filter)
+    f.avg(column("a"), my_filter)
+
+Now:
+
+.. code-block:: python
+
+    f.sum(column("a"), filter=my_filter)
+    f.avg(column("a"), filter=my_filter)
+
 DataFusion 53.0.0
 -----------------
 
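
To see why the upgrade-guide entry in the hunk above asks callers to pass `filter` by keyword, the sketch below binds the same positional call against both signatures. `sum_before` and `sum_after` are hypothetical stubs that copy only the parameter order from this diff, and `my_filter` is a placeholder object rather than a real `Expr`.

```python
import inspect


def sum_before(expression, filter=None):
    """Hypothetical stand-in for the old `sum` parameter order."""


def sum_after(expression, distinct=False, filter=None):
    """Hypothetical stand-in for the new `sum` parameter order."""


my_filter = object()  # placeholder for a filter Expr

# Bind the same positional call against each signature without running anything.
old = inspect.signature(sum_before).bind("a", my_filter).arguments
new = inspect.signature(sum_after).bind("a", my_filter).arguments

print(list(old))  # ['expression', 'filter']
print(list(new))  # ['expression', 'distinct']
```

With the new parameter order, the second positional argument lands in `distinct`, so the filter is no longer passed as `filter`. Passing `filter=my_filter` by keyword avoids the mix-up.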

python/datafusion/functions.py

Lines changed: 24 additions & 4 deletions
@@ -4563,17 +4563,19 @@ def grouping(
 
 def avg(
     expression: Expr,
+    distinct: bool = False,
     filter: Expr | None = None,
 ) -> Expr:
     """Returns the average value.
 
     This aggregate function expects a numeric expression and will return a float.
 
     If using the builder functions described in ref:`_aggregation` this function ignores
-    the options ``order_by``, ``null_treatment``, and ``distinct``.
+    the options ``order_by`` and ``null_treatment``.
 
     Args:
         expression: Values to combine into an array
+        distinct: If True, duplicate values are removed before averaging.
         filter: If provided, only compute against rows for which the filter is True
 
     Examples:
@@ -4593,9 +4595,17 @@ def avg(
         ... ).alias("v")])
         >>> result.collect_column("v")[0].as_py()
         2.5
+
+        >>> df = ctx.from_pydict({"a": [1.0, 1.0, 2.0, 3.0]})
+        >>> result = df.aggregate(
+        ... [], [dfn.functions.avg(
+        ... dfn.col("a"), distinct=True,
+        ... ).alias("v")])
+        >>> result.collect_column("v")[0].as_py()
+        2.0
     """
     filter_raw = filter.expr if filter is not None else None
-    return Expr(f.avg(expression.expr, filter=filter_raw))
+    return Expr(f.avg(expression.expr, distinct=distinct, filter=filter_raw))
 
 
 def corr(value_y: Expr, value_x: Expr, filter: Expr | None = None) -> Expr:
@@ -4880,17 +4890,19 @@ def min(expression: Expr, filter: Expr | None = None) -> Expr:
 
 def sum(
     expression: Expr,
+    distinct: bool = False,
     filter: Expr | None = None,
 ) -> Expr:
     """Computes the sum of a set of numbers.
 
     This aggregate function expects a numeric expression.
 
     If using the builder functions described in ref:`_aggregation` this function ignores
-    the options ``order_by``, ``null_treatment``, and ``distinct``.
+    the options ``order_by`` and ``null_treatment``.
 
     Args:
         expression: Values to combine into an array
+        distinct: If True, duplicate values are removed before summing.
         filter: If provided, only compute against rows for which the filter is True
 
     Examples:
@@ -4910,9 +4922,17 @@ def sum(
         ... ).alias("v")])
         >>> result.collect_column("v")[0].as_py()
         5
+
+        >>> df = ctx.from_pydict({"a": [1, 1, 2, 3]})
+        >>> result = df.aggregate(
+        ... [], [dfn.functions.sum(
+        ... dfn.col("a"), distinct=True,
+        ... ).alias("v")])
+        >>> result.collect_column("v")[0].as_py()
+        6
     """
     filter_raw = filter.expr if filter is not None else None
-    return Expr(f.sum(expression.expr, filter=filter_raw))
+    return Expr(f.sum(expression.expr, distinct=distinct, filter=filter_raw))
 
 
 def stddev(expression: Expr, filter: Expr | None = None) -> Expr:

python/tests/test_aggregation.py

Lines changed: 2 additions & 0 deletions
@@ -192,7 +192,9 @@ def test_aggregation_stats(df, agg_expr, calc_expected):
             False,
         ),
         (f.avg(column("b"), filter=column("a") != lit(1)), pa.array([5.0]), False),
+        (f.avg(column("b"), distinct=True), pa.array([5.0]), False),
         (f.sum(column("b"), filter=column("a") != lit(1)), pa.array([10]), False),
+        (f.sum(column("b"), distinct=True), pa.array([10]), False),
         (f.count(column("b"), distinct=True), pa.array([2]), False),
         (f.count(column("b"), filter=column("a") != 3), pa.array([2]), False),
         (f.count(), pa.array([3]), False),
