47 changes: 46 additions & 1 deletion torchtitan/distributed/expert_parallel.py
@@ -49,8 +49,53 @@ def backward(ctx, grad_y):
        return grad_x, None, None, None


# We should delete this, either by ensuring that the _A2A above actually works with dynamo,
# or by killing it in favor of AC2; see
# https://github.com/pytorch/torchtitan/issues/1467#issuecomment-3181235004
class _A2AFunCol(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, out_splits, in_splits, group_name):
        if isinstance(out_splits, torch.Tensor):
            out_splits = out_splits.tolist()
        if isinstance(in_splits, torch.Tensor):
            in_splits = in_splits.tolist()

        y = torch.ops._c10d_functional.all_to_all_single.default(
            x.contiguous(), out_splits, in_splits, group_name
        )
        y = torch.ops._c10d_functional.wait_tensor.default(y)

        ctx.in_splits = in_splits
        ctx.out_splits = out_splits
        ctx.group_name = group_name
        return y

    @staticmethod
    def backward(ctx, grad_y):
        # grad wrt input has length sum(in_splits)
        grad_x = torch.ops._c10d_functional.all_to_all_single.default(
            grad_y.contiguous(), ctx.in_splits, ctx.out_splits, ctx.group_name
        )
        grad_x = torch.ops._c10d_functional.wait_tensor.default(grad_x)
        return grad_x, None, None, None
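

# Illustrative sketch, not part of this PR: a single-rank smoke test for
# _A2AFunCol. With world_size 1 the all-to-all degenerates to an identity
# shuffle, which makes the forward/backward split symmetry easy to check.
# The gloo backend, the env:// rendezvous values, and the helper name below
# are all assumptions for a quick local run.
def _demo_a2a_funcol_single_rank():
    import os

    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.randn(4, 8, requires_grad=True)
    # With a single rank, in_splits == out_splits == [x.shape[0]].
    y = _A2AFunCol.apply(x, [4], [4], dist.group.WORLD.group_name)
    y.sum().backward()
    assert x.grad.shape == x.shape  # backward re-runs the a2a with swapped splits

    dist.destroy_process_group()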


# TODO:
# - we should figure out why dynamo is inlining into _A2A.apply without constructing
#   an autograd.Function in the dynamo graph
# - we should also make sure that dist.all_to_all gets remapped properly by dynamo
@torch.compiler.allow_in_graph
def all_to_all_single_autograd_dynamo_friendly(x, out_splits, in_splits, group_name):
    return _A2AFunCol.apply(x, out_splits, in_splits, group_name)
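

# Illustrative sketch, not part of this PR: because of @torch.compiler.allow_in_graph,
# dynamo records the wrapper above as a single opaque call rather than tracing
# into _A2AFunCol.apply. This helper (a hypothetical name) just builds a compiled
# caller; it assumes a process group is already initialized.
def _demo_compiled_wrapper(group, out_splits, in_splits):
    @torch.compile(fullgraph=True)
    def f(x):
        # Appears as one node in the dynamo graph thanks to allow_in_graph.
        return all_to_all_single_autograd_dynamo_friendly(
            x, out_splits, in_splits, group.group_name
        )

    return f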


def all_to_all_single_autograd(x, out_splits, in_splits, group):
    if torch.compiler.is_compiling():
        return all_to_all_single_autograd_dynamo_friendly(
            x, out_splits, in_splits, group.group_name
        )
    else:
        return _A2A.apply(x, out_splits, in_splits, group)
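

# Illustrative sketch, not part of this PR: the typical expert-parallel call
# pattern around this entry point. `num_tokens_for_rank` is a hypothetical
# int64 tensor of shape (ep_degree,) counting how many local tokens are bound
# for each rank; the counts are exchanged first so both sides know the splits
# before the token shuffle.
def _demo_token_dispatch(x, num_tokens_for_rank, group):
    import torch.distributed as dist

    in_splits = num_tokens_for_rank.to(torch.int64)
    out_splits = torch.empty_like(in_splits)
    # Exchange one counter per rank (equal default splits) to learn out_splits.
    dist.all_to_all_single(out_splits, in_splits, group=group)
    return all_to_all_single_autograd(
        x, out_splits.tolist(), in_splits.tolist(), group
    )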


TOKEN_GROUP_ALIGN_SIZE_M = 8