Merged
@@ -2782,6 +2782,9 @@ def reduce_gradients(
outer_fsdp_group_grad_reduce (bool, optional): Whether to reduce gradients
across outer-DP groups. Defaults to False.
"""
# Sort parameters by their bucket IDs to ensure a deterministic processing order.
# Performing reduce-scatter operations out of order can lead to hangs.
params = sorted(list(params), key=lambda x: self.buffer.param_to_param_group[x])
Contributor

@Skylion007 Skylion007 Nov 17, 2025

I don't think the cast to list is necessary; sorted can take any Iterable, including a set, and it always returns a new list.

Suggested change
- params = sorted(list(params), key=lambda x: self.buffer.param_to_param_group[x])
+ params = sorted(params, key=lambda x: self.buffer.param_to_param_group[x])
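For reference, a minimal standalone sketch of why the cast can be dropped. It is not the actual buffer code: the string names and the plain dict are hypothetical stand-ins for the parameters and for self.buffer.param_to_param_group.

```python
# Hypothetical stand-ins: strings instead of parameter tensors, and a plain
# dict instead of self.buffer.param_to_param_group.
param_to_param_group = {"w2": 1, "w0": 0, "w3": 2, "w1": 0}
params = set(param_to_param_group)  # an unordered input, like a set of params

# sorted() accepts any iterable and always returns a new list, so wrapping
# the input in list() first only adds an extra copy.
with_cast = sorted(list(params), key=lambda p: param_to_param_group[p])
without_cast = sorted(params, key=lambda p: param_to_param_group[p])

# Bucket IDs come out in ascending order either way.
bucket_ids = [param_to_param_group[p] for p in without_cast]
```

Both calls give the same list here. Because sorted is stable, parameters that share a bucket ID keep whatever order the input iterable yielded them in.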

for param in params:
bucket_id = self.buffer.param_to_param_group[param]
param_group = self.buffer.parameter_groups[bucket_id]