
Conversation

Contributor

@fegin fegin commented Aug 29, 2025

Summary
pytorch/pytorch#161224 is the PyTorch PR stack by @fduwjj to implement DeviceMesh.unflatten()

This PR showcases how to use the above stack to rewrite build_mesh() in parallel_dims.py.

Open questions:

  1. How do we let FSDP2 correctly get the SPMD mesh?
  2. How do we let the trainer access each mesh dimension (e.g., dp_cp, dp_shard_cp, ...) easily and correctly?

Both questions arise because we can no longer access meshes created by flatten() and unflatten() under the new proposal.

For 1, implementing get_parent() in DeviceMesh may solve the issue.

For 2, this PR implements __getitem__() to work around the issue, but this requires more discussion.

This is a demonstration of how parallel_dims.py will look when using the pytorch/pytorch#161224 stack.
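The `__getitem__()` workaround from question 2 might look roughly like the following sketch (class and mesh names are hypothetical; in the actual PR the dict would hold `DeviceMesh` objects produced by `unflatten()`, not the placeholder strings used here):

```python
class ParallelDimsSketch:
    """Hypothetical stand-in for TorchTitan's ParallelDims."""

    def __init__(self):
        # In the real PR these would be DeviceMesh objects collected
        # while building the mesh; strings stand in for them here.
        self._meshes = {
            "dp_cp": "mesh<dp_cp>",
            "dp_shard_cp": "mesh<dp_shard_cp>",
        }

    def __getitem__(self, name):
        # Trainer-side access stays unchanged: parallel_dims["dp_cp"].
        if name not in self._meshes:
            raise KeyError(f"unknown mesh dim: {name!r}")
        return self._meshes[name]


pd = ParallelDimsSketch()
print(pd["dp_cp"])
```

The point of the sketch is only that the trainer keeps indexing by dim name even though the meshes can no longer be reached through the root mesh.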
config["name"].append(name)
config["degree"].append(degree)

world_mesh = init_device_mesh(device_type, [self.world_size])
Contributor

@wconstab wconstab Aug 29, 2025

one thought is that we might want to define the semantics of device-mesh creation the following way. For this purpose, I am pretending init_device_mesh and DeviceMesh() are equivalent (and maybe we could get rid of init_device_mesh eventually).

this case would initialize a world group and then split subgroups off of it:

```python
mesh = DeviceMesh((world_size,))
mesh.unflatten((a, b, c))
```

this case would not initialize the world group; it would initialize separate groups a, b, c. Any further unflatten would still use .split on the PGs:

```python
mesh = DeviceMesh((a, b, c))
```

I bring this up because at large scale, it is expensive to initialize the world group, so it is good to let users choose what they actually want to happen
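To make the two paths concrete, here is a plain-Python illustration (no real process groups; the shape is hypothetical) of which rank groups a mesh shape implies, assuming a row-major rank layout:

```python
from itertools import product
from math import prod


def mesh_groups(shape):
    """Enumerate the rank groups a device mesh of `shape` implies,
    one list of groups per mesh dimension (row-major rank layout)."""
    dims = len(shape)
    strides = [prod(shape[d + 1:]) for d in range(dims)]
    groups = {}
    for d in range(dims):
        groups[d] = []
        # Fix every other coordinate, vary only coordinate d.
        other_axes = [range(s) for i, s in enumerate(shape) if i != d]
        for coords in product(*other_axes):
            group = []
            for k in range(shape[d]):
                full = list(coords)
                full.insert(d, k)
                group.append(sum(c * s for c, s in zip(full, strides)))
            groups[d].append(group)
    return groups


# Path 1 -- DeviceMesh((world_size,)) then unflatten((a, b, c)):
#   creates the (expensive) world-sized group plus the per-dimension
#   subgroups computed below.
# Path 2 -- DeviceMesh((a, b, c)) directly: only the subgroups,
#   never the world group.
print(mesh_groups((2, 4)))
```

Either path ends up with the same per-dimension subgroups; the only difference is whether the world-sized group is materialized at all, which is the cost concern at large scale.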

Contributor Author

This PR actually follows this proposal. It's just that we are using init_device_mesh, but assuming both are equivalent in the future, the creation semantics are identical to what you proposed.

if self.etp == self.tp:
add_dim("tp", self.tp, ep_computation_dims)

self._all_meshes = []
Contributor

hm, storing a list of all_meshes and then having to index into it feels like a UX regression. i wonder if we need to assemble the meshes into such a list at all? should we just store self.dp_mesh in parallel_dims and use that directly where we want it?

Contributor Author

@fegin fegin Aug 29, 2025

UX regression for whom? I think it is a UX regression for TorchTitan's ParallelDims, as TorchTitan now has to maintain all the meshes itself. Under the new DeviceMesh proposal you can no longer get all the meshes through the root mesh, so the user, in this case ParallelDims, has to store that information, hence self._all_meshes.

As for people who use ParallelDims, the UX doesn't change -- they can still access those meshes in the same way.

Contributor

my comment was more focused on the choice to put all_meshes into an ordered list and access the most recent one with [-1]. I would maybe either store them in a dict with descriptive names, or just store them directly on ParallelDims as optional attributes.
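The dict alternative could be as small as this (hypothetical names; the placeholder strings stand in for the `DeviceMesh` objects `unflatten()` would return):

```python
# Instead of self._all_meshes = [] plus self._all_meshes[-1], key each
# mesh by the dim name it was created for.
meshes = {}
meshes["dp_shard_cp"] = "mesh<dp_shard_cp>"  # would be a DeviceMesh
meshes["dp_cp"] = "mesh<dp_cp>"

# Lookup is self-documenting; no positional indexing required.
fsdp_mesh = meshes["dp_shard_cp"]
print(fsdp_mesh)
```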

pp: For PP.
dp_replicate: For DDP or HSDP replicate dimension.
dp_shard_cp: For FSDP or HSDP shard dimension. This includes
Contributor

maybe just call it fsdp?

dp: This is used by data loading to decide the global batch size and
which part of data this rank should read. This dim includes both
``dp_replicate`` and ``dp_shard``.
The name is confusing; ``batch`` could be a better name.
Contributor

I agree dp is not informative and batch may be better.
But I think dp_cp is pretty consistent with dp -- we should change both if we change one.
Maybe rename dp_cp to loss?

Contributor Author

yes, this I agree too!


If we're going to pick a JAX-y name "batch" for dp, why don't we just extend this principle to all the mesh dims? The rule is that you name the dimension after the most important thing that will be sharded by it.

cp: For CP.
tp: For TP.
ep: For EP.
dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region.
Contributor

I had to use conditions like

  1. CP must be all in EP
  2. EP spans (TP, ) CP, and part of DP shard

because otherwise the code gets too heavy.
But with this unflatten we don't need such constraints.
We can just use two "global" meshes:

  • one is the dense part: pp, dp_replicate, dp_shard, cp, tp
  • the other is the sparse part: pp, dp_replicate, edp, ep, etp
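With hypothetical degrees, a quick check that the two "global" meshes decompose the same world size (the `edp` formula here mirrors the dp_shard_in_ep computation elsewhere in this PR, for the etp == tp case):

```python
from math import prod

# Hypothetical parallelism degrees (etp == tp case).
pp, dp_replicate, dp_shard, cp, tp = 2, 2, 4, 2, 2
ep, etp = 8, 2

dense = (pp, dp_replicate, dp_shard, cp, tp)  # dense-part mesh
edp = dp_shard * cp // ep                     # valid since etp == tp
sparse = (pp, dp_replicate, edp, ep, etp)     # sparse-part mesh

# Both "global" meshes must cover the same set of ranks.
assert prod(dense) == prod(sparse)
print(prod(dense))
```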

Contributor Author

I think dp_shard_in_ep is your edp -- the dp_shard degree in the EP region.

Contributor

not exactly the same.
The dp_shard_in_ep name hints that part of dp_shard overlaps with ep, but in reality it doesn't have to be the case.

Comment on lines +115 to +119
dp_shard_in_ep = (
self.dp_shard * self.cp // self.ep
if self.etp == self.tp
else self.dp_shard * self.cp * self.tp // self.ep
)
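A quick sanity check of the arithmetic in this snippet, with dp_shard_in_ep rewritten as a plain function over hypothetical degrees:

```python
def dp_shard_in_ep(dp_shard, cp, tp, etp, ep):
    # Mirrors the expression above: when etp == tp, EP borrows only
    # from dp_shard * cp; otherwise the TP dimension is folded in too.
    if etp == tp:
        return dp_shard * cp // ep
    return dp_shard * cp * tp // ep


# etp == tp: EP of 8 consumes dp_shard * cp = 8 ranks entirely.
print(dp_shard_in_ep(dp_shard=4, cp=2, tp=2, etp=2, ep=8))
# etp != tp: TP joins the pool, so dp_shard * cp * tp = 16 feeds EP.
print(dp_shard_in_ep(dp_shard=4, cp=2, tp=2, etp=1, ep=16))
```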
Contributor

ideally we shouldn't do any "real math" in this file, see comment above

else self.dp_shard * self.cp * self.tp // self.ep
)

data_mesh_dims = defaultdict(list)
Contributor

I was thinking we just need two (dense and sparse), any specific reason we have to have three?

Contributor Author

This one can be unflattened from the two meshes you mentioned. But either way, we need dp and cp submeshes. The number of PGs created will be the same.

Comment on lines +122 to +123
non_ep_computation_dims = defaultdict(list)
ep_computation_dims = defaultdict(list)
Contributor

non_ep_ and ep_ sound EP-centric.
How about being neutral and calling them dense vs sparse? The shared expert and router in MoE actually belong to the dense part -- they are computed on every single token.

dp_replicate: For DDP or HSDP replicate dimension.
dp_shard_cp: For FSDP or HSDP shard dimension. This includes
``cp`` even if ``cp`` is 1. As a result, we always
use the name ``dp_shard_cp``, and ``dp_shard`` is not

Is it harmful to just create dp_shard anyway, for symmetry? Are you trying to stop people from accidentally using the wrong mesh dim axis because they weren't thinking about CP?

dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``,
``dp_shard``, and ``cp`` as all of them are data parallelisms.
dp: This is used by data loading to decide the global batch size and
which part of data this raunk should read. This dim includes both

Suggested change
which part of data this raunk should read. This dim includes both
which part of data this rank should read. This dim includes both
