Various small fixes for Megatron-FSDP. #2346
base: main
Changes from all commits
ddbbe75
684a9b6
2e0266d
3116c60
```diff
@@ -311,12 +311,21 @@ def _init_fsdp_param_and_grad_buffer(self):
         else:
             if self.ddp_config.average_in_collective:
                 gradient_scaling_factor = 1.0
                 # Utilized to re-scale expert gradients to DP.
                 # (edp_size/dp_size) * (1/edp_size) = 1/dp_size
                 # FIXME(@cspades): Currently not used gradient_reduce_preprocessing()?
                 expert_gradient_scaling_factor = (
                     self.dist_index.get_dp_group(is_expert_parallel=True).size()
-                    / self.dist_index.get_dp_group().size()
+                    / self.dist_index.get_fsdp_group().size()
                 )
+                if self.dist_index.use_hybrid_fsdp:
+                    # Also divide the DP-Outer size in the conversion factor.
+                    expert_gradient_scaling_factor /= self.dist_index.get_outer_fsdp_group().size()
             else:
-                data_parallel_world_size = self.dist_index.get_dp_group().size()
+                data_parallel_world_size = self.dist_index.get_fsdp_group().size()
+                if self.dist_index.use_hybrid_fsdp:
+                    # Also multiply the DP-Outer size in the DP size.
+                    data_parallel_world_size *= self.dist_index.get_outer_fsdp_group().size()
                 gradient_scaling_factor = 1.0 / data_parallel_world_size
                 expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
```

Contributor

Will the torch 2.9 check affect the behavior of …?

Member Author

I don't think this change is related to the Torch version or the upcoming (not yet released) DeviceMesh changes; it's related to our API requiring …. I don't see a clear way to flatten DP-Inner and DP-Outer without messing with the user's distributed environment, so I don't think there is a clear way to improve the implementation of ….

So to summarize, the issue is: …

So I WAR the ….
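As a sanity check on the scaling factors in the hunk above, here is a small self-contained sketch of the arithmetic with made-up group sizes (it does not use the real `dist_index` API; the variable names and sizes are placeholders):

```python
# Placeholder group sizes for illustration only (not the Megatron-FSDP dist_index API).
fsdp_size = 4        # DP-Shard (FSDP / DP-Inner) group size
outer_fsdp_size = 2  # DP-Outer group size under hybrid FSDP
edp_size = 2         # expert data-parallel group size
use_hybrid_fsdp = True

# Total data-parallel size that non-expert gradients are averaged over.
dp_size = fsdp_size * (outer_fsdp_size if use_hybrid_fsdp else 1)

# average_in_collective=True: the collective already averages over edp_size, so expert
# gradients only need the (edp_size / dp_size) conversion factor.
expert_gradient_scaling_factor = edp_size / fsdp_size
if use_hybrid_fsdp:
    # Also divide by the DP-Outer size, as in the diff above.
    expert_gradient_scaling_factor /= outer_fsdp_size
assert expert_gradient_scaling_factor == edp_size / dp_size
# (edp_size / dp_size) * (1 / edp_size) = 1 / dp_size, matching the comment in the code.
assert expert_gradient_scaling_factor * (1 / edp_size) == 1 / dp_size

# average_in_collective=False: both factors are simply 1 / dp_size.
gradient_scaling_factor = 1.0 / dp_size
print(expert_gradient_scaling_factor, gradient_scaling_factor)
```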
```diff
@@ -25,6 +25,8 @@
 from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType
 from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard
 
+from .utils import get_mesh_names
 
 
 def gather_and_compute_chunk_metadata(dtensor: DTensor) -> ChunkStorageMetadata:
     """
```
```diff
@@ -272,7 +274,14 @@ def gather_uneven_dtensor_to_full_tensor(
     if not device_mesh.mesh_dim_names:
         process_group = device_mesh.get_group()
     else:
-        process_group = device_mesh._flatten().get_group()
+        # Check if the fully-flattened mesh exists first.
+        full_flattened_mesh_dim_name = "_".join(device_mesh.mesh_dim_names)
+        if full_flattened_mesh_dim_name in get_mesh_names(device_mesh):
+            # Retrieve the existing flattened DeviceMesh ProcessGroup.
+            process_group = device_mesh[full_flattened_mesh_dim_name].get_group()
+        else:
+            # Create the _-separated flattened DeviceMesh ProcessGroup.
+            process_group = device_mesh._flatten().get_group()
```
Comment on lines -275 to +284

Member Author

@shjwudp Everything related to (2) is fixed in 3-4 lines of code here ^^^. Before, we just immediately flattened the mesh with `device_mesh._flatten()`.

I believe this still potentially has loopholes: if the user creates a flattened DeviceMesh dimension with the same name but a different topology than our desired mesh, then it will use the user's setting instead. I feel like there is no way to fix this fundamental issue (but we can add a warning message, which may be a good idea), so it will be the user's responsibility to have reasonably-named flattened meshes, i.e. …. Another thing to note is the ….
```diff
 
     # Collect chunk metadata for uneven shards (update if missing)
     if not hasattr(dtensor._local_tensor, "__create_chunk_list__"):
```
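The reuse-before-flatten logic above is small enough to isolate. Below is a rough sketch of the lookup pattern, with `get_mesh_names` stubbed out (the real helper comes from the PR's `.utils` module and its implementation is not reproduced here):

```python
from typing import Optional, Sequence, Tuple


def resolve_flattened_dim(
    mesh_dim_names: Tuple[str, ...], existing_mesh_names: Sequence[str]
) -> Optional[str]:
    """Return the name of an existing fully-flattened mesh dim to reuse, else None.

    `existing_mesh_names` stands in for get_mesh_names(device_mesh), which is
    assumed to list every (flattened) dim name registered on the mesh.
    """
    # e.g. ("dp_outer", "dp_shard") -> "dp_outer_dp_shard"
    full_flattened_name = "_".join(mesh_dim_names)
    if full_flattened_name in existing_mesh_names:
        # Reuse the existing (possibly user-created) flattened mesh dimension.
        return full_flattened_name
    # Caller falls back to device_mesh._flatten(), which creates the "_"-joined dim.
    return None


# Reuses the existing flattened dim instead of calling _flatten() again.
print(resolve_flattened_dim(("dp_outer", "dp_shard"), ["dp_outer_dp_shard"]))  # dp_outer_dp_shard
# No match, so the caller would create it via device_mesh._flatten().
print(resolve_flattened_dim(("dp_outer", "dp_shard"), []))                     # None
```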
Do we need to change this line? My understanding is that we only need to handle the new ValidateError introduced by PyTorch 2.9 and perform the check before calling `device_mesh._flatten`.
Member Author

Similar to your previous comment, this change and the one above are related to (1), so they are not related to the new validation error. The original argument validation was: if you are using HSDP or HFSDP, in which case `dp_outer_dim` is not None, then you must pass `hybrid_fsdp_group`. Now I am relaxing it to only when you are using HFSDP.

While we are talking about this, do you think this is a reasonable relaxation? I currently cannot imagine any situation where we need the fully-flattened DP ranks when we are using HSDP, since HSDP's collectives seem simple and orthogonal.
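For concreteness, here is an illustrative-only sketch of the relaxed check described above; the function name and the parameters `dp_outer_dim`, `hybrid_fsdp_group`, and `use_hybrid_fsdp` follow the discussion and are not the exact Megatron-FSDP signature:

```python
from typing import Any, Optional


def validate_hybrid_fsdp_args(
    dp_outer_dim: Optional[str],
    hybrid_fsdp_group: Optional[Any],
    use_hybrid_fsdp: bool,
) -> None:
    # Old rule: dp_outer_dim is not None (HSDP or HFSDP) => hybrid_fsdp_group required.
    # Relaxed rule: only HFSDP needs the fully-flattened DP process group; plain HSDP
    # with dp_outer_dim set no longer has to pass one.
    if use_hybrid_fsdp and hybrid_fsdp_group is None:
        raise ValueError(
            "hybrid_fsdp_group is required when using HFSDP, since its collectives "
            "need the fully-flattened DP ranks."
        )


# Plain HSDP: dp_outer_dim is set but no hybrid_fsdp_group -> accepted under the relaxed rule.
validate_hybrid_fsdp_args(dp_outer_dim="dp_outer", hybrid_fsdp_group=None, use_hybrid_fsdp=False)
```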
The 3 things I'm doing in this PR: