From f69c9254df006c11d60a099aca3cc81769305581 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Wed, 30 Jul 2025 20:35:32 -0700
Subject: [PATCH 01/16] 20250722_benchmark_sweep

---
 benchmarks/run.py               |  1 +
 benchmarks/run_input_shard.sh   | 31 ++++++++++++++++++++++++++++++
 helion/autotuner/base_search.py | 34 +++++++++++++++++++++++++--------
 helion/autotuner/config_spec.py |  4 ++--
 4 files changed, 60 insertions(+), 10 deletions(-)
 create mode 100644 benchmarks/run_input_shard.sh

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 2c25d275..5473a588 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -27,6 +27,7 @@
 import sys
 from typing import Any
 from typing import Callable
+import time

 # Maps tritonbench op names to Helion kernel examples
 # Can map to a single kernel or a list of kernel variants
diff --git a/benchmarks/run_input_shard.sh b/benchmarks/run_input_shard.sh
new file mode 100644
index 00000000..101d2534
--- /dev/null
+++ b/benchmarks/run_input_shard.sh
@@ -0,0 +1,31 @@
+[[ -z "$RANK_OFFSET" ]] && { echo "Error: RANK_OFFSET is not set"; exit 1; }
+[[ -z "$SHARD" ]] && { echo "Error: SHARD is not set"; exit 1; }
+[[ -z "$WORLD_SIZE" ]] && { echo "Error: WORLD_SIZE is not set"; exit 1; }
+
+# Capture timestamp once for consistent filename
+TIMESTAMP=$(date +%s)
+OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_$((SHARD+1))_of_${WORLD_SIZE}.txt"
+
+# Retry until success
+attempt=0
+while true; do
+# while (( attempt < 10 )); do
+    attempt=$((attempt + 1))
+    echo "Attempt $attempt: Running benchmark for shard $((SHARD+1))/${WORLD_SIZE}..."
+
+    # TIMESTAMP=$(date +%s)
+    # OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_$((SHARD+1))_of_${WORLD_SIZE}.txt"
+
+    CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD)) python benchmarks/run.py --input-shard $((SHARD+1))/${WORLD_SIZE} --metrics accuracy,tflops,gbps,speedup >"$OUTPUT_FILE" 2>&1
+
+    exit_code=$?
+    if [ $exit_code -eq 0 ]; then
+        echo "Success! Benchmark completed for shard $((SHARD+1))/${WORLD_SIZE}"
+        break
+    else
+        echo "Failed with exit code $exit_code. Retrying..."
+        sleep 10 # wait a few seconds before retrying
+    fi
+done
+
+# SHARD=0 RANK_OFFSET=4 WORLD_SIZE=4 bash benchmarks/run_input_shard.sh
diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index b66343cf..f728f6b9 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -17,6 +17,8 @@
 from typing import NamedTuple
 from typing import NoReturn

+from triton.compiler.errors import CompilationError
+
 if TYPE_CHECKING:
     from triton.runtime.jit import JITFunction

@@ -109,10 +111,13 @@ def benchmark(self, config: Config) -> float:
         Returns:
             The performance of the configuration in seconds.
