torchtitan/experiments/llama4/infra/parallelize.py (22 changes: 17 additions & 5 deletions)
@@ -21,11 +21,7 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
 
-from torchtitan.models.llama3.infra.parallelize import (
-    apply_ac,
-    apply_compile,
-    apply_ddp,
-)
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger
 
 from .expert_parallel import (
@@ -385,3 +381,19 @@ def apply_moe_ep_tp(
         device_mesh=experts_mesh,
         parallelize_plan=experts_plan,
     )
+
+
+def apply_compile(model: nn.Module):
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
+    for layer_id, transformer_block in model.layers.named_children():
+        # TODO: remove when torch.compile supports fullgraph=True for llama4 moe
+        fullgraph = True
+        if transformer_block.moe_enabled:
+            fullgraph = False
+        transformer_block = torch.compile(transformer_block, fullgraph=fullgraph)
+        model.layers.register_module(layer_id, transformer_block)
+
+    logger.info("Compiling each TransformerBlock with torch.compile")
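For context, a minimal self-contained sketch of the per-block compile pattern the new apply_compile implements: each block in a layers ModuleDict is compiled individually, with fullgraph relaxed to False for MoE blocks. The Block and ToyModel classes below are illustrative assumptions standing in for torchtitan's TransformerBlock and model, not torchtitan code; only the loop mirrors the diff above.

# Illustrative sketch only: Block/ToyModel are assumptions, not torchtitan code.
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, dim: int, moe_enabled: bool = False):
        super().__init__()
        # MoE blocks hit graph breaks, so they cannot use fullgraph=True yet.
        self.moe_enabled = moe_enabled
        self.ffn = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ffn(x)


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 4):
        super().__init__()
        # Same container shape as in the diff: a ModuleDict of blocks keyed by layer id.
        self.layers = nn.ModuleDict(
            {str(i): Block(dim, moe_enabled=(i % 2 == 1)) for i in range(n_layers)}
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.layers.values():
            x = block(x)
        return x


def apply_compile(model: nn.Module) -> None:
    # Compile each block separately; repeated block structure keeps compilation cheap.
    for layer_id, block in model.layers.named_children():
        fullgraph = not block.moe_enabled
        model.layers.register_module(layer_id, torch.compile(block, fullgraph=fullgraph))


if __name__ == "__main__":
    model = ToyModel()
    apply_compile(model)
    print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 16])

The design choice follows the docstring in the diff: compiling block by block reuses compilation work across identical layers, while fullgraph=False on MoE blocks is a temporary workaround until torch.compile can trace them without graph breaks.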