diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py
index e5cab72ea..3ea3c74f7 100644
--- a/src/liger_kernel/ops/rms_norm.py
+++ b/src/liger_kernel/ops/rms_norm.py
@@ -17,6 +17,14 @@
 import triton
 import triton.language as tl
 
+try:
+    from torch.distributed.tensor import Shard
+
+    _DTENSOR_AVAILABLE = True
+except ImportError:
+    _DTENSOR_AVAILABLE = False
+    Shard = None
+
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
@@ -25,6 +33,30 @@
 from liger_kernel.ops.utils import torch_to_triton_dtype
 from liger_kernel.utils import is_npu_available
+
+
+def _is_hidden_dim_sharded(dtensor: "torch.distributed.tensor.DTensor") -> bool:
+    """
+    Check if the DTensor is sharded on the hidden dimension (last dimension).
+
+    This is used to determine whether we need to gather the full tensor for RMSNorm
+    computation (Tensor Parallel case) or can compute locally (Context Parallel case).
+
+    Args:
+        dtensor: A DTensor instance to check.
+
+    Returns:
+        True if the tensor is sharded on the hidden (last) dimension (TP case),
+        False otherwise (CP case - can compute locally).
+    """
+    if not _DTENSOR_AVAILABLE or Shard is None:
+        return False
+    hidden_dim = dtensor.ndim - 1  # Last dimension is the hidden dimension
+    for placement in dtensor.placements:
+        if isinstance(placement, Shard) and placement.dim == hidden_dim:
+            return True
+    return False
+
 
 if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
@@ -609,12 +641,25 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row
         X: (B, T, H) or (BxT, H)
         W: (H,)
         """
+        # Track DTensor metadata for potential reconstruction in backward
+        ctx.is_dtensor_input = False
+        ctx.dtensor_device_mesh = None
+        ctx.dtensor_placements = None
+
         if isinstance(X, torch.distributed.tensor.DTensor):
-            # Input tensor is output of a tensor parallel module and
-            # needs to be gathered to a local tensor to compute
-            # RMSE layer norm on each TP worker.
-            # TODO: support CP.
-            X = X.full_tensor()
+            if _is_hidden_dim_sharded(X):
+                # Tensor Parallel (TP): hidden dimension is sharded across devices.
+                # RMSNorm requires the full hidden dimension to compute the RMS,
+                # so we need to gather the full tensor.
+                X = X.full_tensor()
+            else:
+                # Context Parallel (CP): sequence dimension is sharded.
+                # RMSNorm computes independently for each sequence position,
+                # so we can compute locally without gathering.
+                ctx.is_dtensor_input = True
+                ctx.dtensor_device_mesh = X.device_mesh
+                ctx.dtensor_placements = X.placements
+                X = X.to_local()
 
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
         ctx.offset = offset
@@ -628,6 +673,15 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row
             ctx.save_for_backward(X, W, RSTD)
         else:
             ctx.save_for_backward(X, RSTD)
+
+        # If input was a CP DTensor, wrap output back into DTensor
+        if ctx.is_dtensor_input:
+            Y = torch.distributed.tensor.DTensor.from_local(
+                Y,
+                device_mesh=ctx.dtensor_device_mesh,
+                placements=ctx.dtensor_placements,
+            )
+
         return Y
 
     @staticmethod
@@ -643,12 +697,36 @@ def backward(ctx, dY):
             W = None
 
         if isinstance(dY, torch.distributed.tensor.DTensor):
-            # Gradients are output of a tensor parallel module and
-            # needs to be gathered to a local tensor for computing RMSE layer.
-            # TODO: support CP.
-            dY = dY.full_tensor()
+            if ctx.is_dtensor_input:
+                # Context Parallel (CP): sequence dimension is sharded.
+                # We can compute gradients locally for each sequence position.
+                dY = dY.to_local()
+            else:
+                # Tensor Parallel (TP): hidden dimension is sharded.
+                # Need to gather the full gradient tensor.
+                dY = dY.full_tensor()
 
         dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
+
+        # If input was a CP DTensor, handle output accordingly
+        if ctx.is_dtensor_input:
+            # Wrap dX back into DTensor with the same placements
+            dX = torch.distributed.tensor.DTensor.from_local(
+                dX,
+                device_mesh=ctx.dtensor_device_mesh,
+                placements=ctx.dtensor_placements,
+            )
+
+            # For dW, we need to all-reduce across all sharded mesh dimensions,
+            # since each device only computed gradients for its local sequence positions
+            # but the weight is shared across all positions. For multi-dimensional meshes
+            # (e.g., batch + sequence sharding), we must reduce across each sharded dim.
+            if dW is not None and _DTENSOR_AVAILABLE and Shard is not None:
+                for i, placement in enumerate(ctx.dtensor_placements):
+                    if isinstance(placement, Shard):
+                        pg = ctx.dtensor_device_mesh.get_group(mesh_dim=i)
+                        torch.distributed.all_reduce(dW, group=pg)
+
        return dX, dW, None, None, None, None, None
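To poke at the TP-vs-CP dispatch above without a multi-GPU job, here is a minimal single-process sketch; it is not part of the patch. It assumes a gloo backend, world_size=1, and a torch version where `torch.distributed.tensor` is public (2.4+); the inlined `any(...)` predicate mirrors `_is_hidden_dim_sharded`:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

x = torch.randn(2, 8, 16)  # (B, T, H): batch, sequence, hidden

for dim, label in [(1, "Shard(1) -> CP (sequence)"), (2, "Shard(2) -> TP (hidden)")]:
    dt = distribute_tensor(x, mesh, [Shard(dim)])
    hidden_dim = dt.ndim - 1
    # Same predicate as _is_hidden_dim_sharded in the patch above.
    needs_gather = any(isinstance(p, Shard) and p.dim == hidden_dim for p in dt.placements)
    print(f"{label}: gather full tensor = {needs_gather}")

dist.destroy_process_group()
```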
diff --git a/test/transformers/test_rms_norm.py b/test/transformers/test_rms_norm.py
index 8c36472bb..12a826650 100644
--- a/test/transformers/test_rms_norm.py
+++ b/test/transformers/test_rms_norm.py
@@ -309,3 +309,103 @@ def test_dtensor_rms_norm(world_size, bs, sl, hd, dtype, atol, rtol, offset, cas
         nprocs=world_size,
         join=True,
     )
+
+
+def _test_dtensor_rms_norm_context_parallel(
+    rank, world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode, file_name
+):
+    """
+    Test RMSNorm with Context Parallel (CP) - sequence dimension sharding.
+
+    Unlike Tensor Parallel (TP), which shards the hidden dimension, CP shards the
+    sequence dimension. RMSNorm can compute locally under CP since each position
+    is independent, avoiding the need for full_tensor() gathering.
+ """ + torch.distributed.init_process_group( + backend=infer_comm_backend(), + init_method=f"file://{file_name}", + rank=rank, + world_size=world_size, + ) + device = f"{infer_device()}:{rank}" if infer_device() != "cpu" else "cpu" + device_mesh = torch.distributed.device_mesh.init_device_mesh( + infer_device(), mesh_shape=(world_size,), mesh_dim_names=("cp",) + ) + + # Create a tensor and shard on sequence dimension (dim=1) for CP + # sl must be divisible by world_size for even sharding + t = torch.randn(bs, sl, hd, device=device, dtype=dtype, requires_grad=True) + dt = torch.distributed.tensor.distribute_tensor( + t, + device_mesh=device_mesh, + placements=[torch.distributed.tensor.Shard(1)], # Shard on sequence dim for CP + ) + + # Weight is replicated across all devices + w = torch.randn(hd, device=device, dtype=dtype, requires_grad=True) + w1 = w.detach().clone().requires_grad_(True) + w2 = w.detach().clone().requires_grad_(True) + + # Forward pass: compare DTensor (CP) result with regular tensor result + y1 = liger_rms_norm(X=dt, W=w1, eps=1e-6, offset=offset, casting_mode=casting_mode) + y2 = liger_rms_norm(X=t, W=w2, eps=1e-6, offset=offset, casting_mode=casting_mode) + + # y1 is a DTensor sharded on sequence dim, y2 is a regular tensor + # Compare the full tensors + torch.testing.assert_close(y1.full_tensor(), y2, atol=atol, rtol=rtol) + + # Backward pass + grad = torch.randn_like(y2) + dgrad = torch.distributed.tensor.distribute_tensor( + grad, + device_mesh=device_mesh, + placements=[torch.distributed.tensor.Shard(1)], # Same sharding as output + ) + + y1.backward(dgrad) + y2.backward(grad) + + # Check weight gradients: should match after all-reduce in backward + torch.testing.assert_close(w1.grad, w2.grad, atol=atol, rtol=rtol) + + # Check input gradients: dt.grad is a DTensor, t.grad is a regular tensor + torch.testing.assert_close(dt.grad.full_tensor(), t.grad, atol=atol, rtol=rtol) + + torch.distributed.destroy_process_group() + + +@pytest.mark.xfail( + torch.cuda.device_count() < 4, + reason="Pending multi-GPU host support. This test requires at least 4 GPUs.", +) +@pytest.mark.parametrize( + "world_size, bs, sl, hd", + [ + (2, 2, 8, 16), # sl=8 divisible by world_size=2 + (4, 2, 16, 32), # sl=16 divisible by world_size=4 + (2, 3, 6, 17), # weird shapes: non-power-of-2 batch, seq, hidden dims + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-4, 1e-6), + (torch.bfloat16, 2e-1, 2e-2), + ], +) +@pytest.mark.parametrize( + "offset, casting_mode", + [ + (0.0, "llama"), + (1.0, "gemma"), + ], +) +def test_dtensor_rms_norm_context_parallel(world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode): + """Test RMSNorm with Context Parallel (sequence dimension sharding).""" + with tempfile.NamedTemporaryFile() as f: + mp.spawn( + _test_dtensor_rms_norm_context_parallel, + args=(world_size, bs, sl, hd, dtype, atol, rtol, offset, casting_mode, f.name), + nprocs=world_size, + join=True, + )