-
Notifications
You must be signed in to change notification settings - Fork 440
Initial compile support for llama4 #1365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
pls try rebase #1403 This is the original issue I mentioned |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some questions.
Could also address #1365 (comment)
- rebase and see if the non-persistent buffer
tokens_per_expert
is causing trouble - manually try change
freqs_cis
to non-persistent and see if the issue is still there. https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/model/model.py#L388
Apply torch.compile to each TransformerBlock, which makes compilation efficient due to | ||
repeated structure. Alternatively one can compile the whole model (after applying DP). | ||
""" | ||
torch._dynamo.config.fail_on_recompile_limit_hit = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this for?
Other than this, it seems we can just apply the same function llama 3 uses.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is to loud error if we recompile more than 8 times (default). currently, we would just silently fallback to eager if it happens.
self.w1, self.w2, self.w3, x, num_tokens_per_expert | ||
) | ||
|
||
# TODO: keeping this for-loop implementation for comparison |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
staticmethod on user-defined classes can not be generically supported, I moved those out.
Could you explain more? Does it mean if we move them out, then torch.compile can trace them in the same graph as the caller module is in?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
@@ -28,89 +97,21 @@ def __init__( | |||
self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) | |||
self.use_grouped_mm = use_grouped_mm | |||
|
|||
@torch._dynamo.set_fullgraph(True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is this annotation for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Compiling the block with fullgraph=False could allow graph breaks to creep in silently with dynamo changes, and we wouldn't know about them until we manually inspect the graph or suspect QPS to have regressed.
This API to more granularly control the fullgraph argument of torch.compile, you can flip it on and off within a compiled region. In this case, we allow graph breaks between GroupedExperts.call and GroupedExperts.forward, i.e. allow graph break on the forward hooks from FSDP
@@ -297,7 +298,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
) | |||
|
|||
# shape (bs*slen*top_k, dim) | |||
routed_output = self.experts(routed_input, num_tokens_per_expert) | |||
with torch._dynamo.set_fullgraph(False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, this annotation is for the FSDP caused graph break, correct?
Can we possibly incur this in the apply_compile
function. Technically this change is model-intrusively, despite being small.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This API can't decorate GroupedExperts.call right now. If it's a problem, we can just compile MoE with fullgraph=False
Status
We don't have a good way in compile to specify fullgraph=True except for FSDP hooks at the moment. We can either leave it
fullgraph=False
or just wrap the experts model code inset_fullgraph(False)/set_fullgraph(True)
.Repro
tested on debug model
NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.expert_parallel_degree=2 --training.compile
logs: https://gist.github.com/xmfan/41b822d9f09eb07fee62d684a061cec1
memory: 2.20GiB -> 1.42GiB
speedup: no big change, need to check with actual model