diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py
index 3c2bb9e84..c387fcd61 100644
--- a/torchtitan/models/deepseek_v3/model/model.py
+++ b/torchtitan/models/deepseek_v3/model/model.py
@@ -336,6 +336,10 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
         self.model_args = model_args
         self.init_weights()
 
+        # Explicitly set dtype to bfloat16 if use_grouped_mm is True
+        if model_args.use_grouped_mm:
+            self.to(dtype=torch.bfloat16)
+
     def init_weights(self, buffer_device: torch.device | None = None) -> None:
         buffer_device = buffer_device or self.freqs_cis.device
         with torch.device(buffer_device):
diff --git a/torchtitan/models/deepseek_v3/model/moe.py b/torchtitan/models/deepseek_v3/model/moe.py
index 1e8e99cbd..5bb408549 100644
--- a/torchtitan/models/deepseek_v3/model/moe.py
+++ b/torchtitan/models/deepseek_v3/model/moe.py
@@ -141,11 +141,12 @@ def _run_experts_grouped_mm(
     assert x.dim() == 3
 
     assert (
-        x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16
-    ), "torch._grouped_mm only supports bf16 dtypes"
+        w1.dtype == w2.dtype == w3.dtype == torch.bfloat16
+    ), "torch._grouped_mm only supports bf16 model weight dtypes"
 
-    h = F.silu(torch._grouped_mm(x, w1, offs=offsets))
-    h = h * torch._grouped_mm(x, w3, offs=offsets)
-    out = torch._grouped_mm(h, w2, offs=offsets)
+    # Cast the activations to bf16 once (avoids materializing two bf16
+    # copies of x) and cast the result back to the caller's dtype.
+    x_bf16 = x.to(torch.bfloat16)
+    h = F.silu(torch._grouped_mm(x_bf16, w1, offs=offsets))
+    h = h * torch._grouped_mm(x_bf16, w3, offs=offsets)
+    out = torch._grouped_mm(h, w2, offs=offsets).to(x.dtype)
 
     return out