Skip to content

Commit f4b5119

Browse files
timsaucerclaude
andcommitted
refactor(spark): reshape varargs to match pyspark signatures
Replace generic ``*args`` with explicit pyspark-style signatures: - json_tuple(col, *fields) — first positional is the JSON expr - format_string(format, *cols) — `format` is the printf template; a plain ``str`` is auto-promoted to a literal - parse_url(url, partToExtract, key=None) — `key` is optional and only meaningful with ``partToExtract='QUERY'`` - try_parse_url(url, partToExtract, key=None) — same shape - url_decode(str), try_url_decode(str), url_encode(str) — single-argument forms (multi-arg calls were always semantically wrong) Tests cover the three-arg parse_url path and the plain-str format_string auto-promotion. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent e9113dd commit f4b5119

2 files changed

Lines changed: 60 additions & 16 deletions

File tree

python/datafusion/functions/spark.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,7 @@ def xxhash64(*cols: Expr) -> Expr:
958958
# ---------------------------------------------------------------------------
959959

960960

961-
def json_tuple(*args: Expr) -> Expr:
961+
def json_tuple(col: Expr, *fields: Expr) -> Expr:
962962
"""Spark ``json_tuple``: extract top-level fields from a JSON string.
963963
964964
Examples:
@@ -972,7 +972,7 @@ def json_tuple(*args: Expr) -> Expr:
972972
>>> r.collect_column("v")[0].as_py()
973973
{'c0': '1', 'c1': 'x'}
974974
"""
975-
return Expr(_f.json_tuple(*[a.expr for a in args]))
975+
return Expr(_f.json_tuple(col.expr, *[f.expr for f in fields]))
976976

977977

978978
# ---------------------------------------------------------------------------
@@ -1439,23 +1439,25 @@ def luhn_check(col: Expr) -> Expr:
14391439
return Expr(_f.luhn_check(col.expr))
14401440

14411441

1442-
def format_string(*args: Expr) -> Expr:
1442+
def format_string(format: str | Expr, *cols: Expr) -> Expr:
14431443
"""Spark ``format_string``: printf-style format string.
14441444
1445-
First arg is the format, remaining args are values to substitute.
1445+
``format`` is the printf-style template (a plain ``str`` is auto-promoted
1446+
to a literal expression); remaining args are values to substitute.
14461447
14471448
Examples:
14481449
>>> ctx = dfn.SessionContext()
14491450
>>> df = ctx.from_pydict({"x": [1]})
14501451
>>> r = df.select(
14511452
... dfn.functions.spark.format_string(
1452-
... dfn.lit("%d-%s"), dfn.lit(42), dfn.lit("hi")
1453+
... "%d-%s", dfn.lit(42), dfn.lit("hi")
14531454
... ).alias("v")
14541455
... )
14551456
>>> r.collect_column("v")[0].as_py()
14561457
'42-hi'
14571458
"""
1458-
return Expr(_f.format_string(*[a.expr for a in args]))
1459+
fmt_expr = format if isinstance(format, Expr) else Expr.literal(format)
1460+
return Expr(_f.format_string(fmt_expr.expr, *[c.expr for c in cols]))
14591461

14601462

14611463
def space(col: Expr) -> Expr:
@@ -1555,9 +1557,17 @@ def make_valid_utf8(str: Expr) -> Expr:
15551557
# ---------------------------------------------------------------------------
15561558

15571559

1558-
def parse_url(*args: Expr) -> Expr:
1560+
def parse_url(
1561+
url: Expr,
1562+
partToExtract: Expr, # noqa: N803
1563+
key: Expr | None = None,
1564+
) -> Expr:
15591565
"""Spark ``parse_url``: extract a part from a URL; errors on invalid URLs.
15601566
1567+
``partToExtract`` is one of ``"HOST"``, ``"PATH"``, ``"QUERY"``,
1568+
``"REF"``, ``"PROTOCOL"``, ``"FILE"``, ``"AUTHORITY"``, ``"USERINFO"``.
1569+
Pass ``key`` only with ``"QUERY"`` to extract a single parameter.
1570+
15611571
Examples:
15621572
>>> ctx = dfn.SessionContext()
15631573
>>> df = ctx.from_pydict({"x": [1]})
@@ -1568,11 +1578,27 @@ def parse_url(*args: Expr) -> Expr:
15681578
... )
15691579
>>> r.collect_column("v")[0].as_py()
15701580
'example.com'
1581+
1582+
>>> r = df.select(
1583+
... dfn.functions.spark.parse_url(
1584+
... dfn.lit("http://example.com/path?q=1"),
1585+
... dfn.lit("QUERY"),
1586+
... key=dfn.lit("q"),
1587+
... ).alias("v")
1588+
... )
1589+
>>> r.collect_column("v")[0].as_py()
1590+
'1'
15711591
"""
1572-
return Expr(_f.parse_url(*[a.expr for a in args]))
1592+
if key is None:
1593+
return Expr(_f.parse_url(url.expr, partToExtract.expr))
1594+
return Expr(_f.parse_url(url.expr, partToExtract.expr, key.expr))
15731595

15741596

1575-
def try_parse_url(*args: Expr) -> Expr:
1597+
def try_parse_url(
1598+
url: Expr,
1599+
partToExtract: Expr, # noqa: N803
1600+
key: Expr | None = None,
1601+
) -> Expr:
15761602
"""Spark ``try_parse_url``: like ``parse_url`` but returns NULL on invalid URLs.
15771603
15781604
Examples:
@@ -1586,10 +1612,12 @@ def try_parse_url(*args: Expr) -> Expr:
15861612
>>> r.collect_column("v")[0].as_py()
15871613
'example.com'
15881614
"""
1589-
return Expr(_f.try_parse_url(*[a.expr for a in args]))
1615+
if key is None:
1616+
return Expr(_f.try_parse_url(url.expr, partToExtract.expr))
1617+
return Expr(_f.try_parse_url(url.expr, partToExtract.expr, key.expr))
15901618

15911619

1592-
def url_decode(*args: Expr) -> Expr:
1620+
def url_decode(str: Expr) -> Expr:
15931621
"""Spark ``url_decode``: decode an application/x-www-form-urlencoded string.
15941622
15951623
Examples:
@@ -1600,10 +1628,10 @@ def url_decode(*args: Expr) -> Expr:
16001628
>>> r.collect_column("v")[0].as_py()
16011629
'a b'
16021630
"""
1603-
return Expr(_f.url_decode(*[a.expr for a in args]))
1631+
return Expr(_f.url_decode(str.expr))
16041632

16051633

1606-
def try_url_decode(*args: Expr) -> Expr:
1634+
def try_url_decode(str: Expr) -> Expr:
16071635
"""Spark ``try_url_decode``: like ``url_decode``; returns NULL on invalid input.
16081636
16091637
Examples:
@@ -1614,10 +1642,10 @@ def try_url_decode(*args: Expr) -> Expr:
16141642
>>> r.collect_column("v")[0].as_py()
16151643
'a b'
16161644
"""
1617-
return Expr(_f.try_url_decode(*[a.expr for a in args]))
1645+
return Expr(_f.try_url_decode(str.expr))
16181646

16191647

1620-
def url_encode(*args: Expr) -> Expr:
1648+
def url_encode(str: Expr) -> Expr:
16211649
"""Spark ``url_encode``: encode a string in application/x-www-form-urlencoded.
16221650
16231651
Examples:
@@ -1628,7 +1656,7 @@ def url_encode(*args: Expr) -> Expr:
16281656
>>> r.collect_column("v")[0].as_py()
16291657
'a+b'
16301658
"""
1631-
return Expr(_f.url_encode(*[a.expr for a in args]))
1659+
return Expr(_f.url_encode(str.expr))
16321660

16331661

16341662
__all__ = [

python/tests/test_spark_functions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,22 @@ def test_like_escape_raises():
278278
spark.ilike(lit("a"), lit("a"), escapeChar="\\")
279279

280280

281+
def test_parse_url_three_arg():
282+
"""parse_url(url, partToExtract, key=...) extracts query parameters."""
283+
ctx = SessionContext()
284+
df = ctx.from_pydict({"x": [1]})
285+
url = lit("http://example.com/p?q=hello&n=1")
286+
assert _val(df, spark.parse_url(url, lit("QUERY"), key=lit("q"))) == "hello"
287+
assert _val(df, spark.try_parse_url(url, lit("QUERY"), key=lit("n"))) == "1"
288+
289+
290+
def test_format_string_plain_str_format():
291+
"""format_string accepts a plain str format that is auto-promoted to lit."""
292+
ctx = SessionContext()
293+
df = ctx.from_pydict({"x": [1]})
294+
assert _val(df, spark.format_string("%d-%s", lit(42), lit("hi"))) == "42-hi"
295+
296+
281297
def test_aggregate_positional_compat():
282298
"""Pyspark-style positional calls still work after the rename to ``col``."""
283299
ctx = SessionContext()

0 commit comments

Comments
 (0)