-
Notifications
You must be signed in to change notification settings - Fork 15
[Example] Layer Norm Forward #170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 1 commit
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
cb1f007
Layer Norm fwd issue
PaulZhang12 955b4e0
Update on "Layer Norm fwd issue"
PaulZhang12 c23130e
Update on "Layer Norm fwd issue"
PaulZhang12 13403ca
Update on "[Example] Layer Norm Forward"
PaulZhang12 17924e7
Update on "[Example] Layer Norm Forward"
PaulZhang12 2a1f519
Update on "[Example] Layer Norm Forward"
PaulZhang12 a0db0d0
Update on "[Example] Layer Norm Forward"
PaulZhang12 0c3a682
Update on "[Example] Layer Norm Forward"
PaulZhang12 a67125f
Update on "[Example] Layer Norm Forward"
PaulZhang12 317d746
Update on "[Example] Layer Norm Forward"
PaulZhang12 43762a0
Update on "[Example] Layer Norm Forward"
PaulZhang12 1d9c518
Update on "[Example] Layer Norm Forward"
PaulZhang12 d8231c6
Update on "[Example] Layer Norm Forward"
PaulZhang12 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from __future__ import annotations | ||
|
||
import torch | ||
|
||
import helion | ||
import helion.language as hl | ||
|
||
""" | ||
NOTE: layer_norm_fwd_ideal does not work! I am keeping this around as a reference | ||
to what I believed should have worked in Helion when I first began without debugging. | ||
|
||
The user experience should be pushed this direction | ||
""" | ||
@helion.kernel(static_shapes=True) | ||
def layer_norm_fwd_ideal( | ||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5 | ||
) -> torch.Tensor: | ||
""" | ||
Layer normalization forward pass. | ||
|
||
Args: | ||
x: Input tensor of shape [batch_size, hidden_size] | ||
weight: Scale parameter of shape [hidden_size] | ||
bias: Bias parameter of shape [hidden_size] | ||
eps: Epsilon for numerical stability | ||
|
||
Returns: | ||
Normalized tensor of shape [batch_size, hidden_size] | ||
""" | ||
m = x.size(0) | ||
out = torch.empty_like(x) | ||
|
||
for tile_b in hl.tile(m): | ||
row = x[tile_b] | ||
mean, var = torch.var_mean(row) | ||
|
||
layer_norm_out = (row - mean) / torch.sqrt(var + eps) | ||
layer_norm_out = layer_norm_out * weight + bias | ||
out[tile_b, :] = layer_norm_out | ||
|
||
return out | ||
|
||
@helion.kernel(static_shapes=True, use_default_config=True)
def layer_norm_fwd(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """
    Layer normalization forward pass (working Helion implementation).

    Args:
        x: Input tensor of shape [m, n]
        weight: Scale parameter of shape [n]
        bias: Bias parameter of shape [n]

    Returns:
        Normalized float16 tensor of shape [m, n]
    """
    m, n = x.size()
    # BUGFIX: the original messages printed `m` even though weight/bias are
    # compared against the hidden dimension `n`, producing a confusing error.
    assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
    assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {n}"
    out = torch.empty(
        [m, n], dtype=torch.float16, device=x.device
    )

    eps = 1e-5

    for tile_m in hl.tile(m):
        # acc = x[tile_m, :].to(torch.float32) works! We should not have to do this cast
        acc = x[tile_m, :]

        # Fused single-pass mean/variance per row (population variance: correction=0).
        var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)

        normalized = (acc - mean) * torch.rsqrt(var + eps)
        acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))

        out[tile_m, :] = acc
    return out
|
||
|
||
def check(batch_size: int, hidden_size: int) -> None:
    """
    Validate layer_norm_fwd against torch.nn.functional.layer_norm, then
    benchmark both with triton's do_bench.

    Args:
        batch_size: Number of rows in the test input.
        hidden_size: Size of the normalized (last) dimension.
    """
    from triton.testing import do_bench

    # Create random input tensors
    x = torch.randn([batch_size, hidden_size], device="cuda", dtype=torch.float16)
    weight = torch.randn([hidden_size], device="cuda", dtype=torch.float16)
    bias = torch.randn([hidden_size], device="cuda", dtype=torch.float16)

    # Run Helion kernel
    result = layer_norm_fwd(x, weight, bias)

    # Run PyTorch layer norm for comparison
    torch_result = torch.nn.functional.layer_norm(
        x, [hidden_size], weight, bias, eps=1e-5
    )

    # Check correctness (loose tolerances: fp16 inputs with fp32 accumulation)
    torch.testing.assert_close(result, torch_result, rtol=1e-2, atol=1e-1)

    # BUGFIX: do_bench reports MILLISECONDS. The original locals were named
    # `*_sec` and the printout omitted the "ms" unit on the torch time, which
    # misrepresented the measurements.
    helion_ms = do_bench(lambda: layer_norm_fwd(x, weight, bias))
    torch_ms = do_bench(
        lambda: torch.nn.functional.layer_norm(
            x, [hidden_size], weight, bias, eps=1e-5
        )
    )

    print(
        f"Helion time: {helion_ms:.4f}ms, torch time: {torch_ms:.4f}ms, speedup: {torch_ms / helion_ms:.2f}x"
    )
