add rmsnorm-add fusion kernel #996

Open — wants to merge 1 commit into main

Conversation

theNiemand

No description provided.


@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @theNiemand, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a highly optimized, fused kernel for the common 'add and RMSNorm' operation, crucial for improving the efficiency of large language models. It leverages Triton to achieve significant performance gains by executing these sequential operations in a single, in-place kernel. Additionally, it provides new utility functions for robust testing and performance measurement of custom CUDA kernels, alongside a comprehensive test suite for the new fused operation.

Highlights

  • New Fused Kernel: Introduced a new Triton-based fused_add_rmsnorm_inplace kernel (lightllm/models/llama/triton_kernel/fused_add_rmsnorm_inplace.py) for efficient, in-place computation of residual addition followed by RMS Normalization. This kernel is designed to optimize performance by fusing these common operations, which is particularly relevant for transformer architectures like Llama (a reference sketch of the computation follows this list).
  • Utility Functions: Added general-purpose utility functions to lightllm/utils/custom_kernel_utis.py: an error function for calculating Signal-to-Noise Ratio (SNR) between tensors (useful for numerical accuracy validation) and a benchmark function to measure CUDA kernel performance (TFLOPS and memory throughput).
  • Comprehensive Testing: Included a dedicated unit test file (unit_tests/models/llama/test_fused_add_rmsnorm.py) to verify the accuracy and benchmark the performance of the new fused RMSNorm kernel. The tests compare the fused kernel's output against a reference PyTorch implementation across various input shapes and use the newly added error and benchmark utilities.
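
For readers unfamiliar with the operation being fused, the computation is an elementwise residual add followed by RMS normalization over the feature dimension. A minimal unfused PyTorch sketch (hypothetical names; not the PR's actual reference implementation) looks roughly like this:

import torch

def add_rmsnorm_reference(x: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Residual add, then RMSNorm over the last (feature) dimension.
    hidden = x + residual
    # Scale by the reciprocal root-mean-square of the features, then apply the weight.
    rms = torch.rsqrt(hidden.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (hidden.float() * rms * weight.float()).to(x.dtype)

The fused kernel described above performs both steps in a single pass and writes the result back in place, avoiding the extra round trips to global memory that the unfused version incurs.
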
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new fused Triton kernel for add + rmsnorm operations, along with utility functions for benchmarking and error calculation, and corresponding unit tests. The core idea of the fused kernel is excellent for performance. However, the Triton kernel wrapper incorrectly restricts its use to smaller feature dimensions, the new benchmark utility is not safe for functions with in-place side effects, and the tests for the new kernel include dimensions that will cause failures and rely on the flawed benchmarking tool. Additionally, there are medium-severity issues related to incorrect docstrings and type hints in the new utility functions. Addressing the critical issues is essential for the correctness and reliability of these new components.

Comment on lines +133 to +140
if N > BLOCK_SIZE:
    raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
num_warps = triton.next_power_of_2(num_warps)
if BLOCK_SIZE > 16384:
    BLOCK_SIZE = 16384

critical

The check if N > BLOCK_SIZE: on line 133 disables the multi-block execution path (the else block) in the Triton kernel. The kernel is designed to handle feature dimensions N larger than BLOCK_SIZE, but this check prevents it. This makes the else branch in the kernel dead code and limits the kernel's applicability to large feature dimensions. The hard cap if BLOCK_SIZE > 16384: on line 139 is also overly restrictive. Remove the error check and the hard cap.

Suggested change
- if N > BLOCK_SIZE:
-     raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
- # heuristics for number of warps
- num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
- num_warps = triton.next_power_of_2(num_warps)
- if BLOCK_SIZE > 16384:
-     BLOCK_SIZE = 16384
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
+ num_warps = triton.next_power_of_2(num_warps)

Comment on lines +208 to +209
with torch.no_grad():
    output = func(*args, **kwargs)

critical

The initial call to func is problematic for functions that perform in-place operations, as it will modify the input tensors. All subsequent calls during the warm-up and measurement phases will then operate on this modified data, leading to incorrect benchmark results. Ensure that each timed call operates on a fresh, unmodified copy of the input data.
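
One possible pattern, sketched here as a hypothetical helper (not the PR's benchmark function), is to clone tensor arguments outside the timed region before every call so in-place kernels never see data mutated by a previous iteration:

import torch

def time_with_fresh_inputs(func, *args, steps: int = 100, **kwargs) -> float:
    """Return the average per-call latency in milliseconds."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total_ms = 0.0
    for _ in range(steps):
        # Clone tensor args before timing; non-tensor args (e.g. a scalar eps) pass through.
        fresh = [a.clone() if isinstance(a, torch.Tensor) else a for a in args]
        torch.cuda.synchronize()  # cloning stays outside the timed region
        start.record()
        func(*fresh, **kwargs)
        end.record()
        torch.cuda.synchronize()
        total_ms += start.elapsed_time(end)
    return total_ms / steps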

def setUp(self):
    """Set up common test parameters."""
    self.tokens = [1, 2, 3, 1024, 2048, 4096, 8192, 16384]
    self.dims = [1, 2, 3, 512, 1024, 1025, 3200, 16384, 32768]  # [512, 1024, 1032, 1536, 3200, 6144, 12800]

high

The dimension 32768 will cause the test to fail with the current implementation of fused_add_rmsnorm_inplace because the wrapper function raises a RuntimeError for dimensions greater than 16384. Remove or adjust this test case.
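
For example, assuming the 16384 cap in the wrapper stays in place, the list could simply drop the failing case:

self.dims = [1, 2, 3, 512, 1024, 1025, 3200, 16384]  # 32768 removed; it exceeds the wrapper's current cap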

Comment on lines +130 to +153
def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> torch.Tensor:
    """
    Compute SNR between y_pred(tensor) and y_real(tensor)

    SNR can be calcualted as following equation:

    SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2

    if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.

    SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)

    Args:
        y_pred (torch.Tensor): _description_
        y_real (torch.Tensor): _description_
        reduction (str, optional): _description_. Defaults to 'mean'.

    Raises:
        ValueError: _description_
        ValueError: _description_

    Returns:
        torch.Tensor: _description_
    """

medium

The return type hint for the error function is -> torch.Tensor, but the function returns a float via .item(). It should be -> float. Also, the docstring contains placeholder text and mentions a reduction parameter that doesn't exist. The formula in the docstring is for an element-wise ratio mean, but the implementation calculates a ratio of summed powers. Update the function signature and the docstring for clarity and correctness.

def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> float:
    """
    Compute SNR error between y_pred(tensor) and y_real(tensor).

    The SNR error is calculated as the ratio of noise power to signal power:
    `sum((y_pred - y_real)^2) / sum(y_real^2)`

    Args:
        y_pred (torch.Tensor): The predicted tensor.
        y_real (torch.Tensor): The ground truth tensor.

    Raises:
        ValueError: If tensors have different shapes.

    Returns:
        float: The computed SNR error value.
    """

Comment on lines +168 to +191
def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
    """
    A decorator function to assist in performance testing of CUDA operations.

    This function will:
    1. Automatically determine whether any parameters in the argument list,
       or the output of the `func`, are of type `torch.Tensor`.
    2. If so, calculate the memory usage of the input and output tensors
       on the GPU (based on their data type and `torch.numel()`).
    3. Establish a CUDA graph and attempt to execute `func` repeatedly for `steps` iterations.
    4. Record the execution time during these iterations.
    5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.

    Args:
        func (function): The function to benchmark.
        shape (list of int): The problem shape.
        tflops (float): The computational workload (in TFLOPS) per call of `func`.
        steps (int): The number of times the function is executed during benchmarking.
        *args: Positional arguments to be passed to the `func`.
        **kwargs: Keyword arguments to be passed to the `func`.

    Returns:
        function result
    """

medium

The docstring for the benchmark function is inaccurate. It's described as a "decorator function", but it's a regular utility function. It claims to "Establish a CUDA graph", but the implementation uses torch.cuda.Event for timing, not CUDA graphs. The Returns section says it returns the "function result", but it actually returns None and only prints the metrics. Update the docstring to accurately describe the function's behavior.

def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
    """
    A utility function to assist in performance testing of CUDA operations.

    This function will:
    1. Automatically determine whether any parameters in the argument list,
       or the output of the `func`, are of type `torch.Tensor`.
    2. If so, calculate the memory usage of the input and output tensors
       on the GPU (based on their data type and `torch.numel()`).
    3. Execute `func` repeatedly for `steps` iterations after a warm-up period.
    4. Record the execution time during these iterations using CUDA events.
    5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.

    Args:
        func (function): The function to benchmark.
        shape (list of int): The problem shape.
        tflops (float): The computational workload (in TFLOPS) per call of `func`.
        steps (int): The number of times the function is executed during benchmarking.
        *args: Positional arguments to be passed to the `func`.
        **kwargs: Keyword arguments to be passed to the `func`.

    Returns:
        None
    """

Comment on lines +47 to +59
def test_performance(self):
    """Test the performance of rmsnorm using benchmark."""
    for token_num in self.tokens:
        for dim in self.dims:
            with self.subTest(shape=[token_num, dim]):
                X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

                shape = [token_num, dim]
                tflops = 0.0
                benchmark(self.torch_add_rmsnorm, shape, tflops, 100, X, R, W)
                benchmark(fused_add_rmsnorm_inplace, shape, tflops, 100, X, R, W, eps=1e-6)

medium

This performance test uses the benchmark utility, which is flawed for in-place operations like torch_add_rmsnorm and fused_add_rmsnorm_inplace. The benchmark function modifies the input tensors on its first call, causing subsequent warm-up and measurement runs to use altered data. This will lead to unreliable performance results. The benchmark function in custom_kernel_utis.py needs to be fixed to handle functions with side effects correctly before this performance test can be considered reliable.
