Skip to content

Commit cad2f48

Browse files
committed
Tracing node execution in cudf-polars
This PR introduces a new *low-overhead* tracing tool for cudf-polars. When enabled, we'll capture a record for each `IR.do_evaluate` node executed while running the polars query.
1 parent 9f2fe17 commit cad2f48

File tree

6 files changed

+372
-5
lines changed

6 files changed

+372
-5
lines changed

python/cudf_polars/cudf_polars/callback.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import pylibcudf
2222
import rmm
23+
import rmm.statistics
2324
from rmm._cuda import gpu
2425

2526
from cudf_polars.dsl.tracing import CUDF_POLARS_NVTX_DOMAIN
@@ -143,6 +144,7 @@ def set_memory_resource(
143144
),
144145
)
145146
rmm.mr.set_current_device_resource(mr)
147+
rmm.statistics.enable_statistics()
146148
try:
147149
yield mr
148150
finally:

python/cudf_polars/cudf_polars/dsl/ir.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from cudf_polars.dsl.expressions.base import ExecutionContext
3434
from cudf_polars.dsl.nodebase import Node
3535
from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
36-
from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
36+
from cudf_polars.dsl.tracing import log_do_evaluate, nvtx_annotate_cudf_polars
3737
from cudf_polars.dsl.utils.reshape import broadcast
3838
from cudf_polars.dsl.utils.windows import range_window_bounds
3939
from cudf_polars.utils import dtypes
@@ -439,6 +439,7 @@ def fast_count(self) -> int: # pragma: no cover
439439

440440
@classmethod
441441
@nvtx_annotate_cudf_polars(message="Scan")
442+
@log_do_evaluate
442443
def do_evaluate(
443444
cls,
444445
schema: Schema,
@@ -929,6 +930,7 @@ def _write_parquet(
929930

930931
@classmethod
931932
@nvtx_annotate_cudf_polars(message="Sink")
933+
@log_do_evaluate
932934
def do_evaluate(
933935
cls,
934936
schema: Schema,
@@ -988,6 +990,7 @@ def is_equal(self, other: Self) -> bool: # noqa: D102
988990

989991
@classmethod
990992
@nvtx_annotate_cudf_polars(message="Cache")
993+
@log_do_evaluate
991994
def do_evaluate(
992995
cls, key: int, refcount: int | None, df: DataFrame
993996
) -> DataFrame: # pragma: no cover; basic evaluation never calls this
@@ -1068,6 +1071,7 @@ def get_hashable(self) -> Hashable:
10681071

10691072
@classmethod
10701073
@nvtx_annotate_cudf_polars(message="DataFrameScan")
1074+
@log_do_evaluate
10711075
def do_evaluate(
10721076
cls,
10731077
schema: Schema,
@@ -1127,6 +1131,7 @@ def _is_len_expr(exprs: tuple[expr.NamedExpr, ...]) -> bool: # pragma: no cover
11271131

11281132
@classmethod
11291133
@nvtx_annotate_cudf_polars(message="Select")
1134+
@log_do_evaluate
11301135
def do_evaluate(
11311136
cls,
11321137
exprs: tuple[expr.NamedExpr, ...],
@@ -1210,6 +1215,7 @@ def __init__(
12101215

12111216
@classmethod
12121217
@nvtx_annotate_cudf_polars(message="Reduce")
1218+
@log_do_evaluate
12131219
def do_evaluate(
12141220
cls,
12151221
exprs: tuple[expr.NamedExpr, ...],
@@ -1306,6 +1312,7 @@ def __init__(
13061312

13071313
@classmethod
13081314
@nvtx_annotate_cudf_polars(message="Rolling")
1315+
@log_do_evaluate
13091316
def do_evaluate(
13101317
cls,
13111318
index: expr.NamedExpr,
@@ -1430,6 +1437,7 @@ def __init__(
14301437

14311438
@classmethod
14321439
@nvtx_annotate_cudf_polars(message="GroupBy")
1440+
@log_do_evaluate
14331441
def do_evaluate(
14341442
cls,
14351443
schema: Schema,
@@ -1594,6 +1602,7 @@ def __init__(
15941602

15951603
@classmethod
15961604
@nvtx_annotate_cudf_polars(message="ConditionalJoin")
1605+
@log_do_evaluate
15971606
def do_evaluate(
15981607
cls,
15991608
predicate_wrapper: Predicate,
@@ -1805,6 +1814,7 @@ def _build_columns(
18051814

18061815
@classmethod
18071816
@nvtx_annotate_cudf_polars(message="Join")
1817+
@log_do_evaluate
18081818
def do_evaluate(
18091819
cls,
18101820
left_on_exprs: Sequence[expr.NamedExpr],
@@ -1950,6 +1960,7 @@ def __init__(
19501960

19511961
@classmethod
19521962
@nvtx_annotate_cudf_polars(message="HStack")
1963+
@log_do_evaluate
19531964
def do_evaluate(
19541965
cls,
19551966
exprs: Sequence[expr.NamedExpr],
@@ -2015,6 +2026,7 @@ def __init__(
20152026

20162027
@classmethod
20172028
@nvtx_annotate_cudf_polars(message="Distinct")
2029+
@log_do_evaluate
20182030
def do_evaluate(
20192031
cls,
20202032
keep: plc.stream_compaction.DuplicateKeepOption,
@@ -2105,6 +2117,7 @@ def __init__(
21052117

21062118
@classmethod
21072119
@nvtx_annotate_cudf_polars(message="Sort")
2120+
@log_do_evaluate
21082121
def do_evaluate(
21092122
cls,
21102123
by: Sequence[expr.NamedExpr],
@@ -2155,6 +2168,7 @@ def __init__(self, schema: Schema, offset: int, length: int | None, df: IR):
21552168

21562169
@classmethod
21572170
@nvtx_annotate_cudf_polars(message="Slice")
2171+
@log_do_evaluate
21582172
def do_evaluate(cls, offset: int, length: int, df: DataFrame) -> DataFrame:
21592173
"""Evaluate and return a dataframe."""
21602174
return df.slice((offset, length))
@@ -2176,6 +2190,7 @@ def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR):
21762190

21772191
@classmethod
21782192
@nvtx_annotate_cudf_polars(message="Filter")
2193+
@log_do_evaluate
21792194
def do_evaluate(cls, mask_expr: expr.NamedExpr, df: DataFrame) -> DataFrame:
21802195
"""Evaluate and return a dataframe."""
21812196
(mask,) = broadcast(mask_expr.evaluate(df), target_length=df.num_rows)
@@ -2195,6 +2210,7 @@ def __init__(self, schema: Schema, df: IR):
21952210

21962211
@classmethod
21972212
@nvtx_annotate_cudf_polars(message="Projection")
2213+
@log_do_evaluate
21982214
def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame:
21992215
"""Evaluate and return a dataframe."""
22002216
# This can reorder things.
@@ -2224,6 +2240,7 @@ def __init__(self, schema: Schema, key: str, left: IR, right: IR):
22242240

22252241
@classmethod
22262242
@nvtx_annotate_cudf_polars(message="MergeSorted")
2243+
@log_do_evaluate
22272244
def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
22282245
"""Evaluate and return a dataframe."""
22292246
left, right = dfs
@@ -2344,6 +2361,7 @@ def get_hashable(self) -> Hashable:
23442361

23452362
@classmethod
23462363
@nvtx_annotate_cudf_polars(message="MapFunction")
2364+
@log_do_evaluate
23472365
def do_evaluate(
23482366
cls, schema: Schema, name: str, options: Any, df: DataFrame
23492367
) -> DataFrame:
@@ -2444,6 +2462,7 @@ def __init__(self, schema: Schema, zlice: Zlice | None, *children: IR):
24442462

24452463
@classmethod
24462464
@nvtx_annotate_cudf_polars(message="Union")
2465+
@log_do_evaluate
24472466
def do_evaluate(cls, zlice: Zlice | None, *dfs: DataFrame) -> DataFrame:
24482467
"""Evaluate and return a dataframe."""
24492468
# TODO: only evaluate what we need if we have a slice?
@@ -2501,6 +2520,7 @@ def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table:
25012520

25022521
@classmethod
25032522
@nvtx_annotate_cudf_polars(message="HConcat")
2523+
@log_do_evaluate
25042524
def do_evaluate(
25052525
cls,
25062526
should_broadcast: bool, # noqa: FBT001
@@ -2546,6 +2566,7 @@ def __init__(self, schema: Schema):
25462566

25472567
@classmethod
25482568
@nvtx_annotate_cudf_polars(message="Empty")
2569+
@log_do_evaluate
25492570
def do_evaluate(cls, schema: Schema) -> DataFrame: # pragma: no cover
25502571
"""Evaluate and return a dataframe."""
25512572
return DataFrame(

python/cudf_polars/cudf_polars/dsl/tracing.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,145 @@
66
from __future__ import annotations
77

88
import functools
9+
import os
10+
import time
11+
from typing import TYPE_CHECKING, Any, Literal
912

1013
import nvtx
14+
from typing_extensions import ParamSpec
15+
16+
import rmm
17+
import rmm.statistics
18+
19+
import cudf_polars.containers
20+
21+
# structlog is an optional dependency: tracing silently disables itself
# when it is not installed.
try:
    import structlog
except ImportError:
    HAS_STRUCTLOG = False
else:
    HAS_STRUCTLOG = True

# Environment-variable values treated as "enabled".
_TRUTHY = frozenset({"1", "true", "y", "yes"})

# Question: should this be toggleable at runtime?
LOG_TRACES = HAS_STRUCTLOG and os.environ.get("CUDF_POLARS_LOG_TRACES", "0") in _TRUTHY

# NVTX domain under which all cudf-polars ranges are reported.
CUDF_POLARS_NVTX_DOMAIN = "cudf_polars"

# Convenience wrapper so call sites only supply ``message=...``.
nvtx_annotate_cudf_polars = functools.partial(
    nvtx.annotate, domain=CUDF_POLARS_NVTX_DOMAIN
)

if TYPE_CHECKING:
    from collections.abc import Callable

    from cudf_polars.dsl import ir
48+
def make_snaphot(
    node_type: type[ir.IR],
    frames: list[cudf_polars.containers.DataFrame],
    extra: dict[str, Any] | None = None,
    phase: Literal["input", "output"] = "input",
) -> dict[str, Any]:
    """
    Build a snapshot record for one side of an ``IR.do_evaluate`` call.

    This does not log anything itself; it returns a flat dict that the
    caller (``log_do_evaluate``) merges and emits. Keys other than
    ``"type"`` are suffixed with ``phase`` so the input and output
    snapshots of one evaluation can be merged without collisions.

    Parameters
    ----------
    node_type
        The type of the IR node.
    frames
        The frames being evaluated (``phase="input"``) or produced
        (``phase="output"``).
    extra
        Extra key/value pairs merged into the record, e.g. timing data.
    phase
        The phase of the evaluation. Either "input" or "output".

    Returns
    -------
    dict
        Record with per-frame shape/size, total byte count, and — when
        RMM statistics tracking is enabled — the current/peak/total RMM
        allocation counters.
    """
    # Build the per-frame info first so the byte total can be computed
    # directly, without reaching back into the record dict.
    frame_infos = [
        {
            "shape": frame.table.shape(),
            # Total device memory held by the frame's columns.
            "size": sum(col.device_buffer_size() for col in frame.table.columns()),
        }
        for frame in frames
    ]
    record: dict[str, Any] = {
        "type": node_type.__name__,
        f"count_frames_{phase}": len(frame_infos),
        f"frames_{phase}": frame_infos,
        f"total_bytes_{phase}": sum(info["size"] for info in frame_infos),
    }

    # get_statistics() returns None unless rmm.statistics.enable_statistics()
    # was called (see callback.set_memory_resource).
    stats = rmm.statistics.get_statistics()
    if stats:
        record.update(
            {
                f"rmm_current_bytes_{phase}": stats.current_bytes,
                f"rmm_current_count_{phase}": stats.current_count,
                f"rmm_peak_bytes_{phase}": stats.peak_bytes,
                f"rmm_peak_count_{phase}": stats.peak_count,
                f"rmm_total_bytes_{phase}": stats.total_bytes,
                f"rmm_total_count_{phase}": stats.total_count,
            }
        )

    if extra:
        record.update(extra)

    return record


# Correctly spelled alias; ``make_snaphot`` (sic) is retained above so
# existing callers keep working.
make_snapshot = make_snaphot
101+
102+
103+
P = ParamSpec("P")


def log_do_evaluate(
    func: Callable[P, cudf_polars.containers.DataFrame],
) -> Callable[P, cudf_polars.containers.DataFrame]:
    """
    Decorator for an ``IR.do_evaluate`` method that logs information before and after evaluation.

    Parameters
    ----------
    func
        The ``IR.do_evaluate`` method to wrap.
    """

    @functools.wraps(func)
    def wrapper(
        cls: type[ir.IR], *args: Any, **kwargs: Any
    ) -> cudf_polars.containers.DataFrame:
        # Fast path: tracing disabled, just delegate.
        if not LOG_TRACES:
            return func(cls, *args, **kwargs)  # type: ignore

        logger = structlog.get_logger()
        # Every DataFrame among the positional/keyword arguments counts
        # as an input frame for the snapshot.
        input_frames = [
            value
            for value in (*args, *kwargs.values())
            if isinstance(value, cudf_polars.containers.DataFrame)
        ]

        before = make_snaphot(cls, input_frames, phase="input")

        # TODO: fix these types! Want some way to say
        # Callable[ir.IR, *P.args, **P.kwargs], cudf_polars.containers.DataFrame]
        # i.e. the first arg is an IR, it returns a DataFrame, and does
        # whatever for the remaining args/kwargs.
        start = time.monotonic_ns()
        result = func(cls, *args, **kwargs)  # type: ignore
        stop = time.monotonic_ns()

        after = make_snaphot(
            cls, [result], phase="output", extra={"start": start, "stop": stop}
        )
        logger.info("Execute IR", **(before | after))

        return result

    return wrapper  # type: ignore

python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def run_duckdb(benchmark: Any, options: Sequence[str] | None = None) -> None:
124124
result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
125125

126126
t1 = time.time()
127-
record = Record(query=q_id, duration=t1 - t0)
127+
record = Record(query=q_id, iteration=i, duration=t1 - t0)
128128
if args.print_results:
129129
print(result)
130130

0 commit comments

Comments
 (0)