
Commit 91698d8

Relax the constraint to pass the full-sharding group (hybrid_fsdp_group): it is now only required for HFSDP (fully-sharded optimizer state), not for plain HSDP.
Signed-off-by: Cory Ye <[email protected]>
1 parent cdeb68c commit 91698d8

5 files changed, +48 -31 lines changed


megatron/core/distributed/fsdp/src/README.md

Lines changed: 4 additions & 4 deletions
@@ -124,7 +124,7 @@ device_mesh = torch.distributed.device_mesh.init_device_mesh(
 device_mesh[("dp_outer", "dp_shard")]._flatten("dp")
 # Only required if using CP. Otherwise, just pass dp_shard to FSDP.
 device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")
-# Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group.
+# Only required if using HFSDP. Otherwise, don't pass hybrid_fsdp_group.
 device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
 hsdp_group = device_mesh["hsdp"].get_group()
 # Initialize DeviceMesh for expert parallel (EP) modules when using FSDP + EP.
@@ -149,7 +149,7 @@ model, optimizer = fully_shard(
     # Only required for TP-sensitive models (i.e. Megatron-LM / TransformerEngine) or when using DTensor-based TP.
     # Otherwise, set this to None.
     tp_dim="tp",
-    # Only required when using HSDP. Otherwise, set this to None.
+    # Only required when fully-sharding the optimizer state in HFSDP. Otherwise, set this to None.
     hybrid_fsdp_group=hsdp_group,
     # Only required for FSDP + EP. Otherwise, set this to None.
     expt_device_mesh=expt_device_mesh,
@@ -185,7 +185,7 @@ model.load_state_dict(ckpt_state_dict["model"], strict=False)
 optimizer.load_state_dict(ckpt_state_dict["optimizer"])
 ```

-- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except for the (required) "outer" DP group (`dp_outer_dim` / `hybrid_fsdp_group`) when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP Replication) and `optim` (Optimizer State Hybrid Sharding, requires `zero_dp_strategy='optim_grads_params'`) are supported.
+- `zero_dp_strategy` (and `outer_dp_sharding_strategy`) configure different degrees of zero-redundancy data parallelism as described in [ZeRO (Zero Redundancy Optimizer)](https://arxiv.org/abs/1910.02054). It reduces CUDA memory utilization during model training by distributing model parameters, gradients, and optimizer states across multiple devices in the DP `ProcessGroup`, and collectively communicating subsets of parameters and gradients to specific devices when needed for computation or differentiation. More aggressive sharding strategies will entail more communication overhead, with `no_shard` being the least memory efficient but most communication efficient, and `optim_grads_params` being the most memory efficient but least communication efficient. `outer_dp_sharding_strategy` has the same options, except for the (required) "outer" DP group (`dp_outer_dim`) when using [Hybrid-Sharded Data Parallelism (HSDP)](https://arxiv.org/pdf/2304.11277), and only `no_shard` (DP Replication) and `optim` (Optimizer State Hybrid Sharding, requires `zero_dp_strategy='optim_grads_params'`) are supported.
   - Default: `optim_grads_params` or `3` for `zero_dp_strategy` and `no_shard` or `0` for `outer_dp_sharding_strategy`
   - `0` or `no_shard` implies that your model is not sharded. Similar memory usage to `DDP`.
   - `1` or `optim` implies that your optimizer state is sharded for distributed optimization. Similar to optimizer state sharding in `ZeRO-DP`.
@@ -199,7 +199,7 @@ optimizer.load_state_dict(ckpt_state_dict["optimizer"])
 - `dp_outer_dim` is the name of the sub-mesh corresponding to the "outer" DP group, which is required for replication or sharding in HSDP. `fully_shard` will perform HSDP if `dp_outer_dim` is specified.
 - `tp_dim` is the name of the sub-mesh used for tensor parallelism (TP), which is required for `(FSDP, TP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` TP.
   - For more information about tensor parallelism, refer to: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053).
-- `hybrid_fsdp_group` is the `ProcessGroup` which contains all ranks in the flattened `dp_shard_dim` and `dp_outer_dim` sub-meshes utilized to specify the `(DP-Outer, DP-Shard)` sharded coordinate system for the weight and gradient buffers. Required for HSDP.
+- `hybrid_fsdp_group` is the `ProcessGroup` which contains all ranks in the flattened `dp_shard_dim` and `dp_outer_dim` sub-meshes utilized to specify the `(DP-Outer, DP-Shard)` sharded coordinate system for the weight and gradient buffers. Required for HFSDP only, i.e. fully-sharded optimizer state with HSDP.
 - `expt_device_mesh` is another [`torch.distributed.DeviceMesh`](https://docs.pytorch.org/docs/stable/distributed.html#devicemesh) tailored for the expert parallel (EP) modules in `MegatronFSDP`.
   - `dp_shard_dim` is the name of the sub-mesh required for FSDP sharding of the EP modules, enabling expert data parallelism (EDP).
   - `tp_dim` is the name of the sub-mesh used for expert tensor parallelism (ETP), which is required for `(FSDP, ETP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` ETP.
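To make the relaxed requirement concrete, here is a minimal sketch of the two configurations described by the README hunks above. It assumes a hypothetical 4-rank distributed run with a 2 (dp_outer) x 2 (dp_shard) mesh, and it only builds the keyword arguments that change between HSDP and HFSDP; the remaining `fully_shard` arguments (model, optimizer, device mesh, and so on) follow the README snippet and are elided here.

```python
import torch

# Hypothetical 2 (dp_outer) x 2 (dp_shard) x 1 (cp) x 1 (tp) mesh over 4 ranks.
device_mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda", (2, 2, 1, 1), mesh_dim_names=("dp_outer", "dp_shard", "cp", "tp")
)
device_mesh[("dp_outer", "dp_shard")]._flatten("dp")
device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")

# HSDP: DP-Outer only replicates the model, so after this commit the
# fully-flattened group is no longer required.
hsdp_kwargs = dict(
    dp_shard_dim="dp_shard_cp",
    dp_outer_dim="dp_outer",
    zero_dp_strategy="optim_grads_params",
    outer_dp_sharding_strategy="no_shard",
    hybrid_fsdp_group=None,
)

# HFSDP: the optimizer state is fully sharded over (DP-Outer, DP-Shard), so the
# flattened "hsdp" group is still required for rank assignment.
device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
hfsdp_kwargs = dict(
    dp_shard_dim="dp_shard_cp",
    dp_outer_dim="dp_outer",
    zero_dp_strategy="optim_grads_params",
    outer_dp_sharding_strategy="optim",
    hybrid_fsdp_group=device_mesh["hsdp"].get_group(),
)
# Pass either set of kwargs to fully_shard(...) as in the README snippet above.
```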

megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py

Lines changed: 9 additions & 5 deletions
@@ -142,12 +142,16 @@ def fully_shard_model(
             "zero_dp_strategy to use FSDP ('optim_grads_params', 3), because "
             "outer sharding is dependent on inner sharding."
         )
-    if (dp_outer_dim is None) ^ (hybrid_fsdp_group is None):
-        # XOR - HSDP requires both or neither of dp_outer_dim and hybrid_fsdp_group
-        # to be specified, so if XOR then raise an error.
+    if _outer_fsdp_sharding and hybrid_fsdp_group is None:
+        # If fully-sharding the optimizer state on DP-Outer, you must provide the
+        # completely flattened HFSDP group for logical rank assignment to the
+        # optimizer state full-sharding ranks.
         raise ValueError(
-            f"dp_outer_dim={dp_outer_dim} and hybrid_fsdp_group={hybrid_fsdp_group} must be "
-            "specified together for Hybrid FSDP (HSDP), or both set to None (for FSDP)."
+            "[HFSDP] Fully-sharding the optimizer on DP-Outer "
+            f"(outer_dp_sharding_strategy={outer_dp_sharding_strategy}) "
+            f"requires a fully-flattened hybrid_fsdp_group={hybrid_fsdp_group} "
+            "for rank assignment to the optimizer state. You can flatten your DeviceMesh "
+            f"via `DeviceMesh[(DP-Outer, DP-Shard)]._flatten()` & `DeviceMesh.get_group()`."
         )
     if init_model_with_meta_device and zero_dp_strategy == "no_shard":
         raise ValueError(
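A standalone paraphrase of the relaxed check above, not the library code itself: the old XOR condition demanded `dp_outer_dim` and `hybrid_fsdp_group` together, while the new condition only demands the flattened group when the outer strategy actually full-shards the optimizer state. The string strategy values follow the README; `_outer_fsdp_sharding` in the real code may also account for the integer aliases.

```python
def validate_hybrid_group(outer_dp_sharding_strategy, hybrid_fsdp_group):
    # HFSDP, i.e. outer_dp_sharding_strategy == "optim", fully shards the optimizer
    # state over (DP-Outer, DP-Shard) and therefore needs the flattened group.
    outer_fsdp_sharding = outer_dp_sharding_strategy == "optim"
    if outer_fsdp_sharding and hybrid_fsdp_group is None:
        raise ValueError(
            "HFSDP requires a fully-flattened hybrid_fsdp_group: flatten the "
            "(dp_outer, dp_shard) sub-meshes and pass DeviceMesh.get_group()."
        )


# Plain HSDP (replication on DP-Outer) no longer needs the flattened group:
validate_hybrid_group(outer_dp_sharding_strategy="no_shard", hybrid_fsdp_group=None)
```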

megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py

Lines changed: 11 additions & 2 deletions
@@ -311,12 +311,21 @@ def _init_fsdp_param_and_grad_buffer(self):
         else:
             if self.ddp_config.average_in_collective:
                 gradient_scaling_factor = 1.0
+                # Utilized to re-scale expert gradients to DP.
+                # (edp_size/dp_size) * (1/edp_size) = 1/dp_size
+                # FIXME(@cspades): Currently not used gradient_reduce_preprocessing()?
                 expert_gradient_scaling_factor = (
                     self.dist_index.get_dp_group(is_expert_parallel=True).size()
-                    / self.dist_index.get_dp_group().size()
+                    / self.dist_index.get_fsdp_group().size()
                 )
+                if self.dist_index.use_hybrid_fsdp:
+                    # Also divide the DP-Outer size in the conversion factor.
+                    expert_gradient_scaling_factor /= self.dist_index.get_outer_fsdp_group().size()
             else:
-                data_parallel_world_size = self.dist_index.get_dp_group().size()
+                data_parallel_world_size = self.dist_index.get_fsdp_group().size()
+                if self.dist_index.use_hybrid_fsdp:
+                    # Also multiply the DP-Outer size in the DP size.
+                    data_parallel_world_size *= self.dist_index.get_outer_fsdp_group().size()
                 gradient_scaling_factor = 1.0 / data_parallel_world_size
                 expert_gradient_scaling_factor = 1.0 / data_parallel_world_size

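A worked example of the scaling factors above, with hypothetical group sizes (8-way FSDP shard, 2-way DP-Outer, 4-way expert data parallelism):

```python
# Hypothetical group sizes for illustration only.
fsdp_size, dp_outer_size, edp_size = 8, 2, 4

# Without average_in_collective: gradients are summed, so scale by 1 / total DP size.
data_parallel_world_size = fsdp_size * dp_outer_size  # 16 under hybrid FSDP
gradient_scaling_factor = 1.0 / data_parallel_world_size  # 1/16
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size  # 1/16

# With average_in_collective: the collective already averages over edp_size, so the
# expert factor converts 1/edp_size into 1/dp_size:
# (edp_size / dp_size) * (1 / edp_size) = 1 / dp_size.
expert_factor = edp_size / (fsdp_size * dp_outer_size)  # 4 / 16 = 0.25
assert expert_factor * (1.0 / edp_size) == 1.0 / data_parallel_world_size
```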

megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py

Lines changed: 11 additions & 13 deletions
@@ -167,13 +167,10 @@ def get_mesh_names(device_mesh: Optional[DeviceMesh] = None) -> list[str]:
         submesh_dim_name
         for child_mesh, root_mesh in _mesh_resources.child_to_root_mapping.items()
         for submesh_dim_name in (child_mesh.mesh_dim_names or [])
-        if root_mesh == device_mesh
+        # Add flattened or other unaccounted for children of the root mesh.
+        if root_mesh == device_mesh and submesh_dim_name not in mesh_dim_names
     ]
-    # Combine without duplicate dimensions.
-    for dim_name in submesh_dim_names:
-        if dim_name not in mesh_dim_names:
-            mesh_dim_names.append(dim_name)
-    return mesh_dim_names
+    return mesh_dim_names + submesh_dim_names


 def contains_submesh(
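The refactor above folds the duplicate filter into the comprehension and concatenates once. The same pattern on plain lists, using hypothetical dimension names in place of the `_mesh_resources` bookkeeping:

```python
# Hypothetical root-mesh and child/flattened sub-mesh dimension names.
mesh_dim_names = ["dp_outer", "dp_shard", "cp", "tp"]
child_dim_names = ["dp", "dp_shard_cp", "hsdp", "cp"]

# Old shape: collect everything, then append names not already present in a loop.
# New shape: filter duplicates inside the comprehension and concatenate once.
submesh_dim_names = [name for name in child_dim_names if name not in mesh_dim_names]
assert mesh_dim_names + submesh_dim_names == [
    "dp_outer", "dp_shard", "cp", "tp", "dp", "dp_shard_cp", "hsdp"
]
```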
@@ -787,16 +784,17 @@ def register_submesh(device_mesh, submesh, is_expert_parallel):
         if self.use_hybrid_fsdp:
             if self.outer_fsdp_group is None:
                 raise ValueError(
-                    "[FSDPDistributedIndex][use_hybrid_fsdp=True] Hybrid FSDP requires "
-                    "an outer-DP process group (dp_outer_dim, outer_fsdp_group)."
+                    "[FSDPDistributedIndex] Hybrid-Sharded Data Parallelism (HSDP) requires a "
+                    "DP-Outer ProcessGroup for model replication or optimizer full-sharding. "
+                    f"Check that {self.device_mesh} contains an outer DP sub-mesh.\n"
+                    f"dp_outer_dim={self.dp_outer_dim} / outer_fsdp_group={self.outer_fsdp_group}"
                 )
-            if self.hybrid_fsdp_group is None:
+            if self.hsdp_outer_dp_shard and self.hybrid_fsdp_group is None:
                 raise ValueError(
-                    "[FSDPDistributedIndex][use_hybrid_fsdp=True] Hybrid FSDP requires "
-                    "a hybrid FSDP process group (hybrid_fsdp_group). "
-                    "This group can be manufactured by flattening the outer-DP "
+                    "[FSDPDistributedIndex] Hybrid FSDP (HFSDP) requires a fully-flattened hybrid "
+                    "FSDP process group (hybrid_fsdp_group). Created by flattening the outer-DP "
                     "(dp_outer_dim, outer_fsdp_group) and FSDP (dp_shard_dim, fsdp_group) "
-                    "process groups or sub-meshes."
+                    "ProcessGroup(s) or sub-meshes."
                 )

     def get_submesh(

tests/unit_tests/distributed/fsdp/test_mfsdp_fully_shard.py

Lines changed: 13 additions & 7 deletions
@@ -1,5 +1,4 @@
 import shutil
-from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path

@@ -277,7 +276,10 @@ def test_fully_shard(
         dp_outer_dim=DP_OUTER if dp_outer_strategy is not None else None,
         tp_dim=TP,
         hybrid_fsdp_group=(
-            device_mesh[HSDP].get_group() if dp_outer_strategy is not None else None
+            # Only need this fully-flattened group if you are using HFSDP.
+            device_mesh[HSDP].get_group()
+            if dp_outer_strategy == OPTIM
+            else None
         ),
         fsdp_unit_modules=fsdp_unit_modules,
         zero_dp_strategy=dp_shard_strategy,
@@ -327,9 +329,7 @@ def test_fully_shard(
     # to verify if any gradients exist or not at this step of training.
     grads_exist_gathered = [None] * sharding_group.size()
     torch.distributed.all_gather_object(
-        object_list=grads_exist_gathered,
-        obj=grads_exist,
-        group=sharding_group,
+        object_list=grads_exist_gathered, obj=grads_exist, group=sharding_group
     )
     # Gradients exist on at least one of the optimizer sharding ranks.
     grads_exist = any(grads_exist_gathered)
@@ -409,7 +409,10 @@ def test_dcp_checkpoint_save_and_load(
         dp_shard_dim=DP_SHARD_CP,
         dp_outer_dim=DP_OUTER,
         tp_dim=TP,
-        hybrid_fsdp_group=device_mesh[HSDP].get_group(),
+        # Only need this fully-flattened group if you are using HFSDP.
+        hybrid_fsdp_group=(
+            device_mesh[HSDP].get_group() if outer_shard_strategy == OPTIM else None
+        ),
         fsdp_unit_modules=fsdp_unit_modules,
         zero_dp_strategy=shard_strategy,
         outer_dp_sharding_strategy=outer_shard_strategy,
@@ -490,7 +493,10 @@ def test_dcp_checkpoint_save_and_load(
         dp_shard_dim=DP_SHARD_CP,
         dp_outer_dim=DP_OUTER,
         tp_dim=TP,
-        hybrid_fsdp_group=device_mesh[HSDP].get_group(),
+        # Only need this fully-flattened group if you are using HFSDP.
+        hybrid_fsdp_group=(
+            device_mesh[HSDP].get_group() if outer_shard_strategy == OPTIM else None
+        ),
         fsdp_unit_modules=fsdp_unit_modules,
         zero_dp_strategy=shard_strategy,
         outer_dp_sharding_strategy=outer_shard_strategy,
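The gradient-existence check in the test relies on `torch.distributed.all_gather_object` so that every rank in the optimizer-sharding group agrees on whether any shard saw gradients. A self-contained sketch of that pattern, assuming an already-initialized process group and simplifying the local check to `p.grad` (the test itself may inspect different gradient attributes):

```python
import torch


def any_rank_has_grads(model: torch.nn.Module, sharding_group) -> bool:
    # Local check on this rank: does at least one parameter carry a gradient?
    grads_exist = any(p.grad is not None for p in model.parameters())
    # Gather the per-rank booleans across the optimizer-sharding group.
    grads_exist_gathered = [None] * sharding_group.size()
    torch.distributed.all_gather_object(
        object_list=grads_exist_gathered, obj=grads_exist, group=sharding_group
    )
    # Gradients exist if any rank in the sharding group observed them.
    return any(grads_exist_gathered)
```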
