 # limitations under the License.

 import logging
-import random
 from typing import List, Optional

 try:
 except ImportError:
     HAVE_EINOPS = False

-import numpy as np
 import torch
 import torch.distributed as dist

 except ImportError:
     HAVE_DTENSOR = False

-from megatron.core import parallel_state, tensor_parallel
+from megatron.core import parallel_state
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.distributed.data_parallel_base import _BaseDataParallel
 from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
-from megatron.core.extensions.transformer_engine import TELinear
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer
@@ -98,8 +95,6 @@ def __init__(
         else:
             self.fsdp_unit_modules = []

-        self._fix_tensor_parallel_attributes(module)
-
         super().__init__(
             config=config,
             module=MegatronFSDP(
@@ -124,8 +119,6 @@ def __init__(
         self.module.state_dict_for_save_checkpoint = self.module.state_dict
         self.state_dict_for_save_checkpoint = self.state_dict

-        self.sync_rng_states_across_tp_group()
-
     def load_state_dict(self, state_dict, strict=True):
         """
         Load the state dictionary into the module.
@@ -148,44 +141,6 @@ def load_state_dict(self, state_dict, strict=True):

         self.module.load_state_dict(custom_state_dict, strict=strict)

-    def _fix_tensor_parallel_attributes(self, module):
-        is_expert_param = lambda n, p: ".experts." in n
-        is_router_param = lambda n, p: ".router.weight" in n
-
-        if parallel_state.get_tensor_model_parallel_group():
-            tp_size = parallel_state.get_tensor_model_parallel_group().size()
-        else:
-            tp_size = 1
-
-        if parallel_state.get_expert_tensor_parallel_group():
-            expt_tp_size = parallel_state.get_expert_tensor_parallel_group().size()
-        else:
-            expt_tp_size = 1
-
-        param_to_direct_module = {}
-        for name, m in module.named_modules():
-            for p in m.parameters(recurse=False):
-                param_to_direct_module[p] = (name, m)
-
-        for name, param in module.named_parameters():
-            if is_expert_param(name, param) and expt_tp_size > 1:
-                setattr(param, "_mcore_tp", True)
-                if "linear_fc1.weight" in name:
-                    setattr(param, "_tp_partition_dim", 0)
-                elif "linear_fc2.weight" in name:
-                    setattr(param, "_tp_partition_dim", 1)
-
-            if not is_expert_param(name, param) and tp_size > 1:
-                m_name, direct_module = param_to_direct_module[param]
-                if isinstance(direct_module, (TELinear,)):
-                    parallel_mode = getattr(direct_module, "parallel_mode", None)
-                    if parallel_mode is None:
-                        setattr(param, "_mcore_tp", True)
-                        setattr(param, "_tp_duplicated", True)
-                elif is_router_param(name, param):
-                    setattr(param, "_mcore_tp", True)
-                    setattr(param, "_tp_duplicated", True)
-
     def _init_dist_index(self, pg_collection):
         """
         Initialize the distributed index for the module.
@@ -199,7 +154,6 @@ def _init_dist_index(self, pg_collection):
         enable_hsdp = self.ddp_config.num_distributed_optimizer_instances > 1
         if pg_collection is None:
             tp_group = parallel_state.get_tensor_model_parallel_group()
-            expt_tp_group = parallel_state.get_expert_tensor_parallel_group()
             if enable_hsdp:
                 dp_cp_group = parallel_state.get_data_parallel_group(
                     with_context_parallel=True, partial_data_parallel=True
@@ -214,11 +168,8 @@ def _init_dist_index(self, pg_collection):
                 )
                 outer_fsdp_group = None
                 hybrid_fsdp_group = None
-            expt_dp_group = parallel_state.get_expert_data_parallel_group()
-            ep_group = parallel_state.get_expert_model_parallel_group()
         else:
             tp_group = getattr(pg_collection, 'tp', None)
-            expt_tp_group = getattr(pg_collection, 'expt_tp', None)
             if enable_hsdp:
                 dp_cp_group = pg_collection.intra_dp_cp
                 outer_fsdp_group = pg_collection.inter_dist_opt
@@ -227,17 +178,11 @@ def _init_dist_index(self, pg_collection):
                 dp_cp_group = pg_collection.dp_cp
                 outer_fsdp_group = None
                 hybrid_fsdp_group = None
-            expt_dp_group = getattr(pg_collection, 'expt_dp', None)
-            ep_group = getattr(pg_collection, 'ep', None)

         if tp_group is None:
             single_rank_group = dist.new_group(ranks=[dist.get_rank()])
             tp_group = single_rank_group

-        if expt_tp_group is None:
-            single_rank_group = dist.new_group(ranks=[dist.get_rank()])
-            expt_tp_group = single_rank_group
-
         if enable_hsdp:
             mesh = _get_hsdp_tp_mesh(outer_fsdp_group, dp_cp_group, tp_group)
             dist_index = FSDPDistributedIndex(
@@ -254,17 +199,6 @@ def _init_dist_index(self, pg_collection):
                 hybrid_fsdp_group=hybrid_fsdp_group,
             )
         else:
-            if ep_group is not None:
-                expt_mesh = _get_dp_tp_mesh(expt_dp_group, expt_tp_group, ep_size=ep_group.size())
-                expt_device_mesh = DeviceMesh.from_group(
-                    [expt_dp_group, expt_tp_group],
-                    device_type="cuda",
-                    mesh=expt_mesh.tolist(),
-                    mesh_dim_names=["dp_cp", "tp"],
-                )
-            else:
-                expt_device_mesh = None
-
             mesh = _get_dp_tp_mesh(dp_cp_group, tp_group)
             dist_index = FSDPDistributedIndex(
                 device_mesh=DeviceMesh.from_group(
@@ -275,11 +209,8 @@ def _init_dist_index(self, pg_collection):
                 ),
                 dp_shard_dim="dp_cp",
                 tp_dim="tp",
-                expt_device_mesh=expt_device_mesh,
             )

-        self.tp_group = tp_group
-
         return dist_index

     def stop_communication(self):
@@ -289,20 +220,6 @@ def stop_communication(self):
         self.module.synchronize_gradient_reduce()
         self.module.synchronize_param_gather()

-    def sync_rng_states_across_tp_group(self):
-        """
-        Synchronize the tensor parallel random number generator states.
-        """
-        if self.tp_group.size() <= 1:
-            return
-
-        if self.tp_group.rank() == 0:
-            broadcast_list = [_get_rng_state_dict()]
-        else:
-            broadcast_list = [None]
-        torch.distributed.broadcast_object_list(broadcast_list, group=self.tp_group, group_src=0)
-        _load_rng_state_dict(broadcast_list[0])
-

 def _get_hsdp_tp_mesh(outer_fsdp_dp_group, dp_cp_group, tp_group):
     assert HAVE_EINOPS, "einops is not installed. Please install it with `pip install einops`."
@@ -356,46 +273,29 @@ def _get_hsdp_tp_mesh(outer_fsdp_dp_group, dp_cp_group, tp_group):
     return mesh


-def _get_dp_tp_mesh(dp_cp_group, tp_group, ep_size=1):
+def _get_dp_tp_mesh(dp_cp_group, tp_group):
     assert HAVE_EINOPS, "einops is not installed. Please install it with `pip install einops`."
     world_size = dist.get_world_size()

     tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
-    # TODO: Supports configurable (dp, cp, ep, tp) order.
-    mesh = einops.rearrange(
-        torch.arange(world_size),
-        "(dp_cp ep tp) -> ep dp_cp tp",
-        dp_cp=dp_cp_group.size(),
-        tp=tp_size,
-        ep=ep_size,
-    )
+    # TODO: Supports configurable (dp, cp, tp) order.
+    mesh = einops.rearrange(torch.arange(world_size), "(dp_cp tp) -> dp_cp tp", tp=tp_size)

-    mesh_dp_ranks = einops.rearrange(mesh, 'ep dp_cp tp -> (ep tp) dp_cp', dp_cp=dp_cp_group.size())
+    mesh_dp_ranks = einops.rearrange(mesh, 'dp_cp tp -> tp dp_cp', tp=tp_size)
     dp_cp_group_ranks = dist.get_process_group_ranks(dp_cp_group)
     assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_dp_ranks, dp_cp_group_ranks), (
         f"[Megatron-FSDP] Data Parallel ranks in the mesh {mesh_dp_ranks} "
         f"do not match the ranks in the DP group {dp_cp_group_ranks}."
     )

-    mesh_tp_ranks = einops.rearrange(mesh, 'ep dp_cp tp -> (dp_cp ep) tp', tp=tp_size)
+    mesh_tp_ranks = einops.rearrange(mesh, 'dp_cp tp -> (dp_cp) tp', tp=tp_size)
     tp_group_ranks = dist.get_process_group_ranks(tp_group)
     assert _check_mesh_ranks_and_group_ranks_are_consistent(mesh_tp_ranks, tp_group_ranks), (
         f"[Megatron-FSDP] Tensor Parallel ranks in the mesh {mesh_tp_ranks} "
         f"do not match the ranks in the TP group {tp_group_ranks}."
     )

-    # Exclude the expert parallel dimension
-    rank = dist.get_rank()
-    dp_tp_meshes = [per_ep_mesh for per_ep_mesh in mesh if rank in per_ep_mesh.reshape(-1).tolist()]
-    assert (
-        len(dp_tp_meshes) == 1
-    ), f"[Megatron-FSDP] Current rank {rank} is not unique in the mesh ranks {mesh.tolist()}."
-    assert len(dp_tp_meshes[0].reshape(-1).tolist()) == dp_cp_group.size() * tp_group.size(), (
-        f"[Megatron-FSDP] DP-TP mesh size {len(dp_tp_meshes[0].reshape(-1).tolist())} "
-        f"does not match expected size {dp_cp_group.size() * tp_group.size()}."
-    )
-
-    return dp_tp_meshes[0]
+    return mesh


 def _check_mesh_ranks_and_group_ranks_are_consistent(mesh_ranks, group_ranks):
@@ -410,22 +310,3 @@ def _check_mesh_ranks_and_group_ranks_are_consistent(mesh_ranks, group_ranks):
         f"{mesh_ranks.tolist()} does not match the group ranks {group_ranks}."
     )
     return sorted(current_ranks[0]) == sorted(group_ranks)
-
-
-def _get_rng_state_dict():
-    rng_state_dict = {
-        'random_rng_state': random.getstate(),
-        'np_rng_state': np.random.get_state(),
-        'torch_rng_state': torch.get_rng_state(),
-        'cuda_rng_state': torch.cuda.get_rng_state(),
-        'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(),
-    }
-    return rng_state_dict
-
-
-def _load_rng_state_dict(rng_state_dict):
-    random.setstate(rng_state_dict['random_rng_state'])
-    np.random.set_state(rng_state_dict['np_rng_state'])
-    torch.set_rng_state(rng_state_dict['torch_rng_state'])
-    torch.cuda.set_rng_state(rng_state_dict['cuda_rng_state'])
-    tensor_parallel.get_cuda_rng_tracker().set_states(rng_state_dict['rng_tracker_states'])
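
For context on the simplified _get_dp_tp_mesh above: ranks are laid out as a (dp_cp, tp) grid with tp as the innermost, fastest-varying dimension. The sketch below is a minimal standalone illustration of that layout (not part of this commit), assuming a hypothetical world_size of 8 and tp_size of 2; the real helper derives both from process groups and additionally validates the rows and columns against the DP and TP group ranks.

# Minimal sketch of the (dp_cp, tp) rank layout built by the simplified
# _get_dp_tp_mesh. world_size and tp_size are hypothetical stand-ins here;
# the real code reads them from torch.distributed process groups.
import einops
import torch

world_size = 8  # assumed total number of ranks
tp_size = 2     # assumed tensor-parallel size

# tp is the fastest-varying dimension: "(dp_cp tp) -> dp_cp tp".
mesh = einops.rearrange(torch.arange(world_size), "(dp_cp tp) -> dp_cp tp", tp=tp_size)
print(mesh.tolist())       # [[0, 1], [2, 3], [4, 5], [6, 7]]  <- each row is one TP group

# Transposing gives the DP/CP groups, which is what the real helper checks
# against dist.get_process_group_ranks() in its consistency asserts.
dp_groups = einops.rearrange(mesh, "dp_cp tp -> tp dp_cp", tp=tp_size)
print(dp_groups.tolist())  # [[0, 2, 4, 6], [1, 3, 5, 7]]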