Skip to content

Commit 0bada48

Browse files
committed
feat: Permissive parsing in *_horiztonal
Threw in some tests for hashing as well
1 parent fb6bc07 commit 0bada48

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

narwhals/_plan/demo.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,33 +95,33 @@ def sum(*columns: str) -> DummyExpr:
9595
return col(columns).sum()
9696

9797

98-
def all_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr:
99-
it = (expr._ir for expr in flatten(exprs))
98+
def all_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr:
99+
it = parse.parse_into_seq_of_expr_ir(*exprs)
100100
return boolean.AllHorizontal().to_function_expr(*it).to_narwhals()
101101

102102

103-
def any_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr:
104-
it = (expr._ir for expr in flatten(exprs))
103+
def any_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr:
104+
it = parse.parse_into_seq_of_expr_ir(*exprs)
105105
return boolean.AnyHorizontal().to_function_expr(*it).to_narwhals()
106106

107107

108-
def sum_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr:
109-
it = (expr._ir for expr in flatten(exprs))
108+
def sum_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr:
109+
it = parse.parse_into_seq_of_expr_ir(*exprs)
110110
return F.SumHorizontal().to_function_expr(*it).to_narwhals()
111111

112112

113-
def min_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr:
114-
it = (expr._ir for expr in flatten(exprs))
113+
def min_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr:
114+
it = parse.parse_into_seq_of_expr_ir(*exprs)
115115
return F.MinHorizontal().to_function_expr(*it).to_narwhals()
116116

117117

118-
def max_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr:
119-
it = (expr._ir for expr in flatten(exprs))
118+
def max_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr:
119+
it = parse.parse_into_seq_of_expr_ir(*exprs)
120120
return F.MaxHorizontal().to_function_expr(*it).to_narwhals()
121121

122122

123-
def mean_horizontal(*exprs: DummyExpr | t.Iterable[DummyExpr]) -> DummyExpr:
124-
it = (expr._ir for expr in flatten(exprs))
123+
def mean_horizontal(*exprs: IntoExpr | t.Iterable[IntoExpr]) -> DummyExpr:
124+
it = parse.parse_into_seq_of_expr_ir(*exprs)
125125
return F.MeanHorizontal().to_function_expr(*it).to_narwhals()
126126

127127

tests/plan/expr_parsing_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
from __future__ import annotations
22

33
from typing import TYPE_CHECKING
4+
from typing import Callable
45
from typing import Iterable
56

67
import pytest
78

89
import narwhals as nw
910
import narwhals._plan.demo as nwd
11+
from narwhals._plan import boolean
12+
from narwhals._plan import functions as F # noqa: N812
1013
from narwhals._plan.common import ExprIR
14+
from narwhals._plan.common import Function
15+
from narwhals._plan.dummy import DummyExpr
16+
from narwhals._plan.expr import FunctionExpr
1117

1218
if TYPE_CHECKING:
1319
from narwhals._plan.common import IntoExpr
@@ -35,3 +41,42 @@ def test_parsing(
3541
assert all(
3642
isinstance(node, ExprIR) for node in nwd.select_context(*exprs, **named_exprs)
3743
)
44+
45+
46+
@pytest.mark.parametrize(
47+
("function", "ir_node"),
48+
[
49+
(nwd.all_horizontal, boolean.AllHorizontal),
50+
(nwd.any_horizontal, boolean.AnyHorizontal),
51+
(nwd.sum_horizontal, F.SumHorizontal),
52+
(nwd.min_horizontal, F.MinHorizontal),
53+
(nwd.max_horizontal, F.MaxHorizontal),
54+
(nwd.mean_horizontal, F.MeanHorizontal),
55+
],
56+
)
57+
@pytest.mark.parametrize(
58+
"args",
59+
[
60+
("a", "b", "c"),
61+
(["a", "b", "c"]),
62+
(nwd.col("d", "e", "f"), nwd.col("g"), "q", nwd.nth(9)),
63+
((nwd.lit(1),)),
64+
([nwd.lit(1), nwd.lit(2), nwd.lit(3)]),
65+
],
66+
)
67+
def test_function_expr_horizontal(
68+
function: Callable[..., DummyExpr],
69+
ir_node: type[Function],
70+
args: Seq[IntoExpr | Iterable[IntoExpr]],
71+
) -> None:
72+
variadic = function(*args)
73+
sequence = function(args)
74+
assert isinstance(variadic, DummyExpr)
75+
assert isinstance(sequence, DummyExpr)
76+
variadic_node = variadic._ir
77+
sequence_node = sequence._ir
78+
unrelated_node = nwd.lit(1)._ir
79+
assert isinstance(variadic_node, FunctionExpr)
80+
assert isinstance(variadic_node.function, ir_node)
81+
assert variadic_node == sequence_node
82+
assert sequence_node != unrelated_node

0 commit comments

Comments
 (0)