Skip to content

Commit c33f43e

Browse files
authored
[https://nvbugs/5518713][fix] Trtllm-gen moe backend for blockwise fp8 ckpt (Qwen3-235B-A22B-FP8) (#7856)
Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent d708701 commit c33f43e

File tree

3 files changed

+43
-10
lines changed

3 files changed

+43
-10
lines changed

cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ at::Tensor run_fp8_block_scale_moe(at::optional<at::Tensor> const& routing_logit
115115
else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Renormalize
116116
|| static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::RenormalizeNaive)
117117
{
118-
TORCH_CHECK(false, "Don't support this routing method type Renormalize(Naive).");
118+
TORCH_CHECK(top_k <= 8 && top_k > 0,
119+
"Current routing kernel (no groups, renormalize) only supports top_k<=8 && top_k>0.");
119120
}
120121
else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4)
121122
{

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -395,12 +395,12 @@ def __init__(
395395
self,
396396
num_experts: int,
397397
top_k: int,
398-
n_group: int,
399-
topk_group: int,
398+
n_group: Optional[int],
399+
topk_group: Optional[int],
400400
intermediate_size: int,
401401
local_expert_offset: int,
402402
local_num_experts: int,
403-
routed_scaling_factor: float,
403+
routed_scaling_factor: Optional[float],
404404
routing_method_type: int,
405405
):
406406

@@ -562,12 +562,12 @@ def fp8_block_scale_moe_runner(
562562
gemm2_weights_scale: torch.Tensor,
563563
num_experts: int,
564564
top_k: int,
565-
n_group: int,
566-
topk_group: int,
565+
n_group: Optional[int],
566+
topk_group: Optional[int],
567567
intermediate_size: int,
568568
local_expert_offset: int,
569569
local_num_experts: int,
570-
routed_scaling_factor: float,
570+
routed_scaling_factor: Optional[float],
571571
routing_method_type: int,
572572
topk_weights: Optional[torch.Tensor] = None,
573573
topk_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -630,12 +630,12 @@ def _(routing_logits: torch.Tensor,
630630
gemm2_weights_scale: torch.Tensor,
631631
num_experts: int,
632632
top_k: int,
633-
n_group: int,
634-
topk_group: int,
633+
n_group: Optional[int],
634+
topk_group: Optional[int],
635635
intermediate_size: int,
636636
local_expert_offset: int,
637637
local_num_experts: int,
638-
routed_scaling_factor: float,
638+
routed_scaling_factor: Optional[float],
639639
routing_method_type: int,
640640
topk_weights: Optional[torch.Tensor] = None,
641641
topk_ids: Optional[torch.Tensor] = None) -> torch.Tensor:

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2959,6 +2959,38 @@ def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
29592959
task = GSM8K(self.MODEL_NAME)
29602960
task.evaluate(llm)
29612961

2962+
@skip_pre_hopper
2963+
@pytest.mark.skip_less_device(8)
2964+
@pytest.mark.parametrize(
2965+
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend",
2966+
[(8, 1, 8, True, True, True, "DEEPGEMM"),
2967+
(8, 1, 8, False, True, True, "DEEPGEMM"),
2968+
(8, 1, 8, True, True, True, "TRTLLM"),
2969+
(8, 1, 8, False, True, True, "TRTLLM")],
2970+
ids=[
2971+
"latency_deepgemm", "throughput_latency_deepgemm", "latency_trtllm",
2972+
"throughput_latency_trtllm"
2973+
])
2974+
def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
2975+
cuda_graph, overlap_scheduler, moe_backend):
2976+
pytorch_config = dict(
2977+
disable_overlap_scheduler=not overlap_scheduler,
2978+
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
2979+
moe_config=MoeConfig(backend=moe_backend))
2980+
2981+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
2982+
with LLM(f"{llm_models_root()}/Qwen3/Qwen3-235B-A22B-FP8",
2983+
tensor_parallel_size=tp_size,
2984+
pipeline_parallel_size=pp_size,
2985+
moe_expert_parallel_size=ep_size,
2986+
**pytorch_config,
2987+
enable_attention_dp=attention_dp,
2988+
kv_cache_config=kv_cache_config) as llm:
2989+
task = MMLU(self.MODEL_NAME)
2990+
task.evaluate(llm)
2991+
task = GSM8K(self.MODEL_NAME)
2992+
task.evaluate(llm)
2993+
29622994
@skip_pre_blackwell
29632995
@pytest.mark.skip_less_mpi_world_size(8)
29642996
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)