Use new DeviceMesh unflatten to rewrite parallel_dims #1660
base: main
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from dataclasses import dataclass

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
@@ -25,6 +26,7 @@ class ParallelDims:
    ep: int
    etp: int
    world_size: int
    mesh_dim_names: tuple[str] = tuple()

    _world_mesh: DeviceMesh = None
@@ -63,6 +65,139 @@ def _validate(self):
        # EP would borrow all cp and tp and some dp_shard degree
        assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0
    def build_mesh(self) -> "ParallelDims":
        """Build the device mesh with the required mesh dimensions.

        The following mesh dimensions may be created based on the parallel configuration:
            pp: For PP.
            dp_replicate: For DDP or HSDP replicate dimension.
            dp_shard_cp: For FSDP or HSDP shard dimension. This includes ``cp``
                even if ``cp`` is 1. As a result, we always use the name
                ``dp_shard_cp``, and ``dp_shard`` is not created as a dimension.

Reviewer comment: Is it harmful to just create dp_shard anyway, for symmetry? Are you trying to stop people from accidentally using the wrong mesh dim axis because they weren't thinking about CP?
            dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``,
                ``dp_shard``, and ``cp`` as all of them are data parallelisms.
            dp: This is used by data loading to decide the global batch size and
                which part of the data this rank should read. This dim includes
                both ``dp_replicate`` and ``dp_shard``. The name is confusing;
                ``batch`` could be a better name.

Reviewer comment: I agree.

Reply: Yes, I agree too!

Reply: If we're going to pick a JAX-y name "batch" for dp, why don't we just extend this principle to all the mesh dims? The rule is you name the dimension after the most important thing that will be sharded by it.

Reply: In general, I'd prefer the atomic mesh dim names to be aligned with the parallelisms, and the flattened dims to be named after their actual usage.

Reply: Yeah, this framing makes sense to me.
            cp: For CP.
            tp: For TP.
            ep: For EP.
            dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region.
Reviewer comment: I had to use conditions like this because otherwise it's too heavy code.

Reply: I think …

Reply: Not exactly the same.
        Note: These dimensions won't exist at the same time. If we consider
        the unflatten() operator only, the following are all the meshes required,
        assuming all degrees are > 1 except for ``pp``:
            ["dp", "cp", "tp"]: The ``dp`` process group is wasted as the dataloader
                doesn't need it for communication.
            ["dp_cp", "tp"]: Loss computation.
            ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation.
            ["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp.
            ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1.

        In reality, we don't actually need to create all of these meshes.
        For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"],
        so we don't actually need to create ["dp_cp", "tp"]. But there are some
        meshes we MUST create if that mesh will be used for a parameter, so the
        non-EP-region-computation mesh and the EP-region-computation mesh are
        required.
        """
        def add_dim(name, degree, *configs):
            # Append this dimension's name and degree to each given mesh-dim config.
            for config in configs:
                config["name"].append(name)
                config["degree"].append(degree)

        world_mesh = init_device_mesh(device_type, [self.world_size])
Reviewer comment: One thought is that we might want to define the semantics of device-mesh creation the following way: one form would initialize a world group and then split subgroups off of it, while the other form would not initialize the world group at all; it would only initialize the separate groups a, b, c. I bring this up because at large scale it is expensive to initialize the world group, so it is good to let users choose what they actually want to happen.

Reply: This PR actually follows this proposal. It's just that we are using _unflatten() on the world mesh.
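A minimal sketch of the two creation styles being discussed, assuming init_device_mesh from torch.distributed.device_mesh; the device type, degrees, and dim names are hypothetical, and whether the second form can avoid materializing a world-sized process group is an assumption about DeviceMesh internals rather than documented behavior:

```python
from torch.distributed.device_mesh import init_device_mesh

# Style 1: materialize a single flat "world" dimension first, then carve
# sub-dimensions out of it (this PR goes this route via _unflatten()).
world = init_device_mesh("cuda", (64,))

# Style 2: declare the named dimensions up front; ideally no world-sized
# process group would ever need to be created.
mesh = init_device_mesh("cuda", (2, 4, 8), mesh_dim_names=("dp", "cp", "tp"))
```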
        dp_shard_in_ep = (
            self.dp_shard * self.cp // self.ep
            if self.etp == self.tp
            else self.dp_shard * self.cp * self.tp // self.ep
        )
Reviewer comment (on the dp_shard_in_ep computation above): Ideally we shouldn't do any "real math" in this file; see the comment above.
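A worked example of the arithmetic above, with hypothetical degrees that are not taken from any config in this PR:

```python
# Hypothetical degrees: dp_shard=4, cp=2, tp=2, etp == tp, ep=8.
dp_shard, cp, tp, ep = 4, 2, 2, 8

# etp == tp: EP borrows cp and tp plus part of dp_shard, so the FSDP shard
# degree inside the EP region is dp_shard * cp // ep.
dp_shard_in_ep = dp_shard * cp // ep
assert dp_shard_in_ep == 1

# The EP-region dims still cover the same number of ranks as the non-EP dims.
assert dp_shard_in_ep * ep * tp == dp_shard * cp * tp  # 16 == 16
```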
        data_mesh_dims = defaultdict(list)

Reviewer comment: I was thinking we just need two (dense and sparse); any specific reason we have to have three?

Reply: This one can be unflattened from the two meshes you mentioned. But either way, we need …
        non_ep_computation_dims = defaultdict(list)
        ep_computation_dims = defaultdict(list)
        if self.pp_enabled:
            add_dim("pp", self.pp, data_mesh_dims)
            add_dim("pp", self.pp, non_ep_computation_dims)
            add_dim("pp", self.pp, ep_computation_dims)

        if self.dp_enabled:
            add_dim("dp", self.dp_replicate * self.dp_shard, data_mesh_dims)
            if self.dp_replicate_enabled:
                add_dim("dp_replicate", self.dp_replicate, non_ep_computation_dims)
                add_dim("dp_replicate", self.dp_replicate, ep_computation_dims)
            if self.dp_shard_enabled:
                add_dim("dp_shard_cp", self.dp_shard * self.cp, non_ep_computation_dims)
                add_dim("dp_shard_in_ep", dp_shard_in_ep, ep_computation_dims)

        if self.cp_enabled:
            add_dim("cp", self.cp, data_mesh_dims)

        if self.tp_enabled:
            add_dim("tp", self.tp, data_mesh_dims, non_ep_computation_dims)
            if self.etp == self.tp:
                add_dim("tp", self.tp, ep_computation_dims)
        self._all_meshes = []

Reviewer comment: Hmm, storing a list of all_meshes and then having to index into it feels like a UX regression. I wonder if we need to assemble the meshes into such a list at all; should we just store them individually?

Reply: UX regression to whom? I think it is a UX regression to TorchTitan. As for people who use …

Reply: My comment was more focused on the choice to put all_meshes into an ordered list and access the most recent one by [-1]. I would maybe either store them in a dict with descriptive names, or just store them directly on ParallelDims as optional attributes.
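A minimal sketch of the dict-based alternative suggested above; the helper name and the purpose keys are hypothetical and not part of this PR:

```python
from torch.distributed.device_mesh import DeviceMesh


def unflatten_by_purpose(
    world_mesh: DeviceMesh, dims_by_purpose: dict[str, dict[str, list]]
) -> dict[str, DeviceMesh]:
    # One unflatten per purpose ("data", "dense", "sparse", ...), keyed by a
    # descriptive name instead of a position in a list.
    return {
        purpose: world_mesh._unflatten(0, cfg["degree"], cfg["name"])
        for purpose, cfg in dims_by_purpose.items()
    }
```

With something like this, the self._all_meshes[-1] lookups below would become lookups by name, e.g. meshes["data"].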
        if self.dp_enabled:
            data_mesh = world_mesh._unflatten(
                0, data_mesh_dims["degree"], data_mesh_dims["name"]
            )
            self._all_meshes.append(data_mesh)
            # Note that we don't create loss_mesh as it is easier to flatten
            # it from data_mesh.
            if self.cp_enabled:
                self._all_meshes[-1]["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
            else:
                self._all_meshes[-1]["dp"]._flatten(mesh_dim_name="dp_cp")
        if self.dp_cp_enabled or self.tp_enabled or self.pp_enabled:
            self._all_meshes.append(
                world_mesh._unflatten(
                    0,
                    non_ep_computation_dims["degree"],
                    non_ep_computation_dims["name"],
                )
            )

        if self.ep_enabled:
            add_dim("ep", self.ep, ep_computation_dims)
            self._all_meshes.append(
                world_mesh._unflatten(
                    0, ep_computation_dims["degree"], ep_computation_dims["name"]
                )
            )

        self._world_mesh = world_mesh
        self.mesh_dim_names = tuple(
            name for m in self._all_meshes for name in m.mesh_dim_names
        )
        return self
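A hedged usage sketch of the rewritten build_mesh(), assuming the ParallelDims fields referenced in this diff; the degrees below are hypothetical and the exact dim-name ordering depends on which meshes get created:

```python
# Hypothetical 64-rank job: FSDP over 16 ranks, CP over 2, TP over 2, EP over 8 (etp == tp).
pd = ParallelDims(
    pp=1, dp_replicate=1, dp_shard=16, cp=2, tp=2, ep=8, etp=2, world_size=64
).build_mesh()

# mesh_dim_names concatenates the dims of every created mesh, roughly:
# the data mesh ("dp", "cp", "tp"), the non-EP computation mesh
# ("dp_shard_cp", "tp"), and the EP computation mesh ("dp_shard_in_ep", "tp", "ep").
print(pd.mesh_dim_names)
```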
    def __getitem__(self, name):
        # This is a hack to make ParallelDims behave like a DeviceMesh.
        # We will need to change the trainer once the design is concluded. For now,
        # this is just a quick hack to make it work with unflatten().

        if name == "mesh_dim_names":
            return [name for m in self._all_meshes for name in m.mesh_dim_names]

        for mesh in self._all_meshes:
            try:
                return mesh[name]
            except KeyError:
                pass
        raise AttributeError(f"ParallelDims has no attribute {name}")
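A small usage sketch of the indexing hack, reusing the hypothetical pd from the sketch above:

```python
fsdp_mesh = pd["dp_shard_cp"]     # found by trying each created mesh in turn
tp_mesh = pd["tp"]                # returns the slice from the first mesh that has "tp"
dim_names = pd["mesh_dim_names"]  # special-cased: aggregates dims across all meshes
```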
    """
    def build_mesh(self) -> DeviceMesh:
        # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
        # is not very clean, due to the limited support from DeviceMesh

@@ -188,14 +323,19 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
        mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
        return mesh
    """
    @property
-   def world_mesh(self) -> str:
+   def world_mesh(self) -> "ParallelDims":
        # This is a hack to make ParallelDims behave like a DeviceMesh.
        # We will need to change the trainer once the design is concluded. For now,
        # this is just a quick hack to make it work with unflatten().

        # doing late init so ParallelDims can still be used as a lightweight
        # dataclass without having to initialize the world mesh
        if self._world_mesh is None:
-           self._world_mesh = self.build_mesh()
-       return self._world_mesh
+           self.build_mesh()
+       return self
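A small sketch of the lazy-init behavior, assuming the same hypothetical degrees as above; note that, because of the hack, the property returns the ParallelDims object itself rather than a DeviceMesh:

```python
pd = ParallelDims(
    pp=1, dp_replicate=1, dp_shard=16, cp=2, tp=2, ep=8, etp=2, world_size=64
)
mesh_like = pd.world_mesh   # first access triggers build_mesh()
assert mesh_like is pd      # the property returns self, not a DeviceMesh
tp_mesh = mesh_like["tp"]   # indexing is forwarded to the created meshes
```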
    @property
    def dp_enabled(self):

Reviewer comment: Maybe just call it fsdp?