Add fused modulated RMSNorm #1225
Conversation
- Added assertions to ensure scale and shift tensors have compatible shapes for broadcasting.
- Updated the documentation for `LigerModulatedRMSNormFunction` to clarify modulation behavior and broadcasting rules.
- Introduced a new test for mixed-dtype modulation to verify correct handling of bf16 and fp32 inputs, ensuring gradients are preserved in their original dtype.

This improves robustness and clarity in the modulation implementation.
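Not the PR's actual test, but a minimal CPU-only sketch of the dtype invariant being checked here (the real test exercises the fused kernel; names and shapes are illustrative):

```python
import torch

def test_mixed_dtype_modulation_sketch():
    # bf16 hidden states modulated by fp32 scale/shift; the math promotes to fp32,
    # but each leaf's gradient should come back in that leaf's original dtype.
    x = torch.randn(2, 16, 512, dtype=torch.bfloat16, requires_grad=True)
    scale = torch.randn(512, dtype=torch.float32, requires_grad=True)
    shift = torch.randn(512, dtype=torch.float32, requires_grad=True)

    out = x * (1 + scale) + shift  # type promotion: bf16 * fp32 -> fp32
    out.sum().backward()

    assert x.grad.dtype == torch.bfloat16
    assert scale.grad.dtype == torch.float32
    assert shift.grad.dtype == torch.float32
```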
- Renamed `TorchModulatedRMSNorm` to `NaiveModulatedRMSNorm` and `LigerRMSNormWithTorchModulation` to `LigerRMSNormWithNaiveModulation` to better reflect their functionality.
- Added detailed docstrings for the new classes to clarify their purpose and implementation.
- Updated the benchmark setup to use the new class names and added support for the `huggingface` kernel provider.
- Enhanced the test suite by parameterizing the `test_correctness_functional` function to include dtype handling for better robustness in testing.
- Adjusted assertions in tests to use dynamic tolerances based on dtype, improving accuracy in validation.

This refactor improves code clarity and testing coverage for the modulated RMSNorm implementations.
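For illustration, dtype-dependent tolerances typically look something like the following sketch (the helper name is hypothetical; the bf16 pair matches the values discussed in the review below, the fp32 values are placeholders):

```python
import torch

def dtype_tolerances(dtype):
    # Low-precision dtypes need looser tolerances against the fp32 reference.
    if dtype in (torch.bfloat16, torch.float16):
        return {"atol": 2e-1, "rtol": 2e-2}
    return {"atol": 1e-5, "rtol": 1e-5}

# Usage in an assertion:
# torch.testing.assert_close(actual, expected, **dtype_tolerances(actual.dtype))
```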
@Tcc0403 This PR is now open and ready for review.
```python
class LigerRMSNormWithNaiveModulation(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LigerRMSNormWithNaiveModulation is equivalent to NaiveModulatedRMSNorm above, but
        uses the LigerRMSNorm kernel for the base normalization step (modulation is
        applied in eager PyTorch). Useful to isolate the benefit of fusing modulation
        into the norm kernel.
        """
        super().__init__()
        self.rms_norm = LigerRMSNorm(hidden_size=hidden_size, eps=eps, in_place=False)

    def forward(self, hidden_states, scale, shift=None):
        output = self.rms_norm(hidden_states) * (1 + scale)
        if shift is not None:
            output = output + shift
        return output
```
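For orientation, a hypothetical call into this module (shapes, dtype, and device are illustrative, not from the PR):

```python
import torch

# Illustrative only: a (hidden,)-shaped modulation broadcasts cleanly in eager PyTorch.
norm = LigerRMSNormWithNaiveModulation(hidden_size=512).cuda()
x = torch.randn(2, 16, 512, device="cuda", dtype=torch.bfloat16)
scale = torch.randn(512, device="cuda", dtype=torch.bfloat16)
shift = torch.randn(512, device="cuda", dtype=torch.bfloat16)
y = norm(x, scale, shift)  # rms_norm(x) * (1 + scale) + shift
```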
I feel it's nice to have but not necessary for comparison.
```python
pytest.param(
    0.0,
    "none",
    marks=pytest.mark.skipif(device == "npu", reason="Ascend NPU does not support this test"),
```
Thanks for caring about NPU support; we should mark it at the function level.
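i.e., something along these lines (a sketch of the suggestion; the parametrization is abbreviated and `device` is assumed to come from the suite's device detection):

```python
import pytest

device = "cuda"  # assumption: resolved by the test suite's device helper

# Skip the whole test on NPU at function level instead of per-parameter.
@pytest.mark.skipif(device == "npu", reason="Ascend NPU does not support this test")
@pytest.mark.parametrize("offset, casting_mode", [(0.0, "none")])
def test_correctness_functional(offset, casting_mode):
    ...
```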
| "bs, sl, hd", | ||
| [ | ||
| (2, 16, 512), | ||
| (5, 7, 123), | ||
| ], |
Set larger shapes, closer to actual model sizes.
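For example (values illustrative, not prescribed by the review; 4096 matches common 7B-class hidden sizes, and the test name is hypothetical):

```python
import pytest

@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 2048, 4096),  # closer to real model shapes
        (5, 7, 123),      # keep one odd shape for edge-case coverage
    ],
)
def test_shapes(bs, sl, hd):
    ...
```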
```python
def _broadcast_modulation(modulation, hidden_states):
    if modulation.dim() == 1:
        return modulation
    if modulation.shape == hidden_states.shape:
        return modulation
    if hidden_states.dim() == 3 and modulation.dim() == 2 and modulation.shape[0] == hidden_states.shape[0]:
        return modulation[:, None, :]
    if hidden_states.dim() == 2 and modulation.dim() == 2 and hidden_states.shape[0] % modulation.shape[0] == 0:
        rows_per_modulation = hidden_states.shape[0] // modulation.shape[0]
        return modulation.repeat_interleave(rows_per_modulation, dim=0)
    raise AssertionError("Unsupported modulation shape for reference implementation.")
```
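The rules this helper encodes, by example (shapes borrowed from the tests above):

```python
import torch

x3 = torch.randn(2, 16, 512)  # (batch, seq, hidden)
x2 = torch.randn(6, 512)      # (rows, hidden)

_broadcast_modulation(torch.randn(512), x3).shape     # (512,): returned as-is, broadcasts over rows
_broadcast_modulation(torch.randn(2, 512), x3).shape  # (2, 1, 512): per-batch modulation
_broadcast_modulation(torch.randn(3, 512), x2).shape  # (6, 512): 6 % 3 == 0, rows repeated
```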
Shouldn't it belong to the reference nn.Module class?
```python
    raise AssertionError("Unsupported modulation shape for reference implementation.")


class ModulatedRMSNormReference(nn.Module):
```
What's the difference between this one and the one in the benchmark? I'd prefer putting the Hugging Face implementation, or whichever reference is currently used in our tests, in one place, and making the benchmark script point to the test reference as a single source of truth.
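e.g. something like this in the benchmark script (the module path is a guess, for illustration only):

```python
# Import the reference the tests already use instead of defining a second copy
# (path is hypothetical):
from test.transformers.test_modulated_rms_norm import ModulatedRMSNormReference
```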
```python
2e-1,
2e-2,
```
Is this the best tolerance we can pass?
Summary
Adds a fused Modulated RMSNorm kernel for the common pattern `rms_norm(x) * (1 + scale) + shift`.
This is intended for FiLM/AdaRMSNorm-style modulation, where the RMSNorm output is scaled and optionally shifted by per-hidden-dimension modulation tensors.
Related issue: #1224
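A plain-PyTorch sketch of the pattern the kernel fuses (reference only, not the Triton implementation; the function name is illustrative and casting-mode handling is elided):

```python
import torch

def modulated_rms_norm_ref(x, weight, scale, shift=None, eps=1e-6):
    # RMSNorm computed in fp32 for stability, then FiLM-style modulation.
    x_fp32 = x.float()
    inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = (x_fp32 * inv_rms).to(x.dtype) * weight * (1 + scale)
    return y if shift is None else y + shift
```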
Changes

- `LigerModulatedRMSNormFunction` with fused forward/backward Triton kernels.
- `LigerModulatedRMSNorm` module wrapper and `liger_modulated_rms_norm` functional API.
- Supported casting modes: `llama`, `gemma`, `none`.
- Modulation tensors may be `(hidden,)`, `(batch, hidden)`, or row-wise `(batch, seq, hidden)` when rows divide the hidden-state rows.
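A hypothetical end-to-end call; the import path and the module's constructor/forward signature are assumptions here, mirroring the naive module shown in the review above:

```python
import torch
from liger_kernel.transformers import LigerModulatedRMSNorm  # import path assumed

norm = LigerModulatedRMSNorm(hidden_size=4096).cuda()
x = torch.randn(2, 2048, 4096, device="cuda", dtype=torch.bfloat16)
scale = torch.randn(2, 4096, device="cuda", dtype=torch.bfloat16)  # (batch, hidden), broadcast row-wise
shift = torch.randn(2, 4096, device="cuda", dtype=torch.bfloat16)
y = norm(x, scale, shift)  # fused rms_norm(x) * (1 + scale) + shift
```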
Benchmark

Environment: NVIDIA GeForce RTX 3090, bf16, hidden size 4096, row-wise `scale` and `shift`.

Speed, p50 ms
Memory, p50 MB