""" - fn = self.kernel.compile_config(config, allow_print=False) - if self.start_precompile_and_check_for_hangs(config, fn)(): - return self.benchmark_function(config, fn) - return inf + try: + fn = self.kernel.compile_config(config, allow_print=False) + if self.start_precompile_and_check_for_hangs(config, fn)(): + return self.benchmark_function(config, fn) + return inf + except Exception as e: + return inf def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: """ @@ -146,9 +151,11 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: self.log.debug("Benchmarking failed: OutOfResources") except PTXASError: self.log.warning(f"PTXASError compiling config: {config}") + except CompilationError: + self.log.debug("Benchmarking failed: Triton CompilationError") except Exception as e: - if not _expected_errors_regexp.search(str(e)): - raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e + # if not _expected_errors_regexp.search(str(e)): + # raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}") return inf @@ -170,6 +177,8 @@ def start_precompile_and_check_for_hangs( """ if not self.settings.autotune_precompile: return PrecompileFuture.skip(self, config, True) + if fn is None: + return PrecompileFuture.skip(self, config, False) ctx = mp.get_context("fork") def extract_launcher( @@ -190,6 +199,8 @@ def extract_launcher( precompiler = make_precompiler(e.kernel)(*e.args, **e.kwargs) if precompiler is already_compiled: return PrecompileFuture.skip(self, config, True) + except Exception as e: + return PrecompileFuture.skip(self, config, False) process: mp.Process = ctx.Process(target=precompiler) # pyright: ignore[reportAssignmentType] process.start() return PrecompileFuture( @@ -209,7 +220,13 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float] Returns: A list of tuples containing configurations and their performance. """ - fns = [self.kernel.compile_config(c, allow_print=False) for c in configs] + fns = [] + for c in configs: + try: + compile_result = self.kernel.compile_config(c, allow_print=False) + fns.append(compile_result) + except Exception as e: + fns.append(None) if self.settings.autotune_precompile: is_workings = PrecompileFuture.wait_for_all( [ @@ -390,11 +407,12 @@ def population_statistics(population: list[PopulationMember]) -> str: raise exc.NoConfigFound return ( f"failed={len(population) - len(working)} " + ) + ( f"min={working[0].perf:.4f} " f"mid={working[len(working) // 2].perf:.4f} " f"max={working[-1].perf:.4f} " f"best={population[0].config!s}" - ) + ) if len(working) > 0 else "all failed!" 
     return (
         f"min={population[0].perf:.4f} "
         f"mid={population[len(population) // 2].perf:.4f} "
diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py
index 812d3f27..80337d05 100644
--- a/helion/autotuner/config_spec.py
+++ b/helion/autotuner/config_spec.py
@@ -411,8 +411,8 @@ def _flat_config(
         default = min(high, 4096)
         value = fn(BlockSizeFragment(low, high, default))
         assert isinstance(value, int)
-        if value >= self.size_hint:
-            return None  # max size becomes persistent reduction
+        if value >= self.size_hint or value < low:
+            return None  # max size or invalid value becomes persistent reduction
         return value

     def _normalize(self, name: str, value: object) -> int | None:

From 07519c3d9ab93d3753794fcaba2ded53c91f7625 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Thu, 31 Jul 2025 12:36:56 -0700
Subject: [PATCH 02/16] output to CSV folder; change SHARD to start from 1

---
 benchmarks/run_input_shard.sh | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/benchmarks/run_input_shard.sh b/benchmarks/run_input_shard.sh
index 101d2534..50c71dd6 100644
--- a/benchmarks/run_input_shard.sh
+++ b/benchmarks/run_input_shard.sh
@@ -4,23 +4,25 @@

 # Capture timestamp once for consistent filename
 TIMESTAMP=$(date +%s)
-OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_$((SHARD+1))_of_${WORLD_SIZE}.txt"
+OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"
+CSV_OUTPUT_DIR="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}_csv"

 # Retry until success
 attempt=0
 while true; do
 # while (( attempt < 10 )); do
     attempt=$((attempt + 1))
-    echo "Attempt $attempt: Running benchmark for shard $((SHARD+1))/${WORLD_SIZE}..."
+    echo "Attempt $attempt: Running benchmark for shard ${SHARD}/${WORLD_SIZE}..."

     # TIMESTAMP=$(date +%s)
-    # OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_$((SHARD+1))_of_${WORLD_SIZE}.txt"
+    # OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"

-    CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD)) python benchmarks/run.py --input-shard $((SHARD+1))/${WORLD_SIZE} --metrics accuracy,tflops,gbps,speedup >"$OUTPUT_FILE" 2>&1
+    mkdir -p ${CSV_OUTPUT_DIR} || true
+    CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD-1)) python benchmarks/run.py --input-shard ${SHARD}/${WORLD_SIZE} --metrics accuracy,tflops,gbps,speedup --csv --output-dir ${CSV_OUTPUT_DIR} >"$OUTPUT_FILE" 2>&1

     exit_code=$?
     if [ $exit_code -eq 0 ]; then
-        echo "Success! Benchmark completed for shard $((SHARD+1))/${WORLD_SIZE}"
+        echo "Success! Benchmark completed for shard ${SHARD}/${WORLD_SIZE}"
         break
     else
         echo "Failed with exit code $exit_code. Retrying..."
@@ -28,4 +30,5 @@ while true; do
     fi
 done

-# SHARD=0 RANK_OFFSET=4 WORLD_SIZE=4 bash benchmarks/run_input_shard.sh
+# Runs the 1st shard of input on GPU-4 (CUDA_VISIBLE_DEVICES=RANK_OFFSET+SHARD-1):
+# SHARD=1 RANK_OFFSET=4 WORLD_SIZE=4 bash benchmarks/run_input_shard.sh

From 637dcfc6bb6a5863257b7e2eeba55eef7277394e Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 4 Aug 2025 17:03:40 -0700
Subject: [PATCH 03/16] More error catching; run kernels from explicit list

---
 benchmarks/run.py             | 54 +++++++++++++++++++++++++++--------
 benchmarks/run_input_shard.sh | 47 +++++++++++++++++++-----------
 2 files changed, 73 insertions(+), 28 deletions(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 5473a588..8215fbd3 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -38,17 +38,6 @@
 # - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
 KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
     # <op_name>: (<tritonbench_module>, <helion_module>, <helion_func>)
-    "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
-    "embedding": (
-        "tritonbench.operators.embedding.operator",
-        "examples.embedding",
-        "embedding_tritonbench",
-    ),
-    "vector_exp": (
-        "tritonbench.operators.vector_exp.operator",
-        "examples.exp",
-        "exp_tritonbench",
-    ),
     "rms_norm": (
         "tritonbench.operators.rms_norm.operator",
         "examples.rms_norm",
@@ -57,12 +46,25 @@
             "num_inputs": 3
         },  # TODO(yf225): reduction dim size = 8192 currently throws error
     ),
-    "sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
+    "layer_norm": (
+        "tritonbench.operators.layer_norm.operator",
+        "examples.layer_norm",
+        "layer_norm_fwd",
+    ),
     "softmax": (
         "tritonbench.operators.softmax.operator",
         "examples.softmax",
         "softmax",
     ),
+    "cross_entropy": (
+        "tritonbench.operators.cross_entropy.operator",
+        "examples.cross_entropy",
+        "cross_entropy",
+        {"B": 4, "T": 512, "v_range": "10,15"}
+        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
+        else {},
+    ),
+    "sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
     "jagged_mean": (
         "tritonbench.operators.jagged_mean.operator",
         "examples.jagged_mean",
@@ -71,6 +73,17 @@
         if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
         else {},
     ),
+    "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
+    "embedding": (
+        "tritonbench.operators.embedding.operator",
+        "examples.embedding",
+        "embedding_tritonbench",
+    ),
+    "vector_exp": (
+        "tritonbench.operators.vector_exp.operator",
+        "examples.exp",
+        "exp_tritonbench",
+    ),
     "fp8_gemm": (
         "tritonbench.operators.fp8_gemm.fp8_gemm",
         "examples.fp8_gemm",
@@ -303,6 +316,23 @@ def run_kernel_variants(
 ) -> None:
     """Run kernel variants in the same benchmark run."""

+    # Configure Helion to use fewer generations for faster autotuning during benchmarks
+    import helion
+    from helion.autotuner import DifferentialEvolutionSearch, LocalAutotuneCache
+    from helion.runtime.kernel import BoundKernel
+    from typing import Sequence
+
+    def fast_autotuner_fn(
+        bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
+    ) -> LocalAutotuneCache:
+        # Use only 1 generation instead of default 20 for faster benchmarking
+        return LocalAutotuneCache(
+            DifferentialEvolutionSearch(bound_kernel, args, num_generations=1, **kwargs)
+        )
+
+    # Set the custom autotuner function
+    helion.set_default_settings(helion.Settings(autotuner_fn=fast_autotuner_fn))
+
     # Import tritonbench components
     try:
         from tritonbench.utils.parser import (  # pyright: ignore[reportMissingImports]
diff --git a/benchmarks/run_input_shard.sh b/benchmarks/run_input_shard.sh
index 50c71dd6..2ef03a49 100644
--- a/benchmarks/run_input_shard.sh
+++ b/benchmarks/run_input_shard.sh
@@ -7,27 +7,42 @@ TIMESTAMP=$(date +%s)
 OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"
 CSV_OUTPUT_DIR="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}_csv"

+KERNEL_NAME_LIST=(
+    "rms_norm"
+    "layer_norm"
+    "softmax"
+    "cross_entropy"
+    "sum"
+    "jagged_mean"
+    "vector_add"
+    "embedding"
+    "vector_exp"
+)
+
 # Retry until success
 attempt=0
-while true; do
-# while (( attempt < 10 )); do
-    attempt=$((attempt + 1))
-    echo "Attempt $attempt: Running benchmark for shard ${SHARD}/${WORLD_SIZE}..."
+for KERNEL_NAME in "${KERNEL_NAME_LIST[@]}"; do
+    while true; do
+        # while (( attempt < 10 )); do
+        attempt=$((attempt + 1))
+        echo "Attempt $attempt: Running benchmark for shard ${SHARD}/${WORLD_SIZE}..."

-    # TIMESTAMP=$(date +%s)
-    # OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"
+        # TIMESTAMP=$(date +%s)
+        # OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"

-    mkdir -p ${CSV_OUTPUT_DIR} || true
-    CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD-1)) python benchmarks/run.py --input-shard ${SHARD}/${WORLD_SIZE} --metrics accuracy,tflops,gbps,speedup --csv --output-dir ${CSV_OUTPUT_DIR} >"$OUTPUT_FILE" 2>&1
+        mkdir -p ${CSV_OUTPUT_DIR} || true
+        CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD-1)) python benchmarks/run.py --input-shard ${SHARD}/${WORLD_SIZE} --kernel ${KERNEL_NAME} --metrics accuracy,tflops,gbps,speedup --csv --output-dir ${CSV_OUTPUT_DIR} >"$OUTPUT_FILE" 2>&1

-    exit_code=$?
-    if [ $exit_code -eq 0 ]; then
-        echo "Success! Benchmark completed for shard ${SHARD}/${WORLD_SIZE}"
-        break
-    else
-        echo "Failed with exit code $exit_code. Retrying..."
+        exit_code=$?
+        # Check for success: exit code 0 AND no exception message in output
+        if [ $exit_code -eq 0 ] && ! grep -q "Caught exception, terminating early with partial results" "$OUTPUT_FILE"; then
+            echo "Success! Benchmark completed for shard ${SHARD}/${WORLD_SIZE}"
+            break
+        else
+            echo "Failed with exit code $exit_code. Retrying..."
+            sleep 10 # wait a few seconds before retrying
+        fi
+    done
 done

 # Runs the 1st shard of input on GPU-4 (CUDA_VISIBLE_DEVICES=RANK_OFFSET+SHARD-1):
 # SHARD=1 RANK_OFFSET=4 WORLD_SIZE=4 bash benchmarks/run_input_shard.sh

From 3c42e5e8e290b911f78478317385f860c6657d01 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 4 Aug 2025 17:14:48 -0700
Subject: [PATCH 04/16] try fix Tile(block_id) error

---
 helion/_compiler/type_propagation.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 70bea5b3..ab3a91d7 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -1004,12 +1004,19 @@ def __init__(self, origin: Origin, block_id: int) -> None:
         self.block_id = block_id

     def proxy(self) -> object:
+        from ..language.tile_proxy import Tile as TileClass
+
         with proxy_tensor.disable_proxy_modes_tracing():
             fake_mode = torch._C._unset_dispatch_mode(  # pyright: ignore[reportAttributeAccessIssue]
                 torch._C._TorchDispatchModeKey.FAKE  # pyright: ignore[reportAttributeAccessIssue]
             )
             try:
-                return Tile(self.block_id)
+                # Create a Tile instance using torch.as_subclass to properly handle tensor subclassing
+                # This avoids the "already associated to a python object" error
+                base_tensor = torch.empty([], dtype=torch.int64, device='meta')
+                tile = base_tensor.as_subclass(TileClass)
+                tile.block_id = self.block_id
+                return tile
             finally:
                 assert fake_mode is not None
                 torch._C._set_dispatch_mode(fake_mode)  # pyright: ignore[reportAttributeAccessIssue]

From 33151ba99d579daa6486c9efeac94a408948a65a Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 4 Aug 2025 17:30:52 -0700
Subject: [PATCH 05/16] improve script

---
 benchmarks/run_input_shard.sh | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/benchmarks/run_input_shard.sh b/benchmarks/run_input_shard.sh
index 2ef03a49..9ce73051 100644
--- a/benchmarks/run_input_shard.sh
+++ b/benchmarks/run_input_shard.sh
@@ -4,8 +4,7 @@

 # Capture timestamp once for consistent filename
 TIMESTAMP=$(date +%s)
-OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"
-CSV_OUTPUT_DIR="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}_csv"
+OUTPUT_DIR="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}"

 KERNEL_NAME_LIST=(
     "rms_norm"
@@ -30,12 +29,13 @@ for KERNEL_NAME in "${KERNEL_NAME_LIST[@]}"; do
         # TIMESTAMP=$(date +%s)
         # OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"

-        mkdir -p ${CSV_OUTPUT_DIR} || true
-        CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD-1)) python benchmarks/run.py --input-shard ${SHARD}/${WORLD_SIZE} --kernel ${KERNEL_NAME} --metrics accuracy,tflops,gbps,speedup --csv --output-dir ${CSV_OUTPUT_DIR} >"$OUTPUT_FILE" 2>&1
+        mkdir -p ${OUTPUT_DIR} || true
+        OUTPUT_FILE="${OUTPUT_DIR}/${KERNEL_NAME}.log"
+        CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD-1)) python benchmarks/run.py --input-shard ${SHARD}/${WORLD_SIZE} --kernel ${KERNEL_NAME} --metrics accuracy,tflops,gbps,speedup --csv --output-dir ${OUTPUT_DIR} >"${OUTPUT_FILE}" 2>&1

         exit_code=$?
         # Check for success: exit code 0 AND no exception message in output
-        if [ $exit_code -eq 0 ] && ! grep -q "Caught exception, terminating early with partial results" "$OUTPUT_FILE"; then
+        if [ $exit_code -eq 0 ] && ! grep -q "Caught exception, terminating early with partial results" "${OUTPUT_FILE}"; then
             echo "Success! Benchmark completed for shard ${SHARD}/${WORLD_SIZE}"
             break
         else

From 9c57ed3729a20f229ce109ba3405afe7350eb9ac Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 4 Aug 2025 21:31:52 -0700
Subject: [PATCH 06/16] fix bug

---
 benchmarks/run.py | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 8215fbd3..b0acde49 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -316,23 +316,6 @@ def run_kernel_variants(
 ) -> None:
     """Run kernel variants in the same benchmark run."""

-    # Configure Helion to use fewer generations for faster autotuning during benchmarks
-    import helion
-    from helion.autotuner import DifferentialEvolutionSearch, LocalAutotuneCache
-    from helion.runtime.kernel import BoundKernel
-    from typing import Sequence
-
-    def fast_autotuner_fn(
-        bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
-    ) -> LocalAutotuneCache:
-        # Use only 1 generation instead of default 20 for faster benchmarking
-        return LocalAutotuneCache(
-            DifferentialEvolutionSearch(bound_kernel, args, num_generations=1, **kwargs)
-        )
-
-    # Set the custom autotuner function
-    helion.set_default_settings(helion.Settings(autotuner_fn=fast_autotuner_fn))
-
     # Import tritonbench components
     try:
         from tritonbench.utils.parser import (  # pyright: ignore[reportMissingImports]

From 09382b0d55466a838a004db043d3e695c0a336f8 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 4 Aug 2025 22:13:15 -0700
Subject: [PATCH 07/16] set static_shapes = True

---
 benchmarks/run.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index b0acde49..2587b44e 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -422,7 +422,7 @@ def helion_method(
         # This ensures we run autotuning even if the kernel has pre-specified configs
         if os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") != "1":
             attr.settings.force_autotune = True
-        attr.settings.static_shape = True  # pyright: ignore[reportAttributeAccessIssue]
+        attr.settings.static_shapes = True  # pyright: ignore[reportAttributeAccessIssue]

     def _inner() -> Callable[..., Any] | object:
         # BENCHMARK HOT PATH, do not add any new logic here

From b5b7658c2535d93b699aebdc9d77dc35922057a2 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Wed, 6 Aug 2025 14:03:02 -0700
Subject: [PATCH 08/16] add only_shapes filtering in KERNEL_MAPPINGS

---
 benchmarks/run.py | 76 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 74 insertions(+), 2 deletions(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 2587b44e..0b63fecc 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -271,6 +271,7 @@ def run_kernel(

     # Extract operator args if present
     operator_args = {}
+    only_shapes = None

     # Normalize to list of variants format
     if isinstance(mapping[1], list):
@@ -279,7 +280,10 @@ def run_kernel(
         variants = mapping[1]
         # Check if last element is args dict
         if len(mapping) > 2 and isinstance(mapping[2], dict):
-            operator_args = mapping[2]
+            operator_args = mapping[2].copy()
+            # Extract only_shapes if present
+            if "only_shapes" in operator_args:
+                only_shapes = operator_args.pop("only_shapes")
     else:
         # Single kernel format
         if len(mapping) == 4 and isinstance(mapping[3], dict):
@@ -287,7 +291,10 @@ def run_kernel(
             tritonbench_module = mapping[0]
             module_path = mapping[1]
             func_name = mapping[2]
-            operator_args = mapping[3]  # pyright: ignore[reportGeneralTypeIssues]
+            operator_args = mapping[3].copy()  # pyright: ignore[reportGeneralTypeIssues]
+            # Extract only_shapes if present
"only_shapes" in operator_args: + only_shapes = operator_args.pop("only_shapes") variants = [(module_path, func_name)] else: # Without args @@ -303,6 +310,7 @@ def run_kernel( tritonbench_args, input_shard_info, operator_args, + only_shapes, ) @@ -313,6 +321,7 @@ def run_kernel_variants( tritonbench_args: list[str], input_shard_info: tuple[int, int] | None = None, operator_args: dict[str, Any] | None = None, + only_shapes: list[str] | None = None, ) -> None: """Run kernel variants in the same benchmark run.""" @@ -377,6 +386,69 @@ def run_kernel_variants( from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports] register_benchmark, ) + + # Inject only_shapes filter if provided + if only_shapes: + print(f"Using only_shapes for {kernel_name}: {only_shapes}", file=sys.stderr) + + # Override the get_input_iter method for the operator class + original_get_input_iter = Operator.get_input_iter + original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None + + # Create a list to store filtered inputs and their shapes + filtered_inputs = [] + + # First, collect all inputs that match the shape filter + temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args) + for inputs in original_get_input_iter(temp_operator): + # Get the shape value for this input + shape_value = None + + if original_get_x_val: + # Use the operator's get_x_val method to get shape representation + shape_value = original_get_x_val(temp_operator, inputs) + else: + # Fallback: try to get shape from the inputs directly + if isinstance(inputs, tuple) and len(inputs) > 0: + if hasattr(inputs[0], 'shape'): + shape_value = list(inputs[0].shape) + elif isinstance(inputs[0], (int, float)): + shape_value = inputs[0] + else: + # For complex inputs, try to extract meaningful shape info + shape_value = inputs + + # Check if this shape matches any in our filter using direct comparison + match_found = False + for expected_shape in only_shapes: + if shape_value == expected_shape: + match_found = True + break + # Also check if shape_value is a tuple/list that matches + elif isinstance(shape_value, (tuple, list)) and isinstance(expected_shape, (tuple, list)): + if len(shape_value) == len(expected_shape) and all(a == b for a, b in zip(shape_value, expected_shape)): + match_found = True + break + + if match_found: + filtered_inputs.append(inputs) + print(f" Including shape: {shape_value}", file=sys.stderr) + + del temp_operator # Clean up temporary operator + + if not filtered_inputs: + print(f"Warning: No shapes matched the filter for {kernel_name}", file=sys.stderr) + + def filtered_get_input_iter(self): + """Custom input iterator that only yields filtered shapes.""" + for inputs in filtered_inputs: + yield inputs + + # Monkey-patch the operator class + Operator.get_input_iter = filtered_get_input_iter + + # Also override _available_num_inputs for proper sharding support + Operator._available_num_inputs = len(filtered_inputs) # Register all variants as separate methods for module_path, func_name in variants: From 22c03f79e7668192cafd7cbbf642709405037aa4 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 6 Aug 2025 14:03:32 -0700 Subject: [PATCH 09/16] choose specific shapes for softmax / embedding / exp --- benchmarks/run.py | 52 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/benchmarks/run.py b/benchmarks/run.py index 0b63fecc..f03dcd41 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -55,6 +55,34 @@ 
"tritonbench.operators.softmax.operator", "examples.softmax", "softmax", + { + "only_shapes": [ + [4096, 6912 * 2 - 4096], + [4096, 6976 * 2 - 4096], + [4096, 7040 * 2 - 4096], + [4096, 7104 * 2 - 4096], + [4096, 7168 * 2 - 4096], + [4096, 7232 * 2 - 4096], + [4096, 7296 * 2 - 4096], + [4096, 7360 * 2 - 4096], + [4096, 7424 * 2 - 4096], + [4096, 7488 * 2 - 4096], + [4096, 7552 * 2 - 4096], + [4096, 7616 * 2 - 4096], + [4096, 7680 * 2 - 4096], + [4096, 7744 * 2 - 4096], + [4096, 7808 * 2 - 4096], + [4096, 7872 * 2 - 4096], + [4096, 7936 * 2 - 4096], + [4096, 8000 * 2 - 4096], + [4096, 8064 * 2 - 4096], + [4096, 8128 * 2 - 4096], + [4096, 8192 * 2 - 4096], + [4096, 8256 * 2 - 4096], + [4096, 8320 * 2 - 4096], + [4096, 8384 * 2 - 4096], + ] + }, ), "cross_entropy": ( "tritonbench.operators.cross_entropy.operator", @@ -78,11 +106,35 @@ "tritonbench.operators.embedding.operator", "examples.embedding", "embedding_tritonbench", + { + "only_shapes": [ + (8, 2048, 4096, 16384), + (8, 2048, 4096, 32768), + (8, 2048, 4096, 65536), + (8, 2048, 4096, 131072), + ] + }, ), "vector_exp": ( "tritonbench.operators.vector_exp.operator", "examples.exp", "exp_tritonbench", + { + "only_shapes": [ + 65536, + 131072, + 262144, + 524288, + 1048576, + 2097152, + 4194304, + 8388608, + 16777216, + 33554432, + 67108864, + 134217728, + ] + }, ), "fp8_gemm": ( "tritonbench.operators.fp8_gemm.fp8_gemm", From 253861c0e6947a97f49a10a8d135d9db4a9cca52 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 6 Aug 2025 22:25:27 -0700 Subject: [PATCH 10/16] Run specific batch_size and n_features for jagged_mean --- benchmarks/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/run.py b/benchmarks/run.py index f03dcd41..2e9d67b2 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -99,7 +99,7 @@ "jagged_mean_tritonbench", {"B": 32, "M": 8, "seqlen": 64} if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1" - else {}, + else {"B": 512, "M": 64}, ), "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"), "embedding": ( From d721fc1429b1f382953475e8321f72796e348767 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 18 Aug 2025 10:50:40 -0700 Subject: [PATCH 11/16] full all shapes for rms_norm --- benchmarks/run.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/benchmarks/run.py b/benchmarks/run.py index 2e9d67b2..2fc8f39f 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -42,9 +42,6 @@ "tritonbench.operators.rms_norm.operator", "examples.rms_norm", "rms_norm_tritonbench", - { - "num_inputs": 3 - }, # TODO(yf225): reduction dim size = 8192 currently throws error ), "layer_norm": ( "tritonbench.operators.layer_norm.operator", From 96358c5ff0634c058bd4798a03e8654f53456d60 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 18 Aug 2025 11:17:46 -0700 Subject: [PATCH 12/16] Always generate inputs beforehand, with fixed initial seeds --- benchmarks/run.py | 62 +++++++++++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/benchmarks/run.py b/benchmarks/run.py index 2fc8f39f..ae09415c 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -29,6 +29,8 @@ from typing import Callable import time +import torch + # Maps tritonbench op names to Helion kernel examples # Can map to a single kernel or a list of kernel variants # Format options: @@ -436,20 +438,31 @@ def run_kernel_variants( register_benchmark, ) - # Inject only_shapes filter if provided + # Always extract all inputs beforehand + # Override the 
+    # Override the get_input_iter method for the operator class
+    original_get_input_iter = Operator.get_input_iter
+    original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None
+
+    # Create a list to store all inputs
+    all_inputs = []
+
+    # Collect all inputs
+    torch.manual_seed(42)
+    temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args)
+    for inputs in original_get_input_iter(temp_operator):
+        # Set random seed for reproducibility
+        torch.manual_seed(42)
+        all_inputs.append(inputs)
+
+    # If only_shapes is provided, filter the inputs
     if only_shapes:
         print(f"Using only_shapes for {kernel_name}: {only_shapes}", file=sys.stderr)

-        # Override the get_input_iter method for the operator class
-        original_get_input_iter = Operator.get_input_iter
-        original_get_x_val = Operator.get_x_val if hasattr(Operator, 'get_x_val') else None
-
-        # Create a list to store filtered inputs and their shapes
+        # Create a list to store filtered inputs
         filtered_inputs = []

-        # First, collect all inputs that match the shape filter
-        temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args)
-        for inputs in original_get_input_iter(temp_operator):
+        # Filter inputs that match the shape filter
+        for inputs in all_inputs:
             # Get the shape value for this input
             shape_value = None
@@ -483,21 +496,28 @@ def run_kernel_variants(
                 filtered_inputs.append(inputs)
                 print(f"  Including shape: {shape_value}", file=sys.stderr)

-        del temp_operator  # Clean up temporary operator
-
         if not filtered_inputs:
             print(f"Warning: No shapes matched the filter for {kernel_name}", file=sys.stderr)

-        def filtered_get_input_iter(self):
-            """Custom input iterator that only yields filtered shapes."""
-            for inputs in filtered_inputs:
-                yield inputs
-
-        # Monkey-patch the operator class
-        Operator.get_input_iter = filtered_get_input_iter
-
-        # Also override _available_num_inputs for proper sharding support
-        Operator._available_num_inputs = len(filtered_inputs)
+        # Use filtered inputs instead of all inputs
+        inputs_to_use = filtered_inputs
+    else:
+        # Use all inputs
+        inputs_to_use = all_inputs
+
+    del temp_operator  # Clean up temporary operator
+
+    # Create a new input iterator function
+    def new_get_input_iter(self):
+        """Custom input iterator that yields pre-collected inputs."""
+        for inputs in inputs_to_use:
+            yield inputs
+
+    # Monkey-patch the operator class
+    Operator.get_input_iter = new_get_input_iter
+
+    # Also override _available_num_inputs for proper sharding support
+    Operator._available_num_inputs = len(inputs_to_use)

     # Register all variants as separate methods
     for module_path, func_name in variants:

From 182ce0becbbd45226aee8e376b7d7d6836a61782 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 18 Aug 2025 11:55:26 -0700
Subject: [PATCH 13/16] recommend remove inductor cache before benchmarking

---
 benchmarks/run.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index ae09415c..a348409e 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -4,6 +4,8 @@

 Currently supported kernels are listed in `KERNEL_MAPPINGS` in `benchmarks/run.py`.

+NOTE: It's recommended to run `rm -rf /tmp/torchinductor_${USER}/*` before running this script to enable autotuning and ensure best performance.
+
 Usage:
 $ python benchmarks/run.py [tritonbench args...] [--kernel <kernel_name>]

From b447976157715dd532e9044457aefa67447bae77 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 18 Aug 2025 14:24:04 -0700
Subject: [PATCH 14/16] use --latency-measure-mode inductor_benchmarker

---
 benchmarks/run_input_shard.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarks/run_input_shard.sh b/benchmarks/run_input_shard.sh
index 9ce73051..53d6282f 100644
--- a/benchmarks/run_input_shard.sh
+++ b/benchmarks/run_input_shard.sh
@@ -31,7 +31,7 @@ for KERNEL_NAME in "${KERNEL_NAME_LIST[@]}"; do

         mkdir -p ${OUTPUT_DIR} || true
         OUTPUT_FILE="${OUTPUT_DIR}/${KERNEL_NAME}.log"
-        CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD-1)) python benchmarks/run.py --input-shard ${SHARD}/${WORLD_SIZE} --kernel ${KERNEL_NAME} --metrics accuracy,tflops,gbps,speedup --csv --output-dir ${OUTPUT_DIR} >"${OUTPUT_FILE}" 2>&1
+        CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD-1)) python benchmarks/run.py --input-shard ${SHARD}/${WORLD_SIZE} --kernel ${KERNEL_NAME} --metrics accuracy,tflops,gbps,speedup --latency-measure-mode inductor_benchmarker --csv --output-dir ${OUTPUT_DIR} >"${OUTPUT_FILE}" 2>&1

         exit_code=$?
         # Check for success: exit code 0 AND no exception message in output

From 9e5d80ef70e513a34a14f5bd2e5e393a4cf1b649 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 18 Aug 2025 14:25:00 -0700
Subject: [PATCH 15/16] use benchmarks_results/ folder

---
 benchmarks/run_input_shard.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarks/run_input_shard.sh b/benchmarks/run_input_shard.sh
index 53d6282f..52f34810 100644
--- a/benchmarks/run_input_shard.sh
+++ b/benchmarks/run_input_shard.sh
@@ -4,7 +4,7 @@

 # Capture timestamp once for consistent filename
 TIMESTAMP=$(date +%s)
-OUTPUT_DIR="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}"
+OUTPUT_DIR="benchmarks_results/benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}"

 KERNEL_NAME_LIST=(
     "rms_norm"

From d4e94831a67aace3cf81b436505d1129ce58704b Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Tue, 19 Aug 2025 18:56:23 -0700
Subject: [PATCH 16/16] layer_norm no-bias for Quack benchmark

---
 benchmarks/run.py      |  4 +-
 examples/layer_norm.py | 88 ++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/benchmarks/run.py b/benchmarks/run.py
index a348409e..87942884 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -50,7 +50,7 @@
     "layer_norm": (
         "tritonbench.operators.layer_norm.operator",
         "examples.layer_norm",
-        "layer_norm_fwd",
+        "layer_norm_fwd_tritonbench",
    ),
     "softmax": (
         "tritonbench.operators.softmax.operator",
         "examples.softmax",
@@ -166,7 +166,7 @@
     "layer_norm": (
         "tritonbench.operators.layer_norm.operator",
         "examples.layer_norm",
-        "layer_norm_fwd",
+        "layer_norm_fwd_tritonbench",
     ),
     "jagged_softmax": (
         "tritonbench.operators.jagged_softmax.operator",
diff --git a/examples/layer_norm.py b/examples/layer_norm.py
index 33cb12fd..65bb8321 100644
--- a/examples/layer_norm.py
+++ b/examples/layer_norm.py
@@ -17,7 +17,7 @@

 # %%
 @helion.kernel
-def layer_norm_fwd(
+def layer_norm_fwd_with_bias(
     x: torch.Tensor,
     nomralized_shape: list[int],
     weight: torch.Tensor,
@@ -25,7 +25,7 @@ def layer_norm_fwd(
     eps: float = 1e-5,
 ) -> torch.Tensor:
     """
-    Performs 1D layer normalization on the input tensor using Helion.
+    Performs 1D layer normalization with bias on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
         nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
@@ -54,6 +54,58 @@ def layer_norm_fwd(
     return out


+@helion.kernel
+def layer_norm_fwd_no_bias(
+    x: torch.Tensor,
+    nomralized_shape: list[int],
+    weight: torch.Tensor,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    """
+    Performs 1D layer normalization without bias on the input tensor using Helion.
+    Args:
+        x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
+        nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        weight (torch.Tensor): Learnable scale parameter of shape [dim].
+        eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
+    Returns:
+        torch.Tensor: The layer-normalized output tensor of shape [batch_size, dim], in FP16.
+    """
+    m, n = x.size()
+    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
+    assert len(nomralized_shape) == 1, (
+        "Helion layer norm only supports 1D layer norm currently"
+    )
+    assert nomralized_shape[0] == n, (
+        f"normalized shape mismatch {nomralized_shape[0]} != {n}"
+    )
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    for tile_m in hl.tile(m):
+        acc = x[tile_m, :].to(torch.float32)
+        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
+        normalized = (acc - mean) * torch.rsqrt(var + eps)
+        acc = normalized * (weight[:].to(torch.float32))
+        out[tile_m, :] = acc
+    return out
+
+
+def layer_norm_fwd_tritonbench(
+    x: torch.Tensor,
+    nomralized_shape: list[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    """
+    Wrapper function that dispatches to the appropriate layer normalization kernel.
+    Compatible with tritonbench which may pass None for bias.
+    """
+    if bias is None:
+        return layer_norm_fwd_no_bias(x, nomralized_shape, weight, eps)
+    else:
+        return layer_norm_fwd_with_bias(x, nomralized_shape, weight, bias, eps)
+
+
 # %%
 def main() -> None:
     """
@@ -61,6 +113,7 @@ def main() -> None:
     - Generates random input, weight, and bias tensors.
     - Runs the Helion layer normalization kernel and compares its output to PyTorch's built-in
       layer_norm function using the run_example utility.
+    - Tests both with bias and without bias (no-bias mode).
     - Prints comparison results and checks for correctness within specified tolerances.
""" batch_size = 32 @@ -70,15 +123,42 @@ def main() -> None: weight = torch.randn([dim], device=device, dtype=torch.float16) bias = torch.randn([dim], device=device, dtype=torch.float16) eps = 1e-4 + + # Test with bias + print("Testing layer_norm WITH bias:") run_example( - layer_norm_fwd, + layer_norm_fwd_with_bias, torch.nn.functional.layer_norm, (x, [dim], weight, bias, eps), - kernel_name="helion", + kernel_name="helion_with_bias", + baseline_name="torch", + rtol=1e-3, + atol=1e-3, + ) + + # Test without bias (no-bias mode) + print("\nTesting layer_norm WITHOUT bias (no-bias mode):") + run_example( + layer_norm_fwd_no_bias, + lambda x, shape, w, e: torch.nn.functional.layer_norm(x, shape, w, None, e), + (x, [dim], weight, eps), + kernel_name="helion_no_bias", baseline_name="torch", rtol=1e-3, atol=1e-3, ) + + # Test wrapper function with bias + print("\nTesting wrapper function WITH bias:") + result_with_bias = layer_norm_fwd_tritonbench(x, [dim], weight, bias, eps) + expected_with_bias = torch.nn.functional.layer_norm(x, [dim], weight, bias, eps) + print(f" Wrapper with bias matches torch: {torch.allclose(result_with_bias, expected_with_bias, rtol=1e-3, atol=1e-3)}") + + # Test wrapper function without bias + print("\nTesting wrapper function WITHOUT bias:") + result_no_bias = layer_norm_fwd_tritonbench(x, [dim], weight, None, eps) + expected_no_bias = torch.nn.functional.layer_norm(x, [dim], weight, None, eps) + print(f" Wrapper without bias matches torch: {torch.allclose(result_no_bias, expected_no_bias, rtol=1e-3, atol=1e-3)}") # %%