
Commit e3bb95d

Layer Norm fwd issue
ghstack-source-id: 38636cd
Pull Request resolved: #170
1 parent 41fe6e9 commit e3bb95d

4 files changed: +160 -0 lines

benchmarks/run.py

Lines changed: 5 additions & 0 deletions

@@ -75,6 +75,11 @@
         "examples.fp8_attention",
         "fp8_attention_tritonbench",
     ),
+    "layer_norm": (
+        "tritonbench.operators.layer_norm.operator",
+        "examples.layer_norm",
+        "helion_layer_norm_wrapper",
+    ),
 }
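For context, a minimal sketch of how an entry in this mapping might be resolved at benchmark time, assuming the tuple layout is (tritonbench operator module, Helion example module, wrapper attribute name). The dictionary name and the resolve_kernel helper below are hypothetical illustrations only, not the actual logic in benchmarks/run.py:

import importlib

# Hypothetical mapping name and helper; the real benchmarks/run.py may differ.
KERNEL_MAPPINGS = {
    "layer_norm": (
        "tritonbench.operators.layer_norm.operator",  # tritonbench operator module
        "examples.layer_norm",                        # Helion example module
        "helion_layer_norm_wrapper",                  # callable exported by the example
    ),
}

def resolve_kernel(name: str):
    _operator_module, example_module, wrapper_name = KERNEL_MAPPINGS[name]
    module = importlib.import_module(example_module)  # e.g. examples.layer_norm
    return getattr(module, wrapper_name)              # e.g. helion_layer_norm_wrapper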
examples/layer_norm.py

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
from __future__ import annotations

from typing import Any

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# TODO(PaulZhang12): Support autotuning, setting reduction_loops currently errors
@helion.kernel(
    static_shapes=True,
    config=helion.Config(
        block_sizes=[32],
        reduction_loops=[None],
        range_unroll_factors=[0],
        range_warp_specializes=[],
        range_num_stages=[0],
        range_multi_buffers=[None],
        range_flattens=[None],
        num_warps=4,
        num_stages=3,
        indexing="pointer",
        pid_type="flat",
    ),
)
def layer_norm_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    for tile_m in hl.tile(m):
        # TODO(PaulZhang12): Eliminate this cast, currently necessary
        acc = x[tile_m, :].to(torch.float32)

        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * weight[:].to(torch.float32) + bias[:].to(torch.float32)

        out[tile_m, :] = acc
    return out


def helion_layer_norm_wrapper(
    x: torch.Tensor,
    dims: list[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> Any:  # noqa: ANN401
    assert len(dims) == 1, "Helion layer norm only supports 1D layer norm currently"
    return layer_norm_fwd(x, weight, bias, eps)


def main() -> None:
    batch_size = 32
    dim = 64
    device = "cuda"

    x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
    weight = torch.randn([dim], device=device, dtype=torch.float16)
    bias = torch.randn([dim], device=device, dtype=torch.float16)
    eps = 1e-4

    run_example(
        helion_layer_norm_wrapper,
        torch.nn.functional.layer_norm,
        (x, [dim], weight, bias, eps),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )


if __name__ == "__main__":
    main()
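For reference, the kernel's row-wise numerics restated in plain eager PyTorch (an editorial sketch, not part of the commit): the variance is biased (correction=0), the scale is rsqrt(var + eps), and the result is cast back to float16, matching what torch.nn.functional.layer_norm computes over the last dimension.

import torch

def layer_norm_reference(x, weight, bias, eps=1e-5):
    # Same math as layer_norm_fwd above, applied row-wise over the last dimension.
    acc = x.to(torch.float32)
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    normalized = (acc - mean) * torch.rsqrt(var + eps)
    out = normalized * weight.to(torch.float32) + bias.to(torch.float32)
    return out.to(torch.float16)

# Expected to agree with torch.nn.functional.layer_norm(x, [x.size(-1)], weight, bias, eps)
# within the fp16 tolerances used by run_example (rtol=1e-3, atol=1e-3).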

test/test_examples.expected

Lines changed: 51 additions & 0 deletions

@@ -883,6 +883,57 @@ def jagged_mean_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, x_feature_
     _launcher(_jagged_mean_kernel_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_feature_counts, x_flat, out, out.stride(0), out.stride(1), x_feature_counts.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, max_M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_layernorm)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _layer_norm_fwd_kernel(x, weight, bias, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    load = tl.load(x + (indices_0[:, None] * 64 + indices_1[None, :] * 1), None)
+    v_0 = load.to(tl.float32)
+    var_mean_extra = tl.reshape(tl.sum(v_0, 1), [_BLOCK_SIZE_0, 1])
+    v_1 = 64
+    v_2 = var_mean_extra / v_1.to(tl.float32)
+    v_3 = v_0 - v_2
+    v_4 = v_3 * v_3
+    var_mean_extra_2 = tl.reshape(tl.sum(v_4, 1), [_BLOCK_SIZE_0, 1])
+    v_5 = 64
+    v_6 = var_mean_extra_2 / v_5.to(tl.float32)
+    v_7 = v_0 - v_2
+    v_8 = v_6 + eps
+    v_9 = libdevice.rsqrt(v_8)
+    v_10 = v_7 * v_9
+    load_1 = tl.load(weight + indices_1 * 1, None)
+    v_11 = load_1.to(tl.float32)
+    v_12 = v_11[None, :]
+    v_13 = v_10 * v_12
+    load_2 = tl.load(bias + indices_1 * 1, None)
+    v_14 = load_2.to(tl.float32)
+    v_15 = v_14[None, :]
+    v_16 = v_13 + v_15
+    v_17 = v_16.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 64 + indices_1[None, :] * 1), v_17, None)
+
+def layer_norm_fwd(x: torch.Tensor, dims: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    m, n = x.size()
+    assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
+    assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {n}'
+    assert len(dims) == 1 and dims[0] == n, f'dim mismatch {dims} != {n}'
+    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = 64
+    _launcher(_layer_norm_fwd_kernel, (triton.cdiv(32, _BLOCK_SIZE_0),), x, weight, bias, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_matmul)
 from __future__ import annotations
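As a cross-check on the generated kernel above, the same two-pass reduction restated in eager PyTorch (an editorial illustration, not part of the expected journal): v_2 is the per-row mean, v_6 is the biased variance (mean of squared deviations), and libdevice.rsqrt(v_6 + eps) is the scale applied before the weight and bias terms.

import torch

x = torch.randn(32, 64, dtype=torch.float32)            # one (_BLOCK_SIZE_0, _RDIM_SIZE_1) tile
mean = x.sum(dim=1, keepdim=True) / 64                  # v_2 in the Triton code
var = ((x - mean) ** 2).sum(dim=1, keepdim=True) / 64   # v_6: biased variance
y = (x - mean) * torch.rsqrt(var + 1e-5)                # v_10, before weight and bias

# Matches torch.var_mean(x, dim=-1, keepdim=True, correction=0) followed by
# normalization, up to floating-point rounding.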

test/test_examples.py

Lines changed: 17 additions & 0 deletions

@@ -599,6 +599,23 @@ def test_fp8_attention(self):
             )
         )
 
+    def test_layernorm(self):
+        args = (
+            torch.randn([32, 64], device=DEVICE, dtype=torch.float16),
+            [64],
+            torch.randn([64], device=DEVICE, dtype=torch.float16),
+            torch.randn([64], device=DEVICE, dtype=torch.float16),
+        )
+        self.assertExpectedJournal(
+            check_example(
+                "layer_norm",
+                args,
+                torch.nn.functional.layer_norm(*args),
+                fn_name="layer_norm_fwd",
+                block_sizes=[32],
+            )
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
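To run just this case locally, an invocation along the lines of python test/test_examples.py TestExamples.test_layernorm should work with the unittest entry point shown above (an editorial suggestion; it assumes a CUDA device for DEVICE and an up-to-date expected journal).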
