add rmsnorm-add fusion kernel #996
base: main
Conversation
Summary of Changes
Hello @theNiemand, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a highly optimized, fused kernel for the common 'add and RMSNorm' operation, crucial for improving the efficiency of large language models. It leverages Triton to achieve significant performance gains by executing these sequential operations in a single, in-place kernel. Additionally, it provides new utility functions for robust testing and performance measurement of custom CUDA kernels, alongside a comprehensive test suite for the new fused operation.
Highlights
- New Fused Kernel: Introduced a new Triton-based `fused_add_rmsnorm_inplace` kernel (`lightllm/models/llama/triton_kernel/fused_add_rmsnorm_inplace.py`) for efficient, in-place computation of residual addition followed by RMS Normalization (a minimal reference sketch of the operation follows this list). This kernel is designed to optimize performance by fusing these common operations, particularly relevant for transformer architectures like Llama.
- Utility Functions: Added general-purpose utility functions to `lightllm/utils/custom_kernel_utis.py`: an `error` function for calculating Signal-to-Noise Ratio (SNR) between tensors (useful for numerical accuracy validation) and a `benchmark` function to measure CUDA kernel performance (TFLOPS and memory throughput).
- Comprehensive Testing: Included a dedicated unit test file (`unit_tests/models/llama/test_fused_add_rmsnorm.py`) to verify the accuracy and benchmark the performance of the new fused RMSNorm kernel. The tests compare the fused kernel's output against a reference PyTorch implementation across various input shapes and use the newly added `error` and `benchmark` utilities.
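For concreteness, the fused operation is a residual add followed by RMS normalization over the feature dimension. Below is a minimal PyTorch sketch of that reference semantics; the function name and the purely functional (not in-place) behavior are illustrative, not the PR's exact code:

```python
import torch

def add_rmsnorm_reference(x: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Residual add followed by RMSNorm over the last dimension (reference only)."""
    h = (x + residual).float()                                   # residual add
    rstd = torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (h * rstd).to(x.dtype) * weight                       # RMSNorm scale
```

The point of the fusion is that the Triton kernel performs the add and the normalization in a single in-place pass, avoiding the extra global-memory round trip that the two separate PyTorch operations above would incur.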
Code Review
This pull request introduces a new fused Triton kernel for `add + rmsnorm` operations, along with utility functions for benchmarking and error calculation, and corresponding unit tests. The core idea of the fused kernel is excellent for performance. However, the Triton kernel wrapper incorrectly restricts its use to smaller feature dimensions, the new `benchmark` utility is not safe for functions with in-place side effects, and the tests for the new kernel include dimensions that will cause failures and rely on the flawed benchmarking tool. Additionally, there are medium-severity issues related to incorrect docstrings and type hints in the new utility functions. Addressing the critical issues is essential for the correctness and reliability of these new components.
```python
if N > BLOCK_SIZE:
    raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
num_warps = triton.next_power_of_2(num_warps)
if BLOCK_SIZE > 16384:
    BLOCK_SIZE = 16384
```
The check `if N > BLOCK_SIZE:` on line 133 disables the multi-block execution path (the `else` block) in the Triton kernel. The kernel is designed to handle feature dimensions `N` larger than `BLOCK_SIZE`, but this check prevents it. This makes the `else` branch in the kernel dead code and limits the kernel's applicability to large feature dimensions. The hard cap `if BLOCK_SIZE > 16384:` on line 139 is also overly restrictive. Remove the error check and the hard cap.
Suggested change:
```diff
-if N > BLOCK_SIZE:
-    raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
 # heuristics for number of warps
 num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
 num_warps = triton.next_power_of_2(num_warps)
-if BLOCK_SIZE > 16384:
-    BLOCK_SIZE = 16384
```
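As a generic illustration of the multi-block path the comment refers to (a sketch, not the PR's kernel; all names here are hypothetical), an RMSNorm-style Triton kernel can cover a feature dimension `N` larger than `BLOCK_SIZE` by looping over column blocks:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _rmsnorm_looped_kernel(X, W, Y, stride, N, eps, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    X += row * stride
    Y += row * stride
    # Pass 1: accumulate the sum of squares over all column blocks of this row.
    acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
        acc += x * x
    rstd = 1.0 / tl.sqrt(tl.sum(acc, axis=0) / N + eps)
    # Pass 2: normalize and scale each column block.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
        w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
        tl.store(Y + cols, x * rstd * w, mask=mask)

def rmsnorm_looped(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    M, N = x.shape
    y = torch.empty_like(x)
    # Cap the block size; rows wider than BLOCK_SIZE are handled by the loops above.
    BLOCK_SIZE = min(triton.next_power_of_2(N), 8192)
    _rmsnorm_looped_kernel[(M,)](x, weight, y, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE)
    return y
```

With this structure the host-side `BLOCK_SIZE` only bounds per-program resources; larger rows simply take more loop iterations, so neither the hard error nor the 16384 cap is needed for correctness.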
```python
with torch.no_grad():
    output = func(*args, **kwargs)
```
The initial call to `func` is problematic for functions that perform in-place operations, as it will modify the input tensors. All subsequent calls during the warm-up and measurement phases will then operate on this modified data, leading to incorrect benchmark results. Ensure that each timed call operates on a fresh, unmodified copy of the input data.
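One way to do that (a sketch with hypothetical helper names, not a prescribed fix) is to copy tensor arguments before every invocation, so in-place kernels always start from the original data:

```python
import torch

def _fresh_inputs(args, kwargs):
    """Return copies of all tensor arguments; non-tensor arguments pass through unchanged."""
    new_args = [a.clone() if isinstance(a, torch.Tensor) else a for a in args]
    new_kwargs = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return new_args, new_kwargs

# Inside benchmark(), every warm-up and timed iteration would then use untouched inputs:
#     call_args, call_kwargs = _fresh_inputs(args, kwargs)
#     func(*call_args, **call_kwargs)
```

The clones do add copy time inside the measured region; if that matters, the copies can be kept outside the timed section by restoring the inputs between iterations or by timing each call individually and subtracting the copy cost.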
```python
def setUp(self):
    """Set up common test parameters."""
    self.tokens = [1, 2, 3, 1024, 2048, 4096, 8192, 16384]
    self.dims = [1, 2, 3, 512, 1024, 1025, 3200, 16384, 32768]  # [512, 1024, 1032, 1536, 3200, 6144, 12800]
```
```python
def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> torch.Tensor:
    """
    Compute SNR between y_pred(tensor) and y_real(tensor)

    SNR can be calcualted as following equation:

        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2

    if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.

        SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)

    Args:
        y_pred (torch.Tensor): _description_
        y_real (torch.Tensor): _description_
        reduction (str, optional): _description_. Defaults to 'mean'.

    Raises:
        ValueError: _description_
        ValueError: _description_

    Returns:
        torch.Tensor: _description_
    """
```
The return type hint for the `error` function is `-> torch.Tensor`, but the function returns a float via `.item()`. It should be `-> float`. Also, the docstring contains placeholder text and mentions a `reduction` parameter that doesn't exist. The formula in the docstring is for an element-wise ratio mean, but the implementation calculates a ratio of summed powers. Update the function signature and the docstring for clarity and correctness.
Suggested change:
```python
def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> float:
    """
    Compute SNR error between y_pred(tensor) and y_real(tensor).

    The SNR error is calculated as the ratio of noise power to signal power:
    `sum((y_pred - y_real)^2) / sum(y_real^2)`

    Args:
        y_pred (torch.Tensor): The predicted tensor.
        y_real (torch.Tensor): The ground truth tensor.

    Raises:
        ValueError: If tensors have different shapes.

    Returns:
        float: The computed SNR error value.
    """
```
```python
def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
    """
    A decorator function to assist in performance testing of CUDA operations.

    This function will:
    1. Automatically determine whether any parameters in the argument list,
       or the output of the `func`, are of type `torch.Tensor`.
    2. If so, calculate the memory usage of the input and output tensors
       on the GPU (based on their data type and `torch.numel()`).
    3. Establish a CUDA graph and attempt to execute `func` repeatedly for `steps` iterations.
    4. Record the execution time during these iterations.
    5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.

    Args:
        func (function): The function to benchmark.
        shape (list of int): The problem shape.
        tflops (float): The computational workload (in TFLOPS) per call of `func`.
        steps (int): The number of times the function is executed during benchmarking.
        *args: Positional arguments to be passed to the `func`.
        **kwargs: Keyword arguments to be passed to the `func`.

    Returns:
        function result
    """
```
The docstring for the `benchmark` function is inaccurate. It's described as a "decorator function", but it's a regular utility function. It claims to "Establish a CUDA graph", but the implementation uses `torch.cuda.Event` for timing, not CUDA graphs. The `Returns` section says it returns the "function result", but it actually returns `None` and only prints the metrics. Update the docstring to accurately describe the function's behavior.
Suggested change:
```python
def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
    """
    A utility function to assist in performance testing of CUDA operations.

    This function will:
    1. Automatically determine whether any parameters in the argument list,
       or the output of the `func`, are of type `torch.Tensor`.
    2. If so, calculate the memory usage of the input and output tensors
       on the GPU (based on their data type and `torch.numel()`).
    3. Execute `func` repeatedly for `steps` iterations after a warm-up period.
    4. Record the execution time during these iterations using CUDA events.
    5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.

    Args:
        func (function): The function to benchmark.
        shape (list of int): The problem shape.
        tflops (float): The computational workload (in TFLOPS) per call of `func`.
        steps (int): The number of times the function is executed during benchmarking.
        *args: Positional arguments to be passed to the `func`.
        **kwargs: Keyword arguments to be passed to the `func`.

    Returns:
        None
    """
```
```python
def test_performance(self):
    """Test the performance of rmsnorm using benchmark."""
    for token_num in self.tokens:
        for dim in self.dims:
            with self.subTest(shape=[token_num, dim]):
                X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

                shape = [token_num, dim]
                tflops = 0.0
                benchmark(self.torch_add_rmsnorm, shape, tflops, 100, X, R, W)
                benchmark(fused_add_rmsnorm_inplace, shape, tflops, 100, X, R, W, eps=1e-6)
```
This performance test uses the `benchmark` utility, which is flawed for in-place operations like `torch_add_rmsnorm` and `fused_add_rmsnorm_inplace`. The benchmark function modifies the input tensors on its first call, causing subsequent warm-up and measurement runs to use altered data. This will lead to unreliable performance results. The `benchmark` function in `custom_kernel_utis.py` needs to be fixed to handle functions with side effects correctly before this performance test can be considered reliable.
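A quick illustration of the effect (an illustrative snippet using the test's tensors):

```python
X0 = X.clone()
fused_add_rmsnorm_inplace(X, R, W, eps=1e-6)   # the benchmark's first call
assert not torch.allclose(X, X0)               # X has been overwritten in place;
# every later warm-up and timed call now normalizes this already-normalized data.
```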