-
Notifications
You must be signed in to change notification settings - Fork 161
feat(RFC): A richer Expr
IR
#2572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
class FunctionFlags(enum.Flag): | |
ALLOW_GROUP_AWARE = 1 << 0 | |
"""> Raise if use in group by | |
Not sure where this is disabled. | |
""" | |
INPUT_WILDCARD_EXPANSION = 1 << 4 | |
"""Appears on all the horizontal aggs. | |
https://github.com/pola-rs/polars/blob/e8ad1059721410e65a3d5c1d84055fb22a4d6d43/crates/polars-plan/src/plans/options.rs#L49-L58 | |
""" | |
RETURNS_SCALAR = 1 << 5 | |
"""Automatically explode on unit length if it ran as final aggregation.""" | |
ROW_SEPARABLE = 1 << 8 | |
"""Not sure lol. | |
https://github.com/pola-rs/polars/pull/22573 | |
""" | |
LENGTH_PRESERVING = 1 << 9 | |
"""mutually exclusive with `RETURNS_SCALAR`""" | |
def is_elementwise(self) -> bool: | |
return self in (FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING) | |
def returns_scalar(self) -> bool: | |
return self in FunctionFlags.RETURNS_SCALAR | |
def is_length_preserving(self) -> bool: | |
return self in FunctionFlags.LENGTH_PRESERVING | |
@staticmethod | |
def default() -> FunctionFlags: | |
return FunctionFlags.ALLOW_GROUP_AWARE |
narwhals/narwhals/_plan/options.py
Lines 52 to 108 in 0bada48
class FunctionOptions(Immutable): | |
"""ExprMetadata` but less god object. | |
https://github.com/pola-rs/polars/blob/3fd7ecc5f9de95f62b70ea718e7e5dbf951b6d1c/crates/polars-plan/src/plans/options.rs | |
""" | |
__slots__ = ("flags",) | |
flags: FunctionFlags | |
def is_elementwise(self) -> bool: | |
return self.flags.is_elementwise() | |
def returns_scalar(self) -> bool: | |
return self.flags.returns_scalar() | |
def is_length_preserving(self) -> bool: | |
return self.flags.is_length_preserving() | |
def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions: | |
if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags: | |
msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." | |
raise TypeError(msg) | |
obj = FunctionOptions.__new__(FunctionOptions) | |
object.__setattr__(obj, "flags", self.flags | flags) | |
return obj | |
def with_elementwise(self) -> FunctionOptions: | |
return self.with_flags( | |
FunctionFlags.ROW_SEPARABLE | FunctionFlags.LENGTH_PRESERVING | |
) | |
@staticmethod | |
def default() -> FunctionOptions: | |
obj = FunctionOptions.__new__(FunctionOptions) | |
object.__setattr__(obj, "flags", FunctionFlags.default()) | |
return obj | |
@staticmethod | |
def elementwise() -> FunctionOptions: | |
return FunctionOptions.default().with_elementwise() | |
@staticmethod | |
def row_separable() -> FunctionOptions: | |
return FunctionOptions.groupwise().with_flags(FunctionFlags.ROW_SEPARABLE) | |
@staticmethod | |
def length_preserving() -> FunctionOptions: | |
return FunctionOptions.default().with_flags(FunctionFlags.LENGTH_PRESERVING) | |
@staticmethod | |
def groupwise() -> FunctionOptions: | |
return FunctionOptions.default() | |
@staticmethod | |
def aggregation() -> FunctionOptions: | |
return FunctionOptions.groupwise().with_flags(FunctionFlags.RETURNS_SCALAR) |
narwhals/narwhals/_plan/common.py
Lines 149 to 172 in 0bada48
class Function(ExprIR): | |
"""Shared by expr functions and namespace functions. | |
https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/expr.rs#L114 | |
""" | |
@property | |
def function_options(self) -> FunctionOptions: | |
from narwhals._plan.options import FunctionOptions | |
return FunctionOptions.default() | |
@property | |
def is_scalar(self) -> bool: | |
return self.function_options.returns_scalar() | |
def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]: | |
from narwhals._plan.expr import FunctionExpr | |
from narwhals._plan.options import FunctionOptions | |
# NOTE: Still need to figure out how these should be generated | |
# Feel like it should be the union of `input` & `function` | |
PLACEHOLDER = FunctionOptions.default() # noqa: N806 | |
return FunctionExpr(input=inputs, function=self, options=PLACEHOLDER) |
narwhals/narwhals/_plan/expr.py
Lines 157 to 185 in 0bada48
class FunctionExpr(ExprIR, t.Generic[_FunctionT]): | |
"""**Representing `Expr::Function`**. | |
https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L114-L120 | |
https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 | |
""" | |
__slots__ = ("function", "input", "options") | |
input: Seq[ExprIR] | |
function: _FunctionT | |
"""Enum type is named `FunctionExpr` in `polars`. | |
Mirroring *exactly* doesn't make much sense in OOP. | |
https://github.com/pola-rs/polars/blob/112cab39380d8bdb82c6b76b31aca9b58c98fd93/crates/polars-plan/src/dsl/function_expr/mod.rs#L123 | |
""" | |
options: FunctionOptions | |
"""Assuming this is **either**: | |
1. `function.function_options` | |
2. The union of (1) and any `FunctionOptions` in `inputs` | |
""" | |
def with_options(self, options: FunctionOptions, /) -> Self: | |
options = self.options.with_flags(options.flags) | |
return type(self)(input=self.input, function=self.function, options=options) |
- Mentioned in (#2391 (comment)) - Needed again for #2572
at the moment it looks like this adds a self-standing |
* chore(typing): Add `_typing_compat.py` - Mentioned in (#2391 (comment)) - Needed again for #2572 * refactor: Reuse `TypeVar` import * refactor: Reuse `@deprecated` import * refactor: Reuse `Protocol38` import * docs: Add module-level docstring
Still need: - reprs - fix the hierarchy issue (#2572 (comment)) - Flag summing (#2572 (comment))
- 1 step closer to the understanding for (#2572 (comment)) - There's still some magic going on when `polars` serializes - Need to track down where `'collect_groups': 'ElementWise'` and `'collect_groups': 'GroupWise'` first appear - Seems like the flags get reduced
narwhals/_plan/functions.py
Outdated
@property | ||
def function_options(self) -> FunctionOptions: | ||
return FunctionOptions.length_preserving() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like FunctionOptions.length_preserving
is one we need to pay more attention to
Thanks for peeking @MarcoGorelli
That is definitely the eventual goal! π€ Despite how quickly things have progressed, I still feel I'm a few steps behind being ready for that just yet. General overviewI'm trying to focus on modeling these structures and how they interact:
My thought was that So like what I have in Current
|
This comment was marked as resolved.
This comment was marked as resolved.
Can't tell if this means `FirstT` will match the entry `firstt`, but preserve the `firstt` fix (https://github.com/codespell-project/codespell#ignoring-words) (#2572 (comment))
I should've expected this, but it was a nice suprise to find we get hashable selectors for free π from narwhals._plan import selectors as ndcs
>>> ndcs.matches("[^z]a")._ir == ndcs.matches("[^z]a")._ir
True
>>> ndcs.matches("[^z]a")._ir == ndcs.matches("abc")._ir
False @MarcoGorelli regarding (#2291) from narwhals._plan import selectors as ndcs
>>> ndcs.all()._ir == ndcs.all()._ir
True
lhs = ndcs.all()
rhs = ndcs.all().mean()
>>> lhs._ir == rhs._ir
False
>>> lhs._ir == rhs._ir.expr
True And the same holds for the non-selectors from narwhals._plan import demo as nwd
lhs = nwd.all()
rhs = nwd.all().mean()
>>> lhs._ir == rhs._ir
False
>>> lhs._ir == rhs._ir.expr
True
>>> type(rhs._ir)
narwhals._plan.aggregation.Mean |
An experiment towards (#2572 (comment))
def test_valid_windows() -> None: | ||
"""Was planning to test this matched, but we seem to allow elementwise horizontal? | ||
https://github.com/narwhals-dev/narwhals/blob/63c8e4771a1df4e0bfeea5559c303a4a447d5cc2/tests/expression_parsing_test.py#L10-L45 | ||
""" | ||
ELEMENTWISE_ERR = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) # noqa: N806 | ||
a = nwd.col("a") | ||
assert a.cum_sum() | ||
assert a.cum_sum().over(order_by="id") | ||
with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): | ||
assert a.cum_sum().abs().over(order_by="id") | ||
|
||
assert (a.cum_sum() + 1).over(order_by="id") | ||
assert a.cum_sum().cum_sum().over(order_by="id") | ||
assert a.cum_sum().cum_sum() | ||
assert nwd.sum_horizontal(a, a.cum_sum()) | ||
with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): | ||
assert nwd.sum_horizontal(a, a.cum_sum()).over(order_by="a") | ||
|
||
assert nwd.sum_horizontal(a, a.cum_sum().over(order_by="i")) | ||
assert nwd.sum_horizontal(a.diff(), a.cum_sum().over(order_by="i")) | ||
with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): | ||
assert nwd.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i") | ||
|
||
with pytest.raises(InvalidOperationError, match=ELEMENTWISE_ERR): | ||
assert nwd.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@MarcoGorelli quick question
This is adapted from an existing test:
tests.expression_parsing_test.test_window_kind
narwhals/tests/expression_parsing_test.py
Lines 10 to 45 in 63c8e47
@pytest.mark.parametrize( | |
("expr", "expected"), | |
[ | |
(nw.col("a"), 0), | |
(nw.col("a").mean(), 0), | |
(nw.col("a").cum_sum(), 1), | |
(nw.col("a").cum_sum().over(order_by="id"), 0), | |
(nw.col("a").cum_sum().abs().over(order_by="id"), 1), | |
((nw.col("a").cum_sum() + 1).over(order_by="id"), 1), | |
(nw.col("a").cum_sum().cum_sum().over(order_by="id"), 1), | |
(nw.col("a").cum_sum().cum_sum(), 2), | |
(nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()), 1), | |
(nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum()).over(order_by="a"), 1), | |
(nw.sum_horizontal(nw.col("a"), nw.col("a").cum_sum().over(order_by="i")), 0), | |
( | |
nw.sum_horizontal( | |
nw.col("a").diff(), nw.col("a").cum_sum().over(order_by="i") | |
), | |
1, | |
), | |
( | |
nw.sum_horizontal(nw.col("a").diff(), nw.col("a").cum_sum()).over( | |
order_by="i" | |
), | |
2, | |
), | |
( | |
nw.sum_horizontal(nw.col("a").diff().abs(), nw.col("a").cum_sum()).over( | |
order_by="i" | |
), | |
2, | |
), | |
], | |
) | |
def test_window_kind(expr: nw.Expr, expected: int) -> None: | |
assert expr._metadata.n_orderable_ops == expected |
AFAICT, all of the expressions I've needed a InvalidOperationError
for shouldn't be valid.
But they aren't raising in current narwhals
π€
1
import narwhals as nw
a = nw.col("a")
a.cum_sum().abs().over(order_by="id")
This error explicitly mentions abs
narwhals/narwhals/_expression_parsing.py
Lines 357 to 362 in 9bd10ad
if self.is_elementwise or self.is_filtration: | |
msg = ( | |
"Cannot use `over` on expressions which are elementwise\n" | |
"(e.g. `abs`) or which change length (e.g. `drop_nulls`)." | |
) | |
raise InvalidOperationError(msg) |
2, 3, 4
These are all raising the same as (1), but the issue seems to be that horizontal functions aren't being treated as elementwise
import narwhals as nw
a = nw.col("a")
nw.sum_horizontal(a, a.cum_sum()).over(order_by="a")
nw.sum_horizontal(a.diff(), a.cum_sum()).over(order_by="i")
nw.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i")
In polars
, they all seem to be elementwise
but with an additional flag
I've done the same in this PR, but I don't think that flag would factor into this?
narwhals/narwhals/_plan/functions.py
Lines 291 to 299 in 9bd10ad
class SumHorizontal(Function): | |
@property | |
def function_options(self) -> FunctionOptions: | |
return FunctionOptions.elementwise().with_flags( | |
FunctionFlags.INPUT_WILDCARD_EXPANSION | |
) | |
def __repr__(self) -> str: | |
return "sum_horizontal" |
Will close #2571
What type of PR is this? (check all applicable)
Related issues
Expr
internal representationΒ #2571Checklist
If you have comments or can explain your changes, please do so below
Important
See (#2571) for detail!!!!!!!
Very open to feedback
Tasks
pl.Expr.meta
pl.Expr.meta
)ExprIR
meta
methods_typing_compat
moduleΒ #2578Merge another PR with (perf: Avoid module-levelimportlib.util.find_spec
Β #2391 (comment)) firstTypeVar
defaults moreTypeVar("T", bound=Thing, default=Thing)
, instead of an opaqueExprIR
Selector
(s)narwhals/narwhals/_plan/expr.py
Lines 336 to 337 in 0bada48
BinaryExpr
that describes the restricted set of operators that are allowedpolars
, since they wrappl.col
internallyIntoExpr
in more places (including and beyond whatnarwhals
allows now)demo.py
*_horizontal
concat_str
)dummy.py
over
,sort_by
)FunctionOptions
+ friends see commentWhere does the{flags: ...}
->{collect_groups: ..., flags: ...}
expansion happen?polars>=1.3.0
fixed the issue (see comment)Ternary
when-then-otherwise
π₯³)Meta
is_*
,has_*
,output_name
meta methodsroot_names
undo_aliases
,pop
Name
name.py
(Expr::KeepName
,Expr::RenameAlias
)polars
will help with themeta
methodsCat
,Struct
,List
(a3e29d1)String
(72c33ce)DateTime
(aee0a7e)_expression_parsing.py
rulesrust
version worksExpansionFlags
expand_function_inputs
rewrite_projections
replace_selector
expand_selector
replace_selector_inner
replace_and_add_to_results
replace_nth
prepare_excluded
expand_columns
expand_dtypes
replace_dtype_or_index_with_column
dtypes_match
(probably can solve w/ existingnarwhals
)expand_indices
replace_index_with_column
replace_wildcard
rewrite_special_aliases
replace_wildcard_with_column
replace_regex
expand_regex
ExprIR.map_ir
Expr
IRΒ #2572 (comment))ExprIR.map_ir
for most nodesWindowExpr.map_ir
FunctionExpr.map_ir
RollingExpr
,AnonymousExpr
inheritselectors