From 9c8b3c6f62cc95cca9ce49036bdbc72700e75611 Mon Sep 17 00:00:00 2001 From: sekyonda <127536312+sekyondaMeta@users.noreply.github.com> Date: Wed, 16 Jul 2025 16:07:26 -0400 Subject: [PATCH 1/9] Add Helion Puzzles --- docs/helion_puzzles.rst | 736 ++++++++++++++++++++++++++++++++++++++++ docs/index.md | 1 + 2 files changed, 737 insertions(+) create mode 100644 docs/helion_puzzles.rst diff --git a/docs/helion_puzzles.rst b/docs/helion_puzzles.rst new file mode 100644 index 00000000..91d66268 --- /dev/null +++ b/docs/helion_puzzles.rst @@ -0,0 +1,736 @@ +Helion Puzzles +============== + +Programming for accelerators such as GPUs is critical for modern AI systems. This often means programming directly in proprietary low-level languages such as CUDA. Helion is a Python-embedded domain-specific language (DSL) for authoring machine learning kernels, designed to compile down to Triton, a performant backend for programming GPUs and other devices. + +Helion aims to raise the level of abstraction compared to Triton, making it easier to write correct and efficient kernels while enabling more automation in the autotuning process. + +This set of puzzles is meant to teach you how to use Helion from first principles in an interactive fashion. You will start with trivial examples and build your way up to real algorithms like Flash Attention and Quantized neural networks. + +Setup +----- + +First, let's install the necessary dependencies. Helion requires a recent version of PyTorch and a development version of Triton. + +.. code-block:: python + + import logging + + import helion + import helion.language as hl + import torch + from torch import Tensor + + # If you set this to info you will see the output Triton Code + logging.getLogger().setLevel(logging.WARNING) + +Let's also create a simple testing function to verify our implementations. + +.. code-block:: python + + from triton.testing import do_bench + def test_kernel(kernel_fn, spec_fn, *args): + """Test a Helion kernel against a reference implementation.""" + # Run our implementation + result = kernel_fn(*args) + # Run reference implementation + expected = spec_fn(*args) + + # Check if results match + torch.testing.assert_close(result, expected) + print("✅ Results Match ✅") + + def benchmark_kernel(kernel_fn, *args, **kwargs): + """Benchmark a Helion kernel.""" + no_args = lambda: kernel_fn(*args, **kwargs) + time_in_ms = do_bench(no_args) + print(f"⏱ Time: {time_in_ms} ms") + + def compare_implementations(kernel_fn, spec_fn, *args, **kwargs): + """Benchmark a Helion kernel and its reference implementation.""" + kernel_no_args = lambda: kernel_fn(*args, **kwargs) + spec_no_args = lambda: spec_fn(*args, **kwargs) + kernel_time = do_bench(kernel_no_args) + spec_time = do_bench(spec_no_args) + print(f"⏱ Helion Kernel Time: {kernel_time:.3f} ms, PyTorch Reference Time: {spec_time:.3f} ms, Speedup: {spec_time/kernel_time:.3f}x") + +Basic Structure of a Helion Kernel +--------------------------------- + +Helion allows you to write GPU kernels using familiar PyTorch syntax. + +A Helion kernel has three main sections: + +1. **Host Section** (CPU) + This is standard PyTorch code executed on the CPU. Memory allocation, and shape computations are done here. Like with `Triton` and `Cuda` you need to setup your output buffers on the host before launching your kernel. + +2. **Device Loop** (GPU Grid) + `for tile in hl.tile(sizes)` - defines parallel execution across GPU thread blocks + +3. **Device Operations** (GPU Kernel) + PyTorch operations inside the loop - automatically compiled and fused + +Example: + +.. code-block:: python + + @helion.kernel(config=helion.Config(block_sizes = [128, 128])) # The @helion.kernel decorator marks this function for compilation + def example_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Host code: Standard PyTorch operations + m, n = x.size() + out = torch.empty_like(x) # Allocate output tensor + + # The hl.tile loop defines the parallel execution structure + for tile_m, tile_n in hl.tile([m, n]): + # Device code: Everything inside the hl.tile loop runs on GPU + out[tile_m, tile_n] = x[tile_m, tile_n] + y[tile_m, tile_n] # Simple element-wise addition expressed w/ pytorch ops + + return out # Return the result back to the host + + # Create some sample data + x = torch.randn(10, 10, device="cuda") + y = torch.randn(10, 10, device="cuda") + + # Run the kernel + result = example_add(x, y) + + # Verify result + expected = x + y + torch.testing.assert_close(result, expected) + print("✅ Results Match ✅") + benchmark_kernel(example_add, x, y) + compare_implementations(example_add, torch.add, x, y) + +Autotuning in Helion +-------------------- + +In the previous example, we explicitly specified a configuration using `config=helion.Config(block_sizes=[128, 128])`. This bypasses Helion's autotuning mechanism and uses our predefined settings. While this is quick to run, manually choosing optimal parameters can be challenging and hardware-dependent. + +### What is Autotuning? + +Autotuning is Helion's process of automatically finding the best configuration parameters for your specific: + +- Hardware (GPU model) +- Problem size +- Operation patterns + +When you omit the `config` parameter, Helion will automatically search for the optimal configuration: + +.. code-block:: python + + @helion.kernel() # No config = automatic tuning + def autotuned_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, n = x.size() + out = torch.empty_like(x) + for tile_m, tile_n in hl.tile([m, n]): + out[tile_m, tile_n] = x[tile_m, tile_n] + y[tile_m, tile_n] + +Feel free to run the above code to see how much more performant it is than the original, although be warned it might take some time 😃 + +Now let's move on to our puzzles! + +Puzzle 1: Constant Add +---------------------- + +Add a constant to a vector. + +.. code-block:: python + + def add_spec(x: Tensor) -> Tensor: + """This is the spec that you should implement in the helion kernel below.""" + return x + 10. + + # ---- ✨ Is this the best block size? ---- + @helion.kernel(config = helion.Config(block_sizes = [1,])) + def add_kernel(x: torch.Tensor) -> torch.Tensor: + # ---- ✨ Your Code Here ✨---- + # Set up the output buffer which you will return + + # Use Helion to tile the computation + for tile_n in hl.tile(TILE_RANGE): + # ---- ✨ Your Code Here ✨---- + + return out + + # Test the kernel + x = torch.randn(8192, device="cuda") + test_kernel(add_kernel, add_spec, x) + benchmark_kernel(add_kernel, x) + compare_implementations(add_kernel, add_spec, x) + +.. code-block:: python + + def add_spec(x: Tensor) -> Tensor: + """This is the spec that you should implement.""" + return x + 10. + + # ---- ✨ Is this the best block size? ---- + @helion.kernel(config = helion.Config(block_sizes = [32,])) + def add_kernel(x: torch.Tensor) -> torch.Tensor: + # ---- ✨ Your Code Here ✨---- + # Set up the output buffer which you will return + TILE_RANGE = x.size() + out = torch.empty_like(x) + # ---- End of Code ---- + + # Use Helion to tile the computation + for tile_n in hl.tile(TILE_RANGE): + # ---- ✨ Your Code Here ✨---- + x_tile = x[tile_n] + out[tile_n] = x_tile + 10.0 + + return out + + # Test the kernel + x = torch.randn(8192, device="cuda") + test_kernel(add_kernel, add_spec, x) + benchmark_kernel(add_kernel, x) + compare_implementations(add_kernel, add_spec, x) + +Puzzle 2: Outer Vector Add +-------------------------- + +Add two vectors using an outer product pattern. + +.. code-block:: python + + def broadcast_add_spec(x: Tensor, y: Tensor) -> Tensor: + return x[None, :] + y[:, None] + + # ---- ✨ Is this the best block size? ---- + @helion.kernel(config = helion.Config(block_sizes = [32, 32])) + def broadcast_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + # ---- ✨ Your Code Here ✨---- + n0 = x.size(0) + n1 = y.size(0) + out = x.new_empty(n1, n0) + + # Use Helion to tile the computation + for tile_i, tile_j in hl.tile([n1, n0]): + # Get tiles from x and y + y_tile = y[tile_i] + x_tile = x[tile_j] + # Compute outer sum + out[tile_i, tile_j] = y_tile[:, None] + x_tile[None, :] + + return out + + # Test the kernel + x = torch.randn(1142, device="cuda") + y = torch.randn(512, device="cuda") + test_kernel(broadcast_add_kernel, broadcast_add_spec, x, y) + benchmark_kernel(broadcast_add_kernel, x, y) + compare_implementations(broadcast_add_kernel, broadcast_add_spec, x, y) + +Puzzle 3: Fused Outer Multiplication +----------------------------------- + +Multiply a row vector to a column vector and take a relu. + +.. code-block:: python + + def mul_relu_block_spec(x: Tensor, y: Tensor) -> Tensor: + return torch.relu(x[None, :] * y[:, None]) + + # ---- ✨ Is this the best block size? ---- + @helion.kernel(config = helion.Config(block_sizes = [32, 32])) + def mul_relu_block_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + n0 = x.size(0) + n1 = y.size(0) + # Create output tensor + out = torch.empty([n1, n0], dtype=x.dtype, device=x.device) + + # Use Helion to tile the computation + for tile_i, tile_j in hl.tile([n1, n0]): + # Get tiles from x and y + y_tile = y[tile_i] + x_tile = x[tile_j] + # Compute outer product followed by ReLU + out[tile_i, tile_j] = torch.relu(y_tile[:, None] * x_tile[None, :]) + + return out + + # Test the kernel + x = torch.randn(512, device="cuda") + y = torch.randn(512, device="cuda") + test_kernel(mul_relu_block_kernel, mul_relu_block_spec, x, y) + compare_implementations(mul_relu_block_kernel, mul_relu_block_spec, x, y) + +Puzzle 4: Fused Outer Multiplication - Backwards +------------------------------------------------ + +While PyTorch and torch.compile automatically generates the backwards pass for your Tensor Operations, Helion does not. So lets practice by writing the backwards function for a fused mul_relu kernel + +.. code-block:: python + + def mul_relu_block_back_spec(x: Tensor, y: Tensor, dz: Tensor) -> Tensor: + x = x.clone() + y = y.clone() + x = x.requires_grad_(True) + z = torch.relu(x * y[:, None]) + grad_x, grad_y = torch.autograd.grad(z, [x, y], dz, retain_graph=True) + return grad_x + + @helion.kernel(config=helion.Config(block_sizes=[32, 32])) + def mul_relu_block_back_kernel( + x: torch.Tensor, y: torch.Tensor, dz: torch.Tensor + ) -> torch.Tensor: + # Get tensor sizes + n0 = x.size(1) + n1 = x.size(0) + # Create output tensor for gradients + dx = torch.empty_like(x) + dy = torch.empty_like(y) + + # Use Helion to tile the computation + for tile_i, tile_j in hl.tile([n1, n0]): + # Get input tiles + x_tile = x[tile_i, tile_j] + y_tile = y[tile_i] + dz_tile = dz[tile_i, tile_j] + + # Compute gradients for ReLU * multiplication backward + # For ReLU, gradient is 1 where input > 0, 0 otherwise + relu_mask = (x_tile * y_tile[:, None]) > 0 + # Chain rule: dx = dz * relu_grad * y + dx[tile_i, tile_j] = dz_tile * relu_mask * y_tile[:, None] + + return dx, dy + + # Test the kernel + x = torch.randn(512, 1024, device="cuda") + y = torch.randn(512, device="cuda") + dz = torch.randn(512, 1024, device="cuda") + test_kernel(mul_relu_block_back_kernel, mul_relu_block_back_spec, x, y, dz) + +Puzzle 7: Long Sum +----------------- + +Sum of a batch of numbers. + +.. code-block:: python + + def sum_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4"]: + return x.sum(1) + + @helion.kernel() + def sum_kernel(x: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + batch, seq_len = x.size() + # Create output tensor + out = torch.empty(batch, dtype=x.dtype, device=x.device) + + # Use Helion to tile the batch dimension + for tile_batch in hl.tile(batch): + # Initialize accumulator for each batch element + acc = torch.zeros_like(tile_batch, dtype=torch.float32) + + # Process the sequence in chunks + for tile_seq in hl.tile(seq_len): + # Get the current chunk + chunk = x[tile_batch, tile_seq] + # Accumulate sum + acc += torch.sum(chunk, dim=1) + + # Store result + out[tile_batch] = acc + + return out + + # Test the kernel + x = torch.randn(4, 200, device="cuda") + test_kernel(sum_kernel, sum_spec, x) + +Puzzle 8: Long Softmax +--------------------- + +Softmax of a batch of logits. + +.. code-block:: python + + def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]: + x_max = x.max(1, keepdim=True)[0] + x = x - x_max + x_exp = x.exp() + return x_exp / x_exp.sum(1, keepdim=True) + + @helion.kernel() + def softmax_kernel(x: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + batch, seq_len = x.size() + # Create output tensor + out = torch.empty_like(x) + + # Use Helion to tile the batch dimension + for tile_batch in hl.tile(batch): + # First pass: find max value for each sequence + max_vals = torch.full_like(tile_batch, float('-inf'), dtype=torch.float32) + + for tile_seq in hl.tile(seq_len): + chunk = x[tile_batch, tile_seq] + max_vals = torch.maximum(max_vals, torch.max(chunk, dim=1)[0]) + + # Second pass: compute sum of exp(x - max) + sum_exp = torch.zeros_like(tile_batch, dtype=torch.float32) + + for tile_seq in hl.tile(seq_len): + chunk = x[tile_batch, tile_seq] + exp_vals = torch.exp(chunk - max_vals[:, None]) + sum_exp += torch.sum(exp_vals, dim=1) + + # Third pass: compute softmax + for tile_seq in hl.tile(seq_len): + chunk = x[tile_batch, tile_seq] + exp_vals = torch.exp(chunk - max_vals[:, None]) + out[tile_batch, tile_seq] = exp_vals / sum_exp[:, None] + + return out + + # Test the kernel + x = torch.randn(4, 200, device="cuda") + test_kernel(softmax_kernel, softmax_spec, x) + +Puzzle 9: Simple FlashAttention +------------------------------- + +A scalar version of FlashAttention. + +.. code-block:: python + + def flashatt_spec(q: Float32[Tensor, "200"], k: Float32[Tensor, "200"], v: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]: + x = q[:, None] * k[None, :] + x_max = x.max(1, keepdim=True)[0] + x = x - x_max + x_exp = x.exp() + soft = x_exp / x_exp.sum(1, keepdim=True) + return (v[None, :] * soft).sum(1) + + @helion.kernel() + def flashatt_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + # Get tensor size + seq_len = q.size(0) + # Create output tensor + out = torch.empty_like(q) + + # Process each query position + for tile_q in hl.tile(seq_len): + q_tile = q[tile_q] + + # Initialize tracking variables for stable softmax + max_val = torch.full_like(q_tile, float('-inf')) + sum_exp = torch.zeros_like(q_tile) + weighted_sum = torch.zeros_like(q_tile) + + # Process in tiles for better cache efficiency + for tile_kv in hl.tile(seq_len): + k_tile = k[tile_kv] + v_tile = v[tile_kv] + + # Compute attention scores + scores = q_tile[:, None] * k_tile[None, :] + + # Find max for numerical stability + batch_max = torch.max(scores, dim=1)[0] + new_max = torch.maximum(max_val, batch_max) + + # Scale old accumulations + scale_factor = torch.exp(max_val - new_max) + sum_exp = sum_exp * scale_factor + weighted_sum = weighted_sum * scale_factor + + # Update with new values + exp_scores = torch.exp(scores - new_max[:, None]) + sum_exp = sum_exp + torch.sum(exp_scores, dim=1) + weighted_sum = weighted_sum + torch.sum(exp_scores * v_tile[None, :], dim=1) + + # Update max_val + max_val = new_max + + # Compute final output + out[tile_q] = weighted_sum / sum_exp + + return out + + # Test the kernel + q = torch.randn(200, device="cuda") + k = torch.randn(200, device="cuda") + v = torch.randn(200, device="cuda") + test_kernel(flashatt_kernel, flashatt_spec, q, k, v) + +Puzzle 10: Two Dimensional Convolution +-------------------------------------- + +A batched 2D convolution. + +.. code-block:: python + + def conv2d_spec(x: Float32[Tensor, "4 8 8"], k: Float32[Tensor, "4 4"]) -> Float32[Tensor, "4 8 8"]: + z = torch.zeros(4, 8, 8) + x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0) + for i in range(8): + for j in range(8): + z[:, i, j] = (k[None, :, :] * x[:, i: i+4, j: j + 4]).sum(1).sum(1) + return z + + @helion.kernel() + def conv2d_kernel(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + batch, h, w = x.size() + kh, kw = k.size()[1:] + + # Create output tensor + out = torch.empty_like(x) + + # Pad the input + x_padded = torch.nn.functional.pad(x, (0, kw, 0, kh, 0, 0), value=0.0) + + # Use Helion to tile the computation + for tile_batch in hl.tile(batch): + # Process each output position + for i in range(h): + for j in range(w): + # Extract the patch + patch = x_padded[tile_batch, i:i+kh, j:j+kw] + # Apply the kernel + out[tile_batch, i, j] = (k[tile_batch] * patch).sum([1, 2]) + + return out + + # Test the kernel + x = torch.randn(4, 8, 8, device="cuda") + k = torch.randn(4, 4, 4, device="cuda") + test_kernel(conv2d_kernel, conv2d_spec, x, k) + +Puzzle 11: Matrix Multiplication +------------------------------- + +A blocked matrix multiplication. + +.. code-block:: python + + def dot_spec(x: Float32[Tensor, "4 32 32"], y: Float32[Tensor, "4 32 32"]) -> Float32[Tensor, "4 32 32"]: + return x @ y + + @helion.kernel() + def dot_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + batch, m, k = x.size() + _, k, n = y.size() + + # Create output tensor + out = torch.empty([batch, m, n], dtype=x.dtype, device=x.device) + + # Use Helion to tile the computation + for tile_batch in hl.tile(batch): + for tile_m, tile_n in hl.tile([m, n]): + # Initialize accumulator + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + + # Process the reduction dimension in tiles + for tile_k in hl.tile(k): + # Get tiles + x_tile = x[tile_batch, tile_m, tile_k] + y_tile = y[tile_batch, tile_k, tile_n] + + # Accumulate matrix multiplication + acc = acc + torch.matmul(x_tile, y_tile) + + # Store result + out[tile_batch, tile_m, tile_n] = acc + + return out + + # Test the kernel + x = torch.randn(4, 32, 32, device="cuda") + y = torch.randn(4, 32, 32, device="cuda") + test_kernel(dot_kernel, dot_spec, x, y) + +Puzzle 12: Quantized Matrix Multiplication +------------------------------------------ + +When doing matrix multiplication with quantized neural networks, a common strategy is to store the weight matrix in lower precision, with a shift and scale term. + +.. code-block:: python + + FPINT = 32 // 4 + GROUP = 8 + + def quant_dot_spec(scale: Float32[Tensor, "32 8"], + offset: Int32[Tensor, "32"], + weight: Int32[Tensor, "32 8"], + activation: Float32[Tensor, "64 32"]) -> Float32[Tensor, "32 32"]: + offset = offset.view(32, 1) + def extract(x): + over = torch.arange(8, device=x.device) * 4 + mask = 2**4 - 1 + return (x[..., None] >> over) & mask + scale = scale[..., None].expand(-1, 8, GROUP).contiguous().view(-1, 64) + offset = extract(offset)[..., None].expand(-1, 1, 8, GROUP).contiguous().view(-1, 64) + return (scale * (extract(weight).view(-1, 64) - offset)) @ activation + + @helion.kernel() + def quant_dot_kernel(scale: torch.Tensor, offset: torch.Tensor, weight: torch.Tensor, activation: torch.Tensor) -> torch.Tensor: + # Get tensor sizes + n_out, n_groups = scale.size() + mid, n_in = activation.size() + + # Create output tensor + out = torch.empty([n_out, n_in], dtype=scale.dtype, device=scale.device) + + # Helper function to extract 4-bit values + def extract_4bit(x, bit_positions): + mask = 2**4 - 1 + shifted = x[..., None] >> (bit_positions * 4) + return shifted & mask + + # Bit positions for extraction + bit_positions = torch.arange(8, device=scale.device) + + # Use Helion to tile the computation + for tile_out in hl.tile(n_out): + for tile_in in hl.tile(n_in): + # Initialize accumulator + acc = hl.zeros([tile_out, tile_in], dtype=torch.float32) + + # Get the offset values for this tile + offset_tile = offset[tile_out] + # Extract 4-bit values from offsets + offset_extracted = extract_4bit(offset_tile, bit_positions) + + # Process in chunks across the middle dimension + for group_idx in range(n_groups): + # Get scale for this group + scale_group = scale[tile_out, group_idx] + + # Get weights for this group + weight_group = weight[tile_out, group_idx] + + # Extract 4-bit values from weights + weight_extracted = extract_4bit(weight_group, bit_positions) + + # Compute dequantized weights: scale * (weight - offset) + offset_group = offset_extracted[:, group_idx:group_idx+1] # Shape: [tile_out, 1, 8] + dequant_weights = scale_group[:, None, None] * (weight_extracted - offset_group) + + # Reshape dequantized weights for matrix multiplication + dequant_weights = dequant_weights.reshape(tile_out.size(0), 8) + + # Get activations for this group + acts_idx = group_idx * 8 + torch.arange(8, device=scale.device) + act_group = activation[acts_idx][:, tile_in] + + # Accumulate to result + acc = acc + torch.matmul(dequant_weights, act_group) + + # Store result + out[tile_out, tile_in] = acc + + return out + + # Test the kernel with smaller inputs for quicker testing + scale = torch.randn(32, 8, device="cuda") + offset = torch.randint(-10, 10, (32,), device="cuda") + weight = torch.randint(0, 16, (32, 8), device="cuda", dtype=torch.int32) + activation = torch.randn(64, 32, device="cuda") + test_kernel(quant_dot_kernel, quant_dot_spec, scale, offset, weight, activation) + +Autotuning in Helion +-------------------- + +One of the major advantages of Helion is its sophisticated autotuning capability. Let's see how we can leverage this for our matrix multiplication kernel: + +.. code-block:: python + + import torch + import helion + import helion.language as hl + import time + + # Define a matrix multiplication kernel + @helion.kernel() # No config means autotuning will be used + def matmul_autotune(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + k, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = acc + torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc + + return out + + # Create larger tensors for better autotuning results + x = torch.randn(1024, 1024, device="cuda") + y = torch.randn(1024, 1024, device="cuda") + + # First run will trigger autotuning + print("Running with autotuning (this might take a while)...") + start = time.time() + result = matmul_autotune(x, y) + end = time.time() + print(f"First run time (including autotuning): {end - start:.2f}s") + + # Second run will use the tuned configuration + start = time.time() + result = matmul_autotune(x, y) + end = time.time() + print(f"Second run time (using tuned config): {end - start:.2f}s") + + # Verify correctness + expected = x @ y + print(f"Result is correct: {torch.allclose(result, expected, rtol=1e-2, atol=1e-2)}") + +Hardcoding Configurations +------------------------- + +After autotuning, you might want to hardcode the best configuration: + +.. code-block:: python + + # Example of hardcoding a configuration after autotuning + @helion.kernel(config=helion.Config( + block_sizes=[[64, 128], [16]], + loop_orders=[[1, 0]], + num_warps=4, + num_stages=3, + indexing='block_ptr', + l2_grouping=32 + )) + def matmul_fixed_config(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + k, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = acc + torch.matmul(x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc + + return out + + # Run with fixed configuration (no autotuning) + start = time.time() + result = matmul_fixed_config(x, y) + end = time.time() + print(f"Run time with fixed config: {end - start:.2f}s") + + # Verify correctness + expected = x @ y + print(f"Result is correct: {torch.allclose(result, expected, rtol=1e-2, atol=1e-2)}") + +Conclusion +---------- + +In this notebook, we've explored how to use Helion to write efficient GPU kernels using a high-level, PyTorch-like syntax. The key advantages of Helion include: + +1. **Higher-level abstraction** than raw Triton, making it easier to write correct kernels +2. **Automatic tiling and memory management**, eliminating a common source of bugs +3. **Powerful autotuning** that can explore a wide range of implementations automatically +4. **Familiar PyTorch syntax** that builds on existing knowledge + +These puzzles should give you a good foundation for writing your own Helion kernels for a variety of applications. diff --git a/docs/index.md b/docs/index.md index 21879d58..60bc939d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -244,4 +244,5 @@ for DEBUG-level logs. Alternatively, you can specify logging for specific module installation api/index +helion_puzzles ``` From d1074d31910e1a6f21ec4649e86f2cd1d577b6a9 Mon Sep 17 00:00:00 2001 From: sekyonda <127536312+sekyondaMeta@users.noreply.github.com> Date: Thu, 17 Jul 2025 11:02:51 -0400 Subject: [PATCH 2/9] Adding Examples and Helion Puzzle --- docs/GenerateExamples.py | 40 +++++++++++ docs/Makefile | 6 +- docs/examples.rst | 30 ++++++++ docs/examples/add.rst | 5 ++ docs/examples/all_gather_matmul.rst | 5 ++ docs/examples/attention.rst | 5 ++ docs/examples/bmm.rst | 5 ++ docs/examples/concatenate.rst | 5 ++ docs/examples/cross_entropy.rst | 5 ++ docs/examples/embedding.rst | 5 ++ docs/examples/exp.rst | 5 ++ docs/examples/fp8_attention.rst | 5 ++ docs/examples/fp8_gemm.rst | 5 ++ docs/examples/jagged_dense_add.rst | 5 ++ docs/examples/jagged_mean.rst | 5 ++ docs/examples/long_sum.rst | 5 ++ docs/examples/matmul.rst | 5 ++ docs/examples/matmul_layernorm.rst | 5 ++ docs/examples/matmul_split_k.rst | 5 ++ docs/examples/moe_matmul_ogs.rst | 5 ++ docs/examples/rms_norm.rst | 5 ++ docs/examples/segment_reduction.rst | 5 ++ docs/examples/softmax.rst | 5 ++ docs/examples/sum.rst | 5 ++ docs/examples/template_via_closure.rst | 5 ++ docs/index.md | 29 ++++---- docs/installation.md | 2 +- examples/add.py | 26 ++++++- examples/all_gather_matmul.py | 68 ++++++++++++++++-- examples/attention.py | 43 ++++++++++-- examples/bmm.py | 30 +++++++- examples/concatenate.py | 20 +++++- examples/cross_entropy.py | 20 +++++- examples/embedding.py | 22 +++++- examples/exp.py | 24 ++++++- examples/fp8_attention.py | 74 +++++++++++++++++++- examples/jagged_dense_add.py | 12 +++- examples/jagged_mean.py | 16 +++-- examples/long_sum.py | 62 ++++++++++++++++- examples/matmul.py | 27 +++++++- examples/matmul_layernorm.py | 45 +++++++++++- examples/matmul_split_k.py | 33 ++++++++- examples/moe_matmul_ogs.py | 20 +++++- examples/rms_norm.py | 35 +++++++++- examples/segment_reduction.py | 95 ++++++++++++++++++++++++-- examples/softmax.py | 51 +++++++++++++- examples/sum.py | 20 +++++- examples/template_via_closure.py | 35 ++++++++-- 48 files changed, 907 insertions(+), 88 deletions(-) create mode 100644 docs/GenerateExamples.py create mode 100644 docs/examples.rst create mode 100644 docs/examples/add.rst create mode 100644 docs/examples/all_gather_matmul.rst create mode 100644 docs/examples/attention.rst create mode 100644 docs/examples/bmm.rst create mode 100644 docs/examples/concatenate.rst create mode 100644 docs/examples/cross_entropy.rst create mode 100644 docs/examples/embedding.rst create mode 100644 docs/examples/exp.rst create mode 100644 docs/examples/fp8_attention.rst create mode 100644 docs/examples/fp8_gemm.rst create mode 100644 docs/examples/jagged_dense_add.rst create mode 100644 docs/examples/jagged_mean.rst create mode 100644 docs/examples/long_sum.rst create mode 100644 docs/examples/matmul.rst create mode 100644 docs/examples/matmul_layernorm.rst create mode 100644 docs/examples/matmul_split_k.rst create mode 100644 docs/examples/moe_matmul_ogs.rst create mode 100644 docs/examples/rms_norm.rst create mode 100644 docs/examples/segment_reduction.rst create mode 100644 docs/examples/softmax.rst create mode 100644 docs/examples/sum.rst create mode 100644 docs/examples/template_via_closure.rst diff --git a/docs/GenerateExamples.py b/docs/GenerateExamples.py new file mode 100644 index 00000000..c6dd41fd --- /dev/null +++ b/docs/GenerateExamples.py @@ -0,0 +1,40 @@ +import os +EXAMPLES_DIR = '../../examples' # Adjust as needed +RST_DIR = './examples' # Relative to your Sphinx source dir +example_files = [ + 'add.py', + 'all_gather_matmul.py', + 'attention.py', + 'bmm.py', + 'concatenate.py', + 'cross_entropy.py', + 'embedding.py', + 'exp.py', + 'fp8_attention.py', + 'fp8_gemm.py', + 'jagged_dense_add.py', + 'jagged_mean.py', + 'long_sum.py', + 'matmul.py', + 'matmul_layernorm.py', + 'matmul_split_k.py', + 'moe_matmul_ogs.py', + 'rms_norm.py', + 'segment_reduction.py', + 'softmax.py', + 'sum.py', + 'template_via_closure.py', +] +os.makedirs(RST_DIR, exist_ok=True) +for fname in example_files: + base = os.path.splitext(fname)[0] + # Capitalize and replace underscores with spaces for nicer titles + title = base.replace('_', ' ').title() + rst_path = os.path.join(RST_DIR, f"{base}.rst") + with open(rst_path, "w") as f: + f.write(f"""{title} +{'=' * len(title)} +.. literalinclude:: {os.path.join(EXAMPLES_DIR, fname)} + :language: python + :linenos: +""") diff --git a/docs/Makefile b/docs/Makefile index 731da962..e6a46333 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,12 +3,14 @@ SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = ../site -html: clean +html: clean genEx @$(SPHINXBUILD) -b html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -livehtml: clean +livehtml: clean genEx sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) --open-browser --port 0 +genEx: + python GenerateExamples.py clean: rm -rf $(BUILDDIR)/* diff --git a/docs/examples.rst b/docs/examples.rst new file mode 100644 index 00000000..d277e8e1 --- /dev/null +++ b/docs/examples.rst @@ -0,0 +1,30 @@ +Examples +======== + +Examples showing the use of Helios in various scenarios. + +.. toctree:: + :maxdepth: 1 + + examples/add + examples/all_gather_matmul + examples/attention + examples/bmm + examples/concatenate + examples/cross_entropy + examples/embedding + examples/exp + examples/fp8_attention + examples/fp8_gemm + examples/jagged_dense_add + examples/jagged_mean + examples/long_sum + examples/matmul + examples/matmul_layernorm + examples/matmul_split_k + examples/moe_matmul_ogs + examples/rms_norm + examples/segment_reduction + examples/softmax + examples/sum + examples/template_via_closure diff --git a/docs/examples/add.rst b/docs/examples/add.rst new file mode 100644 index 00000000..6c102c02 --- /dev/null +++ b/docs/examples/add.rst @@ -0,0 +1,5 @@ +Add +=== +.. literalinclude:: ../../examples/add.py + :language: python + :linenos: diff --git a/docs/examples/all_gather_matmul.rst b/docs/examples/all_gather_matmul.rst new file mode 100644 index 00000000..9f290a5f --- /dev/null +++ b/docs/examples/all_gather_matmul.rst @@ -0,0 +1,5 @@ +All Gather Matmul +================= +.. literalinclude:: ../../examples/all_gather_matmul.py + :language: python + :linenos: diff --git a/docs/examples/attention.rst b/docs/examples/attention.rst new file mode 100644 index 00000000..e8a3c1fe --- /dev/null +++ b/docs/examples/attention.rst @@ -0,0 +1,5 @@ +Attention +========= +.. literalinclude:: ../../examples/attention.py + :language: python + :linenos: diff --git a/docs/examples/bmm.rst b/docs/examples/bmm.rst new file mode 100644 index 00000000..c971ec71 --- /dev/null +++ b/docs/examples/bmm.rst @@ -0,0 +1,5 @@ +Bmm +=== +.. literalinclude:: ../../examples/bmm.py + :language: python + :linenos: diff --git a/docs/examples/concatenate.rst b/docs/examples/concatenate.rst new file mode 100644 index 00000000..417e50c3 --- /dev/null +++ b/docs/examples/concatenate.rst @@ -0,0 +1,5 @@ +Concatenate +=========== +.. literalinclude:: ../../examples/concatenate.py + :language: python + :linenos: diff --git a/docs/examples/cross_entropy.rst b/docs/examples/cross_entropy.rst new file mode 100644 index 00000000..3c3bfe98 --- /dev/null +++ b/docs/examples/cross_entropy.rst @@ -0,0 +1,5 @@ +Cross Entropy +============= +.. literalinclude:: ../../examples/cross_entropy.py + :language: python + :linenos: diff --git a/docs/examples/embedding.rst b/docs/examples/embedding.rst new file mode 100644 index 00000000..97b585a6 --- /dev/null +++ b/docs/examples/embedding.rst @@ -0,0 +1,5 @@ +Embedding +========= +.. literalinclude:: ../../examples/embedding.py + :language: python + :linenos: diff --git a/docs/examples/exp.rst b/docs/examples/exp.rst new file mode 100644 index 00000000..1b7a6be8 --- /dev/null +++ b/docs/examples/exp.rst @@ -0,0 +1,5 @@ +Exp +=== +.. literalinclude:: ../../examples/exp.py + :language: python + :linenos: diff --git a/docs/examples/fp8_attention.rst b/docs/examples/fp8_attention.rst new file mode 100644 index 00000000..67917be6 --- /dev/null +++ b/docs/examples/fp8_attention.rst @@ -0,0 +1,5 @@ +Fp8 Attention +============= +.. literalinclude:: ../../examples/fp8_attention.py + :language: python + :linenos: diff --git a/docs/examples/fp8_gemm.rst b/docs/examples/fp8_gemm.rst new file mode 100644 index 00000000..ae4e4230 --- /dev/null +++ b/docs/examples/fp8_gemm.rst @@ -0,0 +1,5 @@ +Fp8 Gemm +======== +.. literalinclude:: ../../examples/fp8_gemm.py + :language: python + :linenos: diff --git a/docs/examples/jagged_dense_add.rst b/docs/examples/jagged_dense_add.rst new file mode 100644 index 00000000..26909aeb --- /dev/null +++ b/docs/examples/jagged_dense_add.rst @@ -0,0 +1,5 @@ +Jagged Dense Add +================ +.. literalinclude:: ../../examples/jagged_dense_add.py + :language: python + :linenos: diff --git a/docs/examples/jagged_mean.rst b/docs/examples/jagged_mean.rst new file mode 100644 index 00000000..638935f6 --- /dev/null +++ b/docs/examples/jagged_mean.rst @@ -0,0 +1,5 @@ +Jagged Mean +=========== +.. literalinclude:: ../../examples/jagged_mean.py + :language: python + :linenos: diff --git a/docs/examples/long_sum.rst b/docs/examples/long_sum.rst new file mode 100644 index 00000000..ae71aa7b --- /dev/null +++ b/docs/examples/long_sum.rst @@ -0,0 +1,5 @@ +Long Sum +======== +.. literalinclude:: ../../examples/long_sum.py + :language: python + :linenos: diff --git a/docs/examples/matmul.rst b/docs/examples/matmul.rst new file mode 100644 index 00000000..9c07aeb5 --- /dev/null +++ b/docs/examples/matmul.rst @@ -0,0 +1,5 @@ +Matmul +====== +.. literalinclude:: ../../examples/matmul.py + :language: python + :linenos: diff --git a/docs/examples/matmul_layernorm.rst b/docs/examples/matmul_layernorm.rst new file mode 100644 index 00000000..05e71050 --- /dev/null +++ b/docs/examples/matmul_layernorm.rst @@ -0,0 +1,5 @@ +Matmul Layernorm +================ +.. literalinclude:: ../../examples/matmul_layernorm.py + :language: python + :linenos: diff --git a/docs/examples/matmul_split_k.rst b/docs/examples/matmul_split_k.rst new file mode 100644 index 00000000..7e40a33e --- /dev/null +++ b/docs/examples/matmul_split_k.rst @@ -0,0 +1,5 @@ +Matmul Split K +============== +.. literalinclude:: ../../examples/matmul_split_k.py + :language: python + :linenos: diff --git a/docs/examples/moe_matmul_ogs.rst b/docs/examples/moe_matmul_ogs.rst new file mode 100644 index 00000000..f9038cea --- /dev/null +++ b/docs/examples/moe_matmul_ogs.rst @@ -0,0 +1,5 @@ +Moe Matmul Ogs +============== +.. literalinclude:: ../../examples/moe_matmul_ogs.py + :language: python + :linenos: diff --git a/docs/examples/rms_norm.rst b/docs/examples/rms_norm.rst new file mode 100644 index 00000000..dc3789c6 --- /dev/null +++ b/docs/examples/rms_norm.rst @@ -0,0 +1,5 @@ +Rms Norm +======== +.. literalinclude:: ../../examples/rms_norm.py + :language: python + :linenos: diff --git a/docs/examples/segment_reduction.rst b/docs/examples/segment_reduction.rst new file mode 100644 index 00000000..69d72594 --- /dev/null +++ b/docs/examples/segment_reduction.rst @@ -0,0 +1,5 @@ +Segment Reduction +================= +.. literalinclude:: ../../examples/segment_reduction.py + :language: python + :linenos: diff --git a/docs/examples/softmax.rst b/docs/examples/softmax.rst new file mode 100644 index 00000000..4f07e1c6 --- /dev/null +++ b/docs/examples/softmax.rst @@ -0,0 +1,5 @@ +Softmax +======= +.. literalinclude:: ../../examples/softmax.py + :language: python + :linenos: diff --git a/docs/examples/sum.rst b/docs/examples/sum.rst new file mode 100644 index 00000000..438df2b7 --- /dev/null +++ b/docs/examples/sum.rst @@ -0,0 +1,5 @@ +Sum +=== +.. literalinclude:: ../../examples/sum.py + :language: python + :linenos: diff --git a/docs/examples/template_via_closure.rst b/docs/examples/template_via_closure.rst new file mode 100644 index 00000000..fcf278b8 --- /dev/null +++ b/docs/examples/template_via_closure.rst @@ -0,0 +1,5 @@ +Template Via Closure +==================== +.. literalinclude:: ../../examples/template_via_closure.py + :language: python + :linenos: diff --git a/docs/index.md b/docs/index.md index 60bc939d..9194697e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,7 +1,20 @@ # Helion Documentation -> ⚠️ **Early Development Warning** -> Helion is currently in an experimental stage. You should expect bugs, incomplete features, and APIs that may change in future versions. Feedback and bug reports are welcome and appreciated! + +```{toctree} +:maxdepth: 1 +:caption: Table of Contents: +:hidden: + +installation +helion_puzzles +examples +api/index + +``` + +⚠️ **Early Development Warning** +Helion is currently in an experimental stage. You should expect bugs, incomplete features, and APIs that may change in future versions. Feedback and bug reports are welcome and appreciated! **Helion** is a Python-embedded domain-specific language (DSL) for authoring machine learning kernels, designed to compile down to [Triton], @@ -234,15 +247,3 @@ variable will be ignored. Enable logging by setting the environment variable `HELION_LOGS=all` for INFO-level logs, or `HELION_LOGS=+all` for DEBUG-level logs. Alternatively, you can specify logging for specific modules using a comma-separated list (e.g., `HELION_LOGS=+helion.runtime.kernel`). - - -## Table of Contents - -```{toctree} -:maxdepth: 1 -:caption: Contents: - -installation -api/index -helion_puzzles -``` diff --git a/docs/installation.md b/docs/installation.md index 41b52954..112b9353 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -156,4 +156,4 @@ Matches the requirements of [Triton](https://github.com/triton-lang/triton). At Once installation is complete: 1. **Check out the {doc}`api/index` for complete API documentation** -2. **Explore the [examples/](https://github.com/pytorch-labs/helion/tree/main/examples) folder for real-world patterns** +2. **Explore the [examples](examples/) and [Helion Puzzles](helion_puzzles) pages for real-world patterns** diff --git a/examples/add.py b/examples/add.py index c940a626..c937f4ec 100644 --- a/examples/add.py +++ b/examples/add.py @@ -1,14 +1,24 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + @helion.kernel() def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Add two tensors element-wise with broadcasting support. + + Args: + x: First input tensor + y: Second input tensor + + Returns: + A new tensor containing the element-wise sum of x and y + """ # match pytorch broadcasting rules x, y = torch.broadcast_tensors(x, y) out = torch.empty( @@ -24,12 +34,22 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def check(m: int, n: int) -> None: + """ + Verify the add kernel implementation against PyTorch's native add function. + + Args: + m: First dimension of the test tensors + n: Second dimension of the test tensors + """ x = torch.randn([m, n], device="cuda", dtype=torch.float16) y = torch.randn([m, n], device="cuda", dtype=torch.float16) run_example(add, torch.add, (x, y)) def main() -> None: + """ + Main entry point that runs the add kernel verification with 1024x1024 tensors. + """ check(1024, 1024) diff --git a/examples/all_gather_matmul.py b/examples/all_gather_matmul.py index e93f28bd..c0243fda 100644 --- a/examples/all_gather_matmul.py +++ b/examples/all_gather_matmul.py @@ -2,13 +2,13 @@ import os +import helion +import helion.language as hl + import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem -import helion -import helion.language as hl - def copy_engine_all_gather_w_progress( output: torch.Tensor, @@ -17,6 +17,19 @@ def copy_engine_all_gather_w_progress( splits_per_rank: int, backend_stream: torch.cuda.Stream | None = None, ) -> torch.cuda.Stream: + """ + Performs an all-gather operation with progress tracking using symmetric memory. + + Args: + output: The output tensor to store the gathered results + inp: The input tensor to be gathered (must be a symmetric tensor) + progress: Tensor used to track progress of the operation + splits_per_rank: Number of splits per rank + backend_stream: CUDA stream for backend operations (optional) + + Returns: + The CUDA stream used for the operation + """ backend_stream = symm_mem._get_backend_stream(priority=-1) assert inp.is_contiguous() symm_mem_group = dist.group.WORLD @@ -78,6 +91,20 @@ def helion_matmul_w_progress( SPLITS_PER_RANK: int, RANK: int, ) -> torch.Tensor: + """ + Performs matrix multiplication with progress tracking. + + Args: + a: First input tensor for matrix multiplication + a_shared: Shared tensor across ranks + b: Second input tensor for matrix multiplication + progress: Tensor used to track progress of the operation + SPLITS_PER_RANK: Number of splits per rank + RANK: Current process rank + + Returns: + The result of the matrix multiplication + """ M, K = a.size() K2, N = b.size() assert K2 == K, f"size mismatch {K2} != {K}" @@ -119,6 +146,21 @@ def helion_all_gather_matmul( progress: torch.Tensor | None = None, **kwargs: int, ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Combines all-gather and matrix multiplication operations. + + Args: + a_shared: Shared tensor across ranks to be gathered + b: Second input tensor for matrix multiplication + a_out: Optional output tensor for the gathered results + progress: Optional tensor used to track progress of the operation + **kwargs: Additional keyword arguments including splits_per_rank + + Returns: + A tuple containing: + - The gathered tensor + - The result of the matrix multiplication + """ configs = { "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1), } @@ -169,6 +211,16 @@ def helion_all_gather_matmul( def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None: + """ + Tests the helion_all_gather_matmul function against PyTorch's implementation. + + Args: + M: First dimension of the matrix + N: Second dimension of the matrix + K: Third dimension of the matrix + world_size: Number of processes + device: Device to run the test on + """ a_shared = symm_mem.empty( M // world_size, K, dtype=torch.bfloat16, device=device ).normal_() @@ -180,14 +232,20 @@ def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None: dist_group = dist.group.WORLD if dist_group is None: raise RuntimeError("No distributed group available") - ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul( # pyright: ignore[reportCallIssue] - golden_a, [b], gather_dim=0, group_name=dist_group.group_name + ag_golden, mm_golden = ( + torch.ops.symm_mem.fused_all_gather_matmul( # pyright: ignore[reportCallIssue] + golden_a, [b], gather_dim=0, group_name=dist_group.group_name + ) ) torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1) torch.testing.assert_close(a_out, ag_golden) def main() -> None: + """ + Main entry point that initializes the distributed environment and runs the test. + Sets up the distributed process group, runs the test, and then cleans up. + """ rank = int(os.environ["LOCAL_RANK"]) world_size = int(os.environ["WORLD_SIZE"]) torch.manual_seed(42 + rank) diff --git a/examples/attention.py b/examples/attention.py index 3d0b0149..21ff9738 100644 --- a/examples/attention.py +++ b/examples/attention.py @@ -1,16 +1,15 @@ from __future__ import annotations import math -from typing import Callable -from typing import cast - -import torch -from torch.nn.attention.flex_attention import flex_attention +from typing import Callable, cast import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example +from torch.nn.attention.flex_attention import flex_attention + @helion.kernel( # Static shapes provides a speedup for attention @@ -21,6 +20,19 @@ def attention( k_in: torch.Tensor, v_in: torch.Tensor, ) -> torch.Tensor: + """ + Computes scaled dot-product attention. + + Implements the attention mechanism: Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V + + Args: + q_in: Query tensor of shape [..., seq_len_q, head_dim] + k_in: Key tensor of shape [..., seq_len_k, head_dim] + v_in: Value tensor of shape [..., seq_len_k, head_dim] + + Returns: + Output tensor of shape [..., seq_len_q, head_dim] + """ m_dim = q_in.size(-2) n_dim = k_in.size(-2) assert n_dim == v_in.size(-2) @@ -62,6 +74,10 @@ def attention( configs=attention.configs, # pyright: ignore[reportArgumentType] static_shapes=False, ) +""" +Dynamic shape version of the attention kernel. +This version allows for variable input shapes at runtime. +""" def test( @@ -72,6 +88,17 @@ def test( dtype: torch.dtype = torch.float32, device: torch.device | str = "cuda", ) -> None: + """ + Test the attention kernel implementation against PyTorch's native attention functions. + + Args: + z: Batch size + h: Number of attention heads + n_ctx: Sequence length (context size) + head_dim: Dimension of each attention head + dtype: Data type for the tensors + device: Device to run the test on + """ q, k, v = [ torch.randn((z, h, n_ctx, head_dim), dtype=dtype, device=device) for _ in range(3) @@ -98,6 +125,10 @@ def ref_attention( def main() -> None: + """ + Main entry point that runs the attention kernel test with specific parameters. + Tests with batch size 2, 32 heads, 1024 sequence length, and 64-dimensional heads using float16. + """ test(2, 32, 1024, 64, torch.float16) diff --git a/examples/bmm.py b/examples/bmm.py index bdae21b3..2ceb1c86 100644 --- a/examples/bmm.py +++ b/examples/bmm.py @@ -1,15 +1,25 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + """ + Performs batch matrix multiplication. + + Args: + A: Input tensor of shape [B, M, K] + B: Input tensor of shape [B, K, N] + + Returns: + Output tensor of shape [B, M, N] containing the result of batch matrix multiplication + """ # A: [B, M, K], B: [B, K, N], Out: [B, M, N] # dense bmm b, m, k = A.size() b, k, n = B.size() @@ -27,12 +37,26 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: def check(b: int, m: int, k: int, n: int) -> None: + """ + Verify the bmm kernel implementation against PyTorch's native bmm function. + + Args: + b: Batch size + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ x = torch.randn([b, m, k], device="cuda", dtype=torch.float16) y = torch.randn([b, k, n], device="cuda", dtype=torch.float16) run_example(bmm, torch.bmm, (x, y)) def main() -> None: + """ + Main entry point that runs the bmm kernel verification with specific parameters. + Tests with batch size 16, and matrices of dimensions 512x768 and 768x1024. + Ensures torch version is at least 2.8 for 16-bit tensor support in baddbmm. + """ # torch.baddbmm support for 16-bit tensors requires torch 2.8+ assert torch.__version__.split(".")[:2] >= ["2", "8"], "Requires torch 2.8+" check(16, 512, 768, 1024) diff --git a/examples/concatenate.py b/examples/concatenate.py index cb72cf72..439a5dcb 100644 --- a/examples/concatenate.py +++ b/examples/concatenate.py @@ -1,14 +1,24 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + @helion.kernel() def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Concatenates two 2D tensors along dimension 1 (columns). + + Args: + x: First input tensor of shape [M, N1] + y: Second input tensor of shape [M, N2] with same first dimension as x + + Returns: + Output tensor of shape [M, N1+N2] containing the concatenation of x and y along dimension 1 + """ assert x.size(0) == y.size(0) out = torch.empty( [x.size(0), x.size(1) + y.size(1)], dtype=x.dtype, device=x.device @@ -30,6 +40,10 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def main() -> None: + """ + Main entry point that runs the concatenation kernel verification. + Tests with two tensors of shapes [1500, 400] and [1500, 600]. + """ x = torch.randn([1500, 400], device="cuda") y = torch.randn([1500, 600], device="cuda") run_example(concat2d_dim1, lambda x, y: torch.cat([x, y], dim=1), (x, y)) diff --git a/examples/cross_entropy.py b/examples/cross_entropy.py index 28f36cd1..d285566c 100644 --- a/examples/cross_entropy.py +++ b/examples/cross_entropy.py @@ -2,12 +2,12 @@ import os -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + # TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": # Low memory configuration @@ -19,6 +19,20 @@ def cross_entropy( logits: torch.Tensor, # [N, V] input logits labels: torch.Tensor, # [N] target labels ) -> torch.Tensor: + """ + Computes the cross entropy loss between logits and target labels. + + Implements the cross entropy loss function commonly used in classification tasks. + The function computes the log softmax of the logits and then calculates the negative + log likelihood of the true labels. + + Args: + logits: Input logits tensor of shape [N, V] where N is batch size and V is vocabulary size + labels: Target labels tensor of shape [N] containing class indices + + Returns: + A scalar tensor containing the mean cross entropy loss + """ n, v = logits.shape losses = torch.zeros([n], dtype=logits.dtype, device=logits.device) diff --git a/examples/embedding.py b/examples/embedding.py index e9e99f84..b3c61ca8 100644 --- a/examples/embedding.py +++ b/examples/embedding.py @@ -1,14 +1,26 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + @helion.kernel() def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Performs embedding lookup for input indices. + + Maps indices in the input tensor to vectors from the embedding weight matrix. + + Args: + x: Input tensor of indices of any shape + weight: Embedding weight matrix of shape [num_embeddings, embedding_dim] + + Returns: + Output tensor of shape [*x.shape, embedding_dim] containing the embedding vectors + """ x_flat = x.reshape(-1) # collapse x into a single dimension _, embedding_dim = weight.size() out = torch.empty( @@ -28,6 +40,10 @@ def embedding_tritonbench( def main() -> None: + """ + Main entry point that runs the embedding kernel verification. + Tests with a batch of indices and an embedding table of size 16x64. + """ num_embeddings, embedding_dim = 16, 64 x = torch.randint(0, num_embeddings, [256, 32], device="cuda", dtype=torch.int32) weight = torch.randn([num_embeddings, embedding_dim], device="cuda") diff --git a/examples/exp.py b/examples/exp.py index 357f4862..ad32c216 100644 --- a/examples/exp.py +++ b/examples/exp.py @@ -1,14 +1,23 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + @helion.kernel() def exp(x: torch.Tensor) -> torch.Tensor: + """ + Computes the exponential of all elements in the input tensor. + + Args: + x: Input tensor + + Returns: + Output tensor with the exponential of each element in the input + """ out = torch.empty_like(x) for tile in hl.tile(x.size()): out[tile] = torch.exp(x[tile]) @@ -21,11 +30,20 @@ def exp_tritonbench(x: torch.Tensor) -> dict[str, torch.Tensor]: def check(n: int) -> None: + """ + Verify the exp kernel implementation against PyTorch's native exp function. + + Args: + n: Size of the test tensor + """ x = torch.randn(n, device="cuda", dtype=torch.float32) run_example(exp, torch.exp, (x,)) def main() -> None: + """ + Main entry point that runs the exp kernel verification with a tensor of size 1M elements. + """ check(1024 * 1024) diff --git a/examples/fp8_attention.py b/examples/fp8_attention.py index f9c5153b..71c5a622 100644 --- a/examples/fp8_attention.py +++ b/examples/fp8_attention.py @@ -3,11 +3,11 @@ import math from typing import Callable -import torch - import helion import helion.language as hl +import torch + @helion.kernel(static_shapes=True) def fp8_attention_kernel( @@ -17,6 +17,21 @@ def fp8_attention_kernel( batch: int, heads: int, ) -> torch.Tensor: + """ + Computes scaled dot-product attention using FP8 precision. + + Implements the attention mechanism with FP8 tensors for improved performance and memory efficiency. + + Args: + q: Query tensor of shape [batch*heads, seq, dim] in FP8 format + k: Key tensor of shape [batch*heads, seq, dim] in FP8 format + v: Value tensor of shape [batch*heads, dim, seq] (pre-transposed) in FP8 format + batch: Number of batches + heads: Number of attention heads + + Returns: + Output tensor of shape [batch, heads, seq_len, head_dim] in FP8 format + """ batch_heads = q.size(0) seq_len = q.size(1) head_dim = q.size(2) @@ -108,6 +123,20 @@ def fp8_attention_kernel( def preprocess_fp8_attention_inputs( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Preprocesses attention inputs by converting them to FP8 format and reshaping. + + Args: + q: Query tensor of shape [batch, heads, seq_len, head_dim] + k: Key tensor of shape [batch, heads, seq_len, head_dim] + v: Value tensor of shape [batch, heads, seq_len, head_dim] + + Returns: + Tuple of (q_fp8, k_fp8, v_fp8) where: + - q_fp8: Query tensor in FP8 format with shape [batch*heads, seq_len, head_dim] + - k_fp8: Key tensor in FP8 format with shape [batch*heads, seq_len, head_dim] + - v_fp8: Value tensor in FP8 format with shape [batch*heads, head_dim, seq_len] (pre-transposed) + """ q_fp8 = q.to(torch.float8_e5m2) k_fp8 = k.to(torch.float8_e5m2) v = v.permute(0, 1, 3, 2) @@ -122,6 +151,19 @@ def preprocess_fp8_attention_inputs( def fp8_attention_tritonbench( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> Callable[[], torch.Tensor]: + """ + Creates a callable function for benchmarking FP8 attention with tritonbench. + + Preprocesses inputs and returns a lambda function that calls the FP8 attention kernel. + + Args: + q: Query tensor of shape [batch, heads, seq_len, head_dim] + k: Key tensor of shape [batch, heads, seq_len, head_dim] + v: Value tensor of shape [batch, heads, seq_len, head_dim] + + Returns: + A callable function that executes the FP8 attention kernel + """ batch, heads, seq_len, head_dim = q.shape q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v) # Return lambda that calls the kernel - preprocessing is done outside. @@ -138,6 +180,21 @@ def _fp8_attention_pytorch_impl( seq_len: int, head_dim: int, ) -> torch.Tensor: + """ + PyTorch implementation of FP8 attention for comparison with the kernel version. + + Args: + q_fp8: Query tensor in FP8 format with shape [batch*heads, seq_len, head_dim] + k_fp8: Key tensor in FP8 format with shape [batch*heads, seq_len, head_dim] + v_fp8: Value tensor in FP8 format with shape [batch*heads, head_dim, seq_len] (pre-transposed) + batch: Number of batches + heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + + Returns: + Output tensor of shape [batch, heads, seq_len, head_dim] in FP8 format + """ sm_scale = 1.0 / math.sqrt(float(head_dim)) outputs = [] @@ -204,6 +261,15 @@ def fp8_attention_pytorch( def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None: + """ + Verifies the FP8 attention kernel implementation against the PyTorch reference implementation. + + Args: + batch: Number of batches + heads: Number of attention heads + seq_len: Sequence length + head_dim: Dimension of each attention head + """ torch.manual_seed(42) q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") @@ -223,6 +289,10 @@ def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None: def main() -> None: + """ + Main entry point that runs the FP8 attention kernel verification with different configurations. + Tests with small, medium, and large attention configurations. + """ check(1, 2, 128, 64) check(2, 4, 256, 64) check(4, 8, 512, 128) diff --git a/examples/jagged_dense_add.py b/examples/jagged_dense_add.py index d5fb91e8..b795db72 100644 --- a/examples/jagged_dense_add.py +++ b/examples/jagged_dense_add.py @@ -1,11 +1,11 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + """ A tensor x is stored in a jagged-row, prefix-sparse layout that packs only the non-zero elements of each row. All non-zeros are concatenated into a one-dimensional buffer @@ -106,6 +106,12 @@ def random_jagged_2d( def main() -> None: + """ + Main entry point that runs the jagged dense add kernel verification. + + Creates random jagged 2D data and a dense tensor, then compares the kernel + implementation against the PyTorch reference implementation. + """ rows, cols = 256, 5000 x_data, x_offsets = random_jagged_2d(rows, cols, device="cuda") y = torch.randn([rows, cols], device="cuda") diff --git a/examples/jagged_mean.py b/examples/jagged_mean.py index cbc6e99d..a7eb263a 100644 --- a/examples/jagged_mean.py +++ b/examples/jagged_mean.py @@ -2,12 +2,12 @@ import os -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + # TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": # Low memory configuration @@ -139,12 +139,20 @@ def jagged_mean_tritonbench( dtype=torch.int32, device=x_values.device, # pyright: ignore[reportAttributeAccessIssue] ) - max_M_tensor = torch.empty(M, device=x_values.device) # pyright: ignore[reportAttributeAccessIssue] + max_M_tensor = torch.empty( + M, device=x_values.device + ) # pyright: ignore[reportAttributeAccessIssue] return jagged_mean_kernel(x_values, x_offsets, feature_counts, max_M_tensor) def main() -> None: + """ + Main entry point that runs the jagged mean kernel verification. + + Creates test data with random jagged tensors and feature counts, then compares + the kernel implementation against the PyTorch reference implementation. + """ num_rows, max_cols = 32, 64 device = "cuda" diff --git a/examples/long_sum.py b/examples/long_sum.py index 543869da..47948920 100644 --- a/examples/long_sum.py +++ b/examples/long_sum.py @@ -1,13 +1,22 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + def baseline_sum(x: torch.Tensor) -> torch.Tensor: + """ + PyTorch baseline implementation of sum reduction along the last dimension. + + Args: + x: Input tensor + + Returns: + Tensor with sum of elements along the last dimension + """ return x.sum(-1) @@ -22,6 +31,17 @@ def baseline_sum(x: torch.Tensor) -> torch.Tensor: ) ) def longsum(x: torch.Tensor) -> torch.Tensor: + """ + Naive reduction kernel that sums elements along the last dimension. + + Loads the entire reduction dimension at once and reduces in registers. + + Args: + x: Input tensor of shape [m, n] + + Returns: + Output tensor of shape [m] containing the sum of each row + """ m, _ = x.size() out = torch.empty([m], dtype=x.dtype, device=x.device) @@ -43,6 +63,17 @@ def longsum(x: torch.Tensor) -> torch.Tensor: ) ) def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor: + """ + Looped reduction kernel that sums elements along the last dimension. + + Uses a reduction loop with a specified tile size to handle large dimensions efficiently. + + Args: + x: Input tensor of shape [m, n] + + Returns: + Output tensor of shape [m] containing the sum of each row + """ m, _ = x.size() out = torch.empty([m], dtype=x.dtype, device=x.device) @@ -58,6 +89,17 @@ def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor: ) ) def longsum_manual(x: torch.Tensor) -> torch.Tensor: + """ + Manual implementation of looped reduction for summing elements along the last dimension. + + Manually implements the reduction loop with explicit accumulation and final reduction. + + Args: + x: Input tensor of shape [m, n] + + Returns: + Output tensor of shape [m] containing the sum of each row + """ m, n = x.size() out = torch.empty([m], dtype=x.dtype, device=x.device) @@ -73,6 +115,15 @@ def longsum_manual(x: torch.Tensor) -> torch.Tensor: def check(m: int, n: int) -> None: + """ + Verify the sum kernel implementations against PyTorch's native sum function. + + Tests all three kernel variants (naive, looped, manual) against the baseline. + + Args: + m: First dimension of the test tensor + n: Second dimension of the test tensor (reduction dimension) + """ x = torch.randn([m, n], device="cuda", dtype=torch.float32) # Test all three kernel variants against the baseline @@ -86,6 +137,11 @@ def check(m: int, n: int) -> None: def main() -> None: + """ + Main entry point that runs the sum kernel verification with a large tensor. + + Tests with a tensor of shape [4, 130000] to demonstrate handling of long reduction dimensions. + """ check(4, 130000) # seq_len = 128k diff --git a/examples/matmul.py b/examples/matmul.py index 1f6ad675..93c87965 100644 --- a/examples/matmul.py +++ b/examples/matmul.py @@ -1,15 +1,25 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Performs matrix multiplication between two tensors. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + + Returns: + Output tensor of shape [M, N] containing the result of matrix multiplication + """ m, k = x.size() k2, n = y.size() assert k == k2, f"size mismatch {k} != {k2}" @@ -25,12 +35,23 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def check(m: int, k: int, n: int) -> None: + """ + Verify the matmul kernel implementation against PyTorch's native matmul function. + + Args: + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ x = torch.randn([m, k], device="cuda", dtype=torch.float16) y = torch.randn([k, n], device="cuda", dtype=torch.float16) run_example(matmul, torch.matmul, (x, y)) def main() -> None: + """ + Main entry point that runs the matmul kernel verification with 1024x1024 matrices. + """ check(1024, 1024, 1024) diff --git a/examples/matmul_layernorm.py b/examples/matmul_layernorm.py index 4e5ecc35..fc70b1f3 100644 --- a/examples/matmul_layernorm.py +++ b/examples/matmul_layernorm.py @@ -1,11 +1,11 @@ from __future__ import annotations +import helion +import helion.language as hl + import torch import torch.nn.functional as F - -import helion from helion._testing import run_example -import helion.language as hl # static_shapes=True gives a performance boost for matmuls @@ -13,6 +13,18 @@ def matmul_layernorm( x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: + """ + Performs matrix multiplication followed by layer normalization. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + weight: Layer normalization weight parameter of shape [N] + bias: Layer normalization bias parameter of shape [N] + + Returns: + Output tensor of shape [M, N] containing the result of matrix multiplication followed by layer normalization + """ m, k = x.size() k2 = y.size(0) n = hl.register_reduction_dim(y.size(1)) @@ -38,6 +50,18 @@ def matmul_layernorm( def matmul_layernorm_pytorch( x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: + """ + PyTorch reference implementation of matrix multiplication followed by layer normalization. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + weight: Layer normalization weight parameter of shape [N] + bias: Layer normalization bias parameter of shape [N] + + Returns: + Output tensor of shape [M, N] containing the result of matrix multiplication followed by layer normalization + """ matmul_out = torch.matmul(x, y) ln_out = F.layer_norm( @@ -51,6 +75,14 @@ def matmul_layernorm_pytorch( def check(m: int, k: int, n: int) -> None: + """ + Verify the matmul_layernorm kernel implementation against the PyTorch reference implementation. + + Args: + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ x = torch.randn([m, k], device="cuda", dtype=torch.float16) y = torch.randn([k, n], device="cuda", dtype=torch.float16) weight = torch.randn([n], device="cuda", dtype=torch.float16) @@ -59,6 +91,13 @@ def check(m: int, k: int, n: int) -> None: def main() -> None: + """ + Main entry point that runs the matmul_layernorm kernel verification with different matrix sizes. + + Tests with two configurations: + - 32x64 * 64x200 + - 128x256 * 256x400 + """ # TODO(yf225): n=64 or 128 throws error, need to investigate # check(32, 64, 64) # check(32, 64, 128) diff --git a/examples/matmul_split_k.py b/examples/matmul_split_k.py index 66f87449..11c118e5 100644 --- a/examples/matmul_split_k.py +++ b/examples/matmul_split_k.py @@ -1,16 +1,29 @@ from __future__ import annotations -import torch - import helion +import helion.language as hl + +import torch from helion._testing import run_example from helion.autotuner import PowerOfTwoFragment -import helion.language as hl # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Performs matrix multiplication using split-K algorithm for better parallelism. + + Split-K divides the reduction dimension (K) into multiple chunks that can be processed + in parallel, with results atomically accumulated at the end. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + + Returns: + Output tensor of shape [M, N] containing the result of matrix multiplication + """ m, k = x.size() k2, n = y.size() assert k == k2, f"size mismatch {k} != {k2}" @@ -28,12 +41,26 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def check(m: int, k: int, n: int) -> None: + """ + Verify the split-K matmul kernel implementation against PyTorch's native matmul function. + + Args: + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ x = torch.randn([m, k], device="cuda", dtype=torch.float16) y = torch.randn([k, n], device="cuda", dtype=torch.float16) run_example(matmul_split_k, torch.matmul, (x, y), atol=1) def main() -> None: + """ + Main entry point that runs the split-K matmul kernel verification. + + Tests with matrices of shape 64x32768 and 32768x64, which benefits from the split-K approach + due to the large reduction dimension. + """ check(64, 32768, 64) diff --git a/examples/moe_matmul_ogs.py b/examples/moe_matmul_ogs.py index 66b9af24..63c29328 100644 --- a/examples/moe_matmul_ogs.py +++ b/examples/moe_matmul_ogs.py @@ -4,12 +4,12 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + @helion.kernel(static_shapes=False) def moe_matmul_ogs( @@ -151,6 +151,15 @@ def moe_matmul_ogs_reference( def check(T: int, K: int, N: int, n_experts: int) -> None: + """ + Verify the MoE matmul OGS kernel implementation against the reference implementation. + + Args: + T: Number of tokens + K: Input feature dimension + N: Output feature dimension + n_experts: Number of experts + """ dtype = torch.float16 device = "cuda" if torch.cuda.is_available() else "cpu" @@ -173,6 +182,11 @@ def reference_fn() -> torch.Tensor: def main() -> None: + """ + Main entry point that runs the MoE matmul OGS kernel verification. + + Tests with 1000 tokens, 500 input features, 200 output features, and 30 experts. + """ check(1000, 500, 200, 30) diff --git a/examples/rms_norm.py b/examples/rms_norm.py index c1b46841..cedc17de 100644 --- a/examples/rms_norm.py +++ b/examples/rms_norm.py @@ -1,11 +1,11 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + # TritonBench configuration # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg. TRITONBENCH_ARGS = {"num_inputs": 3} @@ -13,6 +13,20 @@ @helion.kernel(static_shapes=True) def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + """ + Performs Root Mean Square (RMS) normalization on the input tensor. + + RMS normalization normalizes by the root mean square of the elements: + output = x / sqrt(mean(x^2) + eps) * weight + + Args: + x: Input tensor of shape [M, N] + weight: Scale parameter of shape [N] + eps: Small constant for numerical stability + + Returns: + Output tensor of shape [M, N] with RMS normalization applied + """ m, n = x.size() assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}" @@ -50,12 +64,27 @@ def rms_norm_pytorch( def check(m: int, n: int) -> None: + """ + Verify the RMS norm kernel implementation against the PyTorch reference implementation. + + Args: + m: First dimension of the test tensor + n: Second dimension of the test tensor + """ x = torch.randn([m, n], device="cuda", dtype=torch.float16) weight = torch.randn([n], device="cuda", dtype=torch.float16) run_example(rms_norm, rms_norm_pytorch, (x, weight, 1e-5)) def main() -> None: + """ + Main entry point that runs the RMS norm kernel verification with different tensor sizes. + + Tests with three configurations: + - 32x64 + - 128x256 + - 1024x1024 + """ check(32, 64) check(128, 256) check(1024, 1024) diff --git a/examples/segment_reduction.py b/examples/segment_reduction.py index 32792de3..a1a675c2 100644 --- a/examples/segment_reduction.py +++ b/examples/segment_reduction.py @@ -1,14 +1,13 @@ # Code based on https://github.com/pytorch-labs/helion/issues/237 from __future__ import annotations +import helion +import helion.language as hl + import torch import triton import triton.language as tl - -import helion -from helion._testing import DEVICE -from helion._testing import run_example -import helion.language as hl +from helion._testing import DEVICE, run_example def combine_fn_helion( @@ -17,6 +16,20 @@ def combine_fn_helion( right_values: torch.Tensor, right_indices: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Combine function for associative scan in Helion implementation. + + Adds values when indices match (same segment), otherwise takes the right value. + + Args: + left_values: Values from the left side of the scan + left_indices: Indices from the left side of the scan + right_values: Values from the right side of the scan + right_indices: Indices from the right side of the scan + + Returns: + Tuple of (combined_values, right_indices) + """ combined_values = torch.where( left_indices == right_indices, left_values + right_values, right_values ) @@ -27,6 +40,19 @@ def combine_fn_helion( def segmented_reduction_helion( indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int ) -> torch.Tensor: + """ + Performs segmented reduction using Helion. + + Reduces input data by summing values with the same index. + + Args: + indices: Tensor of segment indices for each element + input_data: Input tensor of shape [num_elements, num_features] + num_nodes: Number of output nodes/segments + + Returns: + Output tensor of shape [num_nodes, num_features] with reduced values + """ num_elements, num_features = input_data.shape output = torch.zeros( (num_nodes, num_features), dtype=input_data.dtype, device=input_data.device @@ -54,6 +80,20 @@ def combine_fn_triton( right_values: tl.tensor, right_indices: tl.tensor, ) -> tuple[tl.tensor, tl.tensor]: + """ + Combine function for associative scan in Triton implementation. + + Adds values when indices match (same segment), otherwise takes the right value. + + Args: + left_values: Values from the left side of the scan + left_indices: Indices from the left side of the scan + right_values: Values from the right side of the scan + right_indices: Indices from the right side of the scan + + Returns: + Tuple of (combined_values, combined_indices) + """ same_segment = left_indices == right_indices combined_values = tl.where(same_segment, left_values + right_values, right_values) combined_indices = right_indices @@ -79,6 +119,19 @@ def _segmented_reduction_triton( C: tl.constexpr, # Number of features in the input tensor (2d) BLOCK_SIZE: tl.constexpr, # Block size for the scan ) -> None: + """ + Triton kernel for segmented reduction. + + Uses associative scan to efficiently perform segmented reduction. + + Args: + index: Input index tensor + in_ptr: Input data tensor + out_ptr: Output tensor + E: Number of elements in the input tensor + C: Number of features in the input tensor + BLOCK_SIZE: Block size for the scan + """ # Triton version adapted from # https://github.com/fishmingyu/GeoT/blob/main/geot/triton/seg_reduction.py pid = tl.program_id(axis=0) @@ -109,6 +162,19 @@ def _segmented_reduction_triton( def segmented_reduction_triton( indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int ) -> torch.Tensor: + """ + Performs segmented reduction using Triton. + + Wrapper function for the Triton kernel implementation. + + Args: + indices: Tensor of segment indices for each element + input_data: Input tensor of shape [num_elements, num_features] + num_nodes: Number of output nodes/segments + + Returns: + Output tensor of shape [num_nodes, num_features] with reduced values + """ E, C = input_data.shape output = torch.zeros( (num_nodes, C), dtype=input_data.dtype, device=input_data.device @@ -124,6 +190,19 @@ def grid(META: dict[str, int]) -> tuple[int, ...]: def segmented_reduction_pytorch( indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int ) -> torch.Tensor: + """ + Performs segmented reduction using PyTorch's scatter_add. + + Reference implementation using PyTorch's native operations. + + Args: + indices: Tensor of segment indices for each element + input_data: Input tensor of shape [num_elements, num_features] + num_nodes: Number of output nodes/segments + + Returns: + Output tensor of shape [num_nodes, num_features] with reduced values + """ # Run PyTorch reference (scatter_add equivalent) num_features = input_data.size(1) pytorch_output = torch.zeros( @@ -136,6 +215,12 @@ def segmented_reduction_pytorch( def main() -> None: + """ + Main entry point that runs the segmented reduction implementations. + + Creates random data with 100 nodes, 2000 edges, and 128 features, + then compares the Helion implementation against Triton and PyTorch. + """ num_nodes = 100 num_edges = 2000 num_features = 128 diff --git a/examples/softmax.py b/examples/softmax.py index e8dcdcf1..614f8826 100644 --- a/examples/softmax.py +++ b/examples/softmax.py @@ -1,14 +1,23 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + @helion.kernel() def softmax(x: torch.Tensor) -> torch.Tensor: + """ + Performs softmax operation along dimension 1 using PyTorch's built-in softmax. + + Args: + x: Input tensor of shape [N, M] + + Returns: + Output tensor of shape [N, M] with softmax applied along dimension 1 + """ n, _m = x.size() out = torch.empty_like(x) for tile_n in hl.tile(n): @@ -19,6 +28,20 @@ def softmax(x: torch.Tensor) -> torch.Tensor: # This generates the same code as the above, but avoids using the pytorch softmax decomposition @helion.kernel() def softmax_decomposed(x: torch.Tensor) -> torch.Tensor: + """ + Performs softmax operation along dimension 1 using manual decomposition. + + Implements the softmax algorithm step by step: + 1. Find the maximum value for numerical stability + 2. Subtract the maximum and compute exponentials + 3. Normalize by the sum of exponentials + + Args: + x: Input tensor of shape [N, M] + + Returns: + Output tensor of shape [N, M] with softmax applied along dimension 1 + """ n, _m = x.size() out = torch.empty_like(x) for tile_n in hl.tile(n): @@ -33,6 +56,18 @@ def softmax_decomposed(x: torch.Tensor) -> torch.Tensor: # This optimization does softmax in fewer passes, but is less numerically stable @helion.kernel() def softmax_two_pass(x: torch.Tensor) -> torch.Tensor: + """ + Performs softmax operation in two passes for better performance. + + This optimized version computes softmax with fewer passes over the data, + trading some numerical stability for performance. + + Args: + x: Input tensor of shape [M, N] + + Returns: + Output tensor of shape [M, N] with softmax applied along dimension 1 + """ m, n = x.size() out = torch.empty_like(x) block_size_m = hl.register_block_size(m) @@ -55,6 +90,13 @@ def softmax_two_pass(x: torch.Tensor) -> torch.Tensor: def check(m: int, n: int) -> None: + """ + Verify the softmax kernel implementations against PyTorch's native softmax function. + + Args: + m: First dimension of the test tensor + n: Second dimension of the test tensor + """ x = torch.randn([m, n], device="cuda", dtype=torch.float16) kernels = { "helion simple": softmax, @@ -65,6 +107,9 @@ def check(m: int, n: int) -> None: def main() -> None: + """ + Main entry point that runs the softmax kernel verification with a 1024x1024 tensor. + """ check(1024, 1024) diff --git a/examples/sum.py b/examples/sum.py index 3def1af2..2be680e8 100644 --- a/examples/sum.py +++ b/examples/sum.py @@ -1,11 +1,11 @@ from __future__ import annotations -import torch - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example + @helion.kernel() def sum_kernel(x: torch.Tensor) -> torch.Tensor: @@ -30,12 +30,26 @@ def sum_tritonbench(x: torch.Tensor) -> torch.Tensor: def check(m: int, n: int) -> None: + """ + Verify the sum kernel implementation against PyTorch's native sum function. + + Args: + m: First dimension of the test tensor + n: Second dimension of the test tensor + """ x = torch.randn([m, n], device="cuda", dtype=torch.float32) kernels = {"helion": sum_kernel} run_example(kernels, lambda x: x.sum(-1), (x,)) def main() -> None: + """ + Main entry point that runs the sum kernel verification with different tensor sizes. + + Tests with two configurations: + - 512x256 + - 1024x1024 + """ check(512, 256) check(1024, 1024) diff --git a/examples/template_via_closure.py b/examples/template_via_closure.py index 471fdf42..c5e98499 100644 --- a/examples/template_via_closure.py +++ b/examples/template_via_closure.py @@ -2,13 +2,13 @@ from typing import TYPE_CHECKING -import torch -from torch import Tensor - import helion -from helion._testing import run_example import helion.language as hl +import torch +from helion._testing import run_example +from torch import Tensor + if TYPE_CHECKING: from collections.abc import Callable @@ -35,6 +35,17 @@ def matmul_with_epilogue( def autotune(n: int, k: int, m: int) -> None: + """ + Autotunes the matmul_with_epilogue kernel and saves the best configuration. + + Creates random tensors and runs the autotuning process to find the optimal + configuration for the kernel with the given dimensions. + + Args: + n: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + m: Second dimension of the second matrix + """ x = torch.randn([n, k], device="cuda", dtype=torch.float16) y = torch.randn([k, m], device="cuda", dtype=torch.float16) bias = torch.randn([1, m], device="cuda", dtype=torch.float16) @@ -45,6 +56,16 @@ def autotune(n: int, k: int, m: int) -> None: def check(n: int, k: int, m: int) -> None: + """ + Verify the matmul_with_epilogue kernel implementation against a PyTorch baseline. + + Tests matrix multiplication with a ReLU + bias epilogue function. + + Args: + n: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + m: Second dimension of the second matrix + """ x = torch.randn([n, k], device="cuda", dtype=torch.float16) y = torch.randn([k, m], device="cuda", dtype=torch.float16) bias: torch.Tensor = torch.randn([1, m], device="cuda", dtype=torch.float16) @@ -67,6 +88,12 @@ def baseline_wrapper(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def main() -> None: + """ + Main entry point that runs the matmul_with_epilogue kernel verification. + + Tests with 1024x1024 matrices and a ReLU + bias epilogue function. + Uncomment the autotune line to run autotuning instead. + """ # autotune(1024, 1024, 1024) check(1024, 1024, 1024) From f8e3e7b1ea4a8884261a7d9a3f86d825af654ff7 Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Mon, 21 Jul 2025 13:50:58 -0400 Subject: [PATCH 3/9] Update docs/Makefile Co-authored-by: Jason Ansel --- docs/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Makefile b/docs/Makefile index e6a46333..6d0ed7a2 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -9,7 +9,7 @@ html: clean genEx livehtml: clean genEx sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) --open-browser --port 0 -genEx: +generate_examples: python GenerateExamples.py clean: rm -rf $(BUILDDIR)/* From d11c0ed06922b602da48aea732807cdf2754c23b Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Mon, 21 Jul 2025 13:51:05 -0400 Subject: [PATCH 4/9] Update docs/Makefile Co-authored-by: Jason Ansel --- docs/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Makefile b/docs/Makefile index 6d0ed7a2..1563febc 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -10,7 +10,7 @@ livehtml: clean genEx sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) --open-browser --port 0 generate_examples: - python GenerateExamples.py + python generate_examples.py clean: rm -rf $(BUILDDIR)/* From fb673a707dfcaef9af5bf11a89444a599023db25 Mon Sep 17 00:00:00 2001 From: sekyonda <127536312+sekyondaMeta@users.noreply.github.com> Date: Mon, 21 Jul 2025 14:59:37 -0400 Subject: [PATCH 5/9] GeneratePython and lint updates --- docs/GenerateExamples.py | 40 ----------------------- docs/Makefile | 4 +-- docs/examples.rst | 2 +- docs/generate_examples.py | 54 ++++++++++++++++++++++++++++++++ docs/helion_puzzles.rst | 8 ++--- examples/add.py | 6 ++-- examples/all_gather_matmul.py | 12 +++---- examples/attention.py | 11 ++++--- examples/bmm.py | 6 ++-- examples/concatenate.py | 6 ++-- examples/cross_entropy.py | 6 ++-- examples/embedding.py | 6 ++-- examples/exp.py | 6 ++-- examples/fp8_attention.py | 4 +-- examples/jagged_dense_add.py | 6 ++-- examples/jagged_mean.py | 10 +++--- examples/long_sum.py | 6 ++-- examples/matmul.py | 6 ++-- examples/matmul_layernorm.py | 6 ++-- examples/matmul_split_k.py | 6 ++-- examples/moe_matmul_ogs.py | 6 ++-- examples/rms_norm.py | 6 ++-- examples/segment_reduction.py | 9 +++--- examples/softmax.py | 6 ++-- examples/sum.py | 6 ++-- examples/template_via_closure.py | 8 ++--- 26 files changed, 132 insertions(+), 120 deletions(-) delete mode 100644 docs/GenerateExamples.py create mode 100644 docs/generate_examples.py diff --git a/docs/GenerateExamples.py b/docs/GenerateExamples.py deleted file mode 100644 index c6dd41fd..00000000 --- a/docs/GenerateExamples.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -EXAMPLES_DIR = '../../examples' # Adjust as needed -RST_DIR = './examples' # Relative to your Sphinx source dir -example_files = [ - 'add.py', - 'all_gather_matmul.py', - 'attention.py', - 'bmm.py', - 'concatenate.py', - 'cross_entropy.py', - 'embedding.py', - 'exp.py', - 'fp8_attention.py', - 'fp8_gemm.py', - 'jagged_dense_add.py', - 'jagged_mean.py', - 'long_sum.py', - 'matmul.py', - 'matmul_layernorm.py', - 'matmul_split_k.py', - 'moe_matmul_ogs.py', - 'rms_norm.py', - 'segment_reduction.py', - 'softmax.py', - 'sum.py', - 'template_via_closure.py', -] -os.makedirs(RST_DIR, exist_ok=True) -for fname in example_files: - base = os.path.splitext(fname)[0] - # Capitalize and replace underscores with spaces for nicer titles - title = base.replace('_', ' ').title() - rst_path = os.path.join(RST_DIR, f"{base}.rst") - with open(rst_path, "w") as f: - f.write(f"""{title} -{'=' * len(title)} -.. literalinclude:: {os.path.join(EXAMPLES_DIR, fname)} - :language: python - :linenos: -""") diff --git a/docs/Makefile b/docs/Makefile index 1563febc..cfce74cc 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,10 +3,10 @@ SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = ../site -html: clean genEx +html: clean generate_examples @$(SPHINXBUILD) -b html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -livehtml: clean genEx +livehtml: clean generate_examples sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) --open-browser --port 0 generate_examples: diff --git a/docs/examples.rst b/docs/examples.rst index d277e8e1..ad74a9c3 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -1,7 +1,7 @@ Examples ======== -Examples showing the use of Helios in various scenarios. +Examples showing the use of Helion in various scenarios. .. toctree:: :maxdepth: 1 diff --git a/docs/generate_examples.py b/docs/generate_examples.py new file mode 100644 index 00000000..0e89446f --- /dev/null +++ b/docs/generate_examples.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import glob +import os + +# Configuration +EXAMPLES_DIR = "../examples" # Path to examples directory +RST_DIR = "./examples" # Directory for individual RST files +EXAMPLES_RST_PATH = "./examples.rst" # Path to main examples.rst file + +# Create the examples directory if it doesn't exist +os.makedirs(RST_DIR, exist_ok=True) + +# Get all Python files in the examples directory +example_files = [ + os.path.basename(f) for f in glob.glob(os.path.join(EXAMPLES_DIR, "*.py")) +] +example_files.sort() # Sort files alphabetically + +# Generate individual RST files for each example +for fname in example_files: + base = os.path.splitext(fname)[0] + # Capitalize and replace underscores with spaces for nicer titles + title = base.replace("_", " ").title() + rst_path = os.path.join(RST_DIR, f"{base}.rst") + with open(rst_path, "w") as f: + f.write( + f"""{title} +{"=" * len(title)} +.. literalinclude:: ../../examples/{fname} + :language: python + :linenos: +""" + ) + +# Generate the main examples.rst file with toctree +with open(EXAMPLES_RST_PATH, "w") as f: + f.write( + """Examples +======== + +Examples showing the use of Helios in various scenarios. + +.. toctree:: + :maxdepth: 1 + +""" + ) + # Add each example to the toctree + for fname in example_files: + base = os.path.splitext(fname)[0] + f.write(f" examples/{base}\n") + +print(f"Generated {len(example_files)} example RST files and updated examples.rst") diff --git a/docs/helion_puzzles.rst b/docs/helion_puzzles.rst index 91d66268..0a25d5ee 100644 --- a/docs/helion_puzzles.rst +++ b/docs/helion_puzzles.rst @@ -57,17 +57,17 @@ Let's also create a simple testing function to verify our implementations. Basic Structure of a Helion Kernel --------------------------------- -Helion allows you to write GPU kernels using familiar PyTorch syntax. +Helion allows you to write GPU kernels using familiar PyTorch syntax. A Helion kernel has three main sections: -1. **Host Section** (CPU) +1. **Host Section** (CPU) This is standard PyTorch code executed on the CPU. Memory allocation, and shape computations are done here. Like with `Triton` and `Cuda` you need to setup your output buffers on the host before launching your kernel. -2. **Device Loop** (GPU Grid) +2. **Device Loop** (GPU Grid) `for tile in hl.tile(sizes)` - defines parallel execution across GPU thread blocks -3. **Device Operations** (GPU Kernel) +3. **Device Operations** (GPU Kernel) PyTorch operations inside the loop - automatically compiled and fused Example: diff --git a/examples/add.py b/examples/add.py index c937f4ec..4fddeea3 100644 --- a/examples/add.py +++ b/examples/add.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl @helion.kernel() diff --git a/examples/all_gather_matmul.py b/examples/all_gather_matmul.py index c0243fda..49665cbe 100644 --- a/examples/all_gather_matmul.py +++ b/examples/all_gather_matmul.py @@ -2,13 +2,13 @@ import os -import helion -import helion.language as hl - import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem +import helion +import helion.language as hl + def copy_engine_all_gather_w_progress( output: torch.Tensor, @@ -232,10 +232,8 @@ def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None: dist_group = dist.group.WORLD if dist_group is None: raise RuntimeError("No distributed group available") - ag_golden, mm_golden = ( - torch.ops.symm_mem.fused_all_gather_matmul( # pyright: ignore[reportCallIssue] - golden_a, [b], gather_dim=0, group_name=dist_group.group_name - ) + ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul( # pyright: ignore[reportCallIssue] + golden_a, [b], gather_dim=0, group_name=dist_group.group_name ) torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1) torch.testing.assert_close(a_out, ag_golden) diff --git a/examples/attention.py b/examples/attention.py index 21ff9738..29c9e47d 100644 --- a/examples/attention.py +++ b/examples/attention.py @@ -1,15 +1,16 @@ from __future__ import annotations import math -from typing import Callable, cast - -import helion -import helion.language as hl +from typing import Callable +from typing import cast import torch -from helion._testing import run_example from torch.nn.attention.flex_attention import flex_attention +import helion +from helion._testing import run_example +import helion.language as hl + @helion.kernel( # Static shapes provides a speedup for attention diff --git a/examples/bmm.py b/examples/bmm.py index 2ceb1c86..c6b57433 100644 --- a/examples/bmm.py +++ b/examples/bmm.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl # static_shapes=True gives a performance boost for matmuls diff --git a/examples/concatenate.py b/examples/concatenate.py index 439a5dcb..89bd2742 100644 --- a/examples/concatenate.py +++ b/examples/concatenate.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl @helion.kernel() diff --git a/examples/cross_entropy.py b/examples/cross_entropy.py index d285566c..25d2ed1b 100644 --- a/examples/cross_entropy.py +++ b/examples/cross_entropy.py @@ -2,11 +2,11 @@ import os -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl # TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": diff --git a/examples/embedding.py b/examples/embedding.py index b3c61ca8..c169be15 100644 --- a/examples/embedding.py +++ b/examples/embedding.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl @helion.kernel() diff --git a/examples/exp.py b/examples/exp.py index ad32c216..f87b72f2 100644 --- a/examples/exp.py +++ b/examples/exp.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl @helion.kernel() diff --git a/examples/fp8_attention.py b/examples/fp8_attention.py index 71c5a622..be4df7fd 100644 --- a/examples/fp8_attention.py +++ b/examples/fp8_attention.py @@ -3,11 +3,11 @@ import math from typing import Callable +import torch + import helion import helion.language as hl -import torch - @helion.kernel(static_shapes=True) def fp8_attention_kernel( diff --git a/examples/jagged_dense_add.py b/examples/jagged_dense_add.py index b795db72..8798029e 100644 --- a/examples/jagged_dense_add.py +++ b/examples/jagged_dense_add.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl """ A tensor x is stored in a jagged-row, prefix-sparse layout that packs only the non-zero diff --git a/examples/jagged_mean.py b/examples/jagged_mean.py index a7eb263a..0e811231 100644 --- a/examples/jagged_mean.py +++ b/examples/jagged_mean.py @@ -2,11 +2,11 @@ import os -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl # TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": @@ -139,9 +139,7 @@ def jagged_mean_tritonbench( dtype=torch.int32, device=x_values.device, # pyright: ignore[reportAttributeAccessIssue] ) - max_M_tensor = torch.empty( - M, device=x_values.device - ) # pyright: ignore[reportAttributeAccessIssue] + max_M_tensor = torch.empty(M, device=x_values.device) # pyright: ignore[reportAttributeAccessIssue] return jagged_mean_kernel(x_values, x_offsets, feature_counts, max_M_tensor) diff --git a/examples/long_sum.py b/examples/long_sum.py index 47948920..e8a40967 100644 --- a/examples/long_sum.py +++ b/examples/long_sum.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl def baseline_sum(x: torch.Tensor) -> torch.Tensor: diff --git a/examples/matmul.py b/examples/matmul.py index 93c87965..75c8e7fe 100644 --- a/examples/matmul.py +++ b/examples/matmul.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl # static_shapes=True gives a performance boost for matmuls diff --git a/examples/matmul_layernorm.py b/examples/matmul_layernorm.py index fc70b1f3..81075c85 100644 --- a/examples/matmul_layernorm.py +++ b/examples/matmul_layernorm.py @@ -1,11 +1,11 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch import torch.nn.functional as F + +import helion from helion._testing import run_example +import helion.language as hl # static_shapes=True gives a performance boost for matmuls diff --git a/examples/matmul_split_k.py b/examples/matmul_split_k.py index 11c118e5..f8556c4e 100644 --- a/examples/matmul_split_k.py +++ b/examples/matmul_split_k.py @@ -1,11 +1,11 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example from helion.autotuner import PowerOfTwoFragment +import helion.language as hl # static_shapes=True gives a performance boost for matmuls diff --git a/examples/moe_matmul_ogs.py b/examples/moe_matmul_ogs.py index 63c29328..9ffafe6f 100644 --- a/examples/moe_matmul_ogs.py +++ b/examples/moe_matmul_ogs.py @@ -4,11 +4,11 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl @helion.kernel(static_shapes=False) diff --git a/examples/rms_norm.py b/examples/rms_norm.py index cedc17de..ae1dde18 100644 --- a/examples/rms_norm.py +++ b/examples/rms_norm.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl # TritonBench configuration # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg. diff --git a/examples/segment_reduction.py b/examples/segment_reduction.py index a1a675c2..4480b6df 100644 --- a/examples/segment_reduction.py +++ b/examples/segment_reduction.py @@ -1,13 +1,14 @@ # Code based on https://github.com/pytorch-labs/helion/issues/237 from __future__ import annotations -import helion -import helion.language as hl - import torch import triton import triton.language as tl -from helion._testing import DEVICE, run_example + +import helion +from helion._testing import DEVICE +from helion._testing import run_example +import helion.language as hl def combine_fn_helion( diff --git a/examples/softmax.py b/examples/softmax.py index 614f8826..2cad297e 100644 --- a/examples/softmax.py +++ b/examples/softmax.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl @helion.kernel() diff --git a/examples/sum.py b/examples/sum.py index 2be680e8..43746333 100644 --- a/examples/sum.py +++ b/examples/sum.py @@ -1,10 +1,10 @@ from __future__ import annotations -import helion -import helion.language as hl - import torch + +import helion from helion._testing import run_example +import helion.language as hl @helion.kernel() diff --git a/examples/template_via_closure.py b/examples/template_via_closure.py index c5e98499..5b6af8f6 100644 --- a/examples/template_via_closure.py +++ b/examples/template_via_closure.py @@ -2,13 +2,13 @@ from typing import TYPE_CHECKING -import helion -import helion.language as hl - import torch -from helion._testing import run_example from torch import Tensor +import helion +from helion._testing import run_example +import helion.language as hl + if TYPE_CHECKING: from collections.abc import Callable From 74dc53c78828c3520fe444dd49f67d34aa861793 Mon Sep 17 00:00:00 2001 From: sekyonda <127536312+sekyondaMeta@users.noreply.github.com> Date: Thu, 24 Jul 2025 17:56:52 -0400 Subject: [PATCH 6/9] Update to use Sphinx Gallery Changing example generation to use sphinx-gallery --- .gitignore | 1 + docs/Makefile | 8 +-- docs/conf.py | 11 ++++ docs/examples.rst | 30 ---------- docs/examples/add.rst | 5 -- docs/examples/all_gather_matmul.rst | 5 -- docs/examples/attention.rst | 5 -- docs/examples/bmm.rst | 5 -- docs/examples/concatenate.rst | 5 -- docs/examples/cross_entropy.rst | 5 -- docs/examples/embedding.rst | 5 -- docs/examples/exp.rst | 5 -- docs/examples/fp8_attention.rst | 5 -- docs/examples/fp8_gemm.rst | 5 -- docs/examples/jagged_dense_add.rst | 5 -- docs/examples/jagged_mean.rst | 5 -- docs/examples/long_sum.rst | 5 -- docs/examples/matmul.rst | 5 -- docs/examples/matmul_layernorm.rst | 5 -- docs/examples/matmul_split_k.rst | 5 -- docs/examples/moe_matmul_ogs.rst | 5 -- docs/examples/rms_norm.rst | 5 -- docs/examples/segment_reduction.rst | 5 -- docs/examples/softmax.rst | 5 -- docs/examples/sum.rst | 5 -- docs/examples/template_via_closure.rst | 5 -- docs/generate_examples.py | 54 ----------------- docs/index.md | 7 ++- docs/requirements.txt | 5 ++ examples/README.rst | 80 ++++++++++++++++++++++++++ examples/add.py | 19 ++++++ examples/all_gather_matmul.py | 11 ++++ examples/attention.py | 22 +++++++ examples/bmm.py | 19 ++++++ examples/concatenate.py | 16 ++++++ examples/cross_entropy.py | 24 +++++++- examples/embedding.py | 32 ++++++++++- examples/exp.py | 32 ++++++++++- examples/fp8_attention.py | 10 ++++ examples/fp8_gemm.py | 63 ++++++++++++++++++-- examples/jagged_dense_add.py | 73 +++++++++++++++++------ examples/jagged_mean.py | 61 +++++++++++++++----- examples/long_sum.py | 31 +++++++++- examples/matmul.py | 19 ++++++ examples/matmul_layernorm.py | 23 ++++++++ examples/matmul_split_k.py | 20 +++++++ examples/moe_matmul_ogs.py | 65 ++++++++++++++++++++- examples/rms_norm.py | 51 +++++++++++++++- examples/segment_reduction.py | 24 +++++++- examples/softmax.py | 27 ++++++++- examples/sum.py | 42 +++++++++++++- examples/template_via_closure.py | 38 ++++++++++++ 52 files changed, 780 insertions(+), 248 deletions(-) delete mode 100644 docs/examples.rst delete mode 100644 docs/examples/add.rst delete mode 100644 docs/examples/all_gather_matmul.rst delete mode 100644 docs/examples/attention.rst delete mode 100644 docs/examples/bmm.rst delete mode 100644 docs/examples/concatenate.rst delete mode 100644 docs/examples/cross_entropy.rst delete mode 100644 docs/examples/embedding.rst delete mode 100644 docs/examples/exp.rst delete mode 100644 docs/examples/fp8_attention.rst delete mode 100644 docs/examples/fp8_gemm.rst delete mode 100644 docs/examples/jagged_dense_add.rst delete mode 100644 docs/examples/jagged_mean.rst delete mode 100644 docs/examples/long_sum.rst delete mode 100644 docs/examples/matmul.rst delete mode 100644 docs/examples/matmul_layernorm.rst delete mode 100644 docs/examples/matmul_split_k.rst delete mode 100644 docs/examples/moe_matmul_ogs.rst delete mode 100644 docs/examples/rms_norm.rst delete mode 100644 docs/examples/segment_reduction.rst delete mode 100644 docs/examples/softmax.rst delete mode 100644 docs/examples/sum.rst delete mode 100644 docs/examples/template_via_closure.rst delete mode 100644 docs/generate_examples.py create mode 100644 docs/requirements.txt create mode 100644 examples/README.rst diff --git a/.gitignore b/.gitignore index 2dce5632..ea422f57 100644 --- a/.gitignore +++ b/.gitignore @@ -91,3 +91,4 @@ torch benchmarks/tritonbench site generated +docs/examples/ diff --git a/docs/Makefile b/docs/Makefile index cfce74cc..a275d84c 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,16 +3,16 @@ SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = ../site -html: clean generate_examples +html: clean @$(SPHINXBUILD) -b html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -livehtml: clean generate_examples +livehtml: clean sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) --open-browser --port 0 -generate_examples: - python generate_examples.py + clean: rm -rf $(BUILDDIR)/* + rm -rf examples/* # Catch-all target: route all unknown targets to Sphinx-Build using the # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). diff --git a/docs/conf.py b/docs/conf.py index 1f50ac4d..0a5b68fd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,6 +27,7 @@ "sphinx.ext.intersphinx", "myst_parser", "sphinx_autodoc_typehints", + "sphinx_gallery.gen_gallery", ] # MyST parser configuration @@ -44,6 +45,16 @@ "tasklist", ] +sphinx_gallery_conf = { + "examples_dirs": [ + "../examples", + ], # path to your example scripts + "gallery_dirs": "examples", # path to where to save gallery generated output + "filename_pattern": r".*\.py$", # Include all Python files + "ignore_pattern": r"__init__\.py", # Exclude __init__.py files + "plot_gallery": "False", # Don't run the examples +} + # Templates path templates_path = ["_templates"] diff --git a/docs/examples.rst b/docs/examples.rst deleted file mode 100644 index ad74a9c3..00000000 --- a/docs/examples.rst +++ /dev/null @@ -1,30 +0,0 @@ -Examples -======== - -Examples showing the use of Helion in various scenarios. - -.. toctree:: - :maxdepth: 1 - - examples/add - examples/all_gather_matmul - examples/attention - examples/bmm - examples/concatenate - examples/cross_entropy - examples/embedding - examples/exp - examples/fp8_attention - examples/fp8_gemm - examples/jagged_dense_add - examples/jagged_mean - examples/long_sum - examples/matmul - examples/matmul_layernorm - examples/matmul_split_k - examples/moe_matmul_ogs - examples/rms_norm - examples/segment_reduction - examples/softmax - examples/sum - examples/template_via_closure diff --git a/docs/examples/add.rst b/docs/examples/add.rst deleted file mode 100644 index 6c102c02..00000000 --- a/docs/examples/add.rst +++ /dev/null @@ -1,5 +0,0 @@ -Add -=== -.. literalinclude:: ../../examples/add.py - :language: python - :linenos: diff --git a/docs/examples/all_gather_matmul.rst b/docs/examples/all_gather_matmul.rst deleted file mode 100644 index 9f290a5f..00000000 --- a/docs/examples/all_gather_matmul.rst +++ /dev/null @@ -1,5 +0,0 @@ -All Gather Matmul -================= -.. literalinclude:: ../../examples/all_gather_matmul.py - :language: python - :linenos: diff --git a/docs/examples/attention.rst b/docs/examples/attention.rst deleted file mode 100644 index e8a3c1fe..00000000 --- a/docs/examples/attention.rst +++ /dev/null @@ -1,5 +0,0 @@ -Attention -========= -.. literalinclude:: ../../examples/attention.py - :language: python - :linenos: diff --git a/docs/examples/bmm.rst b/docs/examples/bmm.rst deleted file mode 100644 index c971ec71..00000000 --- a/docs/examples/bmm.rst +++ /dev/null @@ -1,5 +0,0 @@ -Bmm -=== -.. literalinclude:: ../../examples/bmm.py - :language: python - :linenos: diff --git a/docs/examples/concatenate.rst b/docs/examples/concatenate.rst deleted file mode 100644 index 417e50c3..00000000 --- a/docs/examples/concatenate.rst +++ /dev/null @@ -1,5 +0,0 @@ -Concatenate -=========== -.. literalinclude:: ../../examples/concatenate.py - :language: python - :linenos: diff --git a/docs/examples/cross_entropy.rst b/docs/examples/cross_entropy.rst deleted file mode 100644 index 3c3bfe98..00000000 --- a/docs/examples/cross_entropy.rst +++ /dev/null @@ -1,5 +0,0 @@ -Cross Entropy -============= -.. literalinclude:: ../../examples/cross_entropy.py - :language: python - :linenos: diff --git a/docs/examples/embedding.rst b/docs/examples/embedding.rst deleted file mode 100644 index 97b585a6..00000000 --- a/docs/examples/embedding.rst +++ /dev/null @@ -1,5 +0,0 @@ -Embedding -========= -.. literalinclude:: ../../examples/embedding.py - :language: python - :linenos: diff --git a/docs/examples/exp.rst b/docs/examples/exp.rst deleted file mode 100644 index 1b7a6be8..00000000 --- a/docs/examples/exp.rst +++ /dev/null @@ -1,5 +0,0 @@ -Exp -=== -.. literalinclude:: ../../examples/exp.py - :language: python - :linenos: diff --git a/docs/examples/fp8_attention.rst b/docs/examples/fp8_attention.rst deleted file mode 100644 index 67917be6..00000000 --- a/docs/examples/fp8_attention.rst +++ /dev/null @@ -1,5 +0,0 @@ -Fp8 Attention -============= -.. literalinclude:: ../../examples/fp8_attention.py - :language: python - :linenos: diff --git a/docs/examples/fp8_gemm.rst b/docs/examples/fp8_gemm.rst deleted file mode 100644 index ae4e4230..00000000 --- a/docs/examples/fp8_gemm.rst +++ /dev/null @@ -1,5 +0,0 @@ -Fp8 Gemm -======== -.. literalinclude:: ../../examples/fp8_gemm.py - :language: python - :linenos: diff --git a/docs/examples/jagged_dense_add.rst b/docs/examples/jagged_dense_add.rst deleted file mode 100644 index 26909aeb..00000000 --- a/docs/examples/jagged_dense_add.rst +++ /dev/null @@ -1,5 +0,0 @@ -Jagged Dense Add -================ -.. literalinclude:: ../../examples/jagged_dense_add.py - :language: python - :linenos: diff --git a/docs/examples/jagged_mean.rst b/docs/examples/jagged_mean.rst deleted file mode 100644 index 638935f6..00000000 --- a/docs/examples/jagged_mean.rst +++ /dev/null @@ -1,5 +0,0 @@ -Jagged Mean -=========== -.. literalinclude:: ../../examples/jagged_mean.py - :language: python - :linenos: diff --git a/docs/examples/long_sum.rst b/docs/examples/long_sum.rst deleted file mode 100644 index ae71aa7b..00000000 --- a/docs/examples/long_sum.rst +++ /dev/null @@ -1,5 +0,0 @@ -Long Sum -======== -.. literalinclude:: ../../examples/long_sum.py - :language: python - :linenos: diff --git a/docs/examples/matmul.rst b/docs/examples/matmul.rst deleted file mode 100644 index 9c07aeb5..00000000 --- a/docs/examples/matmul.rst +++ /dev/null @@ -1,5 +0,0 @@ -Matmul -====== -.. literalinclude:: ../../examples/matmul.py - :language: python - :linenos: diff --git a/docs/examples/matmul_layernorm.rst b/docs/examples/matmul_layernorm.rst deleted file mode 100644 index 05e71050..00000000 --- a/docs/examples/matmul_layernorm.rst +++ /dev/null @@ -1,5 +0,0 @@ -Matmul Layernorm -================ -.. literalinclude:: ../../examples/matmul_layernorm.py - :language: python - :linenos: diff --git a/docs/examples/matmul_split_k.rst b/docs/examples/matmul_split_k.rst deleted file mode 100644 index 7e40a33e..00000000 --- a/docs/examples/matmul_split_k.rst +++ /dev/null @@ -1,5 +0,0 @@ -Matmul Split K -============== -.. literalinclude:: ../../examples/matmul_split_k.py - :language: python - :linenos: diff --git a/docs/examples/moe_matmul_ogs.rst b/docs/examples/moe_matmul_ogs.rst deleted file mode 100644 index f9038cea..00000000 --- a/docs/examples/moe_matmul_ogs.rst +++ /dev/null @@ -1,5 +0,0 @@ -Moe Matmul Ogs -============== -.. literalinclude:: ../../examples/moe_matmul_ogs.py - :language: python - :linenos: diff --git a/docs/examples/rms_norm.rst b/docs/examples/rms_norm.rst deleted file mode 100644 index dc3789c6..00000000 --- a/docs/examples/rms_norm.rst +++ /dev/null @@ -1,5 +0,0 @@ -Rms Norm -======== -.. literalinclude:: ../../examples/rms_norm.py - :language: python - :linenos: diff --git a/docs/examples/segment_reduction.rst b/docs/examples/segment_reduction.rst deleted file mode 100644 index 69d72594..00000000 --- a/docs/examples/segment_reduction.rst +++ /dev/null @@ -1,5 +0,0 @@ -Segment Reduction -================= -.. literalinclude:: ../../examples/segment_reduction.py - :language: python - :linenos: diff --git a/docs/examples/softmax.rst b/docs/examples/softmax.rst deleted file mode 100644 index 4f07e1c6..00000000 --- a/docs/examples/softmax.rst +++ /dev/null @@ -1,5 +0,0 @@ -Softmax -======= -.. literalinclude:: ../../examples/softmax.py - :language: python - :linenos: diff --git a/docs/examples/sum.rst b/docs/examples/sum.rst deleted file mode 100644 index 438df2b7..00000000 --- a/docs/examples/sum.rst +++ /dev/null @@ -1,5 +0,0 @@ -Sum -=== -.. literalinclude:: ../../examples/sum.py - :language: python - :linenos: diff --git a/docs/examples/template_via_closure.rst b/docs/examples/template_via_closure.rst deleted file mode 100644 index fcf278b8..00000000 --- a/docs/examples/template_via_closure.rst +++ /dev/null @@ -1,5 +0,0 @@ -Template Via Closure -==================== -.. literalinclude:: ../../examples/template_via_closure.py - :language: python - :linenos: diff --git a/docs/generate_examples.py b/docs/generate_examples.py deleted file mode 100644 index 0e89446f..00000000 --- a/docs/generate_examples.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -import glob -import os - -# Configuration -EXAMPLES_DIR = "../examples" # Path to examples directory -RST_DIR = "./examples" # Directory for individual RST files -EXAMPLES_RST_PATH = "./examples.rst" # Path to main examples.rst file - -# Create the examples directory if it doesn't exist -os.makedirs(RST_DIR, exist_ok=True) - -# Get all Python files in the examples directory -example_files = [ - os.path.basename(f) for f in glob.glob(os.path.join(EXAMPLES_DIR, "*.py")) -] -example_files.sort() # Sort files alphabetically - -# Generate individual RST files for each example -for fname in example_files: - base = os.path.splitext(fname)[0] - # Capitalize and replace underscores with spaces for nicer titles - title = base.replace("_", " ").title() - rst_path = os.path.join(RST_DIR, f"{base}.rst") - with open(rst_path, "w") as f: - f.write( - f"""{title} -{"=" * len(title)} -.. literalinclude:: ../../examples/{fname} - :language: python - :linenos: -""" - ) - -# Generate the main examples.rst file with toctree -with open(EXAMPLES_RST_PATH, "w") as f: - f.write( - """Examples -======== - -Examples showing the use of Helios in various scenarios. - -.. toctree:: - :maxdepth: 1 - -""" - ) - # Add each example to the toctree - for fname in example_files: - base = os.path.splitext(fname)[0] - f.write(f" examples/{base}\n") - -print(f"Generated {len(example_files)} example RST files and updated examples.rst") diff --git a/docs/index.md b/docs/index.md index 9194697e..039149f8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,14 +1,15 @@ # Helion Documentation + ```{toctree} -:maxdepth: 1 -:caption: Table of Contents: +:maxdepth: 2 +:caption: Contents :hidden: installation +./examples/index helion_puzzles -examples api/index ``` diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..e73fae30 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,5 @@ +sphinx>=7.0.0 +myst-parser>=2.0.0 +sphinx-autodoc-typehints>=1.24.0 +sphinx-rtd-theme>=1.3.0 +sphinx_gallery.gen_gallery diff --git a/examples/README.rst b/examples/README.rst new file mode 100644 index 00000000..2e32391c --- /dev/null +++ b/examples/README.rst @@ -0,0 +1,80 @@ +Helion Examples +============== + +This directory contains examples demonstrating how to use Helion for high-performance tensor operations. +The examples are organized into the following categories: + +Basic Operations +~~~~~~~~~~~~~~~ + +- ``add.py``: Element-wise addition with broadcasting support +- ``exp.py``: Element-wise exponential function +- ``sum.py``: Sum reduction along the last dimension +- ``long_sum.py``: Efficient sum reduction along a long dimension +- ``softmax.py``: Different implementations of the softmax function + +Matrix Operations +~~~~~~~~~~~~~~~~ + +- ``matmul.py``: Basic matrix multiplication +- ``bmm.py``: Batch matrix multiplication +- ``matmul_split_k.py``: Matrix multiplication using split-K algorithm for better parallelism +- ``matmul_layernorm.py``: Fused matrix multiplication and layer normalization +- ``fp8_gemm.py``: Matrix multiplication using FP8 precision + +Attention Mechanisms +~~~~~~~~~~~~~~~~~~~ + +- ``attention.py``: Scaled dot-product attention mechanism +- ``fp8_attention.py``: Attention mechanism using FP8 precision + +Normalization +~~~~~~~~~~~~ + +- ``rms_norm.py``: Root Mean Square (RMS) normalization + +Sparse and Jagged Tensors +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``jagged_dense_add.py``: Addition between a jagged tensor and a dense tensor +- ``jagged_mean.py``: Computing the mean of each row in a jagged tensor +- ``segment_reduction.py``: Segmented reduction operation +- ``moe_matmul_ogs.py``: Mixture-of-Experts matrix multiplication using Outer-Gather-Scatter + +Other Operations +~~~~~~~~~~~~~~~ + +- ``concatenate.py``: Tensor concatenation along a dimension +- ``cross_entropy.py``: Cross entropy loss function +- ``embedding.py``: Embedding lookup operation +- ``all_gather_matmul.py``: All-gather operation followed by matrix multiplication +- ``template_via_closure.py``: Templated matrix multiplication with customizable epilogue function + + +.. toctree:: + :maxdepth: 2 + :caption: Contents + :hidden: + + add + all_gather_matmul + attention + bmm + concatenate + cross_entropy + embedding + exp + fp8_attention + fp8_gemm + jagged_dense_add + jagged_mean + long_sum + matmul + matmul_layernorm + matmul_split_k + moe_matmul_ogs + rms_norm + segment_reduction + softmax + sum + template_via_closure diff --git a/examples/add.py b/examples/add.py index 4fddeea3..897224df 100644 --- a/examples/add.py +++ b/examples/add.py @@ -1,3 +1,13 @@ +""" +Element-wise Addition Example +=========================== + +This example demonstrates how to implement an element-wise addition kernel using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,6 +17,9 @@ import helion.language as hl +# %% +# Addition Kernel +# -------------- @helion.kernel() def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ @@ -33,6 +46,9 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: """ Verify the add kernel implementation against PyTorch's native add function. @@ -46,6 +62,9 @@ def check(m: int, n: int) -> None: run_example(add, torch.add, (x, y)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the add kernel verification with 1024x1024 tensors. diff --git a/examples/all_gather_matmul.py b/examples/all_gather_matmul.py index 49665cbe..8e3deba0 100644 --- a/examples/all_gather_matmul.py +++ b/examples/all_gather_matmul.py @@ -1,3 +1,14 @@ +""" +All-Gather Matrix Multiplication Example +===============================>>>>>>> REPLACE + +This example demonstrates how to implement an all-gather operation followed by matrix multiplication +using Helion and PyTorch's distributed capabilities. +""" + +# %% +# Imports +# ------- from __future__ import annotations import os diff --git a/examples/attention.py b/examples/attention.py index 29c9e47d..1afbdb79 100644 --- a/examples/attention.py +++ b/examples/attention.py @@ -1,3 +1,13 @@ +""" +Attention Mechanism Example +======================== + +This example demonstrates how to implement a scaled dot-product attention mechanism using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import math @@ -12,6 +22,9 @@ import helion.language as hl +# %% +# Attention Kernel Implementation +# ---------------------------- @helion.kernel( # Static shapes provides a speedup for attention static_shapes=True, @@ -70,6 +83,9 @@ def attention( return out.view(q_in.size()) +# %% +# Dynamic Shape Version +# ------------------ attention_dynamic: object = helion.kernel( # pyright: ignore[reportCallIssue] attention.fn, configs=attention.configs, # pyright: ignore[reportArgumentType] @@ -81,6 +97,9 @@ def attention( """ +# %% +# Testing Function +# ------------- def test( z: int, h: int, @@ -125,6 +144,9 @@ def ref_attention( run_example(attention, baselines, (q, k, v)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the attention kernel test with specific parameters. diff --git a/examples/bmm.py b/examples/bmm.py index c6b57433..c5007345 100644 --- a/examples/bmm.py +++ b/examples/bmm.py @@ -1,3 +1,13 @@ +""" +Batch Matrix Multiplication Example +=============================== + +This example demonstrates how to implement a batch matrix multiplication kernel using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,6 +17,9 @@ import helion.language as hl +# %% +# Batch Matrix Multiplication Kernel +# ------------------------------- # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: @@ -36,6 +49,9 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(b: int, m: int, k: int, n: int) -> None: """ Verify the bmm kernel implementation against PyTorch's native bmm function. @@ -51,6 +67,9 @@ def check(b: int, m: int, k: int, n: int) -> None: run_example(bmm, torch.bmm, (x, y)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the bmm kernel verification with specific parameters. diff --git a/examples/concatenate.py b/examples/concatenate.py index 89bd2742..34035249 100644 --- a/examples/concatenate.py +++ b/examples/concatenate.py @@ -1,3 +1,13 @@ +""" +Tensor Concatenation Example +======================== + +This example demonstrates how to implement a tensor concatenation operation using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,6 +17,9 @@ import helion.language as hl +# %% +# Concatenation Kernel +# ----------------- @helion.kernel() def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ @@ -39,6 +52,9 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the concatenation kernel verification. diff --git a/examples/cross_entropy.py b/examples/cross_entropy.py index 25d2ed1b..b91acb92 100644 --- a/examples/cross_entropy.py +++ b/examples/cross_entropy.py @@ -1,3 +1,13 @@ +""" +Cross Entropy Loss Example +====================== + +This example demonstrates how to implement a cross entropy loss function using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import os @@ -8,12 +18,18 @@ from helion._testing import run_example import helion.language as hl +# %% +# Configuration +# ----------- # TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": # Low memory configuration TRITONBENCH_ARGS = {"B": 4, "T": 512, "v_range": "10,15"} +# %% +# Cross Entropy Kernel +# ----------------- @helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper]) def cross_entropy( logits: torch.Tensor, # [N, V] input logits @@ -67,8 +83,14 @@ def cross_entropy( return losses.mean() +# %% +# Main Function +# ----------- def main() -> None: - """Run cross entropy benchmark with different input sizes.""" + """ + Main entry point that runs the cross entropy kernel verification. + Tests with a batch size of 128 and vocabulary size of 1000. + """ # Test with moderate size n, v = 128, 1000 logits = torch.randn(n, v, device="cuda", dtype=torch.float32) diff --git a/examples/embedding.py b/examples/embedding.py index c169be15..66c8261e 100644 --- a/examples/embedding.py +++ b/examples/embedding.py @@ -1,3 +1,13 @@ +""" +Embedding Lookup Example +==================== + +This example demonstrates how to implement an embedding lookup operation using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,6 +17,9 @@ import helion.language as hl +# %% +# Embedding Kernel +# ------------- @helion.kernel() def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: """ @@ -32,13 +45,30 @@ def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: return out.view(*x.size(), embedding_dim) +# %% +# Benchmark Wrapper +# -------------- def embedding_tritonbench( V: int, D: int, inp: torch.Tensor, shared_weight: torch.Tensor ) -> torch.Tensor: - """Wrapper for tritonbench that matches its interface.""" + """ + Wrapper for tritonbench that matches its interface. + + Args: + V: Vocabulary size (unused, provided for compatibility) + D: Embedding dimension (unused, provided for compatibility) + inp: Input tensor of indices + shared_weight: Embedding weight matrix + + Returns: + Output tensor containing the embedding vectors + """ return embedding(inp, shared_weight) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the embedding kernel verification. diff --git a/examples/exp.py b/examples/exp.py index f87b72f2..305a0562 100644 --- a/examples/exp.py +++ b/examples/exp.py @@ -1,3 +1,13 @@ +""" +Exponential Function Example +======================== + +This example demonstrates how to implement an element-wise exponential function using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,6 +17,9 @@ import helion.language as hl +# %% +# Exponential Kernel +# --------------- @helion.kernel() def exp(x: torch.Tensor) -> torch.Tensor: """ @@ -24,11 +37,25 @@ def exp(x: torch.Tensor) -> torch.Tensor: return out +# %% +# Benchmark Wrapper +# -------------- def exp_tritonbench(x: torch.Tensor) -> dict[str, torch.Tensor]: - """Wrapper for tritonbench that returns output in expected format.""" + """ + Wrapper for tritonbench that returns output in expected format. + + Args: + x: Input tensor + + Returns: + Dictionary containing the output tensor + """ return {"output": exp(x)} +# %% +# Verification Function +# ------------------- def check(n: int) -> None: """ Verify the exp kernel implementation against PyTorch's native exp function. @@ -40,6 +67,9 @@ def check(n: int) -> None: run_example(exp, torch.exp, (x,)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the exp kernel verification with a tensor of size 1M elements. diff --git a/examples/fp8_attention.py b/examples/fp8_attention.py index be4df7fd..69a44d0e 100644 --- a/examples/fp8_attention.py +++ b/examples/fp8_attention.py @@ -1,3 +1,13 @@ +""" +FP8 Attention Mechanism Example +====================>>>>>>> REPLACE + +This example demonstrates how to implement a scaled dot-product attention mechanism using FP8 precision in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import math diff --git a/examples/fp8_gemm.py b/examples/fp8_gemm.py index 81cc6815..7b8ce9e3 100644 --- a/examples/fp8_gemm.py +++ b/examples/fp8_gemm.py @@ -1,3 +1,13 @@ +""" +FP8 Matrix Multiplication Example +============================ + +This example demonstrates how to implement a matrix multiplication kernel using FP8 precision in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,9 +17,13 @@ import helion.language as hl +# %% +# FP8 GEMM Kernel +# ------------ @helion.kernel(static_shapes=True) def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """FP8 General Matrix Multiplication (GEMM). + """ + FP8 General Matrix Multiplication (GEMM). This kernel demonstrates FP8 computation in Helion. When lowered to Triton, the tl.dot operation will handle @@ -47,10 +61,22 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Reference Implementation +# -------------------- def reference_fp8_gemm_pytorch( x_fp8: torch.Tensor, y_fp8: torch.Tensor ) -> torch.Tensor: - """Reference implementation using torch._scaled_mm.""" + """ + Reference implementation using torch._scaled_mm. + + Args: + x_fp8: Input tensor in FP8 format + y_fp8: Input tensor in FP8 format + + Returns: + Output tensor in FP16 format + """ # torch._scaled_mm requires column-major for second operand y_fp8_t = y_fp8.T.contiguous().T scale_a = torch.tensor(1.0, device=x_fp8.device) @@ -60,13 +86,35 @@ def reference_fp8_gemm_pytorch( ) +# %% +# Benchmark Wrapper +# -------------- def fp8_gemm_tritonbench(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """Wrapper for TritonBench compatibility.""" + """ + Wrapper for TritonBench compatibility. + + Args: + a: First input tensor in FP8 format + b: Second input tensor in FP8 format + + Returns: + Output tensor from the fp8_gemm kernel + """ return fp8_gemm(a, b) +# %% +# Verification Function +# ------------------- def check(m: int, k: int, n: int) -> None: - """Test the FP8 GEMM implementation.""" + """ + Test the FP8 GEMM implementation against the PyTorch reference implementation. + + Args: + m: First dimension of the first matrix + k: Second dimension of the first matrix / First dimension of the second matrix + n: Second dimension of the second matrix + """ # Create FP8 tensors x = torch.randn([m, k], device="cuda", dtype=torch.float32) y = torch.randn([k, n], device="cuda", dtype=torch.float32) @@ -78,7 +126,14 @@ def check(m: int, k: int, n: int) -> None: run_example(fp8_gemm, reference_fp8_gemm_pytorch, (x_fp8, y_fp8)) +# %% +# Main Function +# ----------- def main() -> None: + """ + Main entry point that runs the FP8 GEMM kernel verification with different matrix sizes. + Tests with small (256x256), medium (512x512), and large (1024x1024) matrices. + """ # Test with different sizes check(256, 256, 256) check(512, 512, 512) diff --git a/examples/jagged_dense_add.py b/examples/jagged_dense_add.py index 8798029e..201e5556 100644 --- a/examples/jagged_dense_add.py +++ b/examples/jagged_dense_add.py @@ -1,3 +1,14 @@ +""" +Jagged Dense Addition Example +========================= + +This example demonstrates how to implement an addition operation between a jagged tensor +and a dense tensor using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -6,6 +17,9 @@ from helion._testing import run_example import helion.language as hl +# %% +# Jagged Tensor Format +# ----------------- """ A tensor x is stored in a jagged-row, prefix-sparse layout that packs only the non-zero elements of each row. All non-zeros are concatenated into a one-dimensional buffer @@ -14,12 +28,12 @@ contains exactly the first K_i non-zero entries of that row (with K_i = x_offsets[i+1] − x_offsets[i]). Elements beyond column K_i − 1 are implicitly zero and therefore omitted from storage. - -This example implements a kernel that adds a dense matrix y to a -jagged matrix x. It is intended to illustrate how to work with jagged tensors. """ +# %% +# Jagged Dense Addition Kernel +# ------------------------ @helion.kernel() def jagged_dense_add_2d( x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor @@ -28,16 +42,14 @@ def jagged_dense_add_2d( Add a jagged-prefix sparse tensor (x_data, x_offsets) to a dense matrix y and return the dense result. - Args - ---- - x_data : 1-D tensor holding all non-zero elements row-by-row. - x_offsets : (num_rows + 1) tensor. Row i is the slice - x_data[x_offsets[i] : x_offsets[i+1]] (length K_i). - y: (num_rows, N) tensor, N >= max(K_i). + Args: + x_data: 1-D tensor holding all non-zero elements row-by-row + x_offsets: (num_rows + 1) tensor. Row i is the slice + x_data[x_offsets[i] : x_offsets[i+1]] (length K_i) + y: (num_rows, N) tensor, N >= max(K_i) - Returns - ------- - result : dense + jagged, shape (num_rows, N). + Returns: + Dense tensor of shape (num_rows, N) containing the sum of the jagged and dense tensors """ num_rows = y.size(0) assert x_offsets.size(0) == num_rows + 1 @@ -63,12 +75,25 @@ def jagged_dense_add_2d( return out +# %% +# Reference Implementation +# -------------------- def jagged_dense_add_2d_reference( x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: - """The same as the above, but implemented in pure PyTorch.""" + """ + Reference implementation of jagged dense addition in pure PyTorch. + + Args: + x_data: 1-D tensor holding all non-zero elements row-by-row + x_offsets: (num_rows + 1) tensor with offsets for each row + y: Dense tensor to add to the jagged tensor + + Returns: + Dense tensor containing the sum of the jagged and dense tensors + """ num_rows = x_offsets.numel() - 1 assert y.shape[0] == num_rows out = y.clone() @@ -79,6 +104,9 @@ def jagged_dense_add_2d_reference( return out +# %% +# Utility Function +# ------------- def random_jagged_2d( num_rows: int, max_cols: int, @@ -87,10 +115,18 @@ def random_jagged_2d( device: torch.device | str = "cuda", ) -> tuple[torch.Tensor, torch.Tensor]: """ - Produces: - x_data – 1-D tensor holding all non-zeros row-by-row - x_offsets – (num_rows+1) tensor; x_data[x_offsets[i]:x_offsets[i+1]] is row i - Each row i has a random non-zero prefix length K_i in [1, max_cols]. + Generate random jagged 2D tensor data. + + Args: + num_rows: Number of rows in the jagged tensor + max_cols: Maximum number of columns per row + dtype: Data type for the tensor values + device: Device to create the tensors on + + Returns: + Tuple of (x_data, x_offsets) where: + - x_data: 1-D tensor holding all non-zeros row-by-row + - x_offsets: (num_rows+1) tensor with offsets for each row """ # random positive K_i for each row lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device) @@ -105,6 +141,9 @@ def random_jagged_2d( return x_data, x_offsets +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the jagged dense add kernel verification. diff --git a/examples/jagged_mean.py b/examples/jagged_mean.py index 0e811231..5494e7d0 100644 --- a/examples/jagged_mean.py +++ b/examples/jagged_mean.py @@ -1,3 +1,14 @@ +""" +Jagged Mean Example +=============== + +This example demonstrates how to compute the mean of each row in a jagged tensor +with variable features per row using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import os @@ -8,12 +19,18 @@ from helion._testing import run_example import helion.language as hl +# %% +# Configuration +# ----------- # TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": # Low memory configuration TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64} +# %% +# Jagged Mean Kernel +# --------------- @helion.kernel() def jagged_mean_kernel( x_data: torch.Tensor, @@ -24,18 +41,16 @@ def jagged_mean_kernel( """ Compute the mean of each row in a jagged tensor with variable features per row. - Args - ---- - x_data : 2-D tensor of shape (total_elements, max_M) holding all elements. - x_offsets : (num_rows + 1) tensor. Row i is the slice - x_data[x_offsets[i] : x_offsets[i+1], :]. - x_feature_counts: (num_rows) tensor. Number of valid features for each row. - max_M_tensor : Dummy tensor whose numel() gives max number of features. - - Returns - ------- - result : 2-D tensor of shape (num_rows, max_M) containing the mean of each row. - Invalid features (beyond x_feature_counts[i]) are set to 0. + Args: + x_data: 2-D tensor of shape (total_elements, max_M) holding all elements + x_offsets: (num_rows + 1) tensor. Row i is the slice + x_data[x_offsets[i] : x_offsets[i+1], :] + x_feature_counts: (num_rows) tensor. Number of valid features for each row + max_M_tensor: Dummy tensor whose numel() gives max number of features + + Returns: + 2-D tensor of shape (num_rows, max_M) containing the mean of each row. + Invalid features (beyond x_feature_counts[i]) are set to 0. """ num_rows = x_offsets.size(0) - 1 max_M = max_M_tensor.numel() # Extract max features from dummy tensor @@ -96,13 +111,27 @@ def jagged_mean_kernel( return out +# %% +# Reference Implementation +# -------------------- def reference_jagged_mean_kernel_pytorch( x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_counts: torch.Tensor, max_M: int, ) -> torch.Tensor: - """PyTorch reference implementation for jagged mean with variable features.""" + """ + PyTorch reference implementation for jagged mean with variable features. + + Args: + x_data: 2-D tensor holding all elements + x_offsets: Offsets tensor for row indexing + x_feature_counts: Number of valid features per row + max_M: Maximum number of features + + Returns: + Tensor containing the mean of each row + """ num_rows = x_offsets.numel() - 1 out = torch.zeros((num_rows, max_M), dtype=x_data.dtype, device=x_data.device) for i in range(num_rows): @@ -114,6 +143,9 @@ def reference_jagged_mean_kernel_pytorch( return out +# %% +# Benchmark Wrapper +# -------------- def jagged_mean_tritonbench( x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float ) -> torch.Tensor: @@ -144,6 +176,9 @@ def jagged_mean_tritonbench( return jagged_mean_kernel(x_values, x_offsets, feature_counts, max_M_tensor) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the jagged mean kernel verification. diff --git a/examples/long_sum.py b/examples/long_sum.py index e8a40967..19bd20b4 100644 --- a/examples/long_sum.py +++ b/examples/long_sum.py @@ -1,3 +1,13 @@ +""" +Long Dimension Sum Example +====================== + +This example demonstrates how to implement efficient sum reduction along a long dimension using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,6 +17,9 @@ import helion.language as hl +# %% +# Baseline Implementation +# ------------------- def baseline_sum(x: torch.Tensor) -> torch.Tensor: """ PyTorch baseline implementation of sum reduction along the last dimension. @@ -20,7 +33,9 @@ def baseline_sum(x: torch.Tensor) -> torch.Tensor: return x.sum(-1) -# Naive Reduction: Load the entire reduction dim at once, and reduce in reg. +# %% +# Naive Reduction Kernel +# ------------------ @helion.kernel( config=helion.Config( block_sizes=[1], @@ -50,7 +65,9 @@ def longsum(x: torch.Tensor) -> torch.Tensor: return out -# Looped reduction +# %% +# Looped Reduction Kernel +# ------------------- @helion.kernel( config=helion.Config( block_sizes=[1], @@ -82,7 +99,9 @@ def longsum_w_red_loop(x: torch.Tensor) -> torch.Tensor: return out -# This generates the same code as above, but manually implements looped reduction. +# %% +# Manual Looped Reduction Kernel +# -------------------------- @helion.kernel( config=helion.Config( block_sizes=[32768, 1], num_warps=16, num_stages=5, indexing="pointer" @@ -114,6 +133,9 @@ def longsum_manual(x: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: """ Verify the sum kernel implementations against PyTorch's native sum function. @@ -136,6 +158,9 @@ def check(m: int, n: int) -> None: run_example(kernels, baseline_sum, (x,)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the sum kernel verification with a large tensor. diff --git a/examples/matmul.py b/examples/matmul.py index 75c8e7fe..1441d756 100644 --- a/examples/matmul.py +++ b/examples/matmul.py @@ -1,3 +1,13 @@ +""" +Matrix Multiplication Example +============================ + +This example demonstrates how to implement a basic matrix multiplication kernel using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,6 +17,9 @@ import helion.language as hl +# %% +# Matrix Multiplication Kernel +# --------------------------- # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -34,6 +47,9 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, k: int, n: int) -> None: """ Verify the matmul kernel implementation against PyTorch's native matmul function. @@ -48,6 +64,9 @@ def check(m: int, k: int, n: int) -> None: run_example(matmul, torch.matmul, (x, y)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the matmul kernel verification with 1024x1024 matrices. diff --git a/examples/matmul_layernorm.py b/examples/matmul_layernorm.py index 81075c85..59d45e52 100644 --- a/examples/matmul_layernorm.py +++ b/examples/matmul_layernorm.py @@ -1,3 +1,14 @@ +""" +Matrix Multiplication with Layer Normalization Example +============================================== + +This example demonstrates how to implement a fused matrix multiplication and layer normalization +operation using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -8,6 +19,9 @@ import helion.language as hl +# %% +# MatMul-LayerNorm Kernel +# -------------------- # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def matmul_layernorm( @@ -47,6 +61,9 @@ def matmul_layernorm( return out +# %% +# Reference Implementation +# -------------------- def matmul_layernorm_pytorch( x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor ) -> torch.Tensor: @@ -74,6 +91,9 @@ def matmul_layernorm_pytorch( return ln_out.to(torch.promote_types(x.dtype, y.dtype)) +# %% +# Verification Function +# ------------------- def check(m: int, k: int, n: int) -> None: """ Verify the matmul_layernorm kernel implementation against the PyTorch reference implementation. @@ -90,6 +110,9 @@ def check(m: int, k: int, n: int) -> None: run_example(matmul_layernorm, matmul_layernorm_pytorch, (x, y, weight, bias)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the matmul_layernorm kernel verification with different matrix sizes. diff --git a/examples/matmul_split_k.py b/examples/matmul_split_k.py index f8556c4e..efe7fb88 100644 --- a/examples/matmul_split_k.py +++ b/examples/matmul_split_k.py @@ -1,3 +1,14 @@ +""" +Split-K Matrix Multiplication Example +================================ + +This example demonstrates how to implement a matrix multiplication kernel using the split-K +algorithm for better parallelism in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -8,6 +19,9 @@ import helion.language as hl +# %% +# Split-K Matrix Multiplication Kernel +# -------------------------------- # static_shapes=True gives a performance boost for matmuls @helion.kernel(static_shapes=True) def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -40,6 +54,9 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, k: int, n: int) -> None: """ Verify the split-K matmul kernel implementation against PyTorch's native matmul function. @@ -54,6 +71,9 @@ def check(m: int, k: int, n: int) -> None: run_example(matmul_split_k, torch.matmul, (x, y), atol=1) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the split-K matmul kernel verification. diff --git a/examples/moe_matmul_ogs.py b/examples/moe_matmul_ogs.py index 9ffafe6f..a32232b4 100644 --- a/examples/moe_matmul_ogs.py +++ b/examples/moe_matmul_ogs.py @@ -1,7 +1,14 @@ """ -Mixture-of-Experts (MoE) matmul with Outer-Gather-Scatter (OGS) +Mixture-of-Experts Matrix Multiplication Example +========================================= + +This example demonstrates how to implement a Mixture-of-Experts (MoE) matrix multiplication +using the Outer-Gather-Scatter (OGS) approach in Helion. """ +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -11,6 +18,9 @@ import helion.language as hl +# %% +# MoE MatMul OGS Kernel +# ------------------ @helion.kernel(static_shapes=False) def moe_matmul_ogs( A: torch.Tensor, # [T, K] - Input activations (T tokens, K features) @@ -20,6 +30,23 @@ def moe_matmul_ogs( sorted_to_orig_token_idx: torch.Tensor, # [T] - Maps sorted token positions back to original positions max_T_per_expert_tensor: torch.Tensor, # [max_T_per_expert] - Dummy tensor whose size indicates max tokens per expert ) -> torch.Tensor: # [T, N] - Output activations + """ + Performs Mixture-of-Experts (MoE) matrix multiplication using the Outer-Gather-Scatter approach. + + This kernel efficiently handles sparse expert routing by grouping tokens by their assigned expert, + performing matrix multiplications for each expert, and scattering results back to the original token order. + + Args: + A: Input activations tensor of shape [T, K] (T tokens, K features) + W: Expert weights tensor of shape [E, K, N] (E experts, K input features, N output features) + expert_token_counts: Number of tokens assigned to each expert, shape [E] + expert_token_offsets: Starting position of each expert's tokens in sorted order, shape [E+1] + sorted_to_orig_token_idx: Maps sorted token positions back to original positions, shape [T] + max_T_per_expert_tensor: Dummy tensor whose size indicates max tokens per expert + + Returns: + Output activations tensor of shape [T, N] + """ # Extract dimensions from input tensors T, K = A.shape E, _, N = W.shape @@ -89,6 +116,9 @@ def moe_matmul_ogs( return C +# %% +# Helper Function for Kernel Arguments +# -------------------------------- def moe_matmul_ogs_helion_kernel_args_gen( A: torch.Tensor, # [T, K] - Input activations W: torch.Tensor, # [E, K, N] - Expert weights @@ -96,6 +126,19 @@ def moe_matmul_ogs_helion_kernel_args_gen( ) -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor ]: + """ + Generates the arguments needed for the MoE MatMul OGS kernel. + + Prepares the data structures needed for efficient token routing and processing. + + Args: + A: Input activations tensor of shape [T, K] + W: Expert weights tensor of shape [E, K, N] + top1_expert_per_token: Expert assignment for each token, shape [T] + + Returns: + Tuple of tensors needed for the MoE MatMul OGS kernel + """ E = W.size(0) # Number of experts device = A.device @@ -131,9 +174,23 @@ def moe_matmul_ogs_helion_kernel_args_gen( ) +# %% +# Reference Implementation +# -------------------- def moe_matmul_ogs_reference( A: torch.Tensor, W: torch.Tensor, top1_expert_per_token: torch.Tensor ) -> torch.Tensor: + """ + PyTorch reference implementation of MoE matrix multiplication. + + Args: + A: Input activations tensor of shape [T, K] + W: Expert weights tensor of shape [E, K, N] + top1_expert_per_token: Expert assignment for each token, shape [T] + + Returns: + Output activations tensor of shape [T, N] + """ T, K = A.shape N = W.size(2) device, dtype = A.device, torch.promote_types(A.dtype, W.dtype) @@ -150,6 +207,9 @@ def moe_matmul_ogs_reference( return C +# %% +# Verification Function +# ------------------- def check(T: int, K: int, N: int, n_experts: int) -> None: """ Verify the MoE matmul OGS kernel implementation against the reference implementation. @@ -181,6 +241,9 @@ def reference_fn() -> torch.Tensor: run_example(helion_fn, reference_fn, ()) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the MoE matmul OGS kernel verification. diff --git a/examples/rms_norm.py b/examples/rms_norm.py index ae1dde18..678d7f86 100644 --- a/examples/rms_norm.py +++ b/examples/rms_norm.py @@ -1,3 +1,14 @@ +""" +Root Mean Square Normalization Example +================================= + +This example demonstrates how to implement a Root Mean Square (RMS) normalization +operation using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -6,11 +17,17 @@ from helion._testing import run_example import helion.language as hl +# %% +# Configuration +# ----------- # TritonBench configuration # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg. TRITONBENCH_ARGS = {"num_inputs": 3} +# %% +# RMS Normalization Kernel +# --------------------- @helion.kernel(static_shapes=True) def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: """ @@ -47,15 +64,41 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch. return out +# %% +# Benchmark Wrapper +# -------------- def rms_norm_tritonbench(H: int, inp: torch.Tensor) -> torch.Tensor: - """Wrapper for tritonbench that matches expected interface.""" + """ + Wrapper for tritonbench that matches expected interface. + + Args: + H: Hidden dimension size + inp: Input tensor + + Returns: + Normalized tensor + """ weight = torch.ones(H, device=inp.device, dtype=inp.dtype) return rms_norm(inp, weight, eps=1e-6) +# %% +# Reference Implementation +# -------------------- def rms_norm_pytorch( x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5 ) -> torch.Tensor: + """ + PyTorch reference implementation of RMS normalization. + + Args: + x: Input tensor + weight: Scale parameter + eps: Small constant for numerical stability + + Returns: + Normalized tensor + """ input_dtype = x.dtype hidden_states = x.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) @@ -63,6 +106,9 @@ def rms_norm_pytorch( return weight * hidden_states.to(input_dtype) +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: """ Verify the RMS norm kernel implementation against the PyTorch reference implementation. @@ -76,6 +122,9 @@ def check(m: int, n: int) -> None: run_example(rms_norm, rms_norm_pytorch, (x, weight, 1e-5)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the RMS norm kernel verification with different tensor sizes. diff --git a/examples/segment_reduction.py b/examples/segment_reduction.py index 4480b6df..e8ce4cff 100644 --- a/examples/segment_reduction.py +++ b/examples/segment_reduction.py @@ -1,4 +1,14 @@ -# Code based on https://github.com/pytorch-labs/helion/issues/237 +""" +Segmented Reduction Example +======================= + +This example demonstrates how to implement a segmented reduction operation using Helion, +comparing it with Triton and PyTorch implementations. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -11,6 +21,9 @@ import helion.language as hl +# %% +# Helion Implementation +# ----------------- def combine_fn_helion( left_values: torch.Tensor, left_indices: torch.Tensor, @@ -74,6 +87,9 @@ def segmented_reduction_helion( return output +# %% +# Triton Implementation +# ----------------- @triton.jit def combine_fn_triton( left_values: tl.tensor, @@ -188,6 +204,9 @@ def grid(META: dict[str, int]) -> tuple[int, ...]: return output +# %% +# PyTorch Reference Implementation +# ---------------------------- def segmented_reduction_pytorch( indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int ) -> torch.Tensor: @@ -215,6 +234,9 @@ def segmented_reduction_pytorch( return pytorch_output +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the segmented reduction implementations. diff --git a/examples/softmax.py b/examples/softmax.py index 2cad297e..edb91f55 100644 --- a/examples/softmax.py +++ b/examples/softmax.py @@ -1,3 +1,13 @@ +""" +Softmax Function Example +=================== + +This example demonstrates how to implement softmax operations using different approaches in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,6 +17,9 @@ import helion.language as hl +# %% +# Simple Softmax Kernel +# ----------------- @helion.kernel() def softmax(x: torch.Tensor) -> torch.Tensor: """ @@ -25,7 +38,9 @@ def softmax(x: torch.Tensor) -> torch.Tensor: return out -# This generates the same code as the above, but avoids using the pytorch softmax decomposition +# %% +# Decomposed Softmax Kernel +# --------------------- @helion.kernel() def softmax_decomposed(x: torch.Tensor) -> torch.Tensor: """ @@ -53,7 +68,9 @@ def softmax_decomposed(x: torch.Tensor) -> torch.Tensor: return out -# This optimization does softmax in fewer passes, but is less numerically stable +# %% +# Two-Pass Optimized Softmax Kernel +# ----------------------------- @helion.kernel() def softmax_two_pass(x: torch.Tensor) -> torch.Tensor: """ @@ -89,6 +106,9 @@ def softmax_two_pass(x: torch.Tensor) -> torch.Tensor: return out +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: """ Verify the softmax kernel implementations against PyTorch's native softmax function. @@ -106,6 +126,9 @@ def check(m: int, n: int) -> None: run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the softmax kernel verification with a 1024x1024 tensor. diff --git a/examples/sum.py b/examples/sum.py index 43746333..ef47ba28 100644 --- a/examples/sum.py +++ b/examples/sum.py @@ -1,3 +1,13 @@ +""" +Sum Reduction Example +================ + +This example demonstrates how to implement a sum reduction operation along the last dimension using Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations import torch @@ -7,9 +17,20 @@ import helion.language as hl +# %% +# Sum Kernel +# -------- @helion.kernel() def sum_kernel(x: torch.Tensor) -> torch.Tensor: - """Sum 2D tensor along the last dimension.""" + """ + Sums a 2D tensor along the last dimension. + + Args: + x: Input tensor of shape [M, N] + + Returns: + Output tensor of shape [M] containing the sum of each row + """ m, n = x.shape out = torch.empty([m], dtype=x.dtype, device=x.device) @@ -19,8 +40,19 @@ def sum_kernel(x: torch.Tensor) -> torch.Tensor: return out +# %% +# Benchmark Wrapper +# -------------- def sum_tritonbench(x: torch.Tensor) -> torch.Tensor: - """Wrapper for tritonbench that handles 1D input.""" + """ + Wrapper for tritonbench that handles 1D input. + + Args: + x: Input tensor (1D or 2D) + + Returns: + Sum of the tensor along the last dimension + """ if x.ndim == 1: # For 1D tensors, reshape to 2D for sum_kernel x_2d = x.unsqueeze(0) @@ -29,6 +61,9 @@ def sum_tritonbench(x: torch.Tensor) -> torch.Tensor: return sum_kernel(x) +# %% +# Verification Function +# ------------------- def check(m: int, n: int) -> None: """ Verify the sum kernel implementation against PyTorch's native sum function. @@ -42,6 +77,9 @@ def check(m: int, n: int) -> None: run_example(kernels, lambda x: x.sum(-1), (x,)) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the sum kernel verification with different tensor sizes. diff --git a/examples/template_via_closure.py b/examples/template_via_closure.py index 5b6af8f6..96d0920d 100644 --- a/examples/template_via_closure.py +++ b/examples/template_via_closure.py @@ -1,3 +1,14 @@ +""" +Template via Closure Example +======================= + +This example demonstrates how to implement a templated matrix multiplication kernel +with a customizable epilogue function using closures in Helion. +""" + +# %% +# Imports +# ------- from __future__ import annotations from typing import TYPE_CHECKING @@ -13,6 +24,9 @@ from collections.abc import Callable +# %% +# Templated MatMul Kernel +# ------------------- @helion.kernel( # static_shapes=True gives a performance boost for matmuls static_shapes=True, @@ -20,6 +34,21 @@ def matmul_with_epilogue( x: Tensor, y: Tensor, epilogue: Callable[[Tensor, list[Tensor]], Tensor] ) -> Tensor: + """ + Matrix multiplication with a customizable epilogue function. + + This kernel demonstrates how to use closures to create templated kernels + where the epilogue operation can be customized at runtime. + + Args: + x: First input tensor of shape [M, K] + y: Second input tensor of shape [K, N] + epilogue: Function that takes the accumulator and tile indices and returns + the final output for that tile + + Returns: + Output tensor of shape [M, N] with the epilogue function applied + """ m, k = x.size() k2, n = y.size() assert k == k2, f"size mismatch {k} != {k2}" @@ -34,6 +63,9 @@ def matmul_with_epilogue( return out +# %% +# Autotuning Function +# --------------- def autotune(n: int, k: int, m: int) -> None: """ Autotunes the matmul_with_epilogue kernel and saves the best configuration. @@ -55,6 +87,9 @@ def autotune(n: int, k: int, m: int) -> None: best_config.save("best_config.json") +# %% +# Verification Function +# ------------------- def check(n: int, k: int, m: int) -> None: """ Verify the matmul_with_epilogue kernel implementation against a PyTorch baseline. @@ -87,6 +122,9 @@ def baseline_wrapper(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ) +# %% +# Main Function +# ----------- def main() -> None: """ Main entry point that runs the matmul_with_epilogue kernel verification. From c1006aef48f8cc28dcb3ecdfc70c740713503638 Mon Sep 17 00:00:00 2001 From: sekyonda <127536312+sekyondaMeta@users.noreply.github.com> Date: Fri, 25 Jul 2025 08:34:18 -0400 Subject: [PATCH 7/9] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index ea422f57..4c0778e8 100644 --- a/.gitignore +++ b/.gitignore @@ -91,4 +91,5 @@ torch benchmarks/tritonbench site generated +uv.lock docs/examples/ From 046372caf2749ba49733c074b69022fd8f3f6dde Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Fri, 25 Jul 2025 08:40:21 -0400 Subject: [PATCH 8/9] Update examples/README.rst Co-authored-by: Jason Ansel --- examples/README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/README.rst b/examples/README.rst index 2e32391c..dc02d515 100644 --- a/examples/README.rst +++ b/examples/README.rst @@ -22,7 +22,7 @@ Matrix Operations - ``matmul_layernorm.py``: Fused matrix multiplication and layer normalization - ``fp8_gemm.py``: Matrix multiplication using FP8 precision -Attention Mechanisms +Attention Operations ~~~~~~~~~~~~~~~~~~~ - ``attention.py``: Scaled dot-product attention mechanism From d99c7bbbf49030851ec7d985c75ccfaa6c38f32f Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Fri, 25 Jul 2025 08:40:28 -0400 Subject: [PATCH 9/9] Update examples/README.rst Co-authored-by: Jason Ansel --- examples/README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/README.rst b/examples/README.rst index dc02d515..e50eba55 100644 --- a/examples/README.rst +++ b/examples/README.rst @@ -13,7 +13,7 @@ Basic Operations - ``long_sum.py``: Efficient sum reduction along a long dimension - ``softmax.py``: Different implementations of the softmax function -Matrix Operations +Matrix Multiplication Operations ~~~~~~~~~~~~~~~~ - ``matmul.py``: Basic matrix multiplication