|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | from typing import TYPE_CHECKING
|
| 4 | +from typing import Callable |
4 | 5 | from typing import Iterable
|
5 | 6 |
|
6 | 7 | import pytest
|
7 | 8 |
|
8 | 9 | import narwhals as nw
|
9 | 10 | import narwhals._plan.demo as nwd
|
| 11 | +from narwhals._plan import boolean |
| 12 | +from narwhals._plan import functions as F # noqa: N812 |
10 | 13 | from narwhals._plan.common import ExprIR
|
| 14 | +from narwhals._plan.common import Function |
| 15 | +from narwhals._plan.dummy import DummyExpr |
| 16 | +from narwhals._plan.expr import FunctionExpr |
11 | 17 |
|
12 | 18 | if TYPE_CHECKING:
|
13 | 19 | from narwhals._plan.common import IntoExpr
|
@@ -35,3 +41,42 @@ def test_parsing(
|
35 | 41 | assert all(
|
36 | 42 | isinstance(node, ExprIR) for node in nwd.select_context(*exprs, **named_exprs)
|
37 | 43 | )
|
| 44 | + |
| 45 | + |
| 46 | +@pytest.mark.parametrize( |
| 47 | + ("function", "ir_node"), |
| 48 | + [ |
| 49 | + (nwd.all_horizontal, boolean.AllHorizontal), |
| 50 | + (nwd.any_horizontal, boolean.AnyHorizontal), |
| 51 | + (nwd.sum_horizontal, F.SumHorizontal), |
| 52 | + (nwd.min_horizontal, F.MinHorizontal), |
| 53 | + (nwd.max_horizontal, F.MaxHorizontal), |
| 54 | + (nwd.mean_horizontal, F.MeanHorizontal), |
| 55 | + ], |
| 56 | +) |
| 57 | +@pytest.mark.parametrize( |
| 58 | + "args", |
| 59 | + [ |
| 60 | + ("a", "b", "c"), |
| 61 | + (["a", "b", "c"]), |
| 62 | + (nwd.col("d", "e", "f"), nwd.col("g"), "q", nwd.nth(9)), |
| 63 | + ((nwd.lit(1),)), |
| 64 | + ([nwd.lit(1), nwd.lit(2), nwd.lit(3)]), |
| 65 | + ], |
| 66 | +) |
| 67 | +def test_function_expr_horizontal( |
| 68 | + function: Callable[..., DummyExpr], |
| 69 | + ir_node: type[Function], |
| 70 | + args: Seq[IntoExpr | Iterable[IntoExpr]], |
| 71 | +) -> None: |
| 72 | + variadic = function(*args) |
| 73 | + sequence = function(args) |
| 74 | + assert isinstance(variadic, DummyExpr) |
| 75 | + assert isinstance(sequence, DummyExpr) |
| 76 | + variadic_node = variadic._ir |
| 77 | + sequence_node = sequence._ir |
| 78 | + unrelated_node = nwd.lit(1)._ir |
| 79 | + assert isinstance(variadic_node, FunctionExpr) |
| 80 | + assert isinstance(variadic_node.function, ir_node) |
| 81 | + assert variadic_node == sequence_node |
| 82 | + assert sequence_node != unrelated_node |
0 commit comments