
Commit 0f676de

Layer Norm fwd issue
ghstack-source-id: cb9208f
Pull Request resolved: #170
1 parent 41fe6e9 commit 0f676de

File tree

1 file changed (+67 -0 lines changed)

examples/layer_norm.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
from __future__ import annotations

from typing import Any

import torch

import helion
import helion.language as hl
from helion._testing import run_example


@helion.kernel(static_shapes=True, use_default_config=True)
def layer_norm_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    # Tile over the rows; each tile normalizes its rows over the feature dim.
    for tile_m in hl.tile(m):
        # Accumulate in float32 for numerical stability.
        # TODO (PaulZhang12): Eliminate this cast, currently necessary
        acc = x[tile_m, :].to(torch.float32)

        # correction=0 gives the population variance, matching F.layer_norm.
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * weight[:].to(torch.float32) + bias[:].to(torch.float32)

        out[tile_m, :] = acc
    return out


def layer_norm_torch_callable(
    dims: list[int],
) -> Any:  # noqa: ANN401
    return lambda x, weight, bias, eps: torch.nn.functional.layer_norm(
        x, dims, weight, bias, eps
    )


def main() -> None:
    batch_size = 32
    dim = 64
    device = "cuda"
    eps = 1e-3

    x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
    weight = torch.randn([dim], device=device, dtype=torch.float16)
    bias = torch.randn([dim], device=device, dtype=torch.float16)

    run_example(
        layer_norm_fwd,
        layer_norm_torch_callable([dim]),
        (x, weight, bias, eps),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-4,
        atol=1e-4,
    )


if __name__ == "__main__":
    main()
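
For reference, the math the kernel implements is the standard layer norm, y = (x - mean) / sqrt(var + eps) * weight + bias, using the population variance (correction=0). Below is a minimal eager-mode PyTorch sketch of the same computation; layer_norm_reference is a hypothetical helper for illustration, not part of this commit.

import torch

def layer_norm_reference(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    # Same math as the kernel: normalize each row over the last dim in
    # float32, then apply the affine transform and cast back to float16.
    acc = x.to(torch.float32)
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    normalized = (acc - mean) * torch.rsqrt(var + eps)
    out = normalized * weight.to(torch.float32) + bias.to(torch.float32)
    return out.to(torch.float16)

Up to float16 rounding, this should agree with both layer_norm_fwd and torch.nn.functional.layer_norm within the rtol/atol used above.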
