diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 1e41bf9d8c..1916bfff07 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -1122,6 +1122,7 @@ def initialize_model_parallel( for ranks in expert_decoder_rank_generator.get_ranks('ep'): group = create_group( ranks, + timeout=timeout, pg_options=get_nccl_options("ep", nccl_comm_cfgs), group_desc="EXPERT_MODEL_PARALLEL_GROUP", )