diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
index 8c6bcd152..d90f3a67e 100644
--- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py
+++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -58,7 +58,7 @@ def _distribute_dtensor(
     Below are experimental enhancements to distribute a DTensor.
     This helps enable Simple FSDP + TP, in which
         inner spec/mesh is TP spec/mesh
-        outer spec/mesh is FSDP spec/mesh
+        outer spec/mesh is FSDP/DDP/HSDP spec/mesh
     The logic follows
     https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fsdp/_fsdp_param.py#L261
     """
@@ -78,24 +78,40 @@ def _distribute_dtensor(
     submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
     spanned_mesh = outer_global_mesh[submesh_names]
 
-    if placements[0].is_shard():
-        # for FSDP + TP dtensor placement
-        shard_dim = placements[0].dim
+    if len(placements) == 1:
+        assert placements[0].is_replicate() or placements[0].is_shard()
+        if placements[0].is_shard():
+            # For FSDP + TP dtensor placement
+            shard_dim = placements[0].dim
+            split_factor = inner_spec.num_shards_map[shard_dim]
+            tensor_placement = (
+                (
+                    _StridedShard(shard_dim, split_factor=split_factor)
+                    if split_factor > 1
+                    else placements[0]
+                ),
+                inner_spec.placements[0],
+            )
+        else:
+            # For DDP + TP dtensor placement
+            tensor_placement = (placements[0], inner_spec.placements[0])
+    elif len(placements) == 2:
+        assert placements[0].is_replicate() and placements[1].is_shard()
+        # For HSDP + TP dtensor placement
+        shard_dim = placements[1].dim
         split_factor = inner_spec.num_shards_map[shard_dim]
         tensor_placement = (
+            placements[0],
             (
                 _StridedShard(shard_dim, split_factor=split_factor)
                 if split_factor > 1
-                else placements[0]
+                else placements[1]
             ),
             inner_spec.placements[0],
         )
-    elif placements[0].is_replicate():
-        # for DDP + TP dtensor placement
-        tensor_placement = (placements[0], inner_spec.placements[0])
     else:
         raise ValueError(
-            f"Unsupported placement {placements[0]} for distributing DTensor {tensor}"
+            f"Unsupported placement {placements} for distributing DTensor {tensor}"
         )
 
     current_spec = DTensorSpec(
@@ -105,7 +121,7 @@ def _distribute_dtensor(
     )
     target_spec = DTensorSpec(
         mesh=outer_mesh,
-        placements=(placements[0],),
+        placements=(placements[-1],),
         tensor_meta=inner_spec.tensor_meta,
     )
     result_tensor = redistribute_local_tensor(
@@ -188,9 +204,9 @@ def replicate_compute(self, x):
         # the gradients are partial tensors that needs to perform reduction
         # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
 
-        # support for FSDP/DDP + TP (assuming TP shards the inner-most dim)
+        # support for FSDP/DDP/HSDP + TP (assuming TP shards the inner-most dim)
         if x._spec.mesh.mesh_dim_names[-1] == "tp":
-            dp_placement, tp_placement = x._spec.placements
+            tp_placement = x._spec.placements[-1]
             # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
             # after DeviceMesh supports slicing a non-root mesh
             # dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]
diff --git a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py
index bb92b9d5e..d0579adcd 100755
--- a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py
+++ b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py
@@ -130,19 +130,18 @@ def build_test_list():
             "hsdp",
             ngpu=4,
         ),
-        # TODO: Adds back after HSDP+TP is supported by SimpleFSDP
-        # OverrideDefinitions(
-        #     [
-        #         [
-        #             "--parallelism.data_parallel_shard_degree=2",
-        #             "--parallelism.data_parallel_replicate_degree=2",
-        #             "--parallelism.tensor_parallel_degree=2",
-        #         ]
-        #     ],
-        #     "HSDP+TP",
-        #     "hsdp+tp",
-        #     ngpu=8,
-        # ),
+        OverrideDefinitions(
+            [
+                [
+                    "--parallelism.data_parallel_shard_degree=2",
+                    "--parallelism.data_parallel_replicate_degree=2",
+                    "--parallelism.tensor_parallel_degree=2",
+                ]
+            ],
+            "HSDP+TP",
+            "hsdp+tp",
+            ngpu=8,
+        ),
         OverrideDefinitions(
             [
                 [
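
For readers following the placement logic: below is a minimal, standalone sketch (not part of the patch) of the mapping rule that the new len(placements) branches in _distribute_dtensor implement for FSDP/DDP/HSDP + TP. The combined_placements helper is a hypothetical name, the _StridedShard import path is an assumption that may vary across torch versions, and split_factor stands in for inner_spec.num_shards_map[shard_dim].

# Hypothetical sketch, not part of the patch: maps outer (data-parallel) placements
# plus the inner TP placement onto the spanned DP x TP mesh, mirroring the three
# branches added by this change.
from torch.distributed.tensor import Replicate, Shard

try:
    # Private placement class the patch uses when TP has already split the dim
    # that data parallel shards; the import path is an assumption.
    from torch.distributed.tensor.placement_types import _StridedShard
except ImportError:
    _StridedShard = None


def combined_placements(dp_placements, tp_placement, split_factor=1):
    # split_factor stands in for inner_spec.num_shards_map[shard_dim].
    def maybe_strided(shard):
        if split_factor > 1 and _StridedShard is not None:
            return _StridedShard(shard.dim, split_factor=split_factor)
        return shard

    if len(dp_placements) == 1:
        p = dp_placements[0]
        # FSDP + TP (sharded) or DDP + TP (replicated)
        dp_part = (maybe_strided(p),) if p.is_shard() else (p,)
    elif len(dp_placements) == 2:
        # HSDP + TP: replicate over the replicate mesh dim, (strided) shard over the shard dim
        dp_part = (dp_placements[0], maybe_strided(dp_placements[1]))
    else:
        raise ValueError(f"Unsupported placements {dp_placements}")
    return (*dp_part, tp_placement)


if __name__ == "__main__":
    print(combined_placements((Shard(0),), Shard(0), split_factor=2))              # FSDP + TP
    print(combined_placements((Replicate(),), Shard(0)))                           # DDP + TP
    print(combined_placements((Replicate(), Shard(0)), Shard(0), split_factor=2))  # HSDP + TP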