diff --git a/.github/workflows/check_api_backwards_compatibility_workflow.yml b/.github/workflows/check_api_backwards_compatibility_workflow.yml
index 1dc419b8a6..e21d798c69 100644
--- a/.github/workflows/check_api_backwards_compatibility_workflow.yml
+++ b/.github/workflows/check_api_backwards_compatibility_workflow.yml
@@ -66,7 +66,7 @@ jobs:
       # Default baseline for automatic PR checks
       # Can be: branch name (e.g., 'main'), commit hash, or tag
       # Will be resolved to commit hash during execution
-      DEFAULT_BASELINE: '712dff880cdf88e51289ad71e47d92f46d25a2d3'
+      DEFAULT_BASELINE: 'f7fb5ecbe218672719053fa304d91767ce30ffa1'
       # Tag pattern for auto-detection (e.g., 'core_r*', 'core_v*')
       TAG_PATTERN: 'core_v*'
       # Tag regex filter (e.g., '^core_v[0-9]+\.[0-9]+\.[0-9]+$' for stable versions only)
diff --git a/megatron/core/distributed/fsdp/src/README.md b/megatron/core/distributed/fsdp/src/README.md
index 9e036f22f6..886f122268 100644
--- a/megatron/core/distributed/fsdp/src/README.md
+++ b/megatron/core/distributed/fsdp/src/README.md
@@ -124,7 +124,7 @@ device_mesh = torch.distributed.device_mesh.init_device_mesh(
 device_mesh[("dp_outer", "dp_shard")]._flatten("dp")
 # Only required if using CP. Otherwise, just pass dp_shard to FSDP.
 device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")
-# Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group.
+# Only required if using HFSDP. Otherwise, don't pass hybrid_fsdp_group.
 device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
 hsdp_group = device_mesh["hsdp"].get_group()
 # Initialize DeviceMesh for expert parallel (EP) modules when using FSDP + EP.
@@ -149,7 +149,7 @@ model, optimizer = fully_shard(
     # Only required for TP-sensitive models (i.e. Megatron-LM / TransformerEngine) or when using DTensor-based TP.
     # Otherwise, set this to None.
     tp_dim="tp",
-    # Only required when using HSDP. Otherwise, set this to None.
+    # Only required when fully-sharding the optimizer state in HFSDP. Otherwise, set this to None.
     hybrid_fsdp_group=hsdp_group,
     # Only required for FSDP + EP. Otherwise, set this to None.
     expt_device_mesh=expt_device_mesh,
@@ -185,7 +185,7 @@ model.load_state_dict(ckpt_state_dict["model"], strict=False)
 optimizer.load_state_dict(ckpt_state_dict["optimizer"])
 ```
 
-- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except for the (required) "outer" DP group (`dp_outer_dim` / `hybrid_fsdp_group`) when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP Replication) and `optim` (Optimizer State Hybrid Sharding, requires `zero_dp_strategy='optim_grads_params`) are supported.
+- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except for the (required) "outer" DP group (`dp_outer_dim`) when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP Replication) and `optim` (Optimizer State Hybrid Sharding, requires `zero_dp_strategy='optim_grads_params'`) are supported.
   - Default: `optim_grads_params` or `3` for `zero_dp_strategy` and `no_shard` or `0` for `outer_dp_sharding_strategy`
   - `0` or `no_shard` implies that your model is not sharded. Similar memory usage to `DDP`.
   - `1` or `optim` implies that your optimizer state is sharded for distributed optimization. Similar to optimizer state sharding in `ZeRO-DP`.
@@ -199,7 +199,7 @@ optimizer.load_state_dict(ckpt_state_dict["optimizer"])
   - `dp_outer_dim` is the name of the sub-mesh corresponding to the "outer" DP group, which is required for replication or sharding in HSDP. `fully_shard` will perform HSDP if `dp_outer_dim` is specified.
   - `tp_dim` is the name of the sub-mesh used for tensor parallelism (TP), which is required for `(FSDP, TP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` TP.
     - For more information about tensor parallelism, refer to: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053).
-  - `hybrid_fsdp_group` is the `ProcessGroup` which contains all ranks in the flattened `dp_shard_dim` and `dp_outer_dim` sub-meshes utilized to specify the `(DP-Outer, DP-Shard)` sharded coordinate system for the weight and gradient buffers. Required for HSDP.
+  - `hybrid_fsdp_group` is the `ProcessGroup` which contains all ranks in the flattened `dp_shard_dim` and `dp_outer_dim` sub-meshes utilized to specify the `(DP-Outer, DP-Shard)` sharded coordinate system for the weight and gradient buffers. Required only for HFSDP, i.e. when fully sharding the optimizer state with HSDP.
 - `expt_device_mesh` is another [`torch.distributed.DeviceMesh`](https://docs.pytorch.org/docs/stable/distributed.html#devicemesh) tailored for the expert parallel (EP) modules in `MegatronFSDP`.
   - `dp_shard_dim` is the name of the sub-mesh required for FSDP sharding of the EP modules, enabling expert data parallelism (EDP).
   - `tp_dim` is the name of the sub-mesh used for expert tensor parallelism (ETP), which is required for `(FSDP, ETP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` ETP.
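
For reference, a minimal sketch of the HFSDP setup described by the README hunks above, using the mesh names from the README example. The mesh shape, the placeholder model and optimizer, the `megatron_fsdp` import path, and the leading `fully_shard` argument names are illustrative assumptions, not part of this patch:

```python
import torch
from megatron_fsdp import fully_shard  # Import path assumed for the standalone megatron_fsdp package.

# Assumes torch.distributed is initialized on 4 ranks, e.g. via `torchrun --nproc-per-node=4`.
device_mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda", (2, 2, 1), mesh_dim_names=("dp_outer", "dp_shard", "cp")
)
device_mesh[("dp_outer", "dp_shard")]._flatten("dp")
device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")
# Per the updated README comment: the flattened "hsdp" group is only needed for HFSDP,
# i.e. when outer_dp_sharding_strategy="optim" fully shards the optimizer state on DP-Outer.
device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
hsdp_group = device_mesh["hsdp"].get_group()

# Placeholder model and optimizer for illustration only.
model = torch.nn.Linear(1024, 1024, device="cuda")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model, optimizer = fully_shard(
    model,
    optimizer,
    device_mesh=device_mesh,           # Argument names before dp_shard_dim are assumed.
    dp_shard_dim="dp_shard_cp",
    dp_outer_dim="dp_outer",           # Enables HSDP.
    tp_dim=None,                       # No Megatron-LM / DTensor TP in this sketch.
    hybrid_fsdp_group=hsdp_group,      # Needed only because outer_dp_sharding_strategy="optim".
    zero_dp_strategy="optim_grads_params",
    outer_dp_sharding_strategy="optim",  # HFSDP: fully shard the optimizer state on DP-Outer.
)
```

With `outer_dp_sharding_strategy="no_shard"` (plain HSDP replication), the same call would pass `hybrid_fsdp_group=None`, which matches the relaxed validation in `fully_shard.py` below.
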
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py
index e98362a1a0..d6150ae10b 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py
@@ -142,12 +142,16 @@ def fully_shard_model(
             "zero_dp_strategy to use FSDP ('optim_grads_params', 3), because "
             "outer sharding is dependent on inner sharding."
         )
-    if (dp_outer_dim is None) ^ (hybrid_fsdp_group is None):
-        # XOR - HSDP requires both or neither of dp_outer_dim and hybrid_fsdp_group
-        # to be specified, so if XOR then raise an error.
+    if _outer_fsdp_sharding and hybrid_fsdp_group is None:
+        # If fully-sharding the optimizer state on DP-Outer, you must provide the
+        # completely flattened HFSDP group for logical rank assignment to the
+        # optimizer state full-sharding ranks.
         raise ValueError(
-            f"dp_outer_dim={dp_outer_dim} and hybrid_fsdp_group={hybrid_fsdp_group} must be "
-            "specified together for Hybrid FSDP (HSDP), or both set to None (for FSDP)."
+            "[HFSDP] Fully-sharding the optimizer on DP-Outer "
+            f"(outer_dp_sharding_strategy={outer_dp_sharding_strategy}) "
+            f"requires a fully-flattened hybrid_fsdp_group={hybrid_fsdp_group} "
+            "for rank assignment to the optimizer state. You can flatten your DeviceMesh "
+            "via `DeviceMesh[(DP-Outer, DP-Shard)]._flatten()` & `DeviceMesh.get_group()`."
         )
     if init_model_with_meta_device and zero_dp_strategy == "no_shard":
         raise ValueError(
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
index 8a63e0f5cf..f962ebe4f2 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
@@ -311,12 +311,21 @@ def _init_fsdp_param_and_grad_buffer(self):
         else:
             if self.ddp_config.average_in_collective:
                 gradient_scaling_factor = 1.0
+                # Utilized to re-scale expert gradients to DP.
+                # (edp_size/dp_size) * (1/edp_size) = 1/dp_size
+                # FIXME(@cspades): Currently not used by gradient_reduce_preprocessing()?
                 expert_gradient_scaling_factor = (
                     self.dist_index.get_dp_group(is_expert_parallel=True).size()
-                    / self.dist_index.get_dp_group().size()
+                    / self.dist_index.get_fsdp_group().size()
                 )
+                if self.dist_index.use_hybrid_fsdp:
+                    # Also divide by the DP-Outer size in the conversion factor.
+                    expert_gradient_scaling_factor /= self.dist_index.get_outer_fsdp_group().size()
             else:
-                data_parallel_world_size = self.dist_index.get_dp_group().size()
+                data_parallel_world_size = self.dist_index.get_fsdp_group().size()
+                if self.dist_index.use_hybrid_fsdp:
+                    # Also multiply the DP size by the DP-Outer size.
+                    data_parallel_world_size *= self.dist_index.get_outer_fsdp_group().size()
                 gradient_scaling_factor = 1.0 / data_parallel_world_size
                 expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
 
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py
index 490d80c0f2..d358ae6cab 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py
@@ -25,6 +25,8 @@
 from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType
 from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard
 
+from .utils import get_mesh_names
+
 
 def gather_and_compute_chunk_metadata(dtensor: DTensor) -> ChunkStorageMetadata:
     """
@@ -272,7 +274,14 @@ def gather_uneven_dtensor_to_full_tensor(
     if not device_mesh.mesh_dim_names:
         process_group = device_mesh.get_group()
     else:
-        process_group = device_mesh._flatten().get_group()
+        # Check if the fully-flattened mesh exists first.
+        full_flattened_mesh_dim_name = "_".join(device_mesh.mesh_dim_names)
+        if full_flattened_mesh_dim_name in get_mesh_names(device_mesh):
+            # Retrieve the existing flattened DeviceMesh ProcessGroup.
+            process_group = device_mesh[full_flattened_mesh_dim_name].get_group()
+        else:
+            # Create the _-separated flattened DeviceMesh ProcessGroup.
+            process_group = device_mesh._flatten().get_group()
 
     # Collect chunk metadata for uneven shards (update if missing)
     if not hasattr(dtensor._local_tensor, "__create_chunk_list__"):
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
index b94a332bb0..977ac8c750 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
@@ -167,13 +167,10 @@ def get_mesh_names(device_mesh: Optional[DeviceMesh] = None) -> list[str]:
         submesh_dim_name
         for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items()
         for submesh_dim_name in (child_mesh.mesh_dim_names or [])
-        if root_mesh == device_mesh
+        # Add flattened or other unaccounted-for children of the root mesh.
+        if root_mesh == device_mesh and submesh_dim_name not in mesh_dim_names
     ]
-    # Combine without duplicate dimensions.
-    for dim_name in submesh_dim_names:
-        if dim_name not in mesh_dim_names:
-            mesh_dim_names.append(dim_name)
-    return mesh_dim_names
+    return mesh_dim_names + submesh_dim_names
 
 
 def contains_submesh(
@@ -787,16 +784,17 @@ def register_submesh(device_mesh, submesh, is_expert_parallel):
         if self.use_hybrid_fsdp:
             if self.outer_fsdp_group is None:
                 raise ValueError(
-                    "[FSDPDistributedIndex][use_hybrid_fsdp=True] Hybrid FSDP requires "
-                    "an outer-DP process group (dp_outer_dim, outer_fsdp_group)."
+                    "[FSDPDistributedIndex] Hybrid-Sharded Data Parallelism (HSDP) requires a "
+                    "DP-Outer ProcessGroup for model replication or optimizer full-sharding. "
+                    f"Check that {self.device_mesh} contains an outer DP sub-mesh.\n"
+                    f"dp_outer_dim={self.dp_outer_dim} / outer_fsdp_group={self.outer_fsdp_group}"
                 )
-            if self.hybrid_fsdp_group is None:
+            if self.hsdp_outer_dp_shard and self.hybrid_fsdp_group is None:
                 raise ValueError(
-                    "[FSDPDistributedIndex][use_hybrid_fsdp=True] Hybrid FSDP requires "
-                    "a hybrid FSDP process group (hybrid_fsdp_group). "
-                    "This group can be manufactured by flattening the outer-DP "
+                    "[FSDPDistributedIndex] Hybrid FSDP (HFSDP) requires a fully-flattened hybrid "
+                    "FSDP process group (hybrid_fsdp_group). Created by flattening the outer-DP "
                     "(dp_outer_dim, outer_fsdp_group) and FSDP (dp_shard_dim, fsdp_group) "
-                    "process groups or sub-meshes."
+                    "ProcessGroup(s) or sub-meshes."
                 )
 
     def get_submesh(
diff --git a/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py b/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py
index ee485dd0b0..d5251f217b 100644
--- a/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py
+++ b/tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py
@@ -1,5 +1,6 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
 import shutil
-from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path
 
@@ -277,7 +278,10 @@ def test_fully_shard(
         dp_outer_dim=DP_OUTER if dp_outer_strategy is not None else None,
         tp_dim=TP,
         hybrid_fsdp_group=(
-            device_mesh[HSDP].get_group() if dp_outer_strategy is not None else None
+            # Only need this fully-flattened group if you are using HFSDP.
+            device_mesh[HSDP].get_group()
+            if dp_outer_strategy == OPTIM
+            else None
         ),
         fsdp_unit_modules=fsdp_unit_modules,
         zero_dp_strategy=dp_shard_strategy,
@@ -326,27 +330,19 @@ def test_fully_shard(
         # Because of uneven sharding, we need to gather the result from all ranks
         # to verify if any gradients exist or not at this step of training.
         grads_exist_gathered = [None] * sharding_group.size()
-        torch.distributed.gather_object(
-            grads_exist,
-            object_gather_list=grads_exist_gathered if sharding_group.rank() == 0 else None,
-            group=sharding_group,
-            group_dst=0,
+        torch.distributed.all_gather_object(
+            object_list=grads_exist_gathered, obj=grads_exist, group=sharding_group
         )
-        if sharding_group.rank() == 0:
-            # Gradients exist on at least one of the optimizer sharding ranks.
-            # Update grads_exist on Rank 0 only.
-            grads_exist = any(grads_exist_gathered)
-        torch.distributed.barrier()
+        # Gradients exist on at least one of the optimizer sharding ranks.
+        grads_exist = any(grads_exist_gathered)
 
         # Gradients do not exist until synchronization is activated.
-        # Use collected result on Rank 0 only.
-        if sharding_group.rank() == 0:
-            if step == NUM_STEPS - 1:
-                assert grads_exist, "Root module gradients should exist on final microbatch."
-            else:
-                assert (
-                    not grads_exist
-                ), "Root module gradients should not exist prior to optimization step."
+        if step == NUM_STEPS - 1:
+            assert grads_exist, "Root module gradients should exist on final microbatch."
+        else:
+            assert (
+                not grads_exist
+            ), "Root module gradients should not exist prior to optimization step."
         torch.distributed.barrier()
 
         # Optimizer step. Apply accumulated gradients to the model weights.
@@ -415,7 +411,10 @@ def test_dcp_checkpoint_save_and_load(
         dp_shard_dim=DP_SHARD_CP,
         dp_outer_dim=DP_OUTER,
         tp_dim=TP,
-        hybrid_fsdp_group=device_mesh[HSDP].get_group(),
+        # Only need this fully-flattened group if you are using HFSDP.
+        hybrid_fsdp_group=(
+            device_mesh[HSDP].get_group() if outer_shard_strategy == OPTIM else None
+        ),
         fsdp_unit_modules=fsdp_unit_modules,
         zero_dp_strategy=shard_strategy,
         outer_dp_sharding_strategy=outer_shard_strategy,
@@ -496,7 +495,10 @@ def test_dcp_checkpoint_save_and_load(
         dp_shard_dim=DP_SHARD_CP,
         dp_outer_dim=DP_OUTER,
         tp_dim=TP,
-        hybrid_fsdp_group=device_mesh[HSDP].get_group(),
+        # Only need this fully-flattened group if you are using HFSDP.
+        hybrid_fsdp_group=(
+            device_mesh[HSDP].get_group() if outer_shard_strategy == OPTIM else None
+        ),
         fsdp_unit_modules=fsdp_unit_modules,
         zero_dp_strategy=shard_strategy,
         outer_dp_sharding_strategy=outer_shard_strategy,
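
The validation change in `fully_shard.py` and the test updates above encode one rule: `hybrid_fsdp_group` is required only when the optimizer state is fully sharded over DP-Outer (`outer_dp_sharding_strategy='optim'`, i.e. HFSDP); HSDP with plain DP-Outer replication (`no_shard`) no longer needs it. A small sketch of that gating, mirroring the conditional used in the tests; the helper name and its argument handling are hypothetical, not part of this patch:

```python
from typing import Optional

import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def resolve_hybrid_fsdp_group(
    device_mesh: DeviceMesh, hsdp_dim: str, outer_dp_sharding_strategy: str
) -> Optional[dist.ProcessGroup]:
    """Hypothetical helper: return the fully-flattened (DP-Outer, DP-Shard) group only
    when the optimizer state is fully sharded on DP-Outer (HFSDP)."""
    if outer_dp_sharding_strategy == "optim":
        # Mirrors the tests: device_mesh[HSDP].get_group() only when the outer strategy is "optim".
        return device_mesh[hsdp_dim].get_group()
    # Plain HSDP replication ("no_shard") or non-hybrid FSDP: no flattened group is needed.
    return None
```

Passing `None` in the `no_shard` case exercises the relaxed check in `fully_shard.py`, which now raises only when the optimizer is fully sharded on DP-Outer and no flattened group is provided.
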