Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 90 additions & 36 deletions bench.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
"""
bench.py -- AutoKernel benchmark harness (FIXED -- the agent NEVER modifies this file).
bench.py -- AutoKernel benchmark harness.

Handles:
1. GPU hardware detection and roofline modelling
Expand Down Expand Up @@ -32,6 +32,46 @@
import torch
import torch.nn.functional as F


_USE_XPU = (not torch.cuda.is_available()) and hasattr(torch, "xpu") and torch.xpu.is_available()
_DEFAULT_DEVICE = "xpu" if _USE_XPU else "cuda"


def _sync_device() -> None:
if _USE_XPU:
torch.xpu.synchronize()
else:
torch.cuda.synchronize()


def _empty_cache() -> None:
if _USE_XPU:
torch.xpu.empty_cache()
else:
torch.cuda.empty_cache()


def _reset_peak_memory_stats() -> None:
if _USE_XPU:
torch.xpu.reset_peak_memory_stats()
else:
torch.cuda.reset_peak_memory_stats()


def _max_memory_allocated() -> int:
if _USE_XPU:
return torch.xpu.max_memory_allocated()
return torch.cuda.max_memory_allocated()


def _event(enable_timing: bool = True):
if _USE_XPU:
return torch.xpu.Event(enable_timing=enable_timing)
return torch.cuda.Event(enable_timing=enable_timing)


_OOM_ERROR = torch.OutOfMemoryError if _USE_XPU else torch.cuda.OutOfMemoryError

# ---------------------------------------------------------------------------
# Timeout helper (cross-platform)
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -132,15 +172,15 @@ class GPUSpec:

def detect_gpu() -> GPUSpec:
"""Auto-detect current GPU and return its spec."""
if not torch.cuda.is_available():
print("WARNING: No CUDA GPU detected, using dummy spec")
if (not torch.cuda.is_available()) and (not _USE_XPU):
print("WARNING: No CUDA/XPU GPU detected, using dummy spec")
return GPUSpec()

props = torch.cuda.get_device_properties(0)
props = torch.xpu.get_device_properties(0) if _USE_XPU else torch.cuda.get_device_properties(0)
name = props.name
sm_count = props.multi_processor_count
sm_count = getattr(props, "multi_processor_count", 0)
memory_gb = round(props.total_memory / (1024 ** 3), 1)
cc = (props.major, props.minor)
cc = (getattr(props, "major", 0), getattr(props, "minor", 0))

# On ROCm, device name may be empty; try gcnArchName-based lookup first
gcn_arch = getattr(props, 'gcnArchName', '')
Expand Down Expand Up @@ -473,6 +513,10 @@ def _dtype_bytes(dtype: torch.dtype) -> int:
torch.bfloat16: {"atol": 2e-2, "rtol": 2e-2},
torch.float32: {"atol": 1e-4, "rtol": 1e-4},
},
# Intel XPU bf16 can accumulate larger drift at big shapes.
"xpu_tolerances": {
torch.bfloat16: {"atol": 1e-1, "rtol": 5e-2},
},
# gate_proj + up_proj + silu + mul + down_proj
"flops_fn": lambda s: 2 * s["batch"] * s["dim"] * s["hidden"] * 3,
"bytes_fn": lambda s, dt: (s["batch"] * s["dim"] + s["hidden"] * s["dim"] * 3 + s["batch"] * s["dim"]) * _dtype_bytes(dt),
Expand Down Expand Up @@ -640,7 +684,7 @@ def _has_nan_inf(t: torch.Tensor) -> bool:

def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> dict:
"""Run all correctness stages. Returns dict with results."""
device = "cuda"
device = _DEFAULT_DEVICE
results = {
"smoke_test": "SKIP",
"shape_sweep": "SKIP",
Expand All @@ -656,7 +700,10 @@ def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> d
ref_fn = config["reference_fn"]
sizes = config["test_sizes"]
dtypes = config["test_dtypes"]
tols = config["tolerances"]
tols = dict(config["tolerances"])
if _USE_XPU:
for dt, tol in config.get("xpu_tolerances", {}).items():
tols[dt] = tol

# ------------------------------------------------------------------
# Stage 1: SMOKE TEST -- tiny input, tight tolerance
Expand Down Expand Up @@ -692,7 +739,7 @@ def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> d
details.append(" smoke: TIMEOUT")
all_pass = False
print(" FAIL: TIMEOUT")
except torch.cuda.OutOfMemoryError:
except _OOM_ERROR:
results["smoke_test"] = "FAIL"
details.append(" smoke: OOM")
all_pass = False
Expand Down Expand Up @@ -751,10 +798,10 @@ def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> d
else:
print(f" PASS: {label} {dtype} (max_err={cmp['max_abs_error']:.2e}, within_tol={cmp['pct_within_tol']:.1f}%)")

except torch.cuda.OutOfMemoryError:
except _OOM_ERROR:
# OOM on larger sizes is acceptable -- just skip
print(f" SKIP: {label} {dtype} -> OOM")
torch.cuda.empty_cache()
_empty_cache()
continue
except BenchTimeoutError:
sweep_pass = False
Expand All @@ -767,7 +814,7 @@ def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> d
details.append(f" sweep {label}/{dtype}: {type(e).__name__}: {e}")
print(f" FAIL: {label} {dtype} -> {type(e).__name__}: {e}")
finally:
torch.cuda.empty_cache()
_empty_cache()

if sweep_pass:
results["shape_sweep"] = f"PASS ({sweep_count} configs, worst_err={worst_error:.2e} at {worst_case})"
Expand Down Expand Up @@ -850,9 +897,9 @@ def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> d
details.append(f" stability {case_name}: {cmp['reason']}")
print(f" FAIL: {case_name} -> {cmp['reason']}")

except torch.cuda.OutOfMemoryError:
except _OOM_ERROR:
print(f" SKIP: {case_name} -> OOM")
torch.cuda.empty_cache()
_empty_cache()
except BenchTimeoutError:
stability_pass = False
details.append(f" stability {case_name}: TIMEOUT")
Expand All @@ -862,7 +909,7 @@ def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> d
details.append(f" stability {case_name}: {type(e).__name__}: {e}")
print(f" FAIL: {case_name} -> {type(e).__name__}: {e}")
finally:
torch.cuda.empty_cache()
_empty_cache()

results["numerical_stability"] = "PASS" if stability_pass else "FAIL"
if not stability_pass:
Expand Down Expand Up @@ -903,7 +950,7 @@ def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> d
details.append(f" determinism: {type(e).__name__}: {e}")
print(f" FAIL: {type(e).__name__}: {e}")
finally:
torch.cuda.empty_cache()
_empty_cache()

if not determinism_pass:
all_pass = False
Expand Down Expand Up @@ -940,9 +987,9 @@ def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> d
details.append(f" edge {label}: {cmp['reason']}")
print(f" FAIL: {label} -> {cmp['reason']}")

except torch.cuda.OutOfMemoryError:
except _OOM_ERROR:
print(f" SKIP: {label} -> OOM")
torch.cuda.empty_cache()
_empty_cache()
except BenchTimeoutError:
edge_pass = False
details.append(f" edge {label}: TIMEOUT")
Expand All @@ -952,7 +999,7 @@ def run_correctness(kernel_fn: Callable, config: dict, quick: bool = False) -> d
details.append(f" edge {label}: {type(e).__name__}: {e}")
print(f" FAIL: {label} -> {type(e).__name__}: {e}")
finally:
torch.cuda.empty_cache()
_empty_cache()

results["edge_cases"] = "PASS" if edge_pass else "FAIL"
if not edge_pass:
Expand Down Expand Up @@ -984,16 +1031,16 @@ def _do_bench(fn: Callable, warmup: int = 25, rep: int = 100) -> float:
# Warmup
for _ in range(warmup):
fn()
torch.cuda.synchronize()
_sync_device()

times = []
for _ in range(rep):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start = _event(enable_timing=True)
end = _event(enable_timing=True)
start.record()
fn()
end.record()
torch.cuda.synchronize()
_sync_device()
times.append(start.elapsed_time(end))

times.sort()
Expand All @@ -1003,7 +1050,7 @@ def _do_bench(fn: Callable, warmup: int = 25, rep: int = 100) -> float:
def run_performance(kernel_fn: Callable, config: dict, gpu: GPUSpec,
sizes_filter: str = "all") -> dict:
"""Run performance benchmarks. Returns dict with metrics."""
device = "cuda"
device = _DEFAULT_DEVICE
gen_fn = config["input_generator"]
ref_fn = config["reference_fn"]
flops_fn = config["flops_fn"]
Expand Down Expand Up @@ -1110,16 +1157,16 @@ def run_performance(kernel_fn: Callable, config: dict, gpu: GPUSpec,
f"speedup: {speedup:.3f}x | {throughput_tflops:.3f} TFLOPS | "
f"{pct_peak_compute:.1f}% peak")

except torch.cuda.OutOfMemoryError:
except _OOM_ERROR:
print(f" SKIP: {label} -> OOM")
torch.cuda.empty_cache()
_empty_cache()
except BenchTimeoutError:
print(f" SKIP: {label} -> TIMEOUT")
except Exception as e:
print(f" ERROR: {label} -> {type(e).__name__}: {e}")
traceback.print_exc()
finally:
torch.cuda.empty_cache()
_empty_cache()

# If we didn't bench the primary size, use the last successful one
if primary_result is None and all_results:
Expand All @@ -1137,7 +1184,7 @@ def run_performance(kernel_fn: Callable, config: dict, gpu: GPUSpec,

def run_profile(kernel_fn: Callable, config: dict):
"""Run torch profiler and save a trace."""
device = "cuda"
device = _DEFAULT_DEVICE
gen_fn = config["input_generator"]
sizes = config["test_sizes"]

Expand All @@ -1159,22 +1206,29 @@ def run_profile(kernel_fn: Callable, config: dict):
print("\n=== PROFILING ===")
print(f"Profiling with size: {prof_size}, dtype: {dtype}")

activities = [torch.profiler.ProfilerActivity.CPU]
if _USE_XPU:
xpu_activity = getattr(torch.profiler.ProfilerActivity, "XPU", None)
if xpu_activity is not None:
activities.append(xpu_activity)
else:
print("WARNING: torch.profiler XPU activity unavailable; collecting CPU-only trace.")
else:
activities.append(torch.profiler.ProfilerActivity.CUDA)

with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
activities=activities,
record_shapes=True,
with_stack=True,
) as prof:
# Warmup
for _ in range(5):
kernel_fn(**inputs)
torch.cuda.synchronize()
_sync_device()
# Profiled runs
for _ in range(10):
kernel_fn(**inputs)
torch.cuda.synchronize()
_sync_device()

trace_path = os.path.join(trace_dir, "kernel_trace.json")
prof.export_chrome_trace(trace_path)
Expand Down Expand Up @@ -1322,9 +1376,9 @@ def main():
sizes_filter = args.sizes
if args.quick:
sizes_filter = "large"
torch.cuda.reset_peak_memory_stats()
_reset_peak_memory_stats()
perf_results = run_performance(kernel_fn, config, gpu, sizes_filter=sizes_filter)
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
peak_vram_mb = _max_memory_allocated() / 1024 / 1024
except Exception as e:
print(f"\nFATAL: Performance benchmarking crashed: {type(e).__name__}: {e}")
traceback.print_exc()
Expand Down
Loading