diff --git a/megatron/core/distributed/distributed_data_parallel_config.py b/megatron/core/distributed/distributed_data_parallel_config.py index 3f97beab82..e2a026d836 100644 --- a/megatron/core/distributed/distributed_data_parallel_config.py +++ b/megatron/core/distributed/distributed_data_parallel_config.py @@ -146,6 +146,14 @@ def __post_init__(self): """Check the validity of the config.""" if self.reuse_grad_buf_for_mxfp8_param_ag: assert self.fp8_param_gather, "Reuse grad buffer only when keeping params in MXFP8." + # Using mxfp8 param without overlap param gather and overlap grad reduce will cause NaN. + # TODO: Remove this assertion when the issue is fixed. + assert ( + self.overlap_param_gather + ), "--overlap-param-gather is required when using mxfp8 params" + assert ( + self.overlap_grad_reduce + ), "--overlap-grad-reduce is required when using mxfp8 params" if self.nccl_ub: if 'expandable_segments:True' in os.getenv('PYTORCH_CUDA_ALLOC_CONF', '').split(','): diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py index 8682675849..5151ecabfb 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py @@ -137,6 +137,14 @@ def __post_init__(self): """Check the validity of the config.""" if self.reuse_grad_buf_for_mxfp8_param_ag: assert self.fp8_param_gather, "Reuse grad buffer only when keeping params in MXFP8." + # Using mxfp8 param without overlap param gather and overlap grad reduce will cause NaN. + # TODO: Remove this assertion when the issue is fixed. + assert ( + self.overlap_param_gather + ), "--overlap-param-gather is required when using mxfp8 params" + assert ( + self.overlap_grad_reduce + ), "--overlap-grad-reduce is required when using mxfp8 params" if self.nccl_ub: if 'expandable_segments:True' in os.getenv('PYTORCH_CUDA_ALLOC_CONF', '').split(','):