diff --git a/src/kernelbench_tinker/modal/app.py b/src/kernelbench_tinker/modal/app.py index 01fa67b..ce87066 100644 --- a/src/kernelbench_tinker/modal/app.py +++ b/src/kernelbench_tinker/modal/app.py @@ -81,7 +81,7 @@ # IMPORTANT: Keep cuda_version, flavor, operating_sys consistent with KernelBench's # scripts/generate_and_eval_single_sample_modal.py for compatibility. -cuda_version = "12.8.0" +cuda_version = "13.0.0" flavor = "devel" # "devel" includes full CUDA toolkit for kernel compilation operating_sys = "ubuntu22.04" tag = f"{cuda_version}-{flavor}-{operating_sys}" @@ -313,24 +313,34 @@ def evaluate( # ----------------------------------------------------------------- # speedup = baseline_runtime / kernel_runtime # Only calculated for correct kernels with valid timing data. - # - # TODO: Consider running baseline in same container for consistency - # instead of using precomputed baseline times. - + runtime_ms = result.runtime if result.runtime > 0 else None - baseline_runtime_ms = ( - result.metadata.get("baseline_runtime_ms") - or result.metadata.get("baseline_runtime") - or None - ) - + baseline_runtime_ms = None + + ref_runtime = getattr(result, "ref_runtime", None) + if ref_runtime and ref_runtime > 0: + baseline_runtime_ms = ref_runtime + elif measure_performance and result.correctness and runtime_ms is not None: + try: + from kernelbench.timing import measure_ref_program_time + + baseline_stats = measure_ref_program_time( + ref_arch_name="baseline", + ref_arch_src=ref_code, + num_warmup=5, + num_trials=num_perf_trials, + discard_first=1, + timing_method=timing_method, + precision=precision, + verbose=False, + ) + if baseline_stats: + baseline_runtime_ms = baseline_stats.get("mean") + except Exception: + pass + speedup = None - if ( - result.correctness - and runtime_ms is not None - and baseline_runtime_ms - and baseline_runtime_ms > 0 - ): + if result.correctness and runtime_ms and baseline_runtime_ms and baseline_runtime_ms > 0: speedup = baseline_runtime_ms / runtime_ms # -----------------------------------------------------------------