@@ -66,7 +66,7 @@ jobs:
# Default baseline for automatic PR checks
# Can be: branch name (e.g., 'main'), commit hash, or tag
# Will be resolved to commit hash during execution
DEFAULT_BASELINE: '712dff880cdf88e51289ad71e47d92f46d25a2d3'
DEFAULT_BASELINE: 'f7fb5ecbe218672719053fa304d91767ce30ffa1'
# Tag pattern for auto-detection (e.g., 'core_r*', 'core_v*')
TAG_PATTERN: 'core_v*'
# Tag regex filter (e.g., '^core_v[0-9]+\.[0-9]+\.[0-9]+$' for stable versions only)
8 changes: 4 additions & 4 deletions megatron/core/distributed/fsdp/src/README.md
@@ -124,7 +124,7 @@ device_mesh = torch.distributed.device_mesh.init_device_mesh(
device_mesh[("dp_outer", "dp_shard")]._flatten("dp")
# Only required if using CP. Otherwise, just pass dp_shard to FSDP.
device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")
# Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group.
# Only required if using HFSDP. Otherwise, don't pass hybrid_fsdp_group.
device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
hsdp_group = device_mesh["hsdp"].get_group()
# Initialize DeviceMesh for expert parallel (EP) modules when using FSDP + EP.
@@ -149,7 +149,7 @@ model, optimizer = fully_shard(
# Only required for TP-sensitive models (i.e. Megatron-LM / TransformerEngine) or when using DTensor-based TP.
# Otherwise, set this to None.
tp_dim="tp",
# Only required when using HSDP. Otherwise, set this to None.
# Only required when fully-sharding the optimizer state in HFSDP. Otherwise, set this to None.
hybrid_fsdp_group=hsdp_group,
# Only required for FSDP + EP. Otherwise, set this to None.
expt_device_mesh=expt_device_mesh,
@@ -185,7 +185,7 @@ model.load_state_dict(ckpt_state_dict["model"], strict=False)
optimizer.load_state_dict(ckpt_state_dict["optimizer"])
```

- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except for the (required) "outer" DP group (`dp_outer_dim` / `hybrid_fsdp_group`) when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP Replication) and `optim` (Optimizer State Hybrid Sharding, requires `zero_dp_strategy='optim_grads_params`) are supported.
- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except for the (required) "outer" DP group (`dp_outer_dim`) when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP Replication) and `optim` (Optimizer State Hybrid Sharding, requires `zero_dp_strategy='optim_grads_params`) are supported.
- Default: `optim_grads_params` or `3` for `zero_dp_strategy` and `no_shard` or `0` for `outer_dp_sharding_strategy`
- `0` or `no_shard` implies that your model is not sharded. Similar memory usage to `DDP`.
- `1` or `optim` implies that your optimizer state is sharded for distributed optimization. Similar to optimizer state sharding in `ZeRO-DP`.
Expand All @@ -199,7 +199,7 @@ optimizer.load_state_dict(ckpt_state_dict["optimizer"])
- `dp_outer_dim` is the name of the sub-mesh corresponding to the "outer" DP group, which is required for replication or sharding in HSDP. `fully_shard` will perform HSDP if `dp_outer_dim` is specified.
- `tp_dim` is the name of the sub-mesh used for tensor parallelism (TP), which is required for `(FSDP, TP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` TP.
- For more information about tensor parallelism, refer to: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053).
- `hybrid_fsdp_group` is the `ProcessGroup` which contains all ranks in the flattened `dp_shard_dim` and `dp_outer_dim` sub-meshes utilized to specify the `(DP-Outer, DP-Shard)` sharded coordinate system for the weight and gradient buffers. Required for HSDP.
- `hybrid_fsdp_group` is the `ProcessGroup` which contains all ranks in the flattened `dp_shard_dim` and `dp_outer_dim` sub-meshes utilized to specify the `(DP-Outer, DP-Shard)` sharded coordinate system for the weight and gradient buffers. Required for HFSDP only, i.e. fully-sharded optimizer state with HSDP.
- `expt_device_mesh` is another [`torch.distributed.DeviceMesh`](https://docs.pytorch.org/docs/stable/distributed.html#devicemesh) tailored for the expert parallel (EP) modules in `MegatronFSDP`.
- `dp_shard_dim` is the name of the sub-mesh required for FSDP sharding of the EP modules, enabling expert data parallelism (EDP).
- `tp_dim` is the name of the sub-mesh used for expert tensor parallelism (ETP), which is required for `(FSDP, ETP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` ETP.
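To make the HSDP vs. HFSDP distinction in the README bullets above concrete, here is a minimal sketch. It is not a verbatim README excerpt: the import path, the 16-GPU topology, the placeholder model/optimizer, and the exact `fully_shard` argument order are assumptions and may differ from the real API.

```python
import torch
from torch import nn

# Import path assumed from this repo layout (megatron_fsdp package under fsdp/src).
from megatron_fsdp import fully_shard

# Hypothetical 16-GPU topology: 2 (dp_outer) x 4 (dp_shard) x 2 (cp) x 1 (tp).
device_mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda",
    mesh_shape=(2, 4, 2, 1),
    mesh_dim_names=("dp_outer", "dp_shard", "cp", "tp"),
)
device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")
# Only needed when fully-sharding the optimizer state on DP-Outer (HFSDP).
device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")

model = nn.Linear(1024, 1024)  # placeholder module
optimizer = torch.optim.AdamW(model.parameters())

use_hfsdp = True  # False -> plain HSDP (DP-Outer replication, no hybrid_fsdp_group)
model, optimizer = fully_shard(
    model,
    optimizer,
    device_mesh=device_mesh,
    dp_shard_dim="dp_shard_cp",
    dp_outer_dim="dp_outer",  # enables HSDP
    tp_dim="tp",
    # Only HFSDP needs the fully-flattened (dp_outer, dp_shard, cp) group.
    hybrid_fsdp_group=device_mesh["hsdp"].get_group() if use_hfsdp else None,
    zero_dp_strategy="optim_grads_params",
    outer_dp_sharding_strategy="optim" if use_hfsdp else "no_shard",
)
```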
@@ -142,12 +142,16 @@ def fully_shard_model(
"zero_dp_strategy to use FSDP ('optim_grads_params', 3), because "
"outer sharding is dependent on inner sharding."
)
if (dp_outer_dim is None) ^ (hybrid_fsdp_group is None):
# XOR - HSDP requires both or neither of dp_outer_dim and hybrid_fsdp_group
# to be specified, so if XOR then raise an error.
if _outer_fsdp_sharding and hybrid_fsdp_group is None:
@shjwudp (Contributor) commented on Nov 21, 2025:
Do we need to change this line? My understanding is that we only need to handle the new validation error introduced by PyTorch 2.9 and perform the check before calling device_mesh._flatten.

@cspades (Member, Author) replied on Nov 21, 2025:
As with your previous comment, this change and the one above are related to (1), not to the new validation error. The original argument validation required that if you are using HSDP or HFSDP (in which case dp_outer_dim is not None), you must pass hybrid_fsdp_group. Now I am relaxing it so that hybrid_fsdp_group is required only when you are using HFSDP.

While we are talking about this, do you think this is a reasonable relaxation? I currently cannot imagine any situation where we need the fully-flattened DP ranks when we are using HSDP, since HSDP's collectives seem simple and orthogonal.

The 3 things I'm doing in this PR:

  1. Make HSDP easier to use.
  2. Fix the DeviceMesh validation error.
  3. Fix gradient unit tests.

# If fully-sharding the optimizer state on DP-Outer, you must provide the
# completely flattened HFSDP group for logical rank assignment to the
# optimizer state full-sharding ranks.
raise ValueError(
f"dp_outer_dim={dp_outer_dim} and hybrid_fsdp_group={hybrid_fsdp_group} must be "
"specified together for Hybrid FSDP (HSDP), or both set to None (for FSDP)."
"[HFSDP] Fully-sharding the optimizer on DP-Outer "
f"(outer_dp_sharding_strategy={outer_dp_sharding_strategy}) "
f"requires a fully-flattened hybrid_fsdp_group={hybrid_fsdp_group} "
"for rank assignment to the optimizer state. You can flatten your DeviceMesh "
f"via `DeviceMesh[(DP-Outer, DP-Shard)]._flatten()` & `DeviceMesh.get_group()`."
)
if init_model_with_meta_device and zero_dp_strategy == "no_shard":
raise ValueError(
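As a usage note on the new error message: a minimal sketch of how a caller could build the required group, assuming a `device_mesh` that already has "dp_outer" and "dp_shard" dimensions (names illustrative; `_flatten()` defaults to the underscore-joined name if none is given).

```python
# Flatten (DP-Outer, DP-Shard) into one ProcessGroup for HFSDP rank assignment.
device_mesh[("dp_outer", "dp_shard")]._flatten("hsdp")
hybrid_fsdp_group = device_mesh["hsdp"].get_group()

# Pass it only when the optimizer state is fully sharded on DP-Outer, e.g.
# fully_shard(..., dp_outer_dim="dp_outer",
#             outer_dp_sharding_strategy="optim",
#             hybrid_fsdp_group=hybrid_fsdp_group)
```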
@@ -311,12 +311,21 @@ def _init_fsdp_param_and_grad_buffer(self):
else:
if self.ddp_config.average_in_collective:
gradient_scaling_factor = 1.0
# Utilized to re-scale expert gradients to DP.
# (edp_size/dp_size) * (1/edp_size) = 1/dp_size
# FIXME(@cspades): Currently not used gradient_reduce_preprocessing()?
expert_gradient_scaling_factor = (
self.dist_index.get_dp_group(is_expert_parallel=True).size()
/ self.dist_index.get_dp_group().size()
@shjwudp (Contributor) commented on Nov 21, 2025:
Will the torch 2.9 check affect the behavior of get_dp_group().size()? Do we need to update the logic here?

@cspades (Member, Author) replied on Nov 21, 2025:
I don't think this change is related to the Torch version or the upcoming (not yet released) DeviceMesh changes; it's related to our API requiring FSDPDistributedIndex.hybrid_fsdp_group, which is used by get_dp_group().

I don't see a clear way to flatten DP-Inner and DP-Outer without messing with the user's distributed environment, so I don't think there is a clear way to improve the implementation of get_dp_group() at this moment.

So to summarize, the issue is:

  • Both HSDP and HFSDP use get_dp_group() -> FSDPDistributedIndex.hybrid_fsdp_group.
    • While HSDP does not "need" FSDPDistributedIndex.hybrid_fsdp_group, it makes sense for get_dp_group() to return the fully-flattened DP group when we are using HSDP, since that is the group that the function is designed to get.
  • HSDP does not need this for core Megatron-FSDP functionality; it is only used to compute dp_group.size().

So I work around (WAR) the size() issue as shown in the surrounding diff.

/ self.dist_index.get_fsdp_group().size()
)
if self.dist_index.use_hybrid_fsdp:
# Also divide the DP-Outer size in the conversion factor.
expert_gradient_scaling_factor /= self.dist_index.get_outer_fsdp_group().size()
else:
data_parallel_world_size = self.dist_index.get_dp_group().size()
data_parallel_world_size = self.dist_index.get_fsdp_group().size()
if self.dist_index.use_hybrid_fsdp:
# Also multiply the DP-Outer size in the DP size.
data_parallel_world_size *= self.dist_index.get_outer_fsdp_group().size()
gradient_scaling_factor = 1.0 / data_parallel_world_size
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size

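For intuition, the scaling-factor arithmetic in the hunk above can be checked with plain numbers. The sizes below are hypothetical; the point is that under HSDP the full data-parallel size is the FSDP shard size times the DP-Outer size.

```python
# Hypothetical topology: DP-Outer = 2, FSDP/DP-Shard = 4, expert-DP (EDP) = 2.
outer_size, fsdp_size, edp_size = 2, 4, 2
dp_size = fsdp_size * outer_size  # full data-parallel size under HSDP -> 8

# average_in_collective=False: pre-scale every gradient by 1 / dp_size.
gradient_scaling_factor = 1.0 / dp_size
expert_gradient_scaling_factor = 1.0 / dp_size

# average_in_collective=True: the reduction already averages over EDP (1 / edp_size),
# so expert gradients only need the EDP -> DP conversion factor edp_size / dp_size.
expert_conversion_factor = edp_size / fsdp_size / outer_size  # == edp_size / dp_size
assert expert_conversion_factor * (1.0 / edp_size) == 1.0 / dp_size  # net 1 / dp_size
```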
@@ -25,6 +25,8 @@
from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType
from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard

from .utils import get_mesh_names


def gather_and_compute_chunk_metadata(dtensor: DTensor) -> ChunkStorageMetadata:
"""
@@ -272,7 +274,14 @@ def gather_uneven_dtensor_to_full_tensor(
if not device_mesh.mesh_dim_names:
process_group = device_mesh.get_group()
else:
process_group = device_mesh._flatten().get_group()
# Check if the fully-flattened mesh exists first.
full_flattened_mesh_dim_name = "_".join(device_mesh.mesh_dim_names)
if full_flattened_mesh_dim_name in get_mesh_names(device_mesh):
# Retrieve the existing flattened DeviceMesh ProcessGroup.
process_group = device_mesh[full_flattened_mesh_dim_name].get_group()
else:
# Create the _-separated flattened DeviceMesh ProcessGroup.
process_group = device_mesh._flatten().get_group()
Comment on lines -275 to +284
@cspades (Member, Author) commented:
@shjwudp Everything related to (2) is fixed in the 3-4 lines of code above.

Before, we just called _flatten() immediately. Going into Torch 2.10 or 2.11, PyTorch will not allow us to create a new DeviceMesh that matches the flattened name of an existing DeviceMesh. So I use our helper function get_mesh_names(), which checks for sub- and flattened dimensions, and if an existing flattened mesh is found, we just use that mesh.

I believe this still potentially has loopholes. If the user creates a flattened DeviceMesh dimension with the same name but a different topology from our desired mesh, then we will use the user's mesh instead. I don't see a way to fix this fundamental issue (though adding a warning message may be a good idea), so it will be the user's responsibility to name flattened meshes reasonably, i.e. dp_cp ~ the flattening of the dp and cp dims, which is the default naming (i.e. "_".join([<mesh dims to flatten>])) of device_mesh._flatten().

Another thing to note is that the DTensor.device_mesh here is a child/sub-mesh of the Megatron-FSDP root mesh. In future Torch versions, the flattened mesh will be a member of the DeviceMesh it was flattened from, so there is a lower chance of issues: the user will likely call root_mesh._flatten() but not root_mesh[("dp_shard", "dp_outer")]._flatten(), so we are less likely to accidentally pick up the user's original DeviceMesh!


# Collect chunk metadata for uneven shards (update if missing)
if not hasattr(dtensor._local_tensor, "__create_chunk_list__"):
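A standalone sketch of the check-before-flatten pattern above. The helper name is made up, and the list of known mesh names is passed in explicitly as a stand-in for `get_mesh_names(device_mesh)`.

```python
from torch.distributed.device_mesh import DeviceMesh


def flattened_process_group(device_mesh: DeviceMesh, known_mesh_names: list[str]):
    """Return a ProcessGroup spanning all ranks of `device_mesh`, reusing an
    already-flattened mesh when its default name exists instead of re-flattening
    (which newer PyTorch versions reject as a duplicate mesh name)."""
    if not device_mesh.mesh_dim_names:
        return device_mesh.get_group()
    # _flatten() names the result by joining the dimension names with "_",
    # e.g. ("dp_outer", "dp_shard", "cp") -> "dp_outer_dp_shard_cp".
    flat_name = "_".join(device_mesh.mesh_dim_names)
    if flat_name in known_mesh_names:
        return device_mesh[flat_name].get_group()
    return device_mesh._flatten().get_group()
```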
24 changes: 11 additions & 13 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
@@ -167,13 +167,10 @@ def get_mesh_names(device_mesh: Optional[DeviceMesh] = None) -> list[str]:
submesh_dim_name
for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items()
for submesh_dim_name in (child_mesh.mesh_dim_names or [])
if root_mesh == device_mesh
# Add flattened or other unaccounted for children of the root mesh.
if root_mesh == device_mesh and submesh_dim_name not in mesh_dim_names
]
# Combine without duplicate dimensions.
for dim_name in submesh_dim_names:
if dim_name not in mesh_dim_names:
mesh_dim_names.append(dim_name)
return mesh_dim_names
return mesh_dim_names + submesh_dim_names


def contains_submesh(
@@ -787,16 +784,17 @@ def register_submesh(device_mesh, submesh, is_expert_parallel):
if self.use_hybrid_fsdp:
if self.outer_fsdp_group is None:
raise ValueError(
"[FSDPDistributedIndex][use_hybrid_fsdp=True] Hybrid FSDP requires "
"an outer-DP process group (dp_outer_dim, outer_fsdp_group)."
"[FSDPDistributedIndex] Hybrid-Sharded Data Parallelism (HSDP) requires a "
"DP-Outer ProcessGroup for model replication or optimizer full-sharding. "
f"Check that {self.device_mesh} contains an outer DP sub-mesh.\n"
f"dp_outer_dim={self.dp_outer_dim} / outer_fsdp_group={self.outer_fsdp_group}"
)
if self.hybrid_fsdp_group is None:
if self.hsdp_outer_dp_shard and self.hybrid_fsdp_group is None:
raise ValueError(
"[FSDPDistributedIndex][use_hybrid_fsdp=True] Hybrid FSDP requires "
"a hybrid FSDP process group (hybrid_fsdp_group). "
"This group can be manufactured by flattening the outer-DP "
"[FSDPDistributedIndex] Hybrid FSDP (HFSDP) requires a fully-flattened hybrid "
"FSDP process group (hybrid_fsdp_group). Created by flattening the outer-DP "
"(dp_outer_dim, outer_fsdp_group) and FSDP (dp_shard_dim, fsdp_group) "
"process groups or sub-meshes."
"ProcessGroup(s) or sub-meshes."
)

def get_submesh(
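A pure-Python illustration of what the `get_mesh_names()` refactor above returns (made-up dimension names): root dimensions first, then discovered submesh/flattened dimensions that are not already root dimensions. This matches the old append-if-missing loop as long as the discovered names are themselves unique.

```python
# Root mesh dimensions, plus names discovered from the child-to-root mapping
# (hypothetical; includes flattened meshes and duplicates of root dims).
mesh_dim_names = ["dp_outer", "dp_shard", "cp", "tp"]
discovered_names = ["dp_shard", "cp", "dp_shard_cp", "hsdp"]

# Filter duplicates inside the comprehension, then concatenate.
submesh_dim_names = [n for n in discovered_names if n not in mesh_dim_names]
assert mesh_dim_names + submesh_dim_names == [
    "dp_outer", "dp_shard", "cp", "tp", "dp_shard_cp", "hsdp"
]
```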
46 changes: 24 additions & 22 deletions tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py
@@ -1,5 +1,6 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import shutil
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path

@@ -277,7 +278,10 @@ def test_fully_shard(
dp_outer_dim=DP_OUTER if dp_outer_strategy is not None else None,
tp_dim=TP,
hybrid_fsdp_group=(
device_mesh[HSDP].get_group() if dp_outer_strategy is not None else None
# Only need this fully-flattened group if you are using HFSDP.
device_mesh[HSDP].get_group()
if dp_outer_strategy == OPTIM
else None
),
fsdp_unit_modules=fsdp_unit_modules,
zero_dp_strategy=dp_shard_strategy,
@@ -326,27 +330,19 @@ def test_fully_shard(
# Because of uneven sharding, we need to gather the result from all ranks
# to verify if any gradients exist or not at this step of training.
grads_exist_gathered = [None] * sharding_group.size()
torch.distributed.gather_object(
grads_exist,
object_gather_list=grads_exist_gathered if sharding_group.rank() == 0 else None,
group=sharding_group,
group_dst=0,
torch.distributed.all_gather_object(
object_list=grads_exist_gathered, obj=grads_exist, group=sharding_group
)
if sharding_group.rank() == 0:
# Gradients exist on at least one of the optimizer sharding ranks.
# Update grads_exist on Rank 0 only.
grads_exist = any(grads_exist_gathered)
torch.distributed.barrier()
# Gradients exist on at least one of the optimizer sharding ranks.
grads_exist = any(grads_exist_gathered)

# Gradients do not exist until synchronization is activated.
# Use collected result on Rank 0 only.
if sharding_group.rank() == 0:
if step == NUM_STEPS - 1:
assert grads_exist, "Root module gradients should exist on final microbatch."
else:
assert (
not grads_exist
), "Root module gradients should not exist prior to optimization step."
if step == NUM_STEPS - 1:
assert grads_exist, "Root module gradients should exist on final microbatch."
else:
assert (
not grads_exist
), "Root module gradients should not exist prior to optimization step."
torch.distributed.barrier()
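A minimal sketch of the collective swap above, in isolation. The helper name is made up and it assumes an already-initialized `torch.distributed` process group; the key point is that `all_gather_object` populates the list on every rank, so the assertions no longer need to be gated on rank 0.

```python
import torch.distributed as dist
from torch import nn


def any_rank_has_grads(module: nn.Module, sharding_group: dist.ProcessGroup) -> bool:
    """True on every rank of `sharding_group` if any rank holds a gradient."""
    grads_exist = any(p.grad is not None for p in module.parameters())
    # all_gather_object fills the list on every rank, unlike gather_object,
    # which only fills it on the destination rank (hence the old rank-0 gating).
    gathered = [None] * sharding_group.size()
    dist.all_gather_object(object_list=gathered, obj=grads_exist, group=sharding_group)
    return any(gathered)
```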

# Optimizer step. Apply accumulated gradients to the model weights.
Expand Down Expand Up @@ -415,7 +411,10 @@ def test_dcp_checkpoint_save_and_load(
dp_shard_dim=DP_SHARD_CP,
dp_outer_dim=DP_OUTER,
tp_dim=TP,
hybrid_fsdp_group=device_mesh[HSDP].get_group(),
# Only need this fully-flattened group if you are using HFSDP.
hybrid_fsdp_group=(
device_mesh[HSDP].get_group() if outer_shard_strategy == OPTIM else None
),
fsdp_unit_modules=fsdp_unit_modules,
zero_dp_strategy=shard_strategy,
outer_dp_sharding_strategy=outer_shard_strategy,
@@ -496,7 +495,10 @@ def test_dcp_checkpoint_save_and_load(
dp_shard_dim=DP_SHARD_CP,
dp_outer_dim=DP_OUTER,
tp_dim=TP,
hybrid_fsdp_group=device_mesh[HSDP].get_group(),
# Only need this fully-flattened group if you are using HFSDP.
hybrid_fsdp_group=(
device_mesh[HSDP].get_group() if outer_shard_strategy == OPTIM else None
),
fsdp_unit_modules=fsdp_unit_modules,
zero_dp_strategy=shard_strategy,
outer_dp_sharding_strategy=outer_shard_strategy,