1 file changed: +3, −9 lines
@@ -290,15 +290,9 @@ def sharded_state_dict(
             local_params = group.pop('params')
             # save whether this group is empty, so we can use non-empty rank for metadata
             group['params'] = bool(local_params.unwrap())
-            all_rank_groups = [None for _ in range(get_pg_size(self.pg_collection.dp_cp))]
-            torch.distributed.all_gather_object(
-                all_rank_groups, group, self.pg_collection.dp_cp
-            )
-            nonempty_rank_group = next((g for g in all_rank_groups if g['params']), None)
-            if nonempty_rank_group is None:
-                raise ValueError(
-                    'LayerWiseDistributedOptimizer dist save seeing empty groups on all ranks.'
-                )
+            all_rank_groups = [None for _ in range(torch.distributed.get_world_size())]
+            torch.distributed.all_gather_object(all_rank_groups, group)
+            nonempty_rank_group = next((g for g in all_rank_groups if g['params']), group)
             nonempty_rank_group['params'] = local_params
             sd['optimizer']['param_groups'][i] = nonempty_rank_group
         return sharded_state_dict
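The diff replaces a gather over the dp_cp process group with a gather over the default (world) group, and falls back to the local group instead of raising when every rank reports an empty param group. Below is a minimal, self-contained sketch of that gather-and-pick pattern. It is not the repository's implementation: the function name `pick_nonempty_group`, the toy single-rank init, and the example group dict are all hypothetical; only the `all_gather_object` / `next(...)` fallback logic mirrors the new code in the diff.

    # Sketch of the pattern introduced in this change, assuming
    # torch.distributed has been (or will be) initialized, e.g. via torchrun.
    import torch
    import torch.distributed as dist

    def pick_nonempty_group(group: dict) -> dict:
        # Gather each rank's metadata-only param group over the default
        # process group (the whole world), not a sub-group like dp_cp.
        all_rank_groups = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(all_rank_groups, group)
        # Prefer the first rank whose group has params; if every rank is
        # empty, fall back to the local group instead of raising.
        return next((g for g in all_rank_groups if g['params']), group)

    if __name__ == '__main__':
        # Single-process setup purely for illustration.
        dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
                                rank=0, world_size=1)
        local_group = {'lr': 1e-3, 'params': False}   # empty on this rank
        print(pick_nonempty_group(local_group))       # falls back to local_group
        dist.destroy_process_group()

Using the default group avoids the case where all ranks within one dp_cp sub-group are empty even though some other rank in the world holds the metadata, and the `group` fallback keeps the save path from failing outright when no rank has params.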