diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index 33ff71a98..b1e60f996 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -21,11 +21,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 
 from torchtitan.distributed import ParallelDims
 
-from torchtitan.models.llama3.infra.parallelize import (
-    apply_ac,
-    apply_compile,
-    apply_ddp,
-)
+from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.tools.logging import logger
 
 from .expert_parallel import (
@@ -385,3 +381,19 @@ def apply_moe_ep_tp(
         device_mesh=experts_mesh,
         parallelize_plan=experts_plan,
     )
+
+
+def apply_compile(model: nn.Module):
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
+    for layer_id, transformer_block in model.layers.named_children():
+        # TODO: remove when torch.compile supports fullgraph=True for llama4 moe
+        fullgraph = True
+        if transformer_block.moe_enabled:
+            fullgraph = False
+        transformer_block = torch.compile(transformer_block, fullgraph=fullgraph)
+        model.layers.register_module(layer_id, transformer_block)
+
+    logger.info("Compiling each TransformerBlock with torch.compile")
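
Note (not part of the diff): the new llama4-specific apply_compile follows the same per-block compile pattern as the llama3 version it replaces, except that fullgraph is disabled for MoE blocks since torch.compile does not yet support fullgraph=True there. Below is a minimal standalone sketch of that pattern; ToyBlock and ToyModel are hypothetical stand-ins for the real TransformerBlock/Transformer, and only the moe_enabled flag and the layers container mirror the actual structure.

# Minimal sketch of per-block torch.compile with a MoE fallback (hypothetical toy model).
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim: int, moe_enabled: bool):
        super().__init__()
        self.moe_enabled = moe_enabled
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 4):
        super().__init__()
        # Blocks live in a container that supports named_children() and
        # register_module(), so compiled modules can be swapped back in place.
        self.layers = nn.ModuleDict(
            {str(i): ToyBlock(dim, moe_enabled=(i % 2 == 1)) for i in range(n_layers)}
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.layers.values():
            x = block(x)
        return x


def apply_compile(model: nn.Module) -> None:
    # Compile each block separately; MoE blocks use fullgraph=False because
    # graph breaks are still expected there (see the TODO in the diff).
    for layer_id, block in model.layers.named_children():
        fullgraph = not getattr(block, "moe_enabled", False)
        compiled = torch.compile(block, fullgraph=fullgraph)
        model.layers.register_module(layer_id, compiled)


if __name__ == "__main__":
    model = ToyModel()
    apply_compile(model)
    out = model(torch.randn(2, 16))  # first call triggers compilation per block
    print(out.shape)  # torch.Size([2, 16])

Compiling block-by-block keeps compile time low because each TransformerBlock shares the same structure and the compiled artifact is reused, while still letting individual blocks opt out of fullgraph mode.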