
Add fused modulated RMSNorm #1225

Open

yukiu00 wants to merge 4 commits into linkedin:main from yukiu00:codex-modulated-rms-norm

Conversation

@yukiu00 (Contributor) commented May 13, 2026

Summary

Adds a fused Modulated RMSNorm kernel for the common pattern:

y = rms_norm(x, weight, eps) * (1 + scale) + shift

This is intended for FiLM/AdaRMSNorm-style modulation where RMSNorm output is scaled and optionally shifted by per-hidden-dimension modulation tensors.

Related issue: #1224
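
For reference, the unfused pattern in eager PyTorch looks roughly like this (a minimal sketch; the function name is illustrative and the weight handling follows the formula above, not the kernel source):

import torch

def modulated_rms_norm_ref(x, weight, scale, shift=None, eps=1e-6):
    # RMSNorm: normalize by the root mean square over the hidden dimension.
    variance = x.pow(2).mean(-1, keepdim=True)
    y = x * torch.rsqrt(variance + eps)
    if weight is not None:  # affine weight is optional in this PR
        y = y * weight
    # FiLM/AdaRMSNorm-style modulation: scale, then optional shift.
    y = y * (1 + scale)
    if shift is not None:
        y = y + shift
    return y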

Changes

  • Add LigerModulatedRMSNormFunction with fused forward/backward Triton kernels.
  • Add LigerModulatedRMSNorm module wrapper and liger_modulated_rms_norm functional API (usage sketch below).
  • Support optional RMSNorm weight, optional shift, and existing RMSNorm casting modes: llama, gemma, none.
  • Support modulation shapes that flatten to (hidden,), (batch, hidden), or row-wise (batch, seq, hidden) when rows divide the hidden-state rows.
  • Add correctness tests across fp32/bf16, casting modes, modulation broadcast modes, optional shift, and optional affine weight.
  • Add benchmark script for speed and memory comparison.
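
A hypothetical usage sketch (the import path and call signature are assumptions based on the naming above, not confirmed by this diff):

import torch
from liger_kernel.transformers import LigerModulatedRMSNorm  # assumed import path

norm = LigerModulatedRMSNorm(hidden_size=4096, eps=1e-6).to("cuda")
x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
scale = torch.randn(2, 4096, device="cuda", dtype=torch.bfloat16)  # row-wise (batch, hidden)
shift = torch.randn(2, 4096, device="cuda", dtype=torch.bfloat16)

y = norm(x, scale, shift)  # fused rms_norm(x) * (1 + scale) + shift
y.sum().backward()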

Benchmark

Environment: NVIDIA GeForce RTX 3090, bf16, hidden size 4096, row-wise scale and shift.

Speed (p50, ms)

| BT | Mode | Fused | LigerRMSNorm + torch mod | Speedup |
| --- | --- | --- | --- | --- |
| 1024 | forward | 0.0430 | 0.1091 | 2.5x |
| 1024 | backward | 0.0763 | 0.1403 | 1.8x |
| 1024 | full | 0.1403 | 0.2642 | 1.9x |
| 4096 | forward | 0.1546 | 0.3999 | 2.6x |
| 4096 | backward | 0.2540 | 0.4895 | 1.9x |
| 4096 | full | 0.4787 | 1.3793 | 2.9x |

Memory (p50, MB)

| BT | Mode | Fused | LigerRMSNorm + torch mod | Savings |
| --- | --- | --- | --- | --- |
| 1024 | forward | 32.02 | 56.02 | 42.8% |
| 1024 | backward | 89.33 | 97.33 | 8.2% |
| 1024 | full | 89.33 | 97.33 | 8.2% |
| 4096 | forward | 128.03 | 224.03 | 42.9% |
| 4096 | backward | 353.34 | 385.34 | 8.3% |
| 4096 | full | 353.34 | 385.34 | 8.3% |

@yukiu00 yukiu00 changed the title from "[codex] Add fused modulated RMSNorm" to "Add fused modulated RMSNorm" May 13, 2026
@yukiu00 yukiu00 force-pushed the codex-modulated-rms-norm branch from 259ed8c to 6c1c731 May 13, 2026 11:55
yukiu00 added 3 commits May 13, 2026 21:15
- Added assertions to ensure scale and shift tensors have compatible shapes for broadcasting.
- Updated the documentation for LigerModulatedRMSNormFunction to clarify modulation behavior and broadcasting rules.
- Introduced a new test for mixed dtype modulation to verify correct handling of bf16 and fp32 inputs, ensuring gradients are preserved in their original dtype.

This improves robustness and clarity in the modulation implementation.
- Renamed `TorchModulatedRMSNorm` to `NaiveModulatedRMSNorm` and `LigerRMSNormWithTorchModulation` to `LigerRMSNormWithNaiveModulation` to better reflect their functionality.
- Added detailed docstrings for the new classes to clarify their purpose and implementation.
- Updated the benchmark setup to use the new class names and added support for the `huggingface` kernel provider.
- Enhanced the test suite by parameterizing the `test_correctness_functional` function to include dtype handling for better robustness in testing.
- Adjusted assertions in tests to use dynamic tolerances based on dtype, improving accuracy in validation.

This refactor improves code clarity and testing coverage for the modulated RMSNorm implementations.
@yukiu00 yukiu00 marked this pull request as ready for review May 13, 2026 16:34
@yukiu00 (Contributor, Author) commented May 13, 2026

@Tcc0403 This PR is now open and ready for review.

Comment on lines +46 to +61
class LigerRMSNormWithNaiveModulation(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LigerRMSNormWithNaiveModulation is equivalent to NaiveModulatedRMSNorm above, but
        uses the LigerRMSNorm kernel for the base normalization step (modulation is
        applied in eager PyTorch). Useful to isolate the benefit of fusing modulation
        into the norm kernel.
        """
        super().__init__()
        self.rms_norm = LigerRMSNorm(hidden_size=hidden_size, eps=eps, in_place=False)

    def forward(self, hidden_states, scale, shift=None):
        output = self.rms_norm(hidden_states) * (1 + scale)
        if shift is not None:
            output = output + shift
        return output
Collaborator
I feel it's nice to have but not necessary for comparison.

pytest.param(
    0.0,
    "none",
    marks=pytest.mark.skipif(device == "npu", reason="Ascend NPU does not support this test"),
Collaborator
Thanks for caring about NPU support; we should mark it at the function level.
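
For example (a sketch of the suggested function-level mark; `device` is assumed to be resolved at module import time, as in the existing test suite, and the parameter sets are illustrative):

import pytest

@pytest.mark.skipif(device == "npu", reason="Ascend NPU does not support this test")
@pytest.mark.parametrize("offset, casting_mode", [(0.0, "llama"), (1.0, "gemma"), (0.0, "none")])
def test_correctness(offset, casting_mode):
    ...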

Comment on lines +109 to +113
"bs, sl, hd",
[
    (2, 16, 512),
    (5, 7, 123),
],
Collaborator
Set larger shapes closer to actual models.
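
For example (values illustrative, not prescribed by the reviewer):

@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 2048, 4096),  # Llama-7B-style hidden size
        (3, 423, 2048),   # ragged seq length to exercise non-power-of-2 paths
    ],
)
def test_correctness(bs, sl, hd):
    ...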

Comment on lines +31 to +41
def _broadcast_modulation(modulation, hidden_states):
    if modulation.dim() == 1:
        return modulation
    if modulation.shape == hidden_states.shape:
        return modulation
    if hidden_states.dim() == 3 and modulation.dim() == 2 and modulation.shape[0] == hidden_states.shape[0]:
        return modulation[:, None, :]
    if hidden_states.dim() == 2 and modulation.dim() == 2 and hidden_states.shape[0] % modulation.shape[0] == 0:
        rows_per_modulation = hidden_states.shape[0] // modulation.shape[0]
        return modulation.repeat_interleave(rows_per_modulation, dim=0)
    raise AssertionError("Unsupported modulation shape for reference implementation.")
Collaborator
Shouldn't this belong to the reference nn.Module class?
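
Something like this sketch of the suggested move (broadcasting logic copied from the free function above):

import torch.nn as nn

class ModulatedRMSNormReference(nn.Module):
    @staticmethod
    def _broadcast_modulation(modulation, hidden_states):
        # Same broadcasting rules, just scoped to the reference module.
        if modulation.dim() == 1 or modulation.shape == hidden_states.shape:
            return modulation
        if hidden_states.dim() == 3 and modulation.dim() == 2 and modulation.shape[0] == hidden_states.shape[0]:
            return modulation[:, None, :]
        if hidden_states.dim() == 2 and modulation.dim() == 2 and hidden_states.shape[0] % modulation.shape[0] == 0:
            return modulation.repeat_interleave(hidden_states.shape[0] // modulation.shape[0], dim=0)
        raise AssertionError("Unsupported modulation shape for reference implementation.")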

    raise AssertionError("Unsupported modulation shape for reference implementation.")


class ModulatedRMSNormReference(nn.Module):
Collaborator
What's the difference between this one and the one in the benchmark? I'd prefer putting huggingface, or whatever reference is currently used as our test reference, here, and making the benchmark script point to the test reference as the single source of truth.
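
For example, the benchmark could import the test reference directly (module path and constructor arguments here are hypothetical):

# in the benchmark script
from test.transformers.test_modulated_rms_norm import ModulatedRMSNormReference  # hypothetical path

ref = ModulatedRMSNormReference(hidden_size=4096, eps=1e-6)  # assumed signature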

Comment on lines +121 to +122
    2e-1,
    2e-2,
Collaborator
Is this the tightest tolerance we can pass?
