High priority
- Grouped MM @tianyu-l (see the grouped-MM sketch after this list)
  - [Bug] Potential bugs in "_grouped_mm" in Llama4 MoE codes #1237
  - with Activation Checkpointing
    - gets stuck after a couple of iterations
  - with AdamW
    - gets stuck after a couple of iterations
  - with torch.compile @bdhirsh
    - may need to register torch._grouped_mm (and the triton kernel for aligning indices) - basic compile support for grouped_mm pytorch#153384
    - compile: turn off fullgraph=True to support llama4 #1182
- auxiliary-loss-free load balancing ([llama4] add auxiliary-loss-free load balancing to MoE token routing #1114) (see the load-balancing sketch after this list)
  - remove (persistent) buffers in checkpoint
  - currently it uses the default stream for blocking communication when DP degree > 1; need to assess whether that is acceptable
- selective activation checkpointing (see the selective-AC sketch after this list)
  - currently we are checkpointing every other matmul, which is not adapted to MoE router gate / torch._grouped_mm ops (potential solution) - solved in Add option for selective op AC to filter mm shapes based on fqn #1380
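
For context on the Grouped MM item above, here is a minimal sketch (not the torchtitan code) of how a grouped GEMM replaces the per-expert loop in the MoE feed-forward. It assumes a recent PyTorch build on a supported GPU and a signature roughly like `torch._grouped_mm(x, w, offs=...)`, with tokens pre-sorted by expert and `offs` holding cumulative group sizes; shapes, dtype, and alignment details here are illustrative assumptions.

```python
import torch

# Illustrative sizes only; group boundaries must satisfy the kernel's alignment
# requirement, which is what the index-aligning triton kernel mentioned above handles.
num_experts, dim, hidden, tokens_per_expert = 4, 64, 128, 16

# Tokens already permuted so each expert's tokens are contiguous (assumed bf16 on CUDA).
x = torch.randn(num_experts * tokens_per_expert, dim, device="cuda", dtype=torch.bfloat16)
w = torch.randn(num_experts, dim, hidden, device="cuda", dtype=torch.bfloat16)

# Cumulative token count per expert, i.e. the end offset of each group.
offs = torch.arange(1, num_experts + 1, device="cuda", dtype=torch.int32) * tokens_per_expert

# Reference: one matmul per expert in a Python loop.
ref = torch.cat(
    [x[i * tokens_per_expert : (i + 1) * tokens_per_expert] @ w[i] for i in range(num_experts)]
)

# Single fused call replacing the loop (assumed signature of the private op).
out = torch._grouped_mm(x, w, offs=offs)
torch.testing.assert_close(out, ref)
```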
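The auxiliary-loss-free load balancing item follows the bias-based routing idea from #1114: a per-expert bias steers top-k selection (but not the combine weights) and is nudged each step toward a balanced load. The sketch below is a rough illustration of that idea, not the PR's code; the bias buffer is the persistent state the checkpoint sub-item refers to, and the per-step load counts are what would need an all-reduce (currently blocking on the default stream) when DP degree > 1.

```python
import torch

num_experts, top_k, bias_update_rate = 8, 2, 1e-3

# Persistent per-expert bias: the buffer the checkpoint sub-item refers to.
expert_bias = torch.zeros(num_experts)

def route(scores: torch.Tensor):
    """scores: (num_tokens, num_experts) router affinities."""
    # The bias affects *which* experts are picked, not the combine weights.
    _, topk_idx = torch.topk(scores + expert_bias, top_k, dim=-1)
    topk_weight = scores.gather(-1, topk_idx).softmax(dim=-1)
    return topk_idx, topk_weight

def update_bias(topk_idx: torch.Tensor):
    # Tokens routed to each expert this step; with DP > 1 this count would be
    # all-reduced across ranks (the blocking-communication sub-item above).
    load = torch.bincount(topk_idx.flatten(), minlength=num_experts).float()
    # Nudge under-loaded experts up and over-loaded experts down.
    expert_bias.add_(bias_update_rate * torch.sign(load.mean() - load))

scores = torch.rand(32, num_experts)
topk_idx, topk_weight = route(scores)
update_bias(topk_idx)
```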
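For the selective activation checkpointing item, the sketch below shows an "every other matmul" policy built on the selective-checkpoint API in `torch.utils.checkpoint`; it is a hedged approximation of the current behavior, not the exact torchtitan policy. Because it keys on `aten.mm`, the MoE router gate and `torch._grouped_mm` ops fall outside the heuristic, which is what the fqn-based filtering in #1380 addresses.

```python
from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def _every_other_mm_policy(meta):
    def policy_fn(ctx, func, *args, **kwargs):
        if func == torch.ops.aten.mm.default:
            # Keep separate counters for the forward and recompute passes
            # so the same matmuls are treated as saved in both.
            key = "recompute_mm" if ctx.is_recompute else "forward_mm"
            meta[key] += 1
            if meta[key] % 2 == 0:
                return CheckpointPolicy.MUST_SAVE  # save every second matmul
        # Everything else (including router gate / grouped-mm ops) is recomputed.
        return CheckpointPolicy.PREFER_RECOMPUTE
    return policy_fn

def context_fn():
    return create_selective_checkpoint_contexts(_every_other_mm_policy(defaultdict(int)))

w1, w2 = torch.randn(16, 16, requires_grad=True), torch.randn(16, 16, requires_grad=True)
x = torch.randn(4, 16, requires_grad=True)

def ff(x):
    # Two matmuls: the policy above saves the second and recomputes the first.
    return torch.mm(torch.relu(torch.mm(x, w1)), w2)

out = checkpoint(ff, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```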
Not high priority for now
- for-loop implementation of MoE (see the sketch after this list)
  - with DTensor TP: sharding propagation overhead due to dynamic shapes
    - need to lift cache hit criteria in DTensor sharding prop
    - may be needed by Loss Parallel for per-sequence loss as well
  - with torch.compile: branching on “unbacked” symbolic ints
    - static padding of DTensor may solve this
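
As a reference for the items above, here is a minimal for-loop MoE forward (illustrative names and sizes, not the torchtitan module). Each expert processes a data-dependent number of tokens per step, which is the source of the dynamic shapes that stress DTensor sharding-propagation caching and show up as unbacked symbolic ints under torch.compile.

```python
import torch
import torch.nn.functional as F

num_experts, top_k, dim, hidden = 4, 2, 32, 64

gate = torch.nn.Linear(dim, num_experts, bias=False)
w_in = torch.randn(num_experts, dim, hidden) * 0.02
w_out = torch.randn(num_experts, hidden, dim) * 0.02

def moe_forward(x):                        # x: (num_tokens, dim)
    scores = F.softmax(gate(x), dim=-1)
    topk_weight, topk_idx = torch.topk(scores, top_k, dim=-1)
    out = torch.zeros_like(x)
    for e in range(num_experts):
        token_idx, slot = (topk_idx == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue                       # data-dependent branch -> unbacked symint under compile
        # Per-expert token count is dynamic, so shapes differ every iteration/step.
        h = F.silu(x[token_idx] @ w_in[e]) @ w_out[e]
        out.index_add_(0, token_idx, h * topk_weight[token_idx, slot].unsqueeze(-1))
    return out

y = moe_forward(torch.randn(16, dim))
```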
Not llama4 specific