The feature, motivation
I would like to add a fused Modulated RMSNorm kernel for AdaLN/FiLM-style conditioning patterns used in diffusion Transformers.
The operation is:
y = rms_norm(x, weight, eps) * (1 + scale) + shift
with shift optional:
y = rms_norm(x, weight, eps) * (1 + scale)
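To make the semantics concrete, here is a minimal unfused PyTorch reference for the operation above. It is a sketch, not the proposed kernel: the name `modulated_rms_norm_ref` is hypothetical, `offset` is assumed to be added to the weight (as in the existing RMSNorm options), and `casting_mode` is ignored in favor of computing everything in float32.

```python
from typing import Optional

import torch


def modulated_rms_norm_ref(
    x: torch.Tensor,
    weight: Optional[torch.Tensor],
    scale: torch.Tensor,
    shift: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    offset: float = 0.0,
) -> torch.Tensor:
    """Unfused reference: rms_norm(x, weight, eps) * (1 + scale) + shift."""
    x_f32 = x.float()  # compute in float32 (assumed casting policy)
    # Reciprocal RMS over the hidden (last) dimension.
    rstd = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = x_f32 * rstd
    if weight is not None:  # weight=None corresponds to elementwise_affine=False
        y = y * (offset + weight.float())  # assumption: offset is added to the weight
    y = y * (1.0 + scale.float())
    if shift is not None:  # shift is optional
        y = y + shift.float()
    return y.to(x.dtype)


# (batch, tokens, hidden) activations; per-batch modulation broadcast over tokens.
x = torch.randn(2, 4, 8)
weight = torch.ones(8)
scale = torch.randn(2, 1, 8)
shift = torch.randn(2, 1, 8)
y = modulated_rms_norm_ref(x, weight, scale, shift)
```

With `scale` set to zero and no `shift`, this reduces to plain RMSNorm, which is a convenient sanity check for any fused version.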
This pattern shows up in DiT-style blocks where a conditioning vector produces per-layer modulation parameters. In current PyTorch-style implementations, RMSNorm and the scale/shift modulation are usually separate ops, which means extra kernel launches and an intermediate normalized tensor.
Liger already has a fast RMSNorm kernel, so this seems like a small and useful extension for diffusion/video Transformer workloads.
Proposed API:
liger_modulated_rms_norm(
    X,
    W,
    scale,
    shift=None,
    eps=1e-6,
    offset=0.0,
    casting_mode="llama",
    in_place=True,
)
and a module wrapper:
LigerModulatedRMSNorm(hidden_size, eps=1e-6, ...)
Initial scope:
- Fuse RMSNorm + modulation scale + optional shift
- Support scale/shift shaped per-row or per-batch, broadcast over tokens
- Support W=None / elementwise_affine=False
- Keep existing RMSNorm options where possible: offset, casting_mode, in_place
- Add correctness tests for forward and backward against a PyTorch reference
- Add a benchmark comparing:
  - PyTorch/HF RMSNorm + modulation
  - LigerRMSNorm + modulation
  - fused LigerModulatedRMSNorm
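The forward/backward correctness test in the scope above could look roughly like the following sketch. All names here are hypothetical, and `candidate` is only a stand-in: it is an algebraically equivalent PyTorch rewrite so the harness runs without the kernel, and the fused op would take its place in a real test.

```python
import torch


def reference(x, w, scale, shift, eps=1e-6):
    # Unfused composition used as ground truth.
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
    return normed * (1 + scale) + shift


def candidate(x, w, scale, shift, eps=1e-6):
    # Placeholder for the fused op under test (assumption: same signature).
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (x * rstd) * (w * (1 + scale)) + shift


def check_forward_backward(fn, ref, atol=1e-8, rtol=1e-6):
    torch.manual_seed(0)
    # Shapes for x, w, scale, shift; scale/shift are per-batch, broadcast over tokens.
    shapes = [(2, 4, 8), (8,), (2, 1, 8), (2, 1, 8)]
    inputs = [torch.randn(s, dtype=torch.float64, requires_grad=True) for s in shapes]
    ref_inputs = [t.detach().clone().requires_grad_(True) for t in inputs]

    out, ref_out = fn(*inputs), ref(*ref_inputs)
    grad_out = torch.randn_like(out)  # same upstream gradient for both paths
    out.backward(grad_out)
    ref_out.backward(grad_out)

    torch.testing.assert_close(out, ref_out, atol=atol, rtol=rtol)
    for got, want in zip(inputs, ref_inputs):
        torch.testing.assert_close(got.grad, want.grad, atol=atol, rtol=rtol)
    return True


check_forward_backward(candidate, reference)
```

The check covers gradients for all four inputs, since the backward pass of a fused kernel has to produce `dX`, `dW`, `dscale`, and `dshift`, with the modulation gradients reduced over the broadcast token dimension. Lower-precision dtypes would need looser tolerances than the float64 values used here.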
Out of scope for the first PR:
- Fusing the modulation MLP that produces scale and shift
- Fusing residual gates
- Diffusers monkey-patching
- LayerNorm + modulation
Alternatives
The current alternative is to use LigerRMSNorm and then apply scale/shift with regular PyTorch ops. That works, but still materializes the normalized output and launches extra elementwise kernels.
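For illustration, the unfused pattern looks roughly like this in a DiT-style block. This is a dependency-free sketch: the class name and `cond_size` parameter are hypothetical, and the inline normalization stands in for the step that LigerRMSNorm would perform.

```python
import torch
from torch import nn


class UnfusedModulatedRMSNorm(nn.Module):
    """Baseline: RMSNorm followed by scale/shift as separate elementwise ops."""

    def __init__(self, hidden_size: int, cond_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # adaLN-style projection producing (shift, scale); not part of the proposed fusion.
        self.modulation = nn.Linear(cond_size, 2 * hidden_size)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # cond: (batch, cond_size) -> shift, scale: (batch, 1, hidden)
        shift, scale = self.modulation(cond).unsqueeze(1).chunk(2, dim=-1)
        rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = x * rstd * self.weight  # normalized output is materialized here
        scaled = normed * (1 + scale)    # extra elementwise kernel
        return scaled + shift            # extra elementwise kernel


block = UnfusedModulatedRMSNorm(hidden_size=8, cond_size=16)
x = torch.randn(2, 4, 8)   # (batch, tokens, hidden)
cond = torch.randn(2, 16)  # one conditioning vector per batch element
y = block(x, cond)
```

In eager mode, `normed` and `scaled` are full-size temporaries that a fused kernel would not need to write out.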
Additional context
This is a common pattern in recent diffusion Transformer architectures:
https://ojs.aaai.org/index.php/AAAI/article/view/11671
https://www.wpeebles.com/DiT
https://github.com/facebookresearch/DiT/blob/main/models.py
shift, scale, and gate modulation: https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
https://huggingface.co/docs/diffusers/api/normalization
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py