Skip to content

Conversation

dangotbanned
Copy link
Member

@dangotbanned dangotbanned commented May 18, 2025

Will close #2571

What type of PR is this? (check all applicable)

  • πŸ’Ύ Refactor
  • ✨ Feature
  • πŸ› Bug Fix
  • πŸ”§ Optimization
  • πŸ“ Documentation
  • βœ… Test
  • 🐳 Other

Related issues

Checklist

  • Code follows style guide (ruff)
  • Tests added
  • Documented the changes

If you have comments or can explain your changes, please do so below

Important

See (#2571) for detail!!!!!!!

Very open to feedback

Tasks

@dangotbanned
Copy link
Member Author

dangotbanned commented May 19, 2025

FunctionFlags, FunctionOptions, Function, FunctionExpr

Feel like there's more I need to understand on the correct way to propagate the flags to the top-level.

@MarcoGorelli whenever you get to this - I've left loads of notes w/ references - hoping that you'd be able to demystify the rust magic πŸ˜„

All 4 classes

They're a bit trimmed down from the polars versions, where I couldn't see a need for us to mirror everything if we wouldn't use it

class FunctionFlags(enum.Flag):
ALLOW_GROUP_AWARE = 1 << 0
"""> Raise if use in group by
Not sure where this is disabled.
"""
INPUT_WILDCARD_EXPANSION = 1 << 4
"""Appears on all the horizontal aggs.
https://github.com/pola-rs/polars/blob/e8ad1059721410e65a3d5c1d84055fb22a4d6d43/crates/polars-plan/src/plans/options.rs#L49-L58
"""
RETURNS_SCALAR = 1 << 5
"""Automatically explode on unit length if it ran as final aggregation."""
ROW_SEPARABLE = 1 << 8
"""Not sure lol.
https://github.com/pola-rs/polars/pull/22573
"""
LENGTH_PRESERVING = 1 << 9
"""mutually exclusive with `RETURNS_SCALAR`"""
def is_elementwise(self) -> bool:
return self in (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING)
def returns_scalar(self) -> bool:
return self in FunctionFlags.RETURNS_SCALAR
def is_length_preserving(self) -> bool:
return self in FunctionFlags.LENGTH_PRESERVING
@staticmethod
def default() -> FunctionFlags:
return FunctionFlags.ALLOW_GROUP_AWARE

class FunctionOptions(Immutable):
"""ExprMetadata` but less god object.
https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs
"""
__slots__ = ("flags",)
flags: FunctionFlags
def is_elementwise(self) -> bool:
return self.flags.is_elementwise()
def returns_scalar(self) -> bool:
return self.flags.returns_scalar()
def is_length_preserving(self) -> bool:
return self.flags.is_length_preserving()
def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions:
if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags:
msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive."
raise TypeError(msg)
obj = FunctionOptions.__new__(FunctionOptions)
object.__setattr__(obj, "flags", self.flags | flags)
return obj
def with_elementwise(self) -> FunctionOptions:
return self.with_flags(
FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING
)
@staticmethod
def default() -> FunctionOptions:
obj = FunctionOptions.__new__(FunctionOptions)
object.__setattr__(obj, "flags", FunctionFlags.default())
return obj
@staticmethod
def elementwise() -> FunctionOptions:
return FunctionOptions.default().with_elementwise()
@staticmethod
def row_separable() -> FunctionOptions:
return FunctionOptions.groupwise().with_flags(FunctionFlags.ROW_SEPARABLE)
@staticmethod
def length_preserving() -> FunctionOptions:
return FunctionOptions.default().with_flags(FunctionFlags.LENGTH_PRESERVING)
@staticmethod
def groupwise() -> FunctionOptions:
return FunctionOptions.default()
@staticmethod
def aggregation() -> FunctionOptions:
return FunctionOptions.groupwise().with_flags(FunctionFlags.RETURNS_SCALAR)

class Function(ExprIR):
"""Shared by expr functions and namespace functions.
https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114
"""
@property
def function_options(self) -> FunctionOptions:
from narwhals._plan.options import FunctionOptions
return FunctionOptions.default()
@property
def is_scalar(self) -> bool:
return self.function_options.returns_scalar()
def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]:
from narwhals._plan.expr import FunctionExpr
from narwhals._plan.options import FunctionOptions
# NOTE: Still need to figure out how these should be generated
# Feel like it should be the union of `input` & `function`
PLACEHOLDER = FunctionOptions.default() # noqa: N806
return FunctionExpr(input=inputs, function=self, options=PLACEHOLDER)

class FunctionExpr(ExprIR, t.Generic[_FunctionT]):
"""**Representing `Expr::Function`**.
https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120
https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123
"""
__slots__ = ("function", "input", "options")
input: Seq[ExprIR]
function: _FunctionT
"""Enum type is named `FunctionExpr` in `polars`.
Mirroring *exactly* doesn't make much sense in OOP.
https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123
"""
options: FunctionOptions
"""Assuming this is **either**:
1. `function.function_options`
2. The union of (1) and any `FunctionOptions` in `inputs`
"""
def with_options(self, options: FunctionOptions, /) -> Self:
options = self.options.with_flags(options.flags)
return type(self)(input=self.input, function=self.function, options=options)

dangotbanned added a commit that referenced this pull request May 20, 2025
- Mentioned in (#2391 (comment))
- Needed again for #2572
@dangotbanned dangotbanned mentioned this pull request May 20, 2025
10 tasks
@MarcoGorelli
Copy link
Member

@MarcoGorelli whenever you get to this - I've left loads of notes w/ references - hoping that you'd be able to demystify the rust magic πŸ˜„

at the moment it looks like this adds a self-standing _plan, that's not integrated into the rest of Narwhals? if it's possible to integrate it with the rest to prove that this is feasible, i'd be very interested in taking a close look

MarcoGorelli pushed a commit that referenced this pull request May 21, 2025
* chore(typing): Add `_typing_compat.py`

- Mentioned in (#2391 (comment))
- Needed again for #2572

* refactor: Reuse `TypeVar` import

* refactor: Reuse `@deprecated` import

* refactor: Reuse `Protocol38` import

* docs: Add module-level docstring
dangotbanned added a commit that referenced this pull request May 21, 2025
Still need:
- reprs
- fix the hierarchy issue (#2572 (comment))
- Flag summing (#2572 (comment))
dangotbanned added a commit that referenced this pull request May 21, 2025
dangotbanned added a commit that referenced this pull request May 21, 2025
- 1 step closer to the understanding for (#2572 (comment))
- There's still some magic going on when `polars` serializes
  - Need to track down where `'collect_groups': 'ElementWise'` and `'collect_groups': 'GroupWise'` first appear
  - Seems like the flags get reduced
Comment on lines 200 to 202
@property
def function_options(self) -> FunctionOptions:
return FunctionOptions.length_preserving()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like FunctionOptions.length_preserving is one we need to pay more attention to

@dangotbanned
Copy link
Member Author

Thanks for peeking @MarcoGorelli

at the moment it looks like this adds a self-standing _plan, that's not integrated into the rest of Narwhals? if it's possible to integrate it with the rest to prove that this is feasible, i'd be very interested in taking a close look

That is definitely the eventual goal! 🀞

Despite how quickly things have progressed, I still feel I'm a few steps behind being ready for that just yet.

General overview

I'm trying to focus on modeling these structures and how they interact:

My thought was that narwhals currently solves similar problems, but in different ways to polars.
By getting this subset of polars that we need translated from rust -> python first, we're then in a good position to make decisions on how to bridge any gaps that narwhals is occupying at the moment, if that makes sense?

So like what I have in narwhals/_plan/dummy.py is all about creating an accurate expression graph.
Consuming the graph (evaluating an expression) is a very important step - but the main difference I'm anticipating is:

Current

  • Expressions are based on evaluating callables
  • with other callables used for output names & aliases
  • They also optionally store a ExprMetadata object
  • Some backends do other things with function names, kwargs, depth
  • Lazy backends use more callables for handling windows

ExprIR

  • Expressions are (likely still) based on evaluating callables
  • All of the remaining details are either:
    • Encoded into a node on the graph
      • E.g. the nodes for Expr.var, Expr.std have the name and ddof already
      • Depth can be computed anywhere by traversing the graph
    • Or, can be expressed in terms of operating on a node

I am confident we'll end up with something that's easier to maintain - but trying to integrate the two mid-solve and maintaining that branch over time seems like it'd be a real challenge πŸ˜”


FunctionOptions, FunctionFlags

The parts mentioned in (#2572 (comment)) are some were one of the main hurdles left.
However I think it is really just a skill issue on my part in understanding the rust code. That was all I was hoping for some help with πŸ™
In narwhals terms, it is closest to (https://github.com/narwhals-dev/narwhals/blob/1b93c0ed2dc2bf47d7b8e4b4cab4c9e9cad59800/narwhals/_expression_parsing.py) but only a subset of the rules - since it only concerns functions and how they compose.

Note

Right before I was about to send this comment, I managed to fix the issue I had by updating to 1.30.0 πŸ€¦β€β™‚οΈ
See (diff), which removed the flags I was having trouble finding in (0982b3a)

Example

Now all the flags I've been using are propagated in the same way as in polars! πŸŽ‰

Repro code

import polars as pl

from narwhals._plan import demo as nwd  # noqa: F811
from narwhals._plan import meta  # noqa: F811

expr_pl = (
    pl.col("a")
    .sort()
    .fill_null(1)
    .shift(1)
    .abs()
    .drop_nulls()
    .skew()
    .alias("col->sort->fill_null->shift->abs->drop_nulls->skew->alias")
)
expr_nwd = (
    nwd.col("a")
    .sort()
    .fill_null(1)
    .shift(1)
    .abs()
    .drop_nulls()
    .skew()
    .alias("col->sort->fill_null->shift->abs->drop_nulls->skew->alias")
)
roundtrip_pl = meta.polars_expr_to_dict(expr_pl)
roundtrip_nw = str(expr_nwd._ir)

roundtrip_pl

I haven't added ALLOW_EMPTY_INPUTS - so that's an expected difference between the two

>>> roundtrip_pl
{'Alias': [{'Function': {'input': [{'Function': {'input': [{'Function': {'input': [{'Function': {'input': [{'Function': {'input': [{'Sort': {'expr': {'Column': 'a'},
                   'options': {'descending': False,
                    'nulls_last': False,
                    'multithreaded': True,
                    'maintain_order': False,
                    'limit': None}}},
                 {'Literal': {'Dyn': {'Int': 1}}}],
                'function': 'FillNull',
                'options': {'check_lengths': True,
                 'flags': 'ALLOW_GROUP_AWARE | ROW_SEPARABLE | LENGTH_PRESERVING'}}},
              {'Literal': {'Dyn': {'Int': 1}}}],
             'function': 'Shift',
             'options': {'check_lengths': True,
              'flags': 'ALLOW_GROUP_AWARE | LENGTH_PRESERVING'}}}],
          'function': 'Abs',
          'options': {'check_lengths': True,
           'flags': 'ALLOW_GROUP_AWARE | ROW_SEPARABLE | LENGTH_PRESERVING'}}}],
       'function': 'DropNulls',
       'options': {'check_lengths': True,
        'flags': 'ALLOW_GROUP_AWARE | ALLOW_EMPTY_INPUTS | ROW_SEPARABLE'}}}],
    'function': {'Skew': True},
    'options': {'check_lengths': True,
     'flags': 'ALLOW_GROUP_AWARE | RETURNS_SCALAR'}}},
  'col->sort->fill_null->shift->abs->drop_nulls->skew->alias']}

roundtrip_nw

To produce this I removed the outer ", " and pasted back to ruff to format

The overall shape is very similar and the deviations from polars have been documented

Alias(
    expr=FunctionExpr(
        function=Skew(),
        input=[
            FunctionExpr(
                function=DropNulls(),
                input=[
                    FunctionExpr(
                        function=Abs(),
                        input=[
                            FunctionExpr(
                                function=Shift(n=1),
                                input=[
                                    FunctionExpr(
                                        function=FillNull(),
                                        input=[
                                            Sort(
                                                expr=Column(name="a"),
                                                options=SortOptions(
                                                    descending=False, nulls_last=False
                                                ),
                                            ),
                                            Literal(
                                                value=ScalarLiteral(
                                                    dtype=Unknown, value=1
                                                )
                                            ),
                                        ],
                                        options=FunctionOptions(
                                            flags="ALLOW_GROUP_AWARE | ROW_SEPARABLE | LENGTH_PRESERVING"
                                        ),
                                    )
                                ],
                                options=FunctionOptions(
                                    flags="ALLOW_GROUP_AWARE | LENGTH_PRESERVING"
                                ),
                            )
                        ],
                        options=FunctionOptions(
                            flags="ALLOW_GROUP_AWARE | ROW_SEPARABLE | LENGTH_PRESERVING"
                        ),
                    )
                ],
                options=FunctionOptions(flags="ALLOW_GROUP_AWARE | ROW_SEPARABLE"),
            )
        ],
        options=FunctionOptions(flags="ALLOW_GROUP_AWARE | RETURNS_SCALAR"),
    ),
    name="col->sort->fill_null->shift->abs->drop_nulls->skew->alias",
)

@dangotbanned

This comment was marked as resolved.

dangotbanned added a commit that referenced this pull request May 23, 2025
Can't tell if this means `FirstT` will match the entry `firstt`, but preserve the `firstt` fix (https://github.com/codespell-project/codespell#ignoring-words)

(#2572 (comment))
@dangotbanned
Copy link
Member Author

I should've expected this, but it was a nice suprise to find we get hashable selectors for free πŸ˜„

from narwhals._plan import selectors as ndcs

>>> ndcs.matches("[^z]a")._ir == ndcs.matches("[^z]a")._ir
True

>>> ndcs.matches("[^z]a")._ir == ndcs.matches("abc")._ir
False

@MarcoGorelli regarding (#2291)

from narwhals._plan import selectors as ndcs

>>> ndcs.all()._ir == ndcs.all()._ir
True

lhs = ndcs.all()
rhs = ndcs.all().mean()

>>> lhs._ir == rhs._ir
False

>>> lhs._ir == rhs._ir.expr
True

And the same holds for the non-selectors all πŸ₯³

from narwhals._plan import demo as nwd

lhs = nwd.all()
rhs = nwd.all().mean()
>>> lhs._ir == rhs._ir
False

>>> lhs._ir == rhs._ir.expr
True

>>> type(rhs._ir)
narwhals._plan.aggregation.Mean

dangotbanned added a commit that referenced this pull request May 24, 2025
dangotbanned added a commit that referenced this pull request May 26, 2025
Comment on lines +85 to +110
def test_valid_windows() -> None:
"""Was planning to test this matched, but we seem to allow elementwise horizontal?
https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L10-L45
"""
ELEMENTWISE_ERR = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) # noqa: N806
a = nwd.col("a")
assert a.cum_sum()
assert a.cum_sum().over(order_by="id")
with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR):
assert a.cum_sum().abs().over(order_by="id")

assert (a.cum_sum() + 1).over(order_by="id")
assert a.cum_sum().cum_sum().over(order_by="id")
assert a.cum_sum().cum_sum()
assert nwd.sum_horizontal(a, a.cum_sum())
with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR):
assert nwd.sum_horizontal(a, a.cum_sum()).over(order_by="a")

assert nwd.sum_horizontal(a, a.cum_sum().over(order_by="i"))
assert nwd.sum_horizontal(a.diff(), a.cum_sum().over(order_by="i"))
with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR):
assert nwd.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i")

with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR):
assert nwd.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli quick question

This is adapted from an existing test:

tests.expression_parsing_test.test_window_kind

@pytest.mark.parametrize(
("expr", "expected"),
[
(nw.col("a"), 0),
(nw.col("a").mean(), 0),
(nw.col("a").cum_sum(), 1),
(nw.col("a").cum_sum().over(order_by="id"), 0),
(nw.col("a").cum_sum().abs().over(order_by="id"), 1),
((nw.col("a").cum_sum() + 1).over(order_by="id"), 1),
(nw.col("a").cum_sum().cum_sum().over(order_by="id"), 1),
(nw.col("a").cum_sum().cum_sum(), 2),
(nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()), 1),
(nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over(order_by="a"), 1),
(nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum().over(order_by="i")), 0),
(
nw.sum_horizontal(
nw.col("a").diff(), nw.col("a").cum_sum().over(order_by="i")
),
1,
),
(
nw.sum_horizontal(nw.col("a").diff(), nw.col("a").cum_sum()).over(
order_by="i"
),
2,
),
(
nw.sum_horizontal(nw.col("a").diff().abs(), nw.col("a").cum_sum()).over(
order_by="i"
),
2,
),
],
)
def test_window_kind(expr: nw.Expr, expected: int) -> None:
assert expr._metadata.n_orderable_ops == expected

AFAICT, all of the expressions I've needed a InvalidOperationError for shouldn't be valid.

But they aren't raising in current narwhals πŸ€”

1

import narwhals as nw

a = nw.col("a")
a.cum_sum().abs().over(order_by="id")
This error explicitly mentions abs

if self.is_elementwise or self.is_filtration:
msg = (
"Cannot use `over` on expressions which are elementwise\n"
"(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
)
raise InvalidOperationError(msg)

2, 3, 4

These are all raising the same as (1), but the issue seems to be that horizontal functions aren't being treated as elementwise

import narwhals as nw

a = nw.col("a")
nw.sum_horizontal(a, a.cum_sum()).over(order_by="a")
nw.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i")
nw.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i")

In polars, they all seem to be elementwise but with an additional flag

https://github.com/pola-rs/polars/blob/944bf553f1111c31259b55348a7dd0a512ae51a1/crates/polars-plan/src/dsl/function_expr/mod.rs#L1388-L1392

I've done the same in this PR, but I don't think that flag would factor into this?

class SumHorizontal(Function):
@property
def function_options(self) -> FunctionOptions:
return FunctionOptions.elementwise().with_flags(
FunctionFlags.INPUT_WILDCARD_EXPANSION
)
def __repr__(self) -> str:
return "sum_horizontal"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Enh]: A richer Expr internal representation
2 participants