
Commit 13edb58

Revert "cp: Megatron-FSDP Expert Parallel (DeepSeek-v3) Support into dev (#1987)"
This reverts commit cc33e00.
1 parent cc33e00 commit 13edb58

File tree

21 files changed: +765 −2224 lines changed

megatron/core/distributed/fsdp/mcore_fsdp_adapter.py

Lines changed: 7 additions & 126 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.

 import logging
-import random
 from typing import List, Optional

 try:
@@ -23,7 +22,6 @@
 except ImportError:
     HAVE_EINOPS = False

-import numpy as np
 import torch
 import torch.distributed as dist

@@ -34,11 +32,10 @@
 except ImportError:
     HAVE_DTENSOR = False

-from megatron.core import parallel_state, tensor_parallel
+from megatron.core import parallel_state
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.distributed.data_parallel_base import _BaseDataParallel
 from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
-from megatron.core.extensions.transformer_engine import TELinear
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer
@@ -98,8 +95,6 @@ def __init__(
         else:
             self.fsdp_unit_modules = []

-        self._fix_tensor_parallel_attributes(module)
-
         super().__init__(
             config=config,
             module=MegatronFSDP(
@@ -124,8 +119,6 @@ def __init__(
         self.module.state_dict_for_save_checkpoint = self.module.state_dict
         self.state_dict_for_save_checkpoint = self.state_dict

-        self.sync_rng_states_across_tp_group()
-
     def load_state_dict(self, state_dict, strict=True):
         """
         Load the state dictionary into the module.
@@ -148,44 +141,6 @@ def load_state_dict(self, state_dict, strict=True):

         self.module.load_state_dict(custom_state_dict, strict=strict)

-    def _fix_tensor_parallel_attributes(self, module):
-        is_expert_param = lambda n, p: ".experts." in n
-        is_router_param = lambda n, p: ".router.weight" in n
-
-        if parallel_state.get_tensor_model_parallel_group():
-            tp_size = parallel_state.get_tensor_model_parallel_group().size()
-        else:
-            tp_size = 1
-
-        if parallel_state.get_expert_tensor_parallel_group():
-            expt_tp_size = parallel_state.get_expert_tensor_parallel_group().size()
-        else:
-            expt_tp_size = 1
-
-        param_to_direct_module = {}
-        for name, m in module.named_modules():
-            for p in m.parameters(recurse=False):
-                param_to_direct_module[p] = (name, m)
-
-        for name, param in module.named_parameters():
-            if is_expert_param(name, param) and expt_tp_size > 1:
-                setattr(param, "_mcore_tp", True)
-                if "linear_fc1.weight" in name:
-                    setattr(param, "_tp_partition_dim", 0)
-                elif "linear_fc2.weight" in name:
-                    setattr(param, "_tp_partition_dim", 1)
-
-            if not is_expert_param(name, param) and tp_size > 1:
-                m_name, direct_module = param_to_direct_module[param]
-                if isinstance(direct_module, (TELinear,)):
-                    parallel_mode = getattr(direct_module, "parallel_mode", None)
-                    if parallel_mode is None:
-                        setattr(param, "_mcore_tp", True)
-                        setattr(param, "_tp_duplicated", True)
-                elif is_router_param(name, param):
-                    setattr(param, "_mcore_tp", True)
-                    setattr(param, "_tp_duplicated", True)
-
     def _init_dist_index(self, pg_collection):
         """
         Initialize the distributed index for the module.
@@ -199,7 +154,6 @@ def _init_dist_index(self, pg_collection):
         enable_hsdp = self.ddp_config.num_distributed_optimizer_instances > 1
         if pg_collection is None:
             tp_group = parallel_state.get_tensor_model_parallel_group()
-            expt_tp_group = parallel_state.get_expert_tensor_parallel_group()
             if enable_hsdp:
                 dp_cp_group = parallel_state.get_data_parallel_group(
                     with_context_parallel=True, partial_data_parallel=True
@@ -214,11 +168,8 @@ def _init_dist_index(self, pg_collection):
                 )
                 outer_fsdp_group = None
                 hybrid_fsdp_group = None
-            expt_dp_group = parallel_state.get_expert_data_parallel_group()
-            ep_group = parallel_state.get_expert_model_parallel_group()
         else:
             tp_group = getattr(pg_collection, 'tp', None)
-            expt_tp_group = getattr(pg_collection, 'expt_tp', None)
             if enable_hsdp:
                 dp_cp_group = pg_collection.intra_dp_cp
                 outer_fsdp_group = pg_collection.inter_dist_opt
@@ -227,17 +178,11 @@ def _init_dist_index(self, pg_collection):
                 dp_cp_group = pg_collection.dp_cp
                 outer_fsdp_group = None
                 hybrid_fsdp_group = None
-            expt_dp_group = getattr(pg_collection, 'expt_dp', None)
-            ep_group = getattr(pg_collection, 'ep', None)

         if tp_group is None:
             single_rank_group = dist.new_group(ranks=[dist.get_rank()])
             tp_group = single_rank_group

-        if expt_tp_group is None:
-            single_rank_group = dist.new_group(ranks=[dist.get_rank()])
-            expt_tp_group = single_rank_group
-
         if enable_hsdp:
             mesh = _get_hsdp_tp_mesh(outer_fsdp_group, dp_cp_group, tp_group)
             dist_index = FSDPDistributedIndex(
@@ -254,17 +199,6 @@ def _init_dist_index(self, pg_collection):
                 hybrid_fsdp_group=hybrid_fsdp_group,
             )
         else:
-            if ep_group is not None:
-                expt_mesh = _get_dp_tp_mesh(expt_dp_group, expt_tp_group, ep_size=ep_group.size())
-                expt_device_mesh = DeviceMesh.from_group(
-                    [expt_dp_group, expt_tp_group],
-                    device_type="cuda",
-                    mesh=expt_mesh.tolist(),
-                    mesh_dim_names=["dp_cp", "tp"],
-                )
-            else:
-                expt_device_mesh = None
-
             mesh = _get_dp_tp_mesh(dp_cp_group, tp_group)
             dist_index = FSDPDistributedIndex(
                 device_mesh=DeviceMesh.from_group(
@@ -275,11 +209,8 @@ def _init_dist_index(self, pg_collection):
                 ),
                 dp_shard_dim="dp_cp",
                 tp_dim="tp",
-                expt_device_mesh=expt_device_mesh,
             )

-        self.tp_group = tp_group
-
         return dist_index

     def stop_communication(self):
@@ -289,20 +220,6 @@ def stop_communication(self):
         self.module.synchronize_gradient_reduce()
         self.module.synchronize_param_gather()

-    def sync_rng_states_across_tp_group(self):
-        """
-        Synchronize the tensor parallel random number generator states.
-        """
-        if self.tp_group.size() <= 1:
-            return
-
-        if self.tp_group.rank() == 0:
-            broadcast_list = [_get_rng_state_dict()]
-        else:
-            broadcast_list = [None]
-        torch.distributed.broadcast_object_list(broadcast_list, group=self.tp_group, group_src=0)
-        _load_rng_state_dict(broadcast_list[0])
-

 def _get_hsdp_tp_mesh(outer_fsdp_dp_group, dp_cp_group, tp_group):
     assert HAVE_EINOPS, "einops is not installed. Please install it with `pip install einops`."
@@ -356,46 +273,29 @@ def _get_hsdp_tp_mesh(outer_fsdp_dp_group, dp_cp_group, tp_group):
     return mesh


-def _get_dp_tp_mesh(dp_cp_group, tp_group, ep_size=1):
+def _get_dp_tp_mesh(dp_cp_group, tp_group):
     assert HAVE_EINOPS, "einops is not installed. Please install it with `pip install einops`."
     world_size = dist.get_world_size()

     tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
-    # TODO: Supports configurable (dp, cp, ep, tp) order.
-    mesh = einops.rearrange(
-        torch.arange(world_size),
-        "(dp_cp ep tp) -> ep dp_cp tp",
-        dp_cp=dp_cp_group.size(),
-        tp=tp_size,
-        ep=ep_size,
-    )
+    # TODO: Supports configurable (dp, cp, tp) order.
+    mesh = einops.rearrange(torch.arange(world_size), "(dp_cp tp) -> dp_cp tp", tp=tp_size)

-    mesh_dp_ranks = einops.rearrange(mesh, 'ep dp_cp tp -> (ep tp) dp_cp', dp_cp=dp_cp_group.size())
+    mesh_dp_ranks = einops.rearrange(mesh, 'dp_cp tp -> tp dp_cp', tp=tp_size)
     dp_cp_group_ranks = dist.get_process_group_ranks(dp_cp_group)
     assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_dp_ranks, dp_cp_group_ranks), (
         f"[Megatron-FSDP] Data Parallel ranks in the mesh {mesh_dp_ranks} "
         f"do not match the ranks in the DP group {dp_cp_group_ranks}."
     )

-    mesh_tp_ranks = einops.rearrange(mesh, 'ep dp_cp tp -> (dp_cp ep) tp', tp=tp_size)
+    mesh_tp_ranks = einops.rearrange(mesh, 'dp_cp tp -> (dp_cp) tp', tp=tp_size)
     tp_group_ranks = dist.get_process_group_ranks(tp_group)
     assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_tp_ranks, tp_group_ranks), (
         f"[Megatron-FSDP] Tensor Parallel ranks in the mesh {mesh_tp_ranks} "
         f"do not match the ranks in the TP group {tp_group_ranks}."
     )

-    # Exclude the expert parallel dimension
-    rank = dist.get_rank()
-    dp_tp_meshes = [per_ep_mesh for per_ep_mesh in mesh if rank in per_ep_mesh.reshape(-1).tolist()]
-    assert (
-        len(dp_tp_meshes) == 1
-    ), f"[Megatron-FSDP] Current rank {rank} is not unique in the mesh ranks {mesh.tolist()}."
-    assert len(dp_tp_meshes[0].reshape(-1).tolist()) == dp_cp_group.size() * tp_group.size(), (
-        f"[Megatron-FSDP] DP-TP mesh size {len(dp_tp_meshes[0].reshape(-1).tolist())} "
-        f"does not match expected size {dp_cp_group.size() * tp_group.size()}."
-    )
-
-    return dp_tp_meshes[0]
+    return mesh


 def _check_mesh_ranks_and_group_ranks_are_consistent(mesh_ranks, group_ranks):
@@ -410,22 +310,3 @@ def _check_mesh_ranks_and_group_ranks_are_consistent(mesh_ranks, group_ranks):
         f"{mesh_ranks.tolist()} does not match the group ranks {group_ranks}."
     )
     return sorted(current_ranks[0]) == sorted(group_ranks)
-
-
-def _get_rng_state_dict():
-    rng_state_dict = {
-        'random_rng_state': random.getstate(),
-        'np_rng_state': np.random.get_state(),
-        'torch_rng_state': torch.get_rng_state(),
-        'cuda_rng_state': torch.cuda.get_rng_state(),
-        'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(),
-    }
-    return rng_state_dict
-
-
-def _load_rng_state_dict(rng_state_dict):
-    random.setstate(rng_state_dict['random_rng_state'])
-    np.random.set_state(rng_state_dict['np_rng_state'])
-    torch.set_rng_state(rng_state_dict['torch_rng_state'])
-    torch.cuda.set_rng_state(rng_state_dict['cuda_rng_state'])
-    tensor_parallel.get_cuda_rng_tracker().set_states(rng_state_dict['rng_tracker_states'])
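
For orientation, the restored `_get_dp_tp_mesh` lays ranks out on a plain `dp_cp × tp` grid. Below is a minimal, self-contained sketch of that layout; the world size and TP size are hypothetical, chosen only for illustration, since the real code derives them from the process groups.

```python
import torch
import einops

world_size = 8  # hypothetical total number of ranks
tp_size = 2     # hypothetical tensor-parallel size

# Ranks are laid out as (dp_cp, tp): consecutive ranks share a dp_cp row.
mesh = einops.rearrange(torch.arange(world_size), "(dp_cp tp) -> dp_cp tp", tp=tp_size)
print(mesh)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]])

# Transposing recovers the DP-CP rank groups per TP rank, which the adapter
# cross-checks against dist.get_process_group_ranks(dp_cp_group).
mesh_dp_ranks = einops.rearrange(mesh, "dp_cp tp -> tp dp_cp")
print(mesh_dp_ranks)
# tensor([[0, 2, 4, 6],
#         [1, 3, 5, 7]])
```

The reverted EP-aware variant instead built an `"(dp_cp ep tp) -> ep dp_cp tp"` grid and returned only the current rank's DP-TP slice; after the revert, the full `dp_cp × tp` mesh is returned directly.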

megatron/core/distributed/fsdp/src/README.md

Lines changed: 0 additions & 11 deletions
@@ -127,12 +127,6 @@ device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")
 # Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group.
 device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
 hsdp_group = device_mesh["hsdp"].get_group()
-# Initialize DeviceMesh for expert parallel (EP) modules when using FSDP + EP.
-expert_device_mesh = torch.distributed.device_mesh.init_device_mesh(
-    "cuda",
-    mesh_shape=(expt_dp_shard_size, expt_tp_size),
-    mesh_dim_names=("dp_shard", "tp"),
-)

 # Fully-shards your model and distributes your optimizer.
 model, optimizer = fully_shard(
@@ -151,8 +145,6 @@ model, optimizer = fully_shard(
     tp_dim="tp",
     # Only required when using HSDP. Otherwise, set this to None.
     hybrid_fsdp_group=hsdp_group,
-    # Only required for FSDP + EP. Otherwise, set this to None.
-    expt_device_mesh=expt_device_mesh,
     # FSDP Sharding Strategy: no_shard (0) / optim (1) / optim_grads (2) / optim_grads_params (3)
     zero_dp_strategy=3,
     outer_dp_sharding_strategy=1,
@@ -200,9 +192,6 @@ optimizer.load_state_dict(ckpt_state_dict["optimizer"])
 - `tp_dim` is the name of the sub-mesh used for tensor parallelism (TP), which is required for `(FSDP, TP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` TP.
   - For more information about tensor parallelism, refer to: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053).
 - `hybrid_fsdp_group` is the `ProcessGroup` which contains all ranks in the flattened `dp_shard_dim` and `dp_outer_dim` sub-meshes utilized to specify the `(DP-Outer, DP-Shard)` sharded coordinate system for the weight and gradient buffers. Required for HSDP.
-- `expt_device_mesh` is another [`torch.distributed.DeviceMesh`](https://docs.pytorch.org/docs/stable/distributed.html#devicemesh) tailored for the expert parallel (EP) modules in `MegatronFSDP`.
-  - `dp_shard_dim` is the name of the sub-mesh required for FSDP sharding of the EP modules, enabling expert data parallelism (EDP).
-  - `tp_dim` is the name of the sub-mesh used for expert tensor parallelism (ETP), which is required for `(FSDP, ETP)`-strided sharding when using Megatron-LM or Torch-native `DTensor` ETP.
 - `init_model_with_meta_device` has `MegatronFSDP` initialize your `meta`-device model in shards on every CUDA device to avoid OOM when initializing extremely large models that cannot fit on a single device. Users can initialize their model on a [`meta`-device](https://docs.pytorch.org/docs/stable/meta.html) (`with torch.device('meta'): ...`), and ``MegatronFSDP`` will further shard and initialize the model parameters layer-by-layer adhering to the customizable `module.reset_parameters` method, which prevents the entire model from being allocated in memory at any point during runtime.
   - Defaults to `False`.
   - Note that the `device` argument which installs your model on a specific device or rank will be deactivated when `init_model_with_meta_device=True`.
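
To see how the surviving (post-revert) README flow fits together end to end, here is a minimal sketch of the mesh setup it describes. The dimension sizes are hypothetical, and it assumes `torch.distributed` is already initialized (e.g. via `torchrun`) with exactly `dp_outer * dp_shard * cp * tp` ranks.

```python
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical sizes for an 8-GPU run; dp_outer > 1 only matters for HSDP.
dp_outer, dp_shard, cp, tp = 1, 4, 1, 2

device_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(dp_outer, dp_shard, cp, tp),
    mesh_dim_names=("dp_outer", "dp_shard", "cp", "tp"),
)

# FSDP shards over the flattened (dp_shard, cp) sub-mesh, as in the README.
device_mesh[("dp_shard", "cp")]._flatten("dp_shard_cp")

# Only required if using HSDP. Otherwise, don't pass hybrid_fsdp_group.
device_mesh[("dp_outer", "dp_shard", "cp")]._flatten("hsdp")
hsdp_group = device_mesh["hsdp"].get_group()
```

After this revert, the expert-parallel `expert_device_mesh` is no longer constructed; only the dense `dp_shard_cp`/`tp` sub-meshes (plus `hsdp` for HSDP) feed into `fully_shard`.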

megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py

Lines changed: 1 addition & 9 deletions
@@ -64,7 +64,6 @@ def fully_shard_model(
     dp_outer_dim: Optional[str] = None,
     tp_dim: Optional[str] = None,
     hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
-    expt_device_mesh: Optional[DeviceMesh] = None,
     fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
     zero_dp_strategy: str | int = 3,
     outer_dp_sharding_strategy: str | int = 0,
@@ -184,10 +183,8 @@ def fully_shard_model(
         tp_dim=tp_dim,
         # Only required for HSDP.
         hybrid_fsdp_group=hybrid_fsdp_group,
-        # Access to flattened DP rank assignments for HSDP.
+        # Access to flattened DP rank assignments for HFSDP.
         hsdp_outer_dp_shard=_outer_fsdp_sharding,
-        # Only required for Megatron-FSDP + EP.
-        expt_device_mesh=expt_device_mesh,
     )

     # Wrap model in Megatron FSDP.
@@ -333,7 +330,6 @@ def fully_shard(
     dp_outer_dim: Optional[str] = None,
     tp_dim: Optional[str] = None,
     hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
-    expt_device_mesh: Optional[DeviceMesh] = None,
     fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
     zero_dp_strategy: str | int = 3,
     outer_dp_sharding_strategy: str | int = 0,
@@ -395,9 +391,6 @@ def fully_shard(
             by flattening the outer-FSDP (dp_outer_dim) and FSDP (dp_shard_dim) process groups
             or sub-meshes. Defaults to None. Required for HSDP, i.e. if dp_outer_dim is not None.

-        expt_device_mesh (Optional[DeviceMesh]):
-            Expert parallel device mesh object defining the topology for MoE distributed training.
-
         fsdp_unit_modules (Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]]):
             List of (sub-)module classes or (sub-)module class import paths that are "units",
             which are torch.nn.Module(s) that are sharded and scheduled by Megatron-FSDP.
@@ -510,7 +503,6 @@ def fully_shard(
         dp_outer_dim=dp_outer_dim,
         tp_dim=tp_dim,
         hybrid_fsdp_group=hybrid_fsdp_group,
-        expt_device_mesh=expt_device_mesh,
         fsdp_unit_modules=fsdp_unit_modules,
         zero_dp_strategy=zero_dp_strategy,
         outer_dp_sharding_strategy=outer_dp_sharding_strategy,
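
For context, here is a hedged sketch of a `fully_shard` call after this revert. Only the keyword arguments visible in this diff and the README excerpt are used; the import path and the leading `model`/`optimizer`/`device_mesh`/`dp_shard_dim` arguments are assumptions drawn from the README's usage example and may differ in the actual API.

```python
# Hypothetical usage sketch; `model`, `optimizer`, `device_mesh`, and `hsdp_group`
# are assumed to have been built as in the README example above.
from megatron_fsdp import fully_shard  # assumed import path for the standalone package

model, optimizer = fully_shard(
    model,
    optimizer,
    device_mesh=device_mesh,        # assumed keyword, per the README example
    dp_shard_dim="dp_shard_cp",     # flattened (dp_shard, cp) sub-mesh
    dp_outer_dim="dp_outer",        # only for HSDP
    tp_dim="tp",
    hybrid_fsdp_group=hsdp_group,   # only required for HSDP, otherwise None
    # expt_device_mesh no longer exists after this revert; passing it raises TypeError.
    zero_dp_strategy=3,             # optim_grads_params
    outer_dp_sharding_strategy=1,   # optim (HSDP only)
)
```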

megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py

Lines changed: 4 additions & 7 deletions
@@ -235,10 +235,7 @@ def __init__(
         self.dist_index = dist_index

         # If Megatron Expert Parallelism is enabled, you need to provide an expt_dp_group.
-        if (
-            has_expert_parameters
-            and self.dist_index.get_fsdp_group(is_expert_parallel=True) is None
-        ):
+        if has_expert_parameters and self.dist_index.get_expert_dp_group() is None:
             raise ValueError(
                 "[Megatron-FSDP] Megatron Expert Parallelism is enabled, but no expt_dp_group is"
                 "provided."
@@ -356,7 +353,9 @@ def _init_fsdp_param_and_grad_buffer(self):
         )

         # Set the suggested communication unit size for reduce-scatter and all-gather pipelines.
-        suggested_communication_unit_size = self.ddp_config.suggested_communication_unit_size
+        suggested_communication_unit_size = (
+            self.ddp_config.suggested_communication_unit_size or 1_000_000_000
+        )
         if suggested_communication_unit_size is None:
             if self.data_parallel_sharding_strategy == "optim_grads_params":
                 total_param_elements = 0
@@ -371,8 +370,6 @@ def _init_fsdp_param_and_grad_buffer(self):
                 suggested_communication_unit_size = total_param_elements // total_fsdp_module * 2
             elif self.bucket_size is not None:
                 suggested_communication_unit_size = self.bucket_size
-            else:
-                suggested_communication_unit_size = 1_000_000_000

         # Cap to 1B elements.
         suggested_communication_unit_size = max(
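
One subtlety in the restored hunk above: `or` treats both `None` and `0` as unset, so the `if suggested_communication_unit_size is None:` branch that follows appears to become unreachable. A tiny sketch of the fallback semantics (the function name is illustrative only, not part of the codebase):

```python
def resolve_comm_unit_size(configured):
    # Mirrors `configured or 1_000_000_000`: None (and 0) fall back to 1B elements.
    return configured or 1_000_000_000

assert resolve_comm_unit_size(None) == 1_000_000_000
assert resolve_comm_unit_size(0) == 1_000_000_000
assert resolve_comm_unit_size(250_000_000) == 250_000_000
```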
