Skip to content

[WIP] 20250722 benchmark sweep #347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 177 additions & 20 deletions benchmarks/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

Currently supported kernels are listed in `KERNEL_MAPPINGS` in `benchmarks/run.py`.

NOTE: It's recommended to run `rm -rf /tmp/torchinductor_${USER}/*` before running this script to enable autotuning and ensure best performance.

Usage:
$ python benchmarks/run.py [tritonbench args...] [--kernel <kernel_name(s)>]

Expand All @@ -27,6 +29,9 @@
import sys
from typing import Any
from typing import Callable
import time

import torch

# Maps tritonbench op names to Helion kernel examples
# Can map to a single kernel or a list of kernel variants
Expand All @@ -37,38 +42,100 @@
# - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = { # pyright: ignore[reportAssignmentType]
# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
"vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
"embedding": (
"tritonbench.operators.embedding.operator",
"examples.embedding",
"embedding_tritonbench",
),
"vector_exp": (
"tritonbench.operators.vector_exp.operator",
"examples.exp",
"exp_tritonbench",
),
"rms_norm": (
"tritonbench.operators.rms_norm.operator",
"examples.rms_norm",
"rms_norm_tritonbench",
{
"num_inputs": 3
}, # TODO(yf225): reduction dim size = 8192 currently throws error
),
"sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
"layer_norm": (
"tritonbench.operators.layer_norm.operator",
"examples.layer_norm",
"layer_norm_fwd_tritonbench",
),
"softmax": (
"tritonbench.operators.softmax.operator",
"examples.softmax",
"softmax",
{
"only_shapes": [
[4096, 6912 * 2 - 4096],
[4096, 6976 * 2 - 4096],
[4096, 7040 * 2 - 4096],
[4096, 7104 * 2 - 4096],
[4096, 7168 * 2 - 4096],
[4096, 7232 * 2 - 4096],
[4096, 7296 * 2 - 4096],
[4096, 7360 * 2 - 4096],
[4096, 7424 * 2 - 4096],
[4096, 7488 * 2 - 4096],
[4096, 7552 * 2 - 4096],
[4096, 7616 * 2 - 4096],
[4096, 7680 * 2 - 4096],
[4096, 7744 * 2 - 4096],
[4096, 7808 * 2 - 4096],
[4096, 7872 * 2 - 4096],
[4096, 7936 * 2 - 4096],
[4096, 8000 * 2 - 4096],
[4096, 8064 * 2 - 4096],
[4096, 8128 * 2 - 4096],
[4096, 8192 * 2 - 4096],
[4096, 8256 * 2 - 4096],
[4096, 8320 * 2 - 4096],
[4096, 8384 * 2 - 4096],
]
},
),
"cross_entropy": (
"tritonbench.operators.cross_entropy.operator",
"examples.cross_entropy",
"cross_entropy",
{"B": 4, "T": 512, "v_range": "10,15"}
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
else {},
),
"sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
"jagged_mean": (
"tritonbench.operators.jagged_mean.operator",
"examples.jagged_mean",
"jagged_mean_tritonbench",
{"B": 32, "M": 8, "seqlen": 64}
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
else {},
else {"B": 512, "M": 64},
),
"vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
"embedding": (
"tritonbench.operators.embedding.operator",
"examples.embedding",
"embedding_tritonbench",
{
"only_shapes": [
(8, 2048, 4096, 16384),
(8, 2048, 4096, 32768),
(8, 2048, 4096, 65536),
(8, 2048, 4096, 131072),
]
},
),
"vector_exp": (
"tritonbench.operators.vector_exp.operator",
"examples.exp",
"exp_tritonbench",
{
"only_shapes": [
65536,
131072,
262144,
524288,
1048576,
2097152,
4194304,
8388608,
16777216,
33554432,
67108864,
134217728,
]
},
),
"fp8_gemm": (
"tritonbench.operators.fp8_gemm.fp8_gemm",
Expand Down Expand Up @@ -99,7 +166,7 @@
"layer_norm": (
"tritonbench.operators.layer_norm.operator",
"examples.layer_norm",
"layer_norm_fwd",
"layer_norm_fwd_tritonbench",
),
"jagged_softmax": (
"tritonbench.operators.jagged_softmax.operator",
Expand Down Expand Up @@ -257,6 +324,7 @@ def run_kernel(

# Extract operator args if present
operator_args = {}
only_shapes = None

# Normalize to list of variants format
if isinstance(mapping[1], list):
Expand All @@ -265,15 +333,21 @@ def run_kernel(
variants = mapping[1]
# Check if last element is args dict
if len(mapping) > 2 and isinstance(mapping[2], dict):
operator_args = mapping[2]
operator_args = mapping[2].copy()
# Extract only_shapes if present
if "only_shapes" in operator_args:
only_shapes = operator_args.pop("only_shapes")
else:
# Single kernel format
if len(mapping) == 4 and isinstance(mapping[3], dict):
# With args
tritonbench_module = mapping[0]
module_path = mapping[1]
func_name = mapping[2]
operator_args = mapping[3] # pyright: ignore[reportGeneralTypeIssues]
operator_args = mapping[3].copy() # pyright: ignore[reportGeneralTypeIssues]
# Extract only_shapes if present
if "only_shapes" in operator_args:
only_shapes = operator_args.pop("only_shapes")
variants = [(module_path, func_name)]
else:
# Without args
Expand All @@ -289,6 +363,7 @@ def run_kernel(
tritonbench_args,
input_shard_info,
operator_args,
only_shapes,
)


Expand All @@ -299,6 +374,7 @@ def run_kernel_variants(
tritonbench_args: list[str],
input_shard_info: tuple[int, int] | None = None,
operator_args: dict[str, Any] | None = None,
only_shapes: list[str] | None = None,
) -> None:
"""Run kernel variants in the same benchmark run."""

Expand Down Expand Up @@ -363,6 +439,87 @@ def run_kernel_variants(
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
register_benchmark,
)

# Always extract all inputs beforehand
# Override the get_input_iter method for the operator class
original_get_input_iter = Operator.get_input_iter
original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None

# Create a list to store all inputs
all_inputs = []

# Collect all inputs
torch.manual_seed(42)
temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args)
for inputs in original_get_input_iter(temp_operator):
# Set random seed for reproducibility
torch.manual_seed(42)
all_inputs.append(inputs)

# If only_shapes is provided, filter the inputs
if only_shapes:
print(f"Using only_shapes for {kernel_name}: {only_shapes}", file=sys.stderr)

# Create a list to store filtered inputs
filtered_inputs = []

# Filter inputs that match the shape filter
for inputs in all_inputs:
# Get the shape value for this input
shape_value = None

if original_get_x_val:
# Use the operator's get_x_val method to get shape representation
shape_value = original_get_x_val(temp_operator, inputs)
else:
# Fallback: try to get shape from the inputs directly
if isinstance(inputs, tuple) and len(inputs) > 0:
if hasattr(inputs[0], 'shape'):
shape_value = list(inputs[0].shape)
elif isinstance(inputs[0], (int, float)):
shape_value = inputs[0]
else:
# For complex inputs, try to extract meaningful shape info
shape_value = inputs

# Check if this shape matches any in our filter using direct comparison
match_found = False
for expected_shape in only_shapes:
if shape_value == expected_shape:
match_found = True
break
# Also check if shape_value is a tuple/list that matches
elif isinstance(shape_value, (tuple, list)) and isinstance(expected_shape, (tuple, list)):
if len(shape_value) == len(expected_shape) and all(a == b for a, b in zip(shape_value, expected_shape)):
match_found = True
break

if match_found:
filtered_inputs.append(inputs)
print(f" Including shape: {shape_value}", file=sys.stderr)

if not filtered_inputs:
print(f"Warning: No shapes matched the filter for {kernel_name}", file=sys.stderr)

# Use filtered inputs instead of all inputs
inputs_to_use = filtered_inputs
else:
# Use all inputs
inputs_to_use = all_inputs

del temp_operator # Clean up temporary operator

# Create a new input iterator function
def new_get_input_iter(self):
    """Replacement ``get_input_iter`` that replays the pre-collected
    (and possibly shape-filtered) inputs instead of regenerating them."""
    yield from inputs_to_use

# Monkey-patch the operator class
Operator.get_input_iter = new_get_input_iter

# Also override _available_num_inputs for proper sharding support
Operator._available_num_inputs = len(inputs_to_use)

# Register all variants as separate methods
for module_path, func_name in variants:
Expand Down Expand Up @@ -408,7 +565,7 @@ def helion_method(
# This ensures we run autotuning even if the kernel has pre-specified configs
if os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") != "1":
attr.settings.force_autotune = True
attr.settings.static_shape = True # pyright: ignore[reportAttributeAccessIssue]
attr.settings.static_shapes = True # pyright: ignore[reportAttributeAccessIssue]

def _inner() -> Callable[..., Any] | object:
# BENCHMARK HOT PATH, do not add any new logic here
Expand Down
49 changes: 49 additions & 0 deletions benchmarks/run_input_shard.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env bash
# Run the Helion benchmark suite for one input shard on a dedicated GPU,
# retrying each kernel until it completes successfully.
#
# Required environment variables:
#   RANK_OFFSET  - GPU index offset; shard S runs on GPU (RANK_OFFSET + S - 1)
#   SHARD        - 1-based shard index of the input split to run
#   WORLD_SIZE   - total number of input shards
#
# Example — run the 1st shard of 4 starting at GPU 4:
#   SHARD=1 RANK_OFFSET=4 WORLD_SIZE=4 bash benchmarks/run_input_shard.sh

[[ -z "$RANK_OFFSET" ]] && { echo "Error: RANK_OFFSET is not set"; exit 1; }
[[ -z "$SHARD" ]] && { echo "Error: SHARD is not set"; exit 1; }
[[ -z "$WORLD_SIZE" ]] && { echo "Error: WORLD_SIZE is not set"; exit 1; }

# Capture timestamp once so every kernel in this run logs into the same
# output directory.
TIMESTAMP=$(date +%s)
OUTPUT_DIR="benchmarks_results/benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}"

KERNEL_NAME_LIST=(
"rms_norm"
"layer_norm"
"softmax"
"cross_entropy"
"sum"
"jagged_mean"
"vector_add"
"embedding"
"vector_exp"
)

# Retry each kernel until success. The attempt counter is cumulative across
# kernels (it numbers every launch in this run, not per-kernel).
attempt=0
for KERNEL_NAME in "${KERNEL_NAME_LIST[@]}"; do
while true; do
attempt=$((attempt + 1))
echo "Attempt $attempt: Running benchmark for shard ${SHARD}/${WORLD_SIZE}..."

mkdir -p "${OUTPUT_DIR}"
OUTPUT_FILE="${OUTPUT_DIR}/${KERNEL_NAME}.log"
CUDA_VISIBLE_DEVICES=$((RANK_OFFSET + SHARD - 1)) python benchmarks/run.py --input-shard "${SHARD}/${WORLD_SIZE}" --kernel "${KERNEL_NAME}" --metrics accuracy,tflops,gbps,speedup --latency-measure-mode inductor_benchmarker --csv --output-dir "${OUTPUT_DIR}" >"${OUTPUT_FILE}" 2>&1
exit_code=$?

# Success requires a clean exit code AND no partial-results marker in the
# log — tritonbench can exit 0 after catching an exception mid-run.
if [ $exit_code -eq 0 ] && ! grep -q "Caught exception, terminating early with partial results" "${OUTPUT_FILE}"; then
echo "Success! Benchmark completed for shard ${SHARD}/${WORLD_SIZE}"
break
else
echo "Failed with exit code $exit_code. Retrying..."
sleep 10 # back off briefly before retrying
fi
done
done

# Runs the 1st shard of input on GPU-0:
# SHARD=1 RANK_OFFSET=4 WORLD_SIZE=4 bash benchmarks/run_input_shard.sh
Loading
Loading