diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py index bdf480d867..6a294b6960 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py @@ -2782,6 +2782,9 @@ def reduce_gradients( outer_fsdp_group_grad_reduce (bool, optional): Whether to reduce gradients across outer-DP groups. Defaults to False. """ + # Sort parameters by their bucket IDs to ensure a deterministic processing order. + # Performing reduce-scatter operations out of order can lead to hangs. + params = sorted(list(params), key=lambda x: self.buffer.param_to_param_group[x]) for param in params: bucket_id = self.buffer.param_to_param_group[param] param_group = self.buffer.parameter_groups[bucket_id]