
Commit 5deef0f

Disable low precision combine in BF16 MTP layer
Signed-off-by: Yilin Zhang <[email protected]>
1 parent 24f5cd7 commit 5deef0f

File tree

1 file changed: 8 additions, 2 deletions

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

8 additions, 2 deletions
@@ -386,6 +386,13 @@ def is_post_quant_all2all_supported(self):
         else:
             return False
 
+    def is_low_precision_combine_supported(self):
+        if not self.use_low_precision_combine:
+            return False
+        if self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+            return self.has_fp8_qdq or self.has_nvfp4 or self.has_w4afp8
+        return False
+
     def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -676,8 +683,7 @@ def forward_chunk(
             final_hidden_states = final_hidden_states.view(
                 self.expert_size_per_partition,
                 num_tokens_per_expert_for_fused_moe, self.hidden_size)
-            if self.use_low_precision_combine:
-                assert self.has_nvfp4 or self.has_w4afp8 or self.has_fp8_qdq, "Low precision combine only supports nvfp4, w4afp8 and fp8 qdq"
+            if self.is_low_precision_combine_supported():
                 precision = "fp8"
                 global_scales = None
                 if self.has_nvfp4:
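
For reference, a minimal, self-contained sketch of the gating logic this commit introduces. The MoECombineGate dataclass and the extra AlltoallMethodType variants (NotEnabled, DeepEP) are placeholders invented for illustration; only the attribute names used inside is_low_precision_combine_supported (use_low_precision_combine, alltoall_method_type, has_fp8_qdq, has_nvfp4, has_w4afp8) come from the diff above.

# Illustrative sketch only -- not the real fused MoE class. It isolates the
# gating logic from the diff; the dataclass and extra enum variants are stand-ins.
from dataclasses import dataclass
from enum import Enum, auto


class AlltoallMethodType(Enum):
    # Only DeepEPLowLatency appears in the diff; the other variants are
    # placeholders for this sketch.
    NotEnabled = auto()
    DeepEP = auto()
    DeepEPLowLatency = auto()


@dataclass
class MoECombineGate:
    use_low_precision_combine: bool
    alltoall_method_type: AlltoallMethodType
    has_fp8_qdq: bool = False
    has_nvfp4: bool = False
    has_w4afp8: bool = False

    def is_low_precision_combine_supported(self) -> bool:
        # Mirrors the new helper: the feature flag must be on, the alltoall
        # backend must be DeepEPLowLatency, and at least one supported
        # quantization mode must be active. An unquantized (BF16) layer has
        # none of the quant flags set, so it returns False and the combine
        # stays in full precision.
        if not self.use_low_precision_combine:
            return False
        if self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
            return self.has_fp8_qdq or self.has_nvfp4 or self.has_w4afp8
        return False


# BF16 MTP-style layer: flag requested but no quantized weights -> gated off.
bf16_layer = MoECombineGate(True, AlltoallMethodType.DeepEPLowLatency)
assert not bf16_layer.is_low_precision_combine_supported()

# FP8 layer with the DeepEP low-latency alltoall -> low precision combine allowed.
fp8_layer = MoECombineGate(True, AlltoallMethodType.DeepEPLowLatency, has_fp8_qdq=True)
assert fp8_layer.is_low_precision_combine_supported()

Before this change, forward_chunk asserted on those quantization flags and would fail for a BF16 MTP layer with low precision combine enabled; with the new helper that path is simply skipped and the combine falls back to full precision.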