|
||
|
||
def main() -> None:
    """Entry point: exercise check() on two representative problem sizes."""
    cases = [(128, 768), (32, 1024)]
    for idx, (batch_size, hidden_size) in enumerate(cases):
        # A blank line separates every case after the first, matching the
        # original output exactly.
        prefix = "\n" if idx else ""
        print(f"{prefix}Testing batch_size={batch_size}, hidden_size={hidden_size}")
        check(batch_size, hidden_size)


if __name__ == "__main__":
    main()
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. btw probably need to add a unit test to
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -169,20 +169,25 @@ def convert_arg(arg: Node) -> TensorBox: | |||||
nodes = [] | ||||||
extra_input_names = [] | ||||||
new_node: torch.fx.Node | ||||||
|
||||||
|
||||||
read_buffer_names = set() | ||||||
# Explicitly track the mapping from node to Inductor buffer name. | ||||||
# First, map the original input nodes to their names. | ||||||
node_to_buf_name_mapping: dict[torch.fx.Node, str] = dict( | ||||||
zip(node._input_nodes, input_names, strict=True) | ||||||
) | ||||||
|
||||||
for i, buffer in enumerate(new_buffers): | ||||||
if not isinstance(buffer, ComputedBuffer) or not isinstance( | ||||||
buffer.data, (Pointwise, Reduction) | ||||||
): | ||||||
raise InductorLoweringError( | ||||||
f"Lowering {node.target} returned buffer type {type(buffer)}, expected ComputedBuffer(Pointwise|Reduction): {buffer}" | ||||||
) | ||||||
|
||||||
for name in buffer.get_read_names(): | ||||||
read_buffer_names.add(name) | ||||||
|
||||||
if i == len(new_buffers) - 1: | ||||||
new_node = node | ||||||
if nodes: | ||||||
|
@@ -191,6 +196,7 @@ def convert_arg(arg: Node) -> TensorBox: | |||||
new_node = create_extra_node(node, buffer, [*node._input_nodes, *nodes]) | ||||||
|
||||||
# Store output index if this buffer corresponds to an output | ||||||
import pdb; pdb.set_trace() | ||||||
if buffer.get_name() in buffer_name_to_output_index: | ||||||
new_node.meta["output_index"] = buffer_name_to_output_index[ | ||||||
buffer.get_name() | ||||||
|
@@ -207,7 +213,7 @@ def convert_arg(arg: Node) -> TensorBox: | |||||
current_input_names = [] | ||||||
for inp_node in current_input_nodes: | ||||||
current_input_names.append(node_to_buf_name_mapping[inp_node]) | ||||||
|
||||||
used_input_names = strip_unused_inputs( | ||||||
new_node, | ||||||
buffer.get_read_names(), | ||||||
|
@@ -230,6 +236,7 @@ def convert_arg(arg: Node) -> TensorBox: | |||||
for n in nodes: | ||||||
if "output_index" in n.meta: | ||||||
output_nodes[n.meta["output_index"]] = n.name | ||||||
import pdb; pdb.set_trace() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
same thing but shorter. |
||||||
last_node.meta["output_nodes"] = output_nodes | ||||||
|
||||||
|
||||||
|
@@ -254,6 +261,8 @@ def mask_unused_inputs(n: torch.fx.Node) -> torch.fx.Node | None: | |||||
return n | ||||||
return None | ||||||
|
||||||
if node.name == "var_mean": | ||||||
import pdb; pdb.set_trace() | ||||||
assert len(input_names) == len(node._input_nodes) | ||||||
seen_names: dict[str, None] = {} | ||||||
node.args = map_arg(node.args, mask_unused_inputs) | ||||||
|
@@ -878,11 +887,11 @@ def _collect_multi_outputs( | |||||
Collect outputs for multi-output operations using metadata. | ||||||
""" | ||||||
# Check if this operation has multiple outputs using the new metadata | ||||||
assert "output_nodes" in node.meta | ||||||
assert "output_nodes" in node.meta, "Output nodes not in node.meta" | ||||||
output_nodes = node.meta["output_nodes"] | ||||||
outputs = [None] * len(output_nodes) | ||||||
all_nodes = {n.name: n for n in self.module.graph.nodes} # pyre-ignore[16] | ||||||
|
||||||
import pdb; pdb.set_trace() | ||||||
for idx, node_name in output_nodes.items(): | ||||||
if node_name == node.name: | ||||||
# This is the last node | ||||||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.