
Commit 1b50f73

Layer Norm fwd issue
ghstack-source-id: 44f54ae
Pull Request resolved: #170
1 parent 41fe6e9 commit 1b50f73

File tree

1 file changed: +57 -0 lines changed


examples/layer_norm.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel(static_shapes=True, use_default_config=True)
def layer_norm_fwd(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty(
        [m, n], dtype=torch.float16, device=x.device
    )

    eps = 1e-5

    for tile_m in hl.tile(m):
        # Load one tile of rows and compute in float32 for numerical stability.
        acc = x[tile_m, :].to(torch.float32)  # TODO: Eliminate this cast, currently necessary

        # Per-row mean and biased variance (correction=0) over the feature dimension.
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        # Normalize, then apply the learned scale and shift.
        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * weight[:].to(torch.float32) + bias[:].to(torch.float32)

        out[tile_m, :] = acc
    return out


def main() -> None:
    batch_size = 32
    dim = 64
    device = "cuda"

    x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
    weight = torch.randn([dim], device=device, dtype=torch.float16)
    bias = torch.randn([dim], device=device, dtype=torch.float16)

    baseline_func = lambda x, weight, bias: torch.nn.functional.layer_norm(x, [dim], weight, bias)

    run_example(
        layer_norm_fwd,
        baseline_func,
        (x, weight, bias),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-4,
        atol=1e-4,
    )


if __name__ == "__main__":
    main()
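
For reference, the per-tile arithmetic inside layer_norm_fwd maps one-to-one onto eager PyTorch applied to the whole batch. The sketch below is not part of the commit; it is an assumed plain-PyTorch equivalent of the kernel's math (float32 statistics, biased variance, scale and shift, cast back to float16) and should match torch.nn.functional.layer_norm within the rtol/atol used above.

import torch

def layer_norm_fwd_reference(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    # Assumed eager-mode equivalent of the Helion kernel above (not from the commit).
    # Statistics are computed in float32, exactly as the kernel does per tile.
    acc = x.to(torch.float32)
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    normalized = (acc - mean) * torch.rsqrt(var + eps)
    out = normalized * weight.to(torch.float32) + bias.to(torch.float32)
    return out.to(torch.float16)

# Example check (hypothetical usage, mirroring main() above):
#   ref = layer_norm_fwd_reference(x, weight, bias)
#   expected = torch.nn.functional.layer_norm(x, [dim], weight, bias)
#   torch.testing.assert_close(ref, expected, rtol=1e-4, atol=1e-4)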
