146 changes: 143 additions & 3 deletions torchtitan/distributed/parallel_dims.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from dataclasses import dataclass

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
@@ -25,6 +26,7 @@ class ParallelDims:
ep: int
etp: int
world_size: int
mesh_dim_names: tuple[str, ...] = tuple()

_world_mesh: DeviceMesh = None

@@ -63,6 +65,139 @@ def _validate(self):
# EP would borrow all cp and tp and some dp_shard degree
assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0

def build_mesh(self) -> "ParallelDims":
"""Build the device mesh with the required mesh dimensions.
The following mesh dimensions may be created based on the parallel configuration:
pp: For PP.
dp_replicate: For DDP or HSDP replicate dimension.
dp_shard_cp: For FSDP or HSDP shard dimension. This includes
Contributor
maybe just call it fsdp?

``cp`` even if ``cp`` is 1. As a result, we always
use the name ``dp_shard_cp``, and ``dp_shard`` is not
Is it harmful to just create dp_shard anyway, for symmetry? Are you trying to stop people from accidentally using the wrong mesh dim axis because they weren't thinking about CP?

Contributor
For what symmetry?
There are requests to always have FSDP sharding even when FSDP==1, to handle mixed precision, just like what we are doing for EP params now.
#1469
Especially considering amp is not maintained (is it?)
#1525

created as a dimension.
dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``,
``dp_shard``, and ``cp`` as all of them are data parallelisms.
dp: This is used by data loading to decide the global batch size and
which part of data this raunk should read. This dim includes both
Suggested change:
- which part of data this raunk should read. This dim includes both
+ which part of data this rank should read. This dim includes both

``dp_replicate`` and ``dp_shard``.
The name is confusing; ``batch`` could be a better name.
Contributor
I agree dp is not informative and batch may be better.
But I think dp_cp is pretty consistent with dp -- we should change both if we change one.
Maybe dp_cp to loss?

Contributor Author
yes, I agree with this too!

If we're going to pick a JAX-y name "batch" for dp, why don't we just extend this principle to all the mesh dims? The rule is you name the dimension after the most important thing that will be sharded by it.

Contributor
@tianyu-l Aug 30, 2025

@ezyang
when you say "the most" it is introducing ambiguity!
E.g. TP/SP would shard both parameters, on some model dim, and data, on the seq dim. Meanwhile, CP shards data on the seq dim, and can be used together with dp_shard to fully_shard parameters with FSDP.
For FSDP (dp_shard * cp), which dim should we name the sharding after? The answer is, it doesn't really matter. In PyTorch FSDP2, we by default shard on dim-0, but one could alternatively always shard on the dim which TP / EP don't shard on. (I do think we can just name the flattened dp_shard * cp as fsdp.)

Contributor
In general, I'd prefer the atomic mesh dim names to be aligned with the parallelisms, and the flattened dims to be named after their actual usage, instead of _-concatenated atomic names. Concretely:
dp_shard, cp -> fsdp
dp_replicate, dp_shard -> batch
dp_replicate, dp_shard, cp -> loss

Contributor
Yea this framing makes sense to me.
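For concreteness, a rough sketch of how that naming could map onto the _unflatten/_flatten calls used in this PR; the fsdp/batch/loss names are the reviewer's proposal (not what the PR implements), pp is omitted, and the bare degree names stand in for the corresponding self.* fields:

# Hypothetical renaming per the reviewer's suggestion -- not part of this PR.
dense_mesh = world_mesh._unflatten(
    0,
    [dp_replicate, dp_shard * cp, tp],
    ["dp_replicate", "fsdp", "tp"],  # dp_shard * cp -> "fsdp"
)
data_mesh = world_mesh._unflatten(
    0,
    [dp_replicate * dp_shard, cp, tp],
    ["batch", "cp", "tp"],  # dp_replicate * dp_shard -> "batch"
)
# dp_replicate * dp_shard * cp -> "loss", used for the loss all-reduce
data_mesh["batch", "cp"]._flatten(mesh_dim_name="loss")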

cp: For CP.
tp: For TP.
ep: For EP.
dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region.
Contributor
I had to use conditions like

  1. CP must be all in EP
  2. EP spans (TP, ) CP, and part of DP shard

because otherwise the code gets too heavy.
But with this unflatten we don't need to have such constraints.
We can just use two "global" meshes:

  • one is the dense part: pp, dp_replicate, dp_shard, cp, tp
  • the other is the sparse part: pp, dp_replicate, edp, ep, etp

Contributor Author
I think dp_shard_in_ep is your edp -- the dp_shard degree in the EP region.

Contributor
not exactly the same.
The dp_shard_in_ep name hints that part of dp_shard overlaps with ep, but in reality it doesn't have to be the case.
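A sketch of the two-global-mesh alternative discussed above, written with the same _unflatten call this PR uses; the edp degree and its formula are inferred from this thread (edp here equals the PR's dp_shard_in_ep), so treat it as an assumption rather than the PR's code:

# Hypothetical "dense" and "sparse" global meshes per the reviewer's comment.
dense_mesh = world_mesh._unflatten(
    0,
    [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
    ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
)
# edp is whatever degree is left once pp, dp_replicate, ep, and etp are fixed,
# so that both factorizations cover the same world size.
edp = self.world_size // (self.pp * self.dp_replicate * self.ep * self.etp)
sparse_mesh = world_mesh._unflatten(
    0,
    [self.pp, self.dp_replicate, edp, self.ep, self.etp],
    ["pp", "dp_replicate", "edp", "ep", "etp"],
)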

Note: These dimensions won't all live on a single mesh. If we consider
the unflatten() operator only, the following are all the meshes required,
assuming all degrees are > 1 except for ``pp``:
["dp", "cp", "tp"]: The ``dp`` process group is wasted as the dataloader
doesn't need it for communication.
["dp_cp", "tp"]: Loss computation.
["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation.
["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp.
["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1.
In reality, we don't actually need to create all of these meshes.
For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"].
So we don't actually need to create ["dp_cp", "tp"].
But there are some meshes we MUST create if that mesh will be used for a
parameter. So the non-EP-region computation mesh and the EP-region
computation mesh are required.
"""

def add_dim(name, degree, config):
config["name"].append(name)
config["degree"].append(degree)

world_mesh = init_device_mesh(device_type, [self.world_size])
Contributor
@wconstab Aug 29, 2025

one thought is that we might want to define the semantics of device-mesh creation the following way. for the purpose of this, I am pretending init_device_mesh and DeviceMesh() are equivalent (and maybe we could get rid of init_ eventually).

this case would initialize a world group and then split subgroups off of it

mesh = DeviceMesh((world_size,))
mesh.unflatten((a, b, c))

this case would not initialize the world group, it would initialize separate groups a, b, c
any further 'unflatten' would still use '.split' on the PGs
mesh = DeviceMesh((a, b, c))

I bring this up because at large scale, it is expensive to initialize the world group, so it is good to let users choose what they actually want to happen

Contributor Author
This PR actually follows this proposal. It's just that we are using init_device_mesh; assuming the two become equivalent in the future, the creation semantics are identical to what you proposed.

dp_shard_in_ep = (
self.dp_shard * self.cp // self.ep
if self.etp == self.tp
else self.dp_shard * self.cp * self.tp // self.ep
)
Comment on lines +115 to +119
Contributor
ideally we shouldn't do any "real math" in this file, see comment above
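For reference, the arithmetic in question with some hypothetical degrees (dp_shard=4, cp=2, tp=2, ep=8, which satisfy the _validate constraint above):

# etp == tp: EP spans cp and part of dp_shard (tp stays as expert TP)
dp_shard_in_ep = 4 * 2 // 8        # = 1
# etp == 1: EP additionally absorbs the tp degree
dp_shard_in_ep = 4 * 2 * 2 // 8    # = 2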


data_mesh_dims = defaultdict(list)
Contributor

I was thinking we just need two (dense and sparse), any specific reason we have to have three?

Contributor Author

This one can be unflattened from the two meshes you mentioned. But either way, we need dp and cp submeshes. The number of PGs created will be the same.

non_ep_computation_dims = defaultdict(list)
ep_computation_dims = defaultdict(list)
Comment on lines +122 to +123
Contributor
non_ep_ and ep_ sound EP-centric.
How about being neutral and calling them dense vs sparse? The shared expert and router in MoE actually belong to the dense part -- they are computed on every single token.


if self.pp_enabled:
add_dim("pp", self.pp, data_mesh_dims)
add_dim("pp", self.pp, non_ep_computation_dims)
add_dim("pp", self.pp, ep_computation_dims)

if self.dp_enabled:
add_dim("dp", self.dp_replicate * self.dp_shard, data_mesh_dims)
if self.dp_replicate_enabled:
add_dim("dp_replicate", self.dp_replicate, non_ep_computation_dims)
add_dim("dp_replicate", self.dp_replicate, ep_computation_dims)
if self.dp_shard_enabled:
add_dim("dp_shard_cp", self.dp_shard * self.cp, non_ep_computation_dims)
add_dim("dp_shard_in_ep", dp_shard_in_ep, ep_computation_dims)

if self.cp_enabled:
add_dim("cp", self.cp, data_mesh_dims)

if self.tp_enabled:
add_dim("tp", self.tp, data_mesh_dims, non_ep_computation_dims)
if self.etp == self.tp:
add_dim("tp", self.tp, ep_computation_dims)

self._all_meshes = []
Contributor
hm, storing a list of all_meshes and then having to index into it feels like a UX regression. i wonder if we need to assemble the meshes into such a list at all? should we just store self.dp_mesh in parallel_dims and use that directly where we want it?

Contributor Author
@fegin Aug 29, 2025
UX regression to whom? I think it is a UX regression to TorchTitan ParallelDims, as TorchTitan now has to maintain all the meshes. And under the new DeviceMesh proposal you cannot get all the meshes through the root mesh, so the user, in this case ParallelDims, has to store the information, hence self._all_meshes.

As for people who use ParallelDims, the UX doesn't change -- they can still access those meshes in the same way.

Contributor

my comment was more focused on the choice to put all_meshes into an ordered list and access the most recent one by [-1]. I would maybe either store them in a dict with descriptive names, or just store them directly on ParallelDims as optional attributes.
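A minimal sketch of the dict-based alternative suggested here, reusing the same _unflatten calls as below; the key names ("data"/"dense"/"sparse") are illustrative, not from the PR:

# Hypothetical: keep the unflattened meshes under descriptive keys instead of
# relying on list order and [-1] indexing.
self._meshes: dict[str, DeviceMesh] = {}
if self.dp_enabled:
    self._meshes["data"] = world_mesh._unflatten(
        0, data_mesh_dims["degree"], data_mesh_dims["name"]
    )
if self.dp_cp_enabled or self.tp_enabled or self.pp_enabled:
    self._meshes["dense"] = world_mesh._unflatten(
        0, non_ep_computation_dims["degree"], non_ep_computation_dims["name"]
    )
if self.ep_enabled:
    self._meshes["sparse"] = world_mesh._unflatten(
        0, ep_computation_dims["degree"], ep_computation_dims["name"]
    )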


if self.dp_enabled:
data_mesh = world_mesh._unflatten(
0, data_mesh_dims["degree"], data_mesh_dims["name"]
)
self._all_meshes.append(data_mesh)
# Note that we don't create loss_mesh as it is easier to flatten
# from data_mesh
if self.cp_enabled:
self._all_meshes[-1]["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
else:
self._all_meshes[-1]["dp"]._flatten(mesh_dim_name="dp_cp")

if self.dp_cp_enabled or self.tp_enabled or self.pp_enabled:
self._all_meshes.append(
world_mesh._unflatten(
0,
non_ep_computation_dims["degree"],
non_ep_computation_dims["name"],
)
)

if self.ep_enabled:
add_dim("ep", self.ep, ep_computation_dims)
self._all_meshes.append(
world_mesh._unflatten(
0, ep_computation_dims["degree"], ep_computation_dims["name"]
)
)

self._world_mesh = world_mesh
self.mesh_dim_names = tuple(
name for m in self._all_meshes for name in m.mesh_dim_names
)
return self

def __getitem__(self, name):
# This is a hack to make ParallelDims behave like a DeviceMesh.
# We will need to change trainer if design is concluded. For now,
# this is just a quick hack to make it work with unflatten()

if "mesh_dim_names" == name:
return [name for m in self._all_meshes for name in m.mesh_dim_names]

for mesh in self._all_meshes:
try:
submesh = mesh[name]
return submesh
except KeyError:
pass
raise AttributeError(f"ParallelDims has no attribute {name}")

"""
def build_mesh(self) -> DeviceMesh:
# TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
# is not very clean, due to the limited support from DeviceMesh
@@ -188,14 +323,19 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
return mesh
"""

@property
def world_mesh(self) -> str:
def world_mesh(self) -> "ParallelDims":
# This is a hack to make ParallelDims behave like a DeviceMesh.
# We will need to change trainer if design is concluded. For now,
# this is just a quick hack to make it work with unflatten()

# doing late init so ParallelDims can still be used as a lightweight
# dataclass without having to initialize the world mesh
if self._world_mesh is None:
self._world_mesh = self.build_mesh()
return self._world_mesh
self.build_mesh()
return self

@property
def dp_enabled(self):
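With the world_mesh property and __getitem__ above, callers can keep treating ParallelDims roughly like a DeviceMesh. A sketch of the intended access pattern; the degrees are hypothetical and the constructor keywords follow the dataclass fields referenced in the code above:

parallel_dims = ParallelDims(
    pp=1, dp_replicate=1, dp_shard=4, cp=2, tp=2, ep=8, etp=2, world_size=16,
)
pd = parallel_dims.world_mesh   # lazily calls build_mesh(), returns the ParallelDims itself
dp_mesh = pd["dp"]              # sliced from the data mesh, used by the dataloader
loss_mesh = pd["dp_cp"]         # flattened dim used for the loss all-reduce
ep_mesh = pd["ep"]              # sliced from the EP-region computation mesh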
4 changes: 3 additions & 1 deletion torchtitan/train.py
@@ -123,12 +123,14 @@ def __init__(self, job_config: JobConfig):

# Set random seed, and maybe enable deterministic mode
# (mainly for debugging, expect perf loss).
"""
dist_utils.set_determinism(
world_mesh,
self.device,
job_config.training.seed,
job_config.training.deterministic,
)
"""
self.train_spec = train_spec_module.get_train_spec(job_config.model.name)

# build tokenizer and dataloader
@@ -611,7 +613,7 @@ def train(self):
timeout=timedelta(
seconds=job_config.comm.train_timeout_seconds
),
world_mesh=self.parallel_dims.world_mesh,
world_mesh=self.parallel_dims._world_mesh,
)

if torch.distributed.get_rank() == 0: