Skip to content

Commit 575deca

Browse files
committed
fix empty group on pp chunk
1 parent 3d1e121 commit 575deca

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

megatron/core/optimizer/layer_wise_optimizer.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -290,15 +290,9 @@ def sharded_state_dict(
290 290
local_params = group.pop('params')
291 291
# save whether this group is empty, so we can use non-empty rank for metadata
292 292
group['params'] = bool(local_params.unwrap())
293-
all_rank_groups = [None for _ in range(get_pg_size(self.pg_collection.dp_cp))]
294-
torch.distributed.all_gather_object(
295-
all_rank_groups, group, self.pg_collection.dp_cp
296-
)
297-
nonempty_rank_group = next((g for g in all_rank_groups if g['params']), None)
298-
if nonempty_rank_group is None:
299-
raise ValueError(
300-
'LayerWiseDistributedOptimizer dist save seeing empty groups on all ranks.'
301-
)
293+
all_rank_groups = [None for _ in range(torch.distributed.get_world_size())]
294+
torch.distributed.all_gather_object(all_rank_groups, group)
295+
nonempty_rank_group = next((g for g in all_rank_groups if g['params']), group)
302 296
nonempty_rank_group['params'] = local_params
303 297
sd['optimizer']['param_groups'][i] = nonempty_rank_group
304 298
return sharded_state_dict

0 commit comments

Comments
 (0)