Commit 37e2b64
1 parent 2af6319

Address review comments from tcherckez
Signed-off-by: Neta Zmora <[email protected]>

2 files changed: +18 −18 lines

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
7 additions, 7 deletions

@@ -273,14 +273,14 @@ def trtllm_quant_nvfp4_moe_fused(
     w3_blockscale_fp8: torch.Tensor,  # Block scale for w3 (fp8 )
     fc1_act_global: torch.Tensor,  # Global scale for FC1 activations
     fc2_act_global: torch.Tensor,  # Global scale for FC2 activations
-    fc1_global: Optional[
+    fc1_alpha: Optional[
         torch.Tensor
     ] = None,  # Precomputed global scale for FC1 (1.0 / (fc1_act_global * fc1_weight_gs))
-    fc2_global: Optional[
+    fc2_alpha: Optional[
         torch.Tensor
     ] = None,  # Precomputed global scale for FC2 (1.0 / (fc2_act_global * fc2_weight_gs))
     input_blockscale: Optional[torch.Tensor] = None,  # Input scale factors for NVFP4 input
-    output_dtype: Optional[torch.dtype] = None,  # Output dtype for NVFP4 input
+    output_dtype: Optional[torch.dtype] = None,  # determines output dtype when input is NVFP4
     mlp_style: str = "gated_mlp",
     act_fn: str = "silu",
 ) -> torch.Tensor:

@@ -322,16 +322,16 @@ def trtllm_quant_nvfp4_moe_fused(

     fc2_weight_block_scale = w2_blockscale_fp8
     fc2_weight_gs = w2_global_scale
-    fc1_global = 1.0 / (fc1_act_global * fc1_weight_gs) if fc1_global is None else fc1_global
-    fc2_global = 1.0 / (fc2_act_global * fc2_weight_gs) if fc2_global is None else fc2_global
+    fc1_alpha = 1.0 / (fc1_act_global * fc1_weight_gs) if fc1_alpha is None else fc1_alpha
+    fc2_alpha = 1.0 / (fc2_act_global * fc2_weight_gs) if fc2_alpha is None else fc2_alpha

     quant_scales = [
         fc1_act_global,
         fc1_weight_blockscale.view(torch.int32),
-        fc1_global,
+        fc1_alpha,
         fc2_act_global,
         fc2_weight_block_scale.view(torch.int32),
-        fc2_global,
+        fc2_alpha,
     ]

     if x.dtype in (torch.float16, torch.bfloat16):
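For reference, the renamed fc1_alpha/fc2_alpha arguments are the precomputed GEMM alpha factors; when they are left as None the op derives them from the global scales, as the second hunk shows. A minimal sketch of precomputing them once (e.g., at weight-load time) and passing them in; the scale values here are hypothetical, and in practice they come from the NVFP4-quantized checkpoint:

import torch

# Hypothetical scalar global scales; real values come from NVFP4 calibration.
fc1_act_global = torch.tensor(0.5)
fc1_weight_gs = torch.tensor(2.0)
fc2_act_global = torch.tensor(0.25)
fc2_weight_gs = torch.tensor(4.0)

# Precompute the per-GEMM alpha factors once, mirroring the fallback
# computation inside trtllm_quant_nvfp4_moe_fused.
fc1_alpha = 1.0 / (fc1_act_global * fc1_weight_gs)
fc2_alpha = 1.0 / (fc2_act_global * fc2_weight_gs)

# These tensors would then be passed as the fc1_alpha= / fc2_alpha= keyword
# arguments of torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused.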

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py
11 additions, 11 deletions

@@ -229,9 +229,6 @@ def test_trtllm_fused_moe(
         activation_func=activation_func,
     )

-    torch.cuda.synchronize()
-    print("before fused_moe.cutlass_fused_moe")
-
     assert itype == torch.bfloat16 or itype == torch.float16, (
         "F16 test only supports bfloat16 or float16"
     )

@@ -256,6 +253,7 @@ def get_fc1_expert_weights(
     _, w1_weight = torch.chunk(w31_weight, 2, dim=1)
     mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"

+    torch.cuda.synchronize()
     ad_test_output = torch.ops.auto_deploy.trtllm_moe_fused(
         x,
         selected_experts.to(torch.int),

@@ -500,7 +498,7 @@ def act(weight, mask):
             inter_gs,
             dtype=inter.dtype,
             device=inter.device,
-            block_size=16,
+            block_size=NVFP4_BLOCK_SIZE,
         ).cuda()
         out[mask] = inter @ w2[i].transpose(0, 1)
     return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)

@@ -565,6 +563,7 @@ def break_fp4_bytes(a, dtype):
 ]


+@pytest.mark.parametrize("precompute_fc_alphas", [True, False])
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("num_experts", NUM_EXPERTS)

@@ -579,6 +578,7 @@ def break_fp4_bytes(a, dtype):
     reason="Requires fp4 and trtllm support",
 )
 def test_trtllm_fused_moe_nvfp4(
+    precompute_fc_alphas,
     batch_size,
     hidden_size,
     num_experts,

@@ -693,13 +693,13 @@ def round_up(x, y):

     routing_weights, selected_experts = compute_routing(router_logits, top_k)

-    if True:
+    if precompute_fc_alphas:
         fc1_weight_gs = torch.max(w3_gs, w1_gs)
-        fc1_global = 1.0 / (fc1_act_global * fc1_weight_gs)
-        fc2_global = 1.0 / (fc2_act_global * w2_gs)
+        fc1_alpha = 1.0 / (fc1_act_global * fc1_weight_gs)
+        fc2_alpha = 1.0 / (fc2_act_global * w2_gs)
     else:
-        fc1_global = None
-        fc2_global = None
+        fc1_alpha = None
+        fc2_alpha = None

     mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
     trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused(

@@ -717,8 +717,8 @@ def round_up(x, y):
         w3_blockscale,
         fc1_act_global,
         fc2_act_global,
-        fc1_global=fc1_global,
-        fc2_global=fc2_global,
+        fc1_alpha=fc1_alpha,
+        fc2_alpha=fc2_alpha,
         input_blockscale=None,
         output_dtype=otype,
         mlp_style=mlp_style,
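The new precompute_fc_alphas parameter makes the same test body exercise both code paths: caller-supplied alphas and the op's inline fallback. A self-contained sketch of that parametrization pattern, using toy values rather than the real MoE setup:

import pytest

@pytest.mark.parametrize("precompute_fc_alphas", [True, False])
def test_alpha_code_paths(precompute_fc_alphas):
    act_global, weight_gs = 2.0, 4.0
    # Either precompute the alpha up front or pass None and let the op derive it.
    alpha = 1.0 / (act_global * weight_gs) if precompute_fc_alphas else None
    # Emulate the op's fallback so both parametrizations check the same value.
    effective = alpha if alpha is not None else 1.0 / (act_global * weight_gs)
    assert effective == pytest.approx(0.125)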
