Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions src/kernelbench_tinker/modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
# IMPORTANT: Keep cuda_version, flavor, operating_sys consistent with KernelBench's
# scripts/generate_and_eval_single_sample_modal.py for compatibility.

cuda_version = "12.8.0"
cuda_version = "13.0.0"
flavor = "devel" # "devel" includes full CUDA toolkit for kernel compilation
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"
Expand Down Expand Up @@ -313,24 +313,34 @@ def evaluate(
# -----------------------------------------------------------------
# speedup = baseline_runtime / kernel_runtime
# Only calculated for correct kernels with valid timing data.
#
# TODO: Consider running baseline in same container for consistency
# instead of using precomputed baseline times.


runtime_ms = result.runtime if result.runtime > 0 else None
baseline_runtime_ms = (
result.metadata.get("baseline_runtime_ms")
or result.metadata.get("baseline_runtime")
or None
)

baseline_runtime_ms = None

ref_runtime = getattr(result, "ref_runtime", None)
if ref_runtime and ref_runtime > 0:
baseline_runtime_ms = ref_runtime
elif measure_performance and result.correctness and runtime_ms is not None:
try:
from kernelbench.timing import measure_ref_program_time

baseline_stats = measure_ref_program_time(
ref_arch_name="baseline",
ref_arch_src=ref_code,
num_warmup=5,
num_trials=num_perf_trials,
discard_first=1,
timing_method=timing_method,
precision=precision,
verbose=False,
)
if baseline_stats:
baseline_runtime_ms = baseline_stats.get("mean")
except Exception:
pass

speedup = None
if (
result.correctness
and runtime_ms is not None
and baseline_runtime_ms
and baseline_runtime_ms > 0
):
if result.correctness and runtime_ms and baseline_runtime_ms and baseline_runtime_ms > 0:
speedup = baseline_runtime_ms / runtime_ms

# -----------------------------------------------------------------
Expand Down