Commit e98286d

Layer Norm fwd issue
ghstack-source-id: f2dbcd0
Pull Request resolved: #170
1 parent 41fe6e9 commit e98286d

1 file changed: +68, -0 lines changed


examples/layer_norm.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
from __future__ import annotations

from typing import Any
from typing import Callable

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel(static_shapes=True, use_default_config=True)
def layer_norm_fwd(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    eps = 1e-5

    for tile_m in hl.tile(m):
        # Load one tile of rows and upcast to float32 for the reduction.
        acc = x[tile_m, :].to(
            torch.float32
        )  # TODO (PaulZhang12): Eliminate this cast, currently necessary

        # Per-row mean and population variance (correction=0).
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))

        out[tile_m, :] = acc
    return out


def layer_norm_torch_callable(
    dims: list[int],
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor, float], Any]:
    return lambda x, weight, bias, eps: torch.nn.functional.layer_norm(
        x, dims, weight, bias, eps
    )


def main() -> None:
    batch_size = 32
    dim = 64
    device = "cuda"
    eps = 1e-5  # must match the eps hard-coded inside layer_norm_fwd

    x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
    weight = torch.randn([dim], device=device, dtype=torch.float16)
    bias = torch.randn([dim], device=device, dtype=torch.float16)

    # The baseline factory takes the normalized shape; bind eps here so the
    # kernel and the baseline are both called with the same (x, weight, bias).
    baseline = layer_norm_torch_callable([dim])
    run_example(
        layer_norm_fwd,
        lambda x, weight, bias: baseline(x, weight, bias, eps),
        (x, weight, bias),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-4,
        atol=1e-4,
    )


if __name__ == "__main__":
    main()
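
For reference, the arithmetic inside the hl.tile loop is the standard LayerNorm formula y = (x - mean) / sqrt(var + eps) * weight + bias with the population variance (correction=0), computed in float32 and written back as float16. Below is a minimal eager-PyTorch sketch of the same row-wise computation; the helper name layer_norm_rows_reference is illustrative and not part of this commit.

import torch

def layer_norm_rows_reference(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    # Upcast to float32 for the reduction, mirroring the kernel.
    x32 = x.to(torch.float32)
    # Population mean/variance over the last dimension (correction=0).
    var, mean = torch.var_mean(x32, dim=-1, keepdim=True, correction=0)
    normalized = (x32 - mean) * torch.rsqrt(var + eps)
    out = normalized * weight.to(torch.float32) + bias.to(torch.float32)
    # The kernel stores its result into a float16 output tensor.
    return out.to(torch.float16)

Called as layer_norm_rows_reference(x, weight, bias), this should agree with layer_norm_fwd(x, weight, bias) up to float16 rounding.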

0 commit comments
