[Example] Layer Norm Forward #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · PaulZhang12 wants to merge 13 commits
5 changes: 5 additions & 0 deletions benchmarks/run.py
@@ -75,6 +75,11 @@
"examples.fp8_attention",
"fp8_attention_tritonbench",
),
"layer_norm": (
"tritonbench.operators.layer_norm.operator",
"examples.layer_norm",
"helion_layer_norm_wrapper",
),
}
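Each entry maps an operator name to the tritonbench operator module, the Helion example module, and the name of the wrapper function to benchmark. The helper below is a rough, hypothetical sketch of how such an entry could be resolved (it is not the actual benchmarks/run.py logic):

```python
import importlib
from typing import Any, Callable


def resolve_benchmark_entry(entry: tuple[str, str, str]) -> tuple[Any, Callable[..., Any]]:
    """Resolve (tritonbench operator module, example module, wrapper name)."""
    tb_module_path, example_module_path, wrapper_name = entry
    tb_operator = importlib.import_module(tb_module_path)  # tritonbench baseline operator
    example = importlib.import_module(example_module_path)  # Helion example module
    return tb_operator, getattr(example, wrapper_name)  # wrapper around the Helion kernel


# For the entry added above:
# resolve_benchmark_entry((
#     "tritonbench.operators.layer_norm.operator",
#     "examples.layer_norm",
#     "helion_layer_norm_wrapper",
# ))
```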


87 changes: 87 additions & 0 deletions examples/layer_norm.py
@@ -0,0 +1,87 @@
from __future__ import annotations

from typing import Any

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# TODO(PaulZhang12): Support autotuning; setting reduction_loops currently errors
Contributor:
What is the error you are getting? Do you need help fixing this one? For benchmarking we need to run the autotuner.

Author (@PaulZhang12, Jul 22, 2025):
I filed issue #345; I haven't had time to look into it yet, but I can address it before merging.
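For reference, once the reduction_loops issue is resolved, benchmarking would want the autotuned path. A minimal sketch, assuming Helion's default behavior of autotuning on the first call when no explicit config is passed; the body mirrors layer_norm_fwd below, only the pinned config is dropped:

```python
import torch

import helion
import helion.language as hl


# Sketch only: same math as layer_norm_fwd below, but with no pinned config so
# the autotuner is free to search block_sizes, reduction_loops, etc.
@helion.kernel(static_shapes=True)
def layer_norm_fwd_autotuned(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    for tile_m in hl.tile(m):
        acc = x[tile_m, :].to(torch.float32)
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
        normalized = (acc - mean) * torch.rsqrt(var + eps)
        out[tile_m, :] = normalized * weight[:].to(torch.float32) + bias[:].to(torch.float32)
    return out
```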

@helion.kernel(
    static_shapes=True,
    config=helion.Config(
        block_sizes=[32],
        reduction_loops=[None],
        range_unroll_factors=[0],
        range_warp_specializes=[],
        range_num_stages=[0],
        range_multi_buffers=[None],
        range_flattens=[None],
        num_warps=4,
        num_stages=3,
        indexing="pointer",
        pid_type="flat",
    ),
)
def layer_norm_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    for tile_m in hl.tile(m):
        acc = x[tile_m, :].to(
            torch.float32
        )  # TODO(PaulZhang12): Eliminate this cast; it is currently necessary

        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))

        out[tile_m, :] = acc
    return out
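For readers checking the math, the loop body above is the standard LayerNorm recipe applied per row tile in float32: biased mean/variance over the feature dimension, normalization with rsqrt, then scale and shift. A plain eager-mode sketch of the same computation (for reference only, not part of the PR):

```python
import torch


def layer_norm_fwd_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    # Compute in float32, exactly like the kernel, then cast back to float16.
    acc = x.to(torch.float32)
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    normalized = (acc - mean) * torch.rsqrt(var + eps)
    out = normalized * weight.to(torch.float32) + bias.to(torch.float32)
    return out.to(torch.float16)
```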


def helion_layer_norm_wrapper(
Contributor:
Why do we need a wrapper?

Author (@PaulZhang12):
PyTorch's layer_norm (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.layer_norm.html) takes normalized_shape as its second argument, a list[int]. The wrapper works around that by not passing it into the kernel.

Contributor:
We should be able to pass it into the kernel; I think we can remove this.
    x: torch.Tensor,
    dims: list[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> Any:  # noqa: ANN401
    assert len(dims) == 1, "Helion layer norm only supports 1D layer norm currently"
    return layer_norm_fwd(x, weight, bias, eps)
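As discussed in the review thread above, the wrapper lets the Helion path be called with the same positional signature as torch.nn.functional.layer_norm (input, normalized_shape, weight, bias, eps); normalized_shape is checked to be one-dimensional and then dropped. A usage sketch (assumes examples/layer_norm.py is importable and a CUDA device is available):

```python
import torch
import torch.nn.functional as F

x = torch.randn([32, 64], device="cuda", dtype=torch.float16)
weight = torch.randn([64], device="cuda", dtype=torch.float16)
bias = torch.randn([64], device="cuda", dtype=torch.float16)

# Same call site for both; the wrapper forwards everything except [64] to the kernel.
out_helion = helion_layer_norm_wrapper(x, [64], weight, bias, 1e-5)
out_torch = F.layer_norm(x, [64], weight, bias, 1e-5)
torch.testing.assert_close(out_helion, out_torch, rtol=1e-3, atol=1e-3)
```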


def main() -> None:
    batch_size = 32
    dim = 64
    device = "cuda"

    x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
    weight = torch.randn([dim], device=device, dtype=torch.float16)
    bias = torch.randn([dim], device=device, dtype=torch.float16)
    eps = 1e-4

    run_example(
        helion_layer_norm_wrapper,
        torch.nn.functional.layer_norm,
        (x, [dim], weight, bias, eps),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )


if __name__ == "__main__":
    main()
Contributor:
By the way, you probably need to add a unit test to test_examples.py, similar to the other examples.

50 changes: 50 additions & 0 deletions test/test_examples.expected
@@ -883,6 +883,56 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
    _launcher(_jagged_mean_kernel_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_layernorm)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _layer_norm_fwd_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
    load = tl.load(x + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
    v_0 = load.to(tl.float32)
    var_mean_extra = tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1])
    v_1 = 64
    v_2 = var_mean_extra / v_1.to(tl.float32)
    v_3 = v_0 - v_2
    v_4 = v_3 * v_3
    var_mean_extra_2 = tl.reshape(tl.sum(v_4, 1), [_BLOCK_SIZE_0, 1])
    v_5 = 64
    v_6 = var_mean_extra_2 / v_5.to(tl.float32)
    v_7 = v_0 - v_2
    v_8 = v_6 + eps
    v_9 = libdevice.rsqrt(v_8)
    v_10 = v_7 * v_9
    load_1 = tl.load(weight + indices_1 * 1, None)
    v_11 = load_1.to(tl.float32)
    v_12 = v_11[None, :]
    v_13 = v_10 * v_12
    load_2 = tl.load(bias + indices_1 * 1, None)
    v_14 = load_2.to(tl.float32)
    v_15 = v_14[None, :]
    v_16 = v_13 + v_15
    v_17 = v_16.to(tl.float16)
    tl.store(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), v_17, None)

def layer_norm_fwd(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
    m, n = x.size()
    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
    assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {n}'
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    _BLOCK_SIZE_0 = 32
    _RDIM_SIZE_1 = 64
    _launcher(_layer_norm_fwd_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_matmul)
from __future__ import annotations

15 changes: 15 additions & 0 deletions test/test_examples.py
@@ -599,6 +599,21 @@ def test_fp8_attention(self):
            )
        )

    def test_layernorm(self):
        x = torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
        weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
        bias = torch.randn([64], device=DEVICE, dtype=torch.float16)

        self.assertExpectedJournal(
            check_example(
                "layer_norm",
                (x, weight, bias),
                torch.nn.functional.layer_norm(x, [64], weight, bias),
                fn_name="layer_norm_fwd",
                block_sizes=[32],
            )
        )


if __name__ == "__main__":
    unittest.main()