[Example] Layer Norm Forward #170
@@ -0,0 +1,87 @@
from __future__ import annotations

from typing import Any

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# TODO(PaulZhang12): Support autotuning, setting reduction_loops currently errors
@helion.kernel(
    static_shapes=True,
    config=helion.Config(
        block_sizes=[32],
        reduction_loops=[None],
        range_unroll_factors=[0],
        range_warp_specializes=[],
        range_num_stages=[0],
        range_multi_buffers=[None],
        range_flattens=[None],
        num_warps=4,
        num_stages=3,
        indexing="pointer",
        pid_type="flat",
    ),
)
def layer_norm_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    m, n = x.size()
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)

    # Normalize one tile of rows at a time; statistics reduce over the feature dim.
    for tile_m in hl.tile(m):
        acc = x[tile_m, :].to(
            torch.float32
        )  # TODO(PaulZhang12): Eliminate this cast, currently necessary

        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))

        out[tile_m, :] = acc
    return out

Review thread on helion_layer_norm_wrapper:
- Why do we need a wrapper?
- PyTorch's layer_norm (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.layer_norm.html) takes in normalized_shape as a separate argument, so the wrapper matches that signature for the run_example comparison.
- We should be able to pass it into the kernel; I think we can remove this.

def helion_layer_norm_wrapper(
    x: torch.Tensor,
    dims: list[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> Any:  # noqa: ANN401
    assert len(dims) == 1, "Helion layer norm only supports 1D layer norm currently"
    return layer_norm_fwd(x, weight, bias, eps)


def main() -> None:
    batch_size = 32
    dim = 64
    device = "cuda"

    x = torch.randn([batch_size, dim], device=device, dtype=torch.float16)
    weight = torch.randn([dim], device=device, dtype=torch.float16)
    bias = torch.randn([dim], device=device, dtype=torch.float16)
    eps = 1e-4

    run_example(
        helion_layer_norm_wrapper,
        torch.nn.functional.layer_norm,
        (x, [dim], weight, bias, eps),
        kernel_name="helion",
        baseline_name="torch",
        rtol=1e-3,
        atol=1e-3,
    )


if __name__ == "__main__":
    main()
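For reference, each row tile above implements the standard layer norm computation; the biased variance matches correction=0, and gamma/beta below stand for the weight and bias tensors of size n. In LaTeX:

% Per-row layer norm as computed in layer_norm_fwd (biased variance, correction=0)
\[
\mu = \frac{1}{n}\sum_{j=1}^{n} x_j,
\qquad
\sigma^2 = \frac{1}{n}\sum_{j=1}^{n} \left(x_j - \mu\right)^2,
\qquad
y_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}}\,\gamma_j + \beta_j .
\]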
Review comment: btw, probably need to add a unit test for this.
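A minimal sketch of such a test, assuming a pytest-style test file and that the kernel is importable as examples.layer_norm.layer_norm_fwd (the intended test location and import path are not specified in this thread):

# Hedged sketch only: the module path below is an assumption, not part of this PR.
import pytest
import torch

from examples.layer_norm import layer_norm_fwd  # assumed import path


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_layer_norm_fwd_matches_torch() -> None:
    batch_size, dim = 32, 64
    x = torch.randn([batch_size, dim], device="cuda", dtype=torch.float16)
    weight = torch.randn([dim], device="cuda", dtype=torch.float16)
    bias = torch.randn([dim], device="cuda", dtype=torch.float16)
    eps = 1e-4

    result = layer_norm_fwd(x, weight, bias, eps)
    expected = torch.nn.functional.layer_norm(x, [dim], weight, bias, eps)

    # Same tolerances that run_example uses in the example's main().
    torch.testing.assert_close(result, expected, rtol=1e-3, atol=1e-3)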
Review comment: What is the error you are getting? Do you need help fixing this one? For benchmarking we need to run the autotuner.
Reply: I filed issue #345; I haven't had time to look into it, but I can address it before merging.
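Once the reduction_loops error (#345) is fixed, one way to exercise the autotuner would be to drop the hard-coded config from the decorator. This is a sketch only, assuming Helion's usual behavior of autotuning a kernel on its first call when no config is supplied:

# Sketch: same kernel body as above, but without config=, so the autotuner would
# search block sizes, reduction loops, etc. on the first call.
# Blocked today by the reduction_loops error tracked in #345.
import torch

import helion
import helion.language as hl


@helion.kernel(static_shapes=True)
def layer_norm_fwd_autotuned(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    for tile_m in hl.tile(m):
        acc = x[tile_m, :].to(torch.float32)
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
        acc = (acc - mean) * torch.rsqrt(var + eps)
        acc = acc * weight[:].to(torch.float32) + bias[:].to(torch.float32)
        out[tile_m, :] = acc
    return out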