diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py
index 5a0793ef5b9..e8c6a05340c 100644
--- a/megatron/core/transformer/moe/moe_utils.py
+++ b/megatron/core/transformer/moe/moe_utils.py
@@ -770,12 +770,15 @@ def reduce_aux_losses_tracker_across_ranks(track_names: Optional[List[str]] = No
             torch.distributed.all_reduce(
                 values, group=tracker[name]['avg_group'], op=torch.distributed.ReduceOp.AVG
             )
-        # This ensures proper loss averaging across all ranks including CP ranks
-        torch.distributed.all_reduce(
-            values,
-            group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-            op=torch.distributed.ReduceOp.AVG,
-        )
+        # Average aux losses across data parallel ranks.
+        # The `global_load_balancing_loss` already uses `tp_dp_cp_group` in `reduce_group`,
+        # so we don't need to reduce it again. Others use `tp_cp_group` in `reduce_group`.
+        if name != "global_load_balancing_loss":
+            torch.distributed.all_reduce(
+                values,
+                group=parallel_state.get_data_parallel_group(with_context_parallel=False),
+                op=torch.distributed.ReduceOp.AVG,
+            )
 
 
 def track_moe_metrics(
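
Below is a minimal sketch (not the actual Megatron-LM code) of the per-loss reduction flow this diff produces. It assumes `torch.distributed` is already initialized and that the tracker maps each aux-loss name to its values tensor plus optional `reduce_group`/`avg_group` entries, mirroring the structure implied by the hunk; the function name and signature are hypothetical.

```python
import torch
import torch.distributed as dist


def reduce_tracked_losses(tracker, data_parallel_group):
    """Reduce each tracked aux loss across its configured groups, then average over DP."""
    for name, entry in tracker.items():
        values = entry["values"]
        # Reductions configured when the loss was saved to the tracker.
        if entry.get("reduce_group") is not None:
            dist.all_reduce(values, group=entry["reduce_group"])
        if entry.get("avg_group") is not None:
            dist.all_reduce(values, group=entry["avg_group"], op=dist.ReduceOp.AVG)
        # `global_load_balancing_loss` is already reduced over a group that spans the
        # data-parallel ranks (tp_dp_cp), so averaging it again over DP would
        # double-count; every other loss still needs the DP average.
        if name != "global_load_balancing_loss":
            dist.all_reduce(values, group=data_parallel_group, op=dist.ReduceOp.AVG)
```

The design point the diff makes: the extra all-reduce is now scoped to the plain data-parallel group (`with_context_parallel=False`) and skipped for the one loss whose `reduce_group` already covers those ranks, instead of unconditionally averaging every loss over the DP+CP group.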