diff --git a/benchmark/scripts/benchmark_rms_norm.py b/benchmark/scripts/benchmark_rms_norm.py index 6bcd56a83..11b288ddb 100644 --- a/benchmark/scripts/benchmark_rms_norm.py +++ b/benchmark/scripts/benchmark_rms_norm.py @@ -41,6 +41,7 @@ def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu M = extra_benchmark_config["M"] eps = extra_benchmark_config["eps"] dtype = extra_benchmark_config["dtype"] + freeze_weight = extra_benchmark_config.get("freeze_weight", False) x_shape = (M, N) @@ -51,6 +52,10 @@ def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu dy = torch.randn_like(x) x.requires_grad_(True) + if freeze_weight: + triton_rms.weight.requires_grad_(False) + llama_rms.weight.requires_grad_(False) + # utility functions def y_fwd(): @@ -60,10 +65,16 @@ def y_fwd(): if provider == "huggingface": return llama_rms(x) + grad_to_none = [x] + if provider == "liger" and triton_rms.weight.requires_grad: + grad_to_none.append(triton_rms.weight) + elif provider == "huggingface" and llama_rms.weight.requires_grad: + grad_to_none.append(llama_rms.weight) + if mode == "forward": ms_50, ms_20, ms_80 = triton.testing.do_bench( y_fwd, - grad_to_none=[x], + grad_to_none=grad_to_none, rep=500, quantiles=QUANTILES, ) @@ -71,7 +82,7 @@ def y_fwd(): y = y_fwd() ms_50, ms_20, ms_80 = triton.testing.do_bench( lambda: y.backward(dy, retain_graph=True), - grad_to_none=[x], + grad_to_none=grad_to_none, rep=500, quantiles=QUANTILES, ) @@ -83,7 +94,7 @@ def full(): ms_50, ms_20, ms_80 = triton.testing.do_bench( full, - grad_to_none=[x], + grad_to_none=grad_to_none, rep=500, quantiles=QUANTILES, ) @@ -103,6 +114,7 @@ def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO M = extra_benchmark_config["M"] eps = extra_benchmark_config["eps"] dtype = extra_benchmark_config["dtype"] + freeze_weight = extra_benchmark_config.get("freeze_weight", False) x_shape = (M, N) @@ -113,6 +125,10 @@ def 
bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO dy = torch.randn_like(x) x.requires_grad_(True) + if freeze_weight: + triton_rms.weight.requires_grad_(False) + llama_rms.weight.requires_grad_(False) + # utility functions def y_fwd(): if provider == "liger": @@ -142,7 +158,10 @@ def full(): "x_label": "hidden size", "x_values": [2**i for i in range(10, 16)], "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}], + "extra_benchmark_configs": [ + {"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6, "freeze_weight": False}, + {"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6, "freeze_weight": True}, + ], "overwrite": args.overwrite, } diff --git a/benchmark/scripts/benchmark_rms_norm_mixed.py b/benchmark/scripts/benchmark_rms_norm_mixed.py new file mode 100644 index 000000000..4ebc55bc9 --- /dev/null +++ b/benchmark/scripts/benchmark_rms_norm_mixed.py @@ -0,0 +1,223 @@ +import math + +import torch +import torch.nn as nn +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.utils import infer_device + +device = infer_device() + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class LoRALinear(nn.Module): 
+ def __init__(self, in_features, out_features, r=8, alpha=16.0, bias=False): + super().__init__() + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + self.weight.requires_grad_(False) # base weight frozen (LoRA) + self.lora_A = nn.Parameter(torch.empty(r, in_features)) + self.lora_B = nn.Parameter(torch.empty(out_features, r)) + self.scaling = alpha / r + if bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + + # Init with small random values so grads flow through both A and B + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_B, a=math.sqrt(5)) + + def forward(self, x): + base = x @ self.weight.t() + lora = (x @ self.lora_A.t()) @ self.lora_B.t() + out = base + lora * self.scaling + if self.bias is not None: + out = out + self.bias + return out + + +class MixedBlock(nn.Module): + def __init__(self, norm_cls, hidden_size, eps, lora_r, lora_alpha): + super().__init__() + self.norm = norm_cls(hidden_size=hidden_size, eps=eps) + self.proj = LoRALinear(hidden_size, hidden_size, r=lora_r, alpha=lora_alpha) + + def forward(self, x): + return self.proj(self.norm(x)) + + +def _build_block(provider, hidden_size, eps, dtype, lora_r, lora_alpha, freeze_norm_weight): + norm_cls = LigerRMSNorm if provider == "liger" else LlamaRMSNorm + block = MixedBlock(norm_cls, hidden_size=hidden_size, eps=eps, lora_r=lora_r, lora_alpha=lora_alpha) + block = block.to(device=device, dtype=dtype) + if freeze_norm_weight: + block.norm.weight.requires_grad_(False) + return block + + +def _grad_to_none_tensors(module, x): + tensors = [x] + for p in module.parameters(): + if p.requires_grad: + tensors.append(p) + return tensors + + +def bench_speed_rms_norm_mixed(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra = 
input.extra_benchmark_config + M = extra["M"] + eps = extra["eps"] + dtype = extra["dtype"] + lora_r = extra["lora_r"] + lora_alpha = extra["lora_alpha"] + freeze_norm_weight = extra.get("freeze_norm_weight", True) + + x_shape = (M, N) + + block = _build_block(provider, N, eps, dtype, lora_r, lora_alpha, freeze_norm_weight) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def y_fwd(): + return block(x) + + grad_to_none = _grad_to_none_tensors(block, x) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + y_fwd, + grad_to_none=grad_to_none, + rep=500, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = y_fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + grad_to_none=grad_to_none, + rep=500, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=grad_to_none, + rep=500, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_rms_norm_mixed(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + N = input.x + provider = input.kernel_provider + + extra = input.extra_benchmark_config + M = extra["M"] + eps = extra["eps"] + dtype = extra["dtype"] + lora_r = extra["lora_r"] + lora_alpha = extra["lora_alpha"] + freeze_norm_weight = extra.get("freeze_norm_weight", True) + + x_shape = (M, N) + + block = _build_block(provider, N, eps, dtype, lora_r, lora_alpha, freeze_norm_weight) + + x = torch.randn(x_shape, dtype=dtype, device=device) + dy = torch.randn_like(x) + x.requires_grad_(True) + + def y_fwd(): + return block(x) + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + 
y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "rms_norm_mixed", + "x_name": "H", + "x_label": "hidden size", + "x_values": [2**i for i in range(10, 16)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "M": 2048, + "dtype": torch.bfloat16, + "eps": 1e-6, + "lora_r": 8, + "lora_alpha": 16.0, + "freeze_norm_weight": True, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_rms_norm_mixed, + kernel_operation_modes=["forward", "full", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_rms_norm_mixed, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/src/liger_kernel/ops/fused_add_rms_norm.py b/src/liger_kernel/ops/fused_add_rms_norm.py index 866377687..e14b4eeb4 100644 --- a/src/liger_kernel/ops/fused_add_rms_norm.py +++ b/src/liger_kernel/ops/fused_add_rms_norm.py @@ -140,6 +140,7 @@ def _fused_add_rms_norm_backward_kernel( casting_mode: tl.constexpr, BLOCK_SIZE: tl.constexpr, has_dS_out: tl.constexpr, + compute_dW: tl.constexpr, ): """ This kernel is adapted from the rms_norm backward kernel, and is adapted to support the residual @@ -161,10 +162,10 @@ def _fused_add_rms_norm_backward_kernel( col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < n_cols - dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) W_row = W_row + offset + if compute_dW: + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) for row_idx in range(row_start, row_end): dy_base = dY_ptr + row_idx * dY_row_stride @@ -202,16 +203,18 @@ def _fused_add_rms_norm_backward_kernel( else: dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row) - # calculate the gradient of W - if casting_mode 
== _CASTING_MODE_LLAMA: - dW_row += dY_row * (X_row * rstd_row).to(X_dtype) - else: - # here X_row is already in fp32 (see previous if block) - dW_row += dY_row * (X_row * rstd_row) + if compute_dW: + # calculate the gradient of W + if casting_mode == _CASTING_MODE_LLAMA: + dW_row += dY_row * (X_row * rstd_row).to(X_dtype) + else: + # here X_row is already in fp32 (see previous if block) + dW_row += dY_row * (X_row * rstd_row) tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask) - tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask) + if compute_dW: + tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask) _str_to_casting_mode = { @@ -276,11 +279,14 @@ def fused_add_rms_norm_forward(X, R, W, eps, offset, casting_mode): return Y.view(*shape), S.view(*shape), RSTD, BLOCK_SIZE, num_warps, casting_mode -def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place): +def fused_add_rms_norm_backward( + dY, dS_out, S, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, compute_dW +): shape = dY.shape dim = shape[-1] dY = dY.view(-1, dim) - dS_out = dS_out.view(-1, dim) + if dS_out is not None: + dS_out = dS_out.view(-1, dim) S = S.view(-1, dim) n_rows, n_cols = dY.shape @@ -292,8 +298,12 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL elif S.device.type == "npu": sm_count = get_npu_core_count() - # fp32 for numerical stability especially. - _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + compute_dW = compute_dW and W is not None + if compute_dW: + # fp32 for numerical stability especially. 
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + else: + _dW = None if n_cols > BLOCK_SIZE: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") @@ -325,8 +335,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL W.stride(0), RSTD, RSTD.stride(0), - _dW, - _dW.stride(0), + _dW if compute_dW else S, + _dW.stride(0) if compute_dW else 0, n_rows, n_cols, offset, @@ -335,11 +345,15 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, has_dS_out=dS_out is not None, + compute_dW=compute_dW, **kernel_args, # XPU-specific optimization ) dX = dX.view(*shape) - dW = _dW.sum(dim=0).to(W.dtype) + if compute_dW: + dW = _dW.sum(dim=0).to(W.dtype) + else: + dW = None return dX, dX, dW # dR is equal to dX @@ -394,6 +408,7 @@ def backward(ctx, dY, dS_out): Y: (B, T, H) or (BxT, H) """ S, W, RSTD = ctx.saved_tensors + need_dW = ctx.needs_input_grad[2] dX, dR, dW = fused_add_rms_norm_backward( dY, dS_out, @@ -405,6 +420,9 @@ def backward(ctx, dY, dS_out): ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, + compute_dW=need_dW, ) + if not need_dW: + dW = None return dX, dR, dW, None, None, None, None, None diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 859c546e3..7ed302639 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -113,6 +113,8 @@ def _group_norm_backward_kernel( channels_per_group: tl.constexpr, # number of groups in group norm BLOCK_SIZE: tl.constexpr, dtype: tl.constexpr, + compute_dW: tl.constexpr, + compute_dB: tl.constexpr, ): """ References: @@ -144,8 +146,10 @@ def _group_norm_backward_kernel( # We need to compute the sum terms of the backprop equations across all channels in the group for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): - dW = 0.0 - dB = 0.0 + if compute_dW: + dW = 0.0 + 
if compute_dB: + dB = 0.0 # Move the pointers to the correct channel W = tl.load(W_ptr + channel_idx) for i in tl.range(0, hidden_size, BLOCK_SIZE): @@ -163,16 +167,20 @@ def _group_norm_backward_kernel( ) x_hat = (X - mean) * rstd - dW += tl.sum(UPSTREAM_grad * x_hat) - dB += tl.sum(UPSTREAM_grad) + if compute_dW: + dW += tl.sum(UPSTREAM_grad * x_hat) + if compute_dB: + dB += tl.sum(UPSTREAM_grad) wdy = W * UPSTREAM_grad c1 += tl.sum(x_hat * wdy) c2 += tl.sum(wdy) # Need to ensure additions to the same channel are atomic - tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype)) - tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype)) + if compute_dW: + tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype)) + if compute_dB: + tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype)) N = hidden_size * channels_per_group c1 = c1 / N @@ -237,7 +245,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE -def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): +def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups, compute_dW, compute_dB): shape = dY.shape batch_size = shape[0] hidden_size = dY.shape[-1] @@ -248,8 +256,14 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): dtype=X.dtype, device=X.device, ) - DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device) - DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device) + if compute_dW: + DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device) + else: + DW = None + if compute_dB: + DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device) + else: + DB = None triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) @@ -263,13 +277,15 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): Mean.stride(1), RSTD, DX, - DW, - DB, + DW if compute_dW else X, + DB if 
compute_dB else X, dY, hidden_size, channels_per_group, BLOCK_SIZE=BLOCK_SIZE, dtype=triton_dtype, + compute_dW=compute_dW, + compute_dB=compute_dB, ) # Return tensors in the original shape @@ -305,5 +321,9 @@ def forward( @ensure_contiguous def backward(ctx, dY): X, W, B, Mean, RSTD = ctx.saved_tensors - DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups) + need_dW = ctx.needs_input_grad[1] + need_dB = ctx.needs_input_grad[2] + DX, DW, DB = group_norm_backward( + dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups, compute_dW=need_dW, compute_dB=need_dB + ) return DX, DW, DB, None, None, None diff --git a/src/liger_kernel/ops/layer_norm.py b/src/liger_kernel/ops/layer_norm.py index e8ac6b5f3..5fd88d23e 100644 --- a/src/liger_kernel/ops/layer_norm.py +++ b/src/liger_kernel/ops/layer_norm.py @@ -107,6 +107,8 @@ def _layer_norm_backward_kernel( n_cols, rows_per_program: tl.constexpr, BLOCK_SIZE: tl.constexpr, + compute_dW: tl.constexpr, + compute_dB: tl.constexpr, ): """ References: @@ -119,8 +121,10 @@ def _layer_norm_backward_kernel( cols = tl.arange(0, BLOCK_SIZE) mask = cols < n_cols - dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + if compute_dW: + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + if compute_dB: + db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) # Pre-load weights once (same optimization as forward pass) w = tl.load(W_ptr + cols, mask=mask, other=0.0) @@ -159,11 +163,15 @@ def _layer_norm_backward_kernel( # Accumulate weight and bias gradients for this thread block's assigned rows dw = dy_f32 * x_hat db = dy_f32 - dW_row += dw - db_row += db + if compute_dW: + dW_row += dw + if compute_dB: + db_row += db - tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask) - tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask) + if compute_dW: + tl.store(DW_ptr + row_block_id * stride_dw + cols, dW_row, mask=mask) + if 
compute_dB: + tl.store(DB_ptr + row_block_id * stride_db + cols, db_row, mask=mask) def layer_norm_forward(X, W, B, eps): @@ -227,7 +235,7 @@ def layer_norm_forward(X, W, B, eps): return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps -def layer_norm_backward(dY, X, W, B, Mean, RSTD): +def layer_norm_backward(dY, X, W, B, Mean, RSTD, compute_dW, compute_dB): """ Args: dY: Gradient of output @@ -253,9 +261,15 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD): elif X.device.type == "npu": sm_count = get_npu_core_count() - # fp32 for numerical stability especially. - _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) - _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + if compute_dW: + # fp32 for numerical stability especially. + _DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + else: + _DW = None + if compute_dB: + _DB = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + else: + _DB = None # Calculate optimal block size and warp configuration BLOCK_SIZE, num_warps = calculate_settings(n_cols) @@ -284,22 +298,30 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD): RSTD.stride(0), DX, DX.stride(0), - _DW, - _DW.stride(0), - _DB, - _DB.stride(0), + _DW if compute_dW else X, + _DW.stride(0) if compute_dW else 0, + _DB if compute_dB else X, + _DB.stride(0) if compute_dB else 0, dY, dY.stride(0), n_rows, n_cols, rows_per_program=rows_per_program, BLOCK_SIZE=BLOCK_SIZE, + compute_dW=compute_dW, + compute_dB=compute_dB, **kernel_args, ) DX = DX.view(*shape) - DW = _DW.sum(dim=0).to(W.dtype) - DB = _DB.sum(dim=0).to(B.dtype) + if compute_dW: + DW = _DW.sum(dim=0).to(W.dtype) + else: + DW = None + if compute_dB: + DB = _DB.sum(dim=0).to(B.dtype) + else: + DB = None return DX, DW, DB @@ -316,5 +338,7 @@ def forward(ctx, X, W, B, eps): @ensure_contiguous def backward(ctx, dY): X, W, B, Mean, RSTD = ctx.saved_tensors - DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, 
RSTD) + need_dW = ctx.needs_input_grad[1] + need_dB = ctx.needs_input_grad[2] + DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD, compute_dW=need_dW, compute_dB=need_dB) return DX, DW, DB, None diff --git a/src/liger_kernel/ops/poly_norm.py b/src/liger_kernel/ops/poly_norm.py index 2198e522d..a6ad502eb 100644 --- a/src/liger_kernel/ops/poly_norm.py +++ b/src/liger_kernel/ops/poly_norm.py @@ -1,3 +1,4 @@ +import math import operator import torch @@ -113,6 +114,8 @@ def _poly_norm_backward_kernel( n_cols, rows_per_program: tl.constexpr, BLOCK_SIZE: tl.constexpr, + compute_dW: tl.constexpr, + compute_dB: tl.constexpr, ): """ PolyNorm Backward Kernel Gradient: @@ -131,10 +134,12 @@ def _poly_norm_backward_kernel( mask = col_offsets < n_cols # Initialize accumulators for weight and bias gradients (scalars) - dW0_acc = 0.0 - dW1_acc = 0.0 - dW2_acc = 0.0 - dB_acc = 0.0 + if compute_dW: + dW0_acc = 0.0 + dW1_acc = 0.0 + dW2_acc = 0.0 + if compute_dB: + dB_acc = 0.0 # Load weights w0 = tl.load(W_ptr + 0).to(tl.float32) @@ -161,7 +166,8 @@ def _poly_norm_backward_kernel( X_pow1 = X_row # Accumulate bias gradient: dB = sum(dY) - dB_acc += tl.sum(dY_row, axis=0) + if compute_dB: + dB_acc += tl.sum(dY_row, axis=0) # Compute gradient w.r.t. 
input using closed-form formula # For p=3: ∂L/∂x from w0 * norm(x³) @@ -182,9 +188,10 @@ def _poly_norm_backward_kernel( grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1) # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p - dW0_acc += rstd_3 * S_3 - dW1_acc += rstd_2 * S_2 - dW2_acc += rstd_1 * S_1 + if compute_dW: + dW0_acc += rstd_3 * S_3 + dW1_acc += rstd_2 * S_2 + dW2_acc += rstd_1 * S_1 # Total gradient dX_row = grad_x_3 + grad_x_2 + grad_x_1 @@ -193,10 +200,12 @@ def _poly_norm_backward_kernel( tl.store(dx_base + col_offsets, dX_row, mask=mask) # Store accumulated gradients (scalars) - tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc) - tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc) - tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc) - tl.store(dB_ptr + row_block_id, dB_acc) + if compute_dW: + tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc) + tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc) + tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc) + if compute_dB: + tl.store(dB_ptr + row_block_id, dB_acc) def poly_norm_forward(X, W, B, eps=1e-6): @@ -255,7 +264,7 @@ def poly_norm_forward(X, W, B, eps=1e-6): return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps -def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place): +def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place, compute_dW, compute_dB): """ PolyNorm Backward Pass @@ -279,8 +288,6 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place): n_rows, n_cols = dY.shape # Get number of SMs for parallelization - import math - sm_count = 1 if X.device.type == "cuda": sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count @@ -295,8 +302,14 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place): else: dX = torch.zeros_like(dY) - _dW = torch.empty((sm_count, 3), dtype=torch.float32, 
device=W.device) - _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device) + if compute_dW: + _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device) + else: + _dW = None + if compute_dB: + _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device) + else: + _dB = None rows_per_program = math.ceil(n_rows / sm_count) grid = (sm_count,) @@ -317,21 +330,29 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place): W, RSTD, RSTD.stride(0), - _dW, - _dW.stride(0), - _dB, + _dW if compute_dW else X, + _dW.stride(0) if compute_dW else 0, + _dB if compute_dB else X, n_rows, n_cols, rows_per_program, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, + compute_dW=compute_dW, + compute_dB=compute_dB, **kernel_args, ) # Reduce gradients across SMs dX = dX.view(*shape) - dW = _dW.sum(dim=0).to(W.dtype) - dB = _dB.sum().to(W.dtype) + if compute_dW: + dW = _dW.sum(dim=0).to(W.dtype) + else: + dW = None + if compute_dB: + dB = _dB.sum().to(W.dtype) + else: + dB = None return dX, dW, dB @@ -380,5 +401,9 @@ def backward(ctx, grad_output): dX, dW, dB: gradients w.r.t. 
X, W, B """ X, W, RSTD = ctx.saved_tensors - dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place) + need_dW = ctx.needs_input_grad[1] + need_dB = ctx.needs_input_grad[2] + dX, dW, dB = poly_norm_backward( + grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, compute_dW=need_dW, compute_dB=need_dB + ) return dX, dW, dB, None, None diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py index e5cab72ea..bc25406dd 100644 --- a/src/liger_kernel/ops/rms_norm.py +++ b/src/liger_kernel/ops/rms_norm.py @@ -140,6 +140,7 @@ def _rms_norm_backward_kernel( rows_per_program, casting_mode: tl.constexpr, elementwise_affine: tl.constexpr, + compute_dW: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ @@ -153,12 +154,11 @@ def _rms_norm_backward_kernel( col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < n_cols - if elementwise_affine: - dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - if elementwise_affine: W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) W_row = W_row + offset + if compute_dW: + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) for row_idx in range(row_start, row_end): dy_base = dY_ptr + row_idx * dY_row_stride @@ -198,7 +198,7 @@ def _rms_norm_backward_kernel( dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row) - if elementwise_affine: + if elementwise_affine and compute_dW: # calculate the gradient of W if casting_mode == _CASTING_MODE_LLAMA: dW_row += dY_row * (X_row * rstd_row).to(X_dtype) @@ -208,7 +208,7 @@ def _rms_norm_backward_kernel( tl.store(dx_base + col_offsets, dX_row.to(X_dtype), mask=mask) - if elementwise_affine: + if elementwise_affine and compute_dW: tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask) @@ -317,6 +317,7 @@ def _block_rms_norm_backward_kernel( offset, casting_mode: tl.constexpr, elementwise_affine: tl.constexpr, + compute_dW: tl.constexpr, 
BLOCK_SIZE: tl.constexpr, BLOCK_ROW: tl.constexpr, ): @@ -332,10 +333,10 @@ def _block_rms_norm_backward_kernel( col_mask = col_offsets < n_cols if elementwise_affine: - dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0) W_row = W_row + offset + if compute_dW: + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW): row_idx = start + tl.arange(0, BLOCK_ROW) @@ -381,7 +382,7 @@ def _block_rms_norm_backward_kernel( -(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row ) - if elementwise_affine: + if elementwise_affine and compute_dW: if casting_mode == _CASTING_MODE_LLAMA: # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0 dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0) @@ -395,7 +396,7 @@ def _block_rms_norm_backward_kernel( mask=row_mask[:, None] & col_mask[None, :], ) - if elementwise_affine: + if elementwise_affine and compute_dW: tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask) @@ -482,7 +483,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode): return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode -def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode): +def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode, compute_dW): shape = dY.shape dim = shape[-1] dY = dY.view(-1, dim) @@ -497,12 +498,18 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp sm_count = get_npu_core_count() if W is not None: + elementwise_affine = True + else: + elementwise_affine = False + + compute_dW = compute_dW and elementwise_affine + if compute_dW: # fp32 for numerical stability especially. 
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) - elementwise_affine = True else: + # When not computing dW, we pass a dummy pointer (X) to the kernel. + # The kernel never reads/writes to it when compute_dW=False. _dW = None - elementwise_affine = False if n_cols > BLOCK_SIZE: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") @@ -532,14 +539,15 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp W.stride(0) if elementwise_affine else 0, RSTD, RSTD.stride(0), - _dW, - _dW.stride(0) if elementwise_affine else 0, + _dW if compute_dW else X, + _dW.stride(0) if compute_dW else 0, n_rows, n_cols, offset, rows_per_program, casting_mode, elementwise_affine=elementwise_affine, + compute_dW=compute_dW, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, **kernel_args, # XPU-specific optimization @@ -559,20 +567,21 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp W.stride(0) if elementwise_affine else 0, RSTD, RSTD.stride(0), - _dW, - _dW.stride(0) if elementwise_affine else 0, + _dW if compute_dW else X, + _dW.stride(0) if compute_dW else 0, n_rows, n_cols, offset, casting_mode, elementwise_affine=elementwise_affine, + compute_dW=compute_dW, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, **kernel_args, # XPU-specific optimization ) dX = dX.view(*shape) - if elementwise_affine: + if compute_dW: dW = _dW.sum(dim=0).to(W.dtype) else: dW = None @@ -648,7 +657,20 @@ def backward(ctx, dY): # TODO: support CP. 
dY = dY.full_tensor() + need_dW = ctx.needs_input_grad[1] and ctx.elementwise_affine dX, dW = rms_norm_backward( - dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode + dY, + X, + W, + RSTD, + ctx.offset, + ctx.casting_mode, + ctx.BLOCK_SIZE, + ctx.num_warps, + ctx.in_place, + ctx.row_mode, + compute_dW=need_dW, ) + if not need_dW: + dW = None return dX, dW, None, None, None, None, None diff --git a/test/transformers/test_fused_add_rms_norm.py b/test/transformers/test_fused_add_rms_norm.py index a2567800c..04de6c710 100644 --- a/test/transformers/test_fused_add_rms_norm.py +++ b/test/transformers/test_fused_add_rms_norm.py @@ -158,6 +158,75 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_m assert_verbose_allclose(r1.grad, r2.grad, atol=atol, rtol=rtol, max_print=20) +@pytest.mark.flaky(reruns=3, reruns_delay=2) +@pytest.mark.parametrize( + "bs, sl, hd", + [ + (2, 128, 512), + (5, 123, 123), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-4, 1e-6), + pytest.param( + torch.bfloat16, + 2e-1, + 2e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), + ], +) +@pytest.mark.parametrize( + "reference, offset, casting_mode", + [ + (LlamaAddRMSNorm, 0.0, "llama"), + (GemmaAddRMSNorm, 1.0, "gemma"), + ], +) +def test_correctness_frozen_weight(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode): + """Test that frozen weight (requires_grad=False) works correctly and produces no weight gradient.""" + _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype) + _residual = torch.randn(bs, sl, hd, device=device, dtype=dtype) + + h1 = _tensor.clone().requires_grad_(True) + r1 = _residual.clone().requires_grad_(True) + h2 = _tensor.clone().requires_grad_(True) + r2 = _residual.clone().requires_grad_(True) + + dh = torch.randn(bs, sl, hd, device=device, dtype=dtype) + dr = torch.randn(bs, sl, hd, 
# NOTE(review): this region was a newline-stripped unified-diff blob (patch residue,
# not valid Python and not an applicable patch). Reconstructed below are the four
# COMPLETE test functions visible in the residue, reformatted as well-formed Python.
# Two fragments were dropped because their definitions are cut off at the chunk
# boundaries: the tail of a fused-add-RMSNorm frozen-weight test (start not visible)
# and the head of a trailing @pytest.mark.parametrize decorator (end not visible).
# All helpers referenced here (pytest, device, LigerGroupNorm, LigerLayerNorm,
# LigerPolyNorm, LigerRMSNorm, NaivePolyNorm, LlamaRMSNorm, GemmaRMSNorm,
# assert_verbose_allclose, supports_bfloat16) are imported/defined at the top of
# their respective test files — TODO confirm against the full files when re-splitting.


# ---- test/transformers/test_group_norm.py ---------------------------------------
@pytest.mark.parametrize(
    "batch_size, num_channels, num_groups, hidden_size",
    [
        (1, 32, 32, 4),
        (16, 48, 12, 8192),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-4),
    ],
)
@pytest.mark.parametrize(
    "freeze_weight, freeze_bias",
    [
        (True, False),
        (False, True),
        (True, True),
    ],
)
def test_liger_group_norm_frozen_params(
    batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol, freeze_weight, freeze_bias
):
    """Test that frozen weight/bias (requires_grad=False) works correctly and produces no gradient."""
    torch.manual_seed(0)

    _tensor = torch.randn(batch_size, num_channels, hidden_size, dtype=dtype, device=device)

    liger_x = _tensor.clone().detach().requires_grad_(True)
    torch_x = _tensor.clone().detach().requires_grad_(True)

    liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).to(device)
    torch_ln = torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6).to(dtype).to(device)

    # Align parameters before freezing so both modules compute identical outputs.
    with torch.no_grad():
        torch_ln.weight.copy_(liger_ln.weight)
        torch_ln.bias.copy_(liger_ln.bias)

    # Freeze weight and/or bias
    if freeze_weight:
        liger_ln.weight.requires_grad_(False)
        torch_ln.weight.requires_grad_(False)
    if freeze_bias:
        liger_ln.bias.requires_grad_(False)
        torch_ln.bias.requires_grad_(False)

    liger_output = liger_ln(liger_x)
    torch_output = torch_ln(torch_x)

    assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol)

    grad_output = torch.randn_like(torch_x)
    liger_output.backward(grad_output, retain_graph=True)
    torch_output.backward(grad_output, retain_graph=True)

    # Check input gradients match
    assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol)

    # Check frozen params have no gradient
    if freeze_weight:
        assert liger_ln.weight.grad is None, "Liger weight.grad should be None when frozen"
        assert torch_ln.weight.grad is None, "Torch weight.grad should be None when frozen"
    else:
        assert torch.allclose(liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol), (
            "Weight grads different"
        )

    if freeze_bias:
        assert liger_ln.bias.grad is None, "Liger bias.grad should be None when frozen"
        assert torch_ln.bias.grad is None, "Torch bias.grad should be None when frozen"
    else:
        assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol), "Bias grads different"


# ---- test/transformers/test_layer_norm.py ---------------------------------------
@pytest.mark.parametrize(
    "batch_size, seq_len, hidden_size",
    [
        (2, 8, 64),
        (4, 16, 128),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
        (torch.bfloat16, 2e-2, 2e-2),  # Relaxed tolerance for bfloat16 due to lower precision
    ],
)
@pytest.mark.parametrize(
    "freeze_weight, freeze_bias",
    [
        (True, False),
        (False, True),
        (True, True),
    ],
)
def test_liger_layer_norm_frozen_params(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype,
    atol: float,
    rtol: float,
    freeze_weight: bool,
    freeze_bias: bool,
) -> None:
    """Test that frozen weight/bias (requires_grad=False) works correctly and produces no gradient."""
    torch.manual_seed(0)

    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)

    liger_x = x.clone().requires_grad_(True)
    torch_x = x.clone().requires_grad_(True)

    liger_ln = LigerLayerNorm(hidden_size, eps=1e-6).to(dtype).to(device)
    torch_ln = torch.nn.LayerNorm(hidden_size, eps=1e-6).to(dtype).to(device)

    # Align parameters before freezing so both modules compute identical outputs.
    with torch.no_grad():
        torch_ln.weight.copy_(liger_ln.weight)
        torch_ln.bias.copy_(liger_ln.bias)

    # Freeze weight and/or bias
    if freeze_weight:
        liger_ln.weight.requires_grad_(False)
        torch_ln.weight.requires_grad_(False)
    if freeze_bias:
        liger_ln.bias.requires_grad_(False)
        torch_ln.bias.requires_grad_(False)

    liger_output = liger_ln(liger_x)
    torch_output = torch_ln(torch_x)

    assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol)

    grad_output = torch.randn_like(x)
    liger_output.backward(grad_output, retain_graph=True)
    torch_output.backward(grad_output, retain_graph=True)

    # Check input gradients match
    assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol)

    # Check frozen params have no gradient
    if freeze_weight:
        assert liger_ln.weight.grad is None, "Liger weight.grad should be None when frozen"
        assert torch_ln.weight.grad is None, "Torch weight.grad should be None when frozen"
    else:
        assert torch.allclose(liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol)

    if freeze_bias:
        assert liger_ln.bias.grad is None, "Liger bias.grad should be None when frozen"
        assert torch_ln.bias.grad is None, "Torch bias.grad should be None when frozen"
    else:
        assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol)


# ---- test/transformers/test_poly_norm.py ----------------------------------------
@pytest.mark.flaky(reruns=3, reruns_delay=2)
@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 128, 512),
        (5, 123, 123),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        pytest.param(
            torch.bfloat16,
            2e-1,
            2e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
@pytest.mark.parametrize(
    "freeze_weight, freeze_bias",
    [
        (True, False),
        (False, True),
        (True, True),
    ],
)
def test_correctness_frozen_params(bs, sl, hd, dtype, atol, rtol, freeze_weight, freeze_bias):
    """Test that frozen weight/bias (requires_grad=False) works correctly and produces no gradient."""
    _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    x1 = _tensor.clone().requires_grad_(True)
    x2 = _tensor.clone().requires_grad_(True)

    grad_output = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    # Reference: Naive PyTorch implementation
    naive_poly_norm = NaivePolyNorm(eps=1e-6).to(device).to(dtype)

    # Liger implementation
    liger_poly_norm = LigerPolyNorm(eps=1e-6).to(device).to(dtype)
    liger_poly_norm.weight.data.copy_(naive_poly_norm.weight.data)
    liger_poly_norm.bias.data.copy_(naive_poly_norm.bias.data)

    # Freeze weight and/or bias
    if freeze_weight:
        naive_poly_norm.weight.requires_grad_(False)
        liger_poly_norm.weight.requires_grad_(False)
    if freeze_bias:
        naive_poly_norm.bias.requires_grad_(False)
        liger_poly_norm.bias.requires_grad_(False)

    ref_output = naive_poly_norm(x1)
    ref_output.backward(grad_output, retain_graph=True)

    triton_output = liger_poly_norm(x2)
    triton_output.backward(grad_output, retain_graph=True)

    # Check forward pass
    assert_verbose_allclose(ref_output, triton_output, atol=atol, rtol=rtol)

    # Check input gradient
    assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol, max_print=20)

    # Check frozen params have no gradient
    if freeze_weight:
        assert naive_poly_norm.weight.grad is None, "Naive weight.grad should be None when frozen"
        assert liger_poly_norm.weight.grad is None, "Liger weight.grad should be None when frozen"
    else:
        assert_verbose_allclose(naive_poly_norm.weight.grad, liger_poly_norm.weight.grad, atol=atol, rtol=rtol)

    if freeze_bias:
        assert naive_poly_norm.bias.grad is None, "Naive bias.grad should be None when frozen"
        assert liger_poly_norm.bias.grad is None, "Liger bias.grad should be None when frozen"
    else:
        assert_verbose_allclose(naive_poly_norm.bias.grad, liger_poly_norm.bias.grad, atol=atol, rtol=rtol)


# ---- test/transformers/test_rms_norm.py -----------------------------------------
@pytest.mark.flaky(reruns=3, reruns_delay=2)
@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 128, 512),
        (5, 123, 123),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        pytest.param(
            torch.bfloat16,
            2e-1,
            2e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
@pytest.mark.parametrize(
    "reference, offset, casting_mode",
    [
        (LlamaRMSNorm, 0.0, "llama"),
        (GemmaRMSNorm, 1.0, "gemma"),
    ],
)
def test_correctness_frozen_weight(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode):
    """Test that frozen weight (requires_grad=False) works correctly and produces no weight gradient."""
    _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    h1 = _tensor.clone().requires_grad_(True)
    h2 = _tensor.clone().requires_grad_(True)

    do = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    # reference (llama or gemma) with frozen weight
    ref_rms = reference(hidden_size=hd, elementwise_affine=True).to(device).to(dtype)
    ref_rms.weight.requires_grad_(False)
    ref_o = ref_rms(h1)
    ref_o.backward(do, retain_graph=True)

    # triton with frozen weight
    triton_rms = (
        LigerRMSNorm(
            hidden_size=hd,
            offset=offset,
            casting_mode=casting_mode,
            in_place=False,
            elementwise_affine=True,
        )
        .to(device)
        .to(dtype)
    )
    triton_rms.weight.requires_grad_(False)
    triton_o = triton_rms(h2)
    triton_o.backward(do, retain_graph=True)

    # Check forward output matches
    assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol)

    # Check weight.grad is None for both (frozen weight should not have gradient)
    assert ref_rms.weight.grad is None, "Reference weight.grad should be None when frozen"
    assert triton_rms.weight.grad is None, "Triton weight.grad should be None when frozen"

    # Check input gradients still match
    assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20)