Skip to content

[SimpleFSDP] Add support for hsdp+tp #1343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions torchtitan/experiments/simple_fsdp/simple_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _distribute_dtensor(
Below are experimental enhancements to distribute a DTensor.
This helps enable Simple FSDP + TP, in which
inner spec/mesh is TP spec/mesh
outer spec/mesh is FSDP spec/mesh
outer spec/mesh is FSDP/DDP/HSDP spec/mesh
The logic follows
https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fsdp/_fsdp_param.py#L261
"""
Expand All @@ -78,38 +78,68 @@ def _distribute_dtensor(
submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
spanned_mesh = outer_global_mesh[submesh_names]

if placements[0].is_shard():
# for FSDP + TP dtensor placement
shard_dim = placements[0].dim
current_tensor = tensor._local_tensor
current_spec = DTensorSpec(
mesh=outer_mesh,
placements=(Replicate(),),
tensor_meta=inner_spec.tensor_meta,
)
target_spec = DTensorSpec(
mesh=outer_mesh,
placements=(placements[0],),
tensor_meta=inner_spec.tensor_meta,
)

if len(placements) == 1:
assert placements[0].is_replicate() or placements[0].is_shard()
if placements[0].is_shard():
shard_dim = placements[0].dim
split_factor = inner_spec.num_shards_map[shard_dim]
tensor_placement = (
(
_StridedShard(shard_dim, split_factor=split_factor)
if split_factor > 1
else placements[0]
),
inner_spec.placements[0],
)
else:
tensor_placement = (placements[0], inner_spec.placements[0])
elif len(placements) == 2:
assert placements[0].is_replicate() and placements[1].is_shard()
shard_dim = placements[1].dim
split_factor = inner_spec.num_shards_map[shard_dim]
tensor_placement = (
placements[0],
(
_StridedShard(shard_dim, split_factor=split_factor)
if split_factor > 1
else placements[0]
else placements[1]
),
inner_spec.placements[0],
)
elif placements[0].is_replicate():
# for DDP + TP dtensor placement
tensor_placement = (placements[0], inner_spec.placements[0])
current_tensor = redistribute_local_tensor(
current_tensor,
current_spec=current_spec,
target_spec=target_spec,
)
current_spec = DTensorSpec(
mesh=outer_mesh,
placements=(placements[0],),
tensor_meta=inner_spec.tensor_meta,
)
target_spec = DTensorSpec(
mesh=outer_mesh,
placements=(placements[1],),
tensor_meta=inner_spec.tensor_meta,
)
else:
raise ValueError(
f"Unsupported placement {placements[0]} for distributing DTensor {tensor}"
f"Unsupported placement {placements} for distributing DTensor {tensor}"
)

current_spec = DTensorSpec(
mesh=outer_mesh,
placements=(Replicate(),),
tensor_meta=inner_spec.tensor_meta,
)
target_spec = DTensorSpec(
mesh=outer_mesh,
placements=(placements[0],),
tensor_meta=inner_spec.tensor_meta,
)
result_tensor = redistribute_local_tensor(
tensor._local_tensor,
current_tensor,
current_spec=current_spec,
target_spec=target_spec,
)
Expand Down Expand Up @@ -188,9 +218,9 @@ def replicate_compute(self, x):
# the gradients are partial tensors that needs to perform reduction
# (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)

# support for FSDP/DDP + TP (assuming TP shards the inner-most dim)
# support for FSDP/DDP/HSDP + TP (assuming TP shards the inner-most dim)
if x._spec.mesh.mesh_dim_names[-1] == "tp":
dp_placement, tp_placement = x._spec.placements
tp_placement = x._spec.placements[-1]
# TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
# after DeviceMesh supports slicing a non-root mesh
# dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]
Expand Down
25 changes: 12 additions & 13 deletions torchtitan/experiments/simple_fsdp/tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,18 @@ def build_test_list():
"hsdp",
ngpu=4,
),
# TODO: Adds back after HSDP+TP is supported by SimpleFSDP
# OverrideDefinitions(
# [
# [
# "--parallelism.data_parallel_shard_degree=2",
# "--parallelism.data_parallel_replicate_degree=2",
# "--parallelism.tensor_parallel_degree=2",
# ]
# ],
# "HSDP+TP",
# "hsdp+tp",
# ngpu=8,
# ),
OverrideDefinitions(
[
[
"--parallelism.data_parallel_shard_degree=2",
"--parallelism.data_parallel_replicate_degree=2",
"--parallelism.tensor_parallel_degree=2",
]
],
"HSDP+TP",
"hsdp+tp",
ngpu=8,
),
OverrideDefinitions(
[
[
Expand Down