Commit 9f9ffab

separate rendering from benchmark function
Signed-off-by: Jaya Venkatesh <[email protected]>
1 parent 7d37e32 commit 9f9ffab

1 file changed

rapids_cli/benchmark/benchmark.py

Lines changed: 107 additions & 102 deletions
@@ -87,7 +87,7 @@ def benchmark_run(
     Notes:
     -----
     The function discovers and loads benchmark functions defined in entry points
-    under the 'rapids_benchmark_check' group. Each benchmark function should
+    under the 'rapids_benchmark' group. Each benchmark function should
     return a tuple of (cpu_time, gpu_time) in seconds.

     Example:
@@ -96,11 +96,33 @@ def benchmark_run(
     > benchmark_run(verbose=False, dry_run=False, filters=['cudf'])  # Run cuDF benchmarks
     """
     filters = [] if not filters else filters
-    console.print(
-        f"[bold green]{BENCHMARK_SYMBOL} Running RAPIDS benchmarks [/bold green]"
-    )

+    # Discover benchmarks
+    benchmarks = _discover_benchmarks(filters, verbose)
+
+    # Handle dry run
+    if dry_run:
+        _render_dry_run()
+        return True
+
+    if not benchmarks:
+        _render_no_benchmarks()
+        return True
+
+    # Execute benchmarks and collect results
+    results = _execute_benchmarks(benchmarks, runs, verbose)
+
+    # Render results
+    _render_results_rich(results, verbose)
+
+    # Return overall success
+    return all(result.status for result in results)
+
+
+def _discover_benchmarks(filters: list[str], verbose: bool) -> list:
+    """Discover available benchmark functions."""
     benchmarks = []
+
     if verbose:
         console.print("Discovering benchmarks")

@@ -114,19 +136,15 @@ def benchmark_run(

     if verbose:
         console.print(f"Discovered {len(benchmarks)} benchmarks")
+
+    return benchmarks

-    if not dry_run:
-        console.print(f"Running benchmarks ({runs} runs each)")
-    else:
-        console.print("Dry run, skipping benchmarks")
-        return True
-
-    if not benchmarks:
-        console.print(
-            "[yellow]No benchmarks found. Install RAPIDS libraries to enable benchmarks.[/yellow]"
-        )
-        return True

+def _execute_benchmarks(benchmarks: list, runs: int, verbose: bool) -> list[BenchmarkResult]:
+    """Execute all benchmarks and collect results."""
+    console.print(f"[bold green]{BENCHMARK_SYMBOL} Running RAPIDS CPU vs GPU benchmarks [/bold green]")
+    console.print(f"Running benchmarks ({runs} runs each)")
+
     results: list[BenchmarkResult] = []

     with Progress(
@@ -139,18 +157,18 @@ def benchmark_run(
     ) as progress:

         for i, benchmark_fn in enumerate(benchmarks):
-            error = None
-            caught_warnings = None
-            all_cpu_times = []
-            all_gpu_times = []
-
             task_id = progress.add_task(
                 f"[{i+1}/{len(benchmarks)}]",
                 total=runs,
                 benchmark_name=f"[{i+1}/{len(benchmarks)}] {benchmark_fn.__name__}",
                 completed=0,
             )

+            all_cpu_times = []
+            all_gpu_times = []
+            caught_warnings = None
+            error = None
+
             try:
                 for run in range(runs):
                     with warnings.catch_warnings(record=True) as w:
@@ -168,106 +186,87 @@ def benchmark_run(

                     progress.update(task_id, completed=run + 1)

+                # Compute statistics
                 if all_cpu_times and all_gpu_times:
                     avg_cpu_time = sum(all_cpu_times) / len(all_cpu_times)
                     avg_gpu_time = sum(all_gpu_times) / len(all_gpu_times)
                     speedup = avg_cpu_time / avg_gpu_time
-
-                    # Calculate standard deviations (only if we have multiple runs)
-                    cpu_std = (
-                        statistics.stdev(all_cpu_times)
-                        if len(all_cpu_times) > 1
-                        else 0.0
-                    )
-                    gpu_std = (
-                        statistics.stdev(all_gpu_times)
-                        if len(all_gpu_times) > 1
-                        else 0.0
-                    )
-
+                    cpu_std = statistics.stdev(all_cpu_times) if len(all_cpu_times) > 1 else 0.0
+                    gpu_std = statistics.stdev(all_gpu_times) if len(all_gpu_times) > 1 else 0.0
                     status = True
-
-                    # Remove progress task and show completion summary with variance
-                    progress.remove_task(task_id)
-                    console.print(
-                        f"[green]✓[/green] [{i+1}/{len(benchmarks)}] {benchmark_fn.__name__}"
-                    )
-
-                    # Show timing details with standard deviation
-                    if runs > 1:
-                        console.print(
-                            f" CPU Time: [red]{avg_cpu_time:.3f}s ± {cpu_std:.3f}s[/red] "
-                            f"GPU Time: [green]{avg_gpu_time:.3f}s ± {gpu_std:.3f}s[/green] "
-                            f"Speedup: [bold green]{speedup:.1f}x[/bold green]"
-                        )
-                    else:
-                        console.print(
-                            f" CPU Time: [red]{avg_cpu_time:.3f}s[/red] "
-                            f"GPU Time: [green]{avg_gpu_time:.3f}s[/green] "
-                            f"Speedup: [bold green]{speedup:.1f}x[/bold green]"
-                        )
-
                 else:
-                    avg_cpu_time = None
-                    avg_gpu_time = None
-                    cpu_std = None
-                    gpu_std = None
-                    speedup = None
+                    avg_cpu_time = avg_gpu_time = speedup = cpu_std = gpu_std = None
                     status = False

-                    # Remove progress and show failure
-                    progress.remove_task(task_id)
-                    console.print(
-                        f"[red]❌[/red] [{i+1}/{len(benchmarks)}] {benchmark_fn.__name__} - "
-                        f"[bold red]Failed[/bold red]"
-                    )
-
             except Exception as e:
                 error = e
                 status = False
-                avg_cpu_time = None
-                avg_gpu_time = None
-                cpu_std = None
-                gpu_std = None
-                speedup = None
-
-                # Remove progress and show failure
-                progress.remove_task(task_id)
-                console.print(
-                    f"[red]❌[/red] [{i+1}/{len(benchmarks)}] {benchmark_fn.__name__} - "
-                    f"[bold red]Error: {str(e)}[/bold red]"
-                )
-
-            results.append(
-                BenchmarkResult(
-                    name=benchmark_fn.__name__,
-                    description=(
-                        benchmark_fn.__doc__.strip().split("\n")[0]
-                        if benchmark_fn.__doc__
-                        else "No description"
-                    ),
-                    status=status,
-                    cpu_time=avg_cpu_time,
-                    gpu_time=avg_gpu_time,
-                    cpu_std=cpu_std,
-                    gpu_std=gpu_std,
-                    speedup=speedup,
-                    error=error,
-                    warnings=caught_warnings,
-                )
+                avg_cpu_time = avg_gpu_time = speedup = cpu_std = gpu_std = None
+
+            # Create result
+            result = BenchmarkResult(
+                name=benchmark_fn.__name__,
+                description=(
+                    benchmark_fn.__doc__.strip().split("\n")[0]
+                    if benchmark_fn.__doc__
+                    else "No description"
+                ),
+                status=status,
+                cpu_time=avg_cpu_time,
+                gpu_time=avg_gpu_time,
+                cpu_std=cpu_std,
+                gpu_std=gpu_std,
+                speedup=speedup,
+                error=error,
+                warnings=caught_warnings,
             )
+            results.append(result)
+
+            # Show immediate feedback
+            _render_benchmark_completion(result, i + 1, len(benchmarks), runs)
+
+            # Remove progress task
+            progress.remove_task(task_id)
+
+    return results

-    # Print warnings
+
+def _render_benchmark_completion(result: BenchmarkResult, index: int, total: int, runs: int):
+    """Render completion of a single benchmark."""
+    if result.status:
+        console.print(f"[green]✓[/green] [{index}/{total}] {result.name}")
+
+        # Show timing details
+        cpu_display = f"{result.cpu_time:.3f}s"
+        gpu_display = f"{result.gpu_time:.3f}s"
+
+        if runs > 1:
+            cpu_display += f" ± {result.cpu_std:.3f}s"
+            gpu_display += f" ± {result.gpu_std:.3f}s"
+
+        console.print(
+            f" CPU Time: [red]{cpu_display}[/red] "
+            f"GPU Time: [green]{gpu_display}[/green] "
+            f"Speedup: [bold green]{result.speedup:.1f}x[/bold green]"
+        )
+    else:
+        console.print(f"[red]❌[/red] [{index}/{total}] {result.name} - "
+                      f"[bold red]Failed[/bold red]")
+
+
+def _render_results_rich(results: list[BenchmarkResult], verbose: bool):
+    """Render final results using Rich console output."""
+    # Show warnings
     for result in results:
         if result.warnings:
             for warning in result.warnings:
                 console.print(f"[bold yellow]Warning[/bold yellow]: {warning.message}")

-    # Display results in a table
+    # Show results table
     if any(result.status for result in results):
         _display_benchmark_results(results, verbose)

-    # Check for failures
+    # Show failures
     failed_benchmarks = [result for result in results if not result.status]
     if failed_benchmarks:
         console.print("\n[bold red]Failed benchmarks:[/bold red]")
@@ -278,9 +277,6 @@ def benchmark_run(
                 raise result.error
             except Exception:
                 console.print_exception()
-        return False
-
-    return True


 def _display_benchmark_results(results: list[BenchmarkResult], verbose: bool) -> None:
@@ -322,3 +318,12 @@ def _display_benchmark_results(results: list[BenchmarkResult], verbose: bool) -> None:

     console.print()
     console.print(table)
+
+def _render_dry_run():
+    """Render dry run output."""
+    console.print(f"[bold green]{BENCHMARK_SYMBOL} Running RAPIDS CPU vs GPU benchmarks [/bold green]")
+    console.print("Dry run, skipping benchmarks")
+
+def _render_no_benchmarks():
+    """Render output when no benchmarks are found."""
+    console.print("[yellow]No benchmarks found. Install RAPIDS libraries to enable benchmarks.[/yellow]")
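Note on the docstring fix in the first hunk: benchmarks are discovered as entry points under the 'rapids_benchmark' group (not 'rapids_benchmark_check'), each benchmark function must return a (cpu_time, gpu_time) tuple in seconds, and the diff shows that the first line of its docstring becomes the result description. As a minimal sketch of that contract (the module name my_benchmarks and the function cudf_groupby are illustrative, not part of this commit):

# my_benchmarks.py -- hypothetical third-party benchmark for the
# 'rapids_benchmark' entry-point group
import time

def cudf_groupby():
    """Compare pandas vs. cuDF groupby."""  # first docstring line is shown as the description
    # Time the CPU (pandas) path.
    start = time.perf_counter()
    ...  # run the pandas workload here
    cpu_time = time.perf_counter() - start

    # Time the GPU (cuDF) path.
    start = time.perf_counter()
    ...  # run the cuDF workload here
    gpu_time = time.perf_counter() - start

    # Contract from the docstring: return (cpu_time, gpu_time) in seconds.
    return cpu_time, gpu_time

Registered in the providing package's pyproject.toml:

[project.entry-points.rapids_benchmark]
cudf_groupby = "my_benchmarks:cudf_groupby"

The body of _discover_benchmarks falls outside this diff's hunks; a typical implementation would load such entry points with importlib.metadata.entry_points(group="rapids_benchmark") and apply the name filters before returning the loaded functions.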
