
Add fused Modulated RMSNorm for DiT-style scale/shift conditioning #1224

@yukiu00

The feature, motivation and pitch

I would like to add a fused Modulated RMSNorm kernel for AdaLN/FiLM-style conditioning patterns used in diffusion Transformers.

The operation is:

y = rms_norm(x, weight, eps) * (1 + scale) + shift

with shift optional:

y = rms_norm(x, weight, eps) * (1 + scale)
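Spelled out for one row x of length H, and assuming `offset` and `W` keep their meaning from Liger's existing RMSNorm (the weight enters as `offset + W`), the operation is:

```latex
\begin{aligned}
\mathrm{rstd} &= \Big(\tfrac{1}{H}\textstyle\sum_{j=1}^{H} x_j^2 + \epsilon\Big)^{-1/2}, \\
\hat{x}_i &= x_i \cdot \mathrm{rstd}, \\
y_i &= \hat{x}_i \,(\mathrm{offset} + W_i)\,(1 + \mathrm{scale}_i) + \mathrm{shift}_i .
\end{aligned}
```

The `shift` term is dropped when shift is None, and the `(offset + W_i)` factor becomes 1 when `W=None`. Because everything after the `rstd` reduction is elementwise, a fused kernel can compute `rstd` once per row and write `y` directly, without storing the normalized row.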

This pattern shows up in DiT-style blocks where a conditioning vector produces per-layer modulation parameters. In current PyTorch-style implementations, RMSNorm and the scale/shift modulation are usually separate ops, which costs extra kernel launches and materializes an intermediate normalized tensor in memory.

Liger already has a fast RMSNorm kernel, so this seems like a small and useful extension for diffusion/video Transformer workloads.

Proposed API:

liger_modulated_rms_norm(
    X,
    W,
    scale,
    shift=None,
    eps=1e-6,
    offset=0.0,
    casting_mode="llama",
    in_place=True,
)
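To pin down the intended semantics, here is a minimal unfused PyTorch reference of the kind the correctness tests could compare against. `modulated_rms_norm_ref` is a hypothetical name, not an existing Liger function. It assumes X is either (N, H) or (B, T, H), treats a 2-D scale/shift as per-batch when X is 3-D, computes in float32, and does not model `casting_mode` or `in_place`.

```python
import torch


def _broadcast_over_tokens(t, X):
    # Per-batch modulation (B, H) -> (B, 1, H) so it broadcasts over T in (B, T, H).
    return t.unsqueeze(1) if X.dim() == 3 and t.dim() == 2 else t


def modulated_rms_norm_ref(X, W, scale, shift=None, eps=1e-6, offset=0.0):
    """Hypothetical unfused reference: rms_norm(X, W, eps) * (1 + scale) + shift."""
    x = X.float()
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    n = x * rstd
    if W is not None:
        # Weight applied as (offset + W), as assumed for the existing RMSNorm.
        n = n * (offset + W.float())
    y = n * (1 + _broadcast_over_tokens(scale, X).float())
    if shift is not None:
        y = y + _broadcast_over_tokens(shift, X).float()
    # Cast back to the input dtype (no attempt to mirror casting_mode).
    return y.to(X.dtype)
```

Per-row modulation (same shape as X) needs no special handling, since it already broadcasts elementwise.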

and a module wrapper:

LigerModulatedRMSNorm(hidden_size, eps=1e-6, ...)
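A module wrapper mainly has to own `weight` (or register none when `elementwise_affine=False`) while taking `scale` and `shift` as forward arguments, because those are produced by the conditioning path on every call. The sketch below uses the hypothetical name `ModulatedRMSNormSketch` and inlines the same unfused math; a real `LigerModulatedRMSNorm` would dispatch to the fused kernel instead.

```python
import torch
from torch import nn


class ModulatedRMSNormSketch(nn.Module):
    """Hypothetical wrapper: learned weight, per-call scale/shift."""

    def __init__(self, hidden_size, eps=1e-6, offset=0.0, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.offset = offset
        if elementwise_affine:
            # Ones init: with offset=0.0 the weight starts as an identity scale.
            self.weight = nn.Parameter(torch.ones(hidden_size))
        else:
            self.register_parameter("weight", None)

    def forward(self, x, scale, shift=None):
        x32 = x.float()
        n = x32 * torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        if self.weight is not None:
            n = n * (self.offset + self.weight.float())
        if x.dim() == 3 and scale.dim() == 2:
            # Per-batch modulation: broadcast over the token dimension.
            scale = scale.unsqueeze(1)
            shift = shift.unsqueeze(1) if shift is not None and shift.dim() == 2 else shift
        y = n * (1 + scale.float())
        if shift is not None:
            y = y + shift.float()
        return y.to(x.dtype)
```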

Initial scope:

  • Fuse RMSNorm + modulation scale + optional shift
  • Support scale/shift shaped either per row (same shape as X) or per batch element (one vector per sample, broadcast over tokens)
  • Support W=None / elementwise_affine=False
  • Keep existing RMSNorm options where possible: offset, casting_mode, in_place
  • Add correctness tests for forward and backward against a PyTorch reference
  • Add a benchmark comparing:
    • PyTorch/HF RMSNorm + modulation
    • LigerRMSNorm + modulation
    • fused LigerModulatedRMSNorm
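For the backward tests, the gradients a fused kernel would need to produce can be written by hand and checked against `torch.autograd`. This is a sketch under stated assumptions: X is (B, T, H), `W` is (H,), and `scale`/`shift` are per-batch (B, H); `modulated_rms_norm_backward` is a hypothetical helper, not an existing Liger function.

```python
import torch


def modulated_rms_norm_backward(g, x, w, scale, eps=1e-6, offset=0.0):
    """Hand-derived grads of y = xhat * (offset + w) * (1 + scale) + shift.

    g is dL/dy. Returns (dx, dw, dscale, dshift); shift does not affect dx/dw/dscale.
    """
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    xhat = x * rstd
    n = xhat * (offset + w)                 # normalized, weighted row
    s = scale.unsqueeze(1)                  # (B, 1, H), broadcast over tokens
    dshift = g.sum(dim=1)                   # reduce over the token dimension
    dscale = (g * n).sum(dim=1)
    dn = g * (1 + s)                        # gradient flowing into the RMSNorm output
    dw = (dn * xhat).sum(dim=(0, 1))        # reduce over batch and tokens
    dxhat = dn * (offset + w)
    # Standard RMSNorm input gradient.
    dx = rstd * (dxhat - xhat * (dxhat * xhat).mean(dim=-1, keepdim=True))
    return dx, dw, dscale, dshift
```

Note that with per-batch modulation, `dscale` and `dshift` need a reduction over tokens, so a fused backward would likely need cross-row accumulation for them, similar to what `dW` already requires.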

Out of scope for the first PR:

  • Fusing the modulation MLP that produces scale and shift
  • Fusing residual gates
  • Diffusers monkey-patching
  • LayerNorm + modulation

Alternatives

The current alternative is to use LigerRMSNorm and then apply scale/shift with regular PyTorch ops. That works, but still materializes the normalized output and launches extra elementwise kernels.
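A rough estimate of the activation traffic involved, assuming eager execution where each op in the unfused path is its own kernel that reads and writes full (tokens x hidden) tensors in global memory. `activation_bytes` is a hypothetical helper; it ignores the small `W`/`scale`/`shift` tensors, cache effects, and any fusion a compiler might do, so treat it as an idealized count rather than a measurement.

```python
def activation_bytes(num_tokens, hidden, bytes_per_elem=2, has_shift=True, fused=False):
    """Idealized forward-pass global-memory traffic in bytes (hypothetical estimate)."""
    tensor_bytes = num_tokens * hidden * bytes_per_elem
    if fused:
        # Single kernel: read x, write y.
        return 2 * tensor_bytes
    # RMSNorm kernel: read x, write n. Mul kernel: read n, write n * (1 + scale).
    passes = 4
    if has_shift:
        # Add kernel: read the product, write y.
        passes += 2
    return passes * tensor_bytes


# Illustrative sizes only (2-byte bf16 activations).
print(activation_bytes(4096, 3072) / activation_bytes(4096, 3072, fused=True))  # 3.0
```

Under these assumptions the fused path moves about a third of the activation bytes when shift is present, and half when it is not.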

Additional context

This is a common pattern in recent diffusion Transformer architectures:
