Use new DeviceMesh unflatten to rewrite parallel_dims #1660
Conversation
This is a demonstration of how parallel_dims will look when using the pytorch/pytorch#161224 stack.
```python
config["name"].append(name)
config["degree"].append(degree)

world_mesh = init_device_mesh(device_type, [self.world_size])
```
One thought is that we might want to define the semantics of device-mesh creation the following way. For the purpose of this, I am pretending init_device_mesh and DeviceMesh() are equivalent (and maybe we could get rid of init_ eventually).

This case would initialize a world group and then split subgroups off of it:

```python
mesh = DeviceMesh((world_size,))
mesh.unflatten((a, b, c))
```

This case would not initialize the world group; it would initialize separate groups a, b, c. Any further unflatten would still use .split on the PGs:

```python
mesh = DeviceMesh((a, b, c))
```

I bring this up because at large scale it is expensive to initialize the world group, so it is good to let users choose what they actually want to happen.
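The proposed unflatten API is not in released PyTorch yet, but the subgroup layout it would produce can be modeled in plain Python. The sketch below is my own illustration (the function name and the row-major rank layout are assumptions, not the pytorch/pytorch#161224 implementation): it shows which ranks would share a process group after unflattening a flat world into a mesh shape and grouping along one dim.

```python
import math

def subgroup_ranks(shape, dim):
    # Sketch only: models which ranks share a process group when a flat
    # world of prod(shape) ranks is laid out row-major over `shape` and
    # grouped along `dim`. Not the real DeviceMesh API.
    world = math.prod(shape)
    strides = [math.prod(shape[i + 1:]) for i in range(len(shape))]
    groups = {}
    for rank in range(world):
        coord = tuple(rank // strides[i] % shape[i] for i in range(len(shape)))
        key = coord[:dim] + coord[dim + 1:]  # fix every coord except `dim`
        groups.setdefault(key, []).append(rank)
    return list(groups.values())

# Unflattening world_size=8 into (dp=2, tp=4): the innermost dim forms
# contiguous groups, the outer dim forms strided groups.
assert subgroup_ranks((2, 4), dim=1) == [[0, 1, 2, 3], [4, 5, 6, 7]]
assert subgroup_ranks((2, 4), dim=0) == [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Either creation path in the comment above would produce these same groups; the difference is only whether the world-sized group is ever materialized.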
This PR actually follows this proposal. It's just that we are using init_device_mesh, but assuming both are equivalent in the future, the creation semantics are identical to what you proposed.
```python
if self.etp == self.tp:
    add_dim("tp", self.tp, ep_computation_dims)

self._all_meshes = []
```
hm, storing a list of all_meshes and then having to index into it feels like a UX regression. i wonder if we need to assemble the meshes into such a list at all? should we just store self.dp_mesh in parallel_dims and use that directly where we want it?
UX regression to whom? I think it is a UX regression to TorchTitan ParallelDims, as TorchTitan now has to maintain all the meshes. And this is the new proposal of DeviceMesh, where you can no longer get all the meshes through the root mesh. So the user, in this case ParallelDims, has to store the information, hence self._all_meshes.

As for people who use ParallelDims, the UX doesn't change -- they can still access those meshes in the same way.
my comment was more focused on the choice to put all_meshes into an ordered list and access the most recent one by [-1]. I would maybe either store them in a dict with descriptive names, or just store them directly on ParallelDims as optional attributes.
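For illustration, the dict-based alternative could look like the sketch below. The class, method names, and string placeholders are hypothetical, not the actual ParallelDims code; the point is only that lookups become named rather than positional.

```python
# Sketch of the dict-based alternative: submeshes stored under
# descriptive names instead of an ordered list indexed with [-1].
class ParallelDimsSketch:
    def __init__(self):
        self._meshes = {}  # name -> mesh object

    def register_mesh(self, name, mesh):
        self._meshes[name] = mesh

    def __getitem__(self, name):
        return self._meshes[name]

pd = ParallelDimsSketch()
pd.register_mesh("dp", "dp-mesh-placeholder")       # placeholder values,
pd.register_mesh("dp_cp", "dp-cp-mesh-placeholder")  # not real meshes
assert pd["dp"] == "dp-mesh-placeholder"  # lookup by name, not position
```

Storing meshes as optional attributes (e.g. `self.dp_mesh = None` until built) is the other option mentioned above; the dict form keeps the "look up a mesh by dim name" access pattern that callers already use.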
```
pp: For PP.
dp_replicate: For DDP or HSDP replicate dimension.
dp_shard_cp: For FSDP or HSDP shard dimension. This includes
```
maybe just call it fsdp?
```
dp: This is used by data loading to decide the global batch size and
    which part of data this raunk should read. This dim includes both
    ``dp_replicate`` and ``dp_shard``.
    The name is confusing; ``batch`` could be a better name.
```
I agree dp is not informative and batch may be better. But I think dp_cp is pretty consistent with dp -- we should change both if we change one. Maybe dp_cp to loss?
yes, I agree with this too!
If we're going to pick a JAX-y name like "batch" for dp, why don't we just extend this principle to all the mesh dims? The rule is: you name the dimension after the most important thing that will be sharded by it.
```
cp: For CP.
tp: For TP.
ep: For EP.
dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region.
```
I had to use conditions like
- CP must be entirely inside EP
- EP spans (TP,) CP, and part of DP shard

because otherwise the code gets too heavy. But with this unflatten we don't need such constraints. We can just use two "global" meshes:
- one is the dense part: pp, dp_replicate, dp_shard, cp, tp
- the other is the sparse part: pp, dp_replicate, edp, ep, etp
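As a concrete check of the two-mesh idea, both layouts must factor the same world size, and the sparse degrees are just a re-factorization of the dense dp_shard * cp * tp block. The degrees below are hypothetical, chosen only to illustrate:

```python
import math

# Hypothetical degrees for the two "global" meshes described above.
dense  = {"pp": 2, "dp_replicate": 2, "dp_shard": 4, "cp": 2, "tp": 2}
sparse = {"pp": 2, "dp_replicate": 2, "edp": 2, "ep": 4, "etp": 2}

# Both global meshes must cover the same set of ranks.
assert math.prod(dense.values()) == math.prod(sparse.values()) == 64

# The sparse (EP-region) block re-factors the dense computation block:
# dp_shard * cp * tp == edp * ep * etp.
assert dense["dp_shard"] * dense["cp"] * dense["tp"] == 16
assert sparse["edp"] * sparse["ep"] * sparse["etp"] == 16
```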
I think dp_shard_in_ep is your edp: the dp_shard degree in the EP region.
Not exactly the same. The dp_shard_in_ep name hints that part of dp_shard overlaps with ep, but in reality that doesn't have to be the case.
```python
dp_shard_in_ep = (
    self.dp_shard * self.cp // self.ep
    if self.etp == self.tp
    else self.dp_shard * self.cp * self.tp // self.ep
)
```
ideally we shouldn't do any "real math" in this file, see comment above
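To make the math being objected to concrete, here is the same expression as a standalone function with hypothetical degrees (the helper name is mine, not the PR's):

```python
def dp_shard_in_ep(dp_shard, cp, tp, ep, etp):
    # Mirrors the expression in the diff: when expert TP equals TP, only
    # dp_shard and cp fold into the EP region; otherwise tp folds in too.
    if etp == tp:
        return dp_shard * cp // ep
    return dp_shard * cp * tp // ep

# etp == tp: EP borrows only from dp_shard * cp -> 4 * 2 // 8 == 1
assert dp_shard_in_ep(dp_shard=4, cp=2, tp=2, ep=8, etp=2) == 1
# etp != tp: tp also folds into the EP region -> 4 * 2 * 2 // 8 == 2
assert dp_shard_in_ep(dp_shard=4, cp=2, tp=2, ep=8, etp=1) == 2
```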
```python
    else self.dp_shard * self.cp * self.tp // self.ep
)

data_mesh_dims = defaultdict(list)
```
I was thinking we just need two (dense and sparse), any specific reason we have to have three?
This one can be unflattened from the two meshes you mentioned. But either way, we need dp and cp submeshes. The number of PGs created will be the same.
```python
non_ep_computation_dims = defaultdict(list)
ep_computation_dims = defaultdict(list)
```
non_ep_ and ep_ sound EP-centric. How about being neutral and calling them dense vs sparse? The shared expert and router in MoE actually belong to the dense part -- they are computed on every single token.
```
dp_replicate: For DDP or HSDP replicate dimension.
dp_shard_cp: For FSDP or HSDP shard dimension. This includes
    ``cp`` even if ``cp`` is 1. As a result, we always
    use the name ``dp_shard_cp``, and ``dp_shard`` is not
```
Is it harmful to just create dp_shard anyway, for symmetry? Are you trying to stop people from accidentally using the wrong mesh dim axis because they weren't thinking about CP?
```
dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``,
    ``dp_shard``, and ``cp`` as all of them are data parallelisms.
dp: This is used by data loading to decide the global batch size and
    which part of data this raunk should read. This dim includes both
```
Suggested change:
```diff
- which part of data this raunk should read. This dim includes both
+ which part of data this rank should read. This dim includes both
```
Summary
pytorch/pytorch#161224 is the PyTorch PR stack by @fduwjj to implement DeviceMesh.unflatten().
This PR showcases how to use the above stack to rewrite build_mesh() in parallel_dims.py.
Open questions:
Both questions arise because we can no longer access meshes created by flatten() and unflatten() under the new proposal.
For 1, implementing get_parent() in DeviceMesh may solve the issue.
For 2, this PR implements __getitem__() to work around the issue, but this requires more discussion.