Commit d44a199

Address review comments

Signed-off-by: Neta Zmora <[email protected]>

1 parent 76c542a commit d44a199

File tree

2 files changed: +15 -15 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 7 additions & 7 deletions
@@ -271,8 +271,8 @@ def trtllm_quant_nvfp4_moe_fused(
     w1_blockscale_fp8: torch.Tensor,  # Block scale for w1 (fp8)
     w2_blockscale_fp8: torch.Tensor,  # Block scale for w2 (fp8)
     w3_blockscale_fp8: torch.Tensor,  # Block scale for w3 (fp8)
-    fc1_act_global: torch.Tensor,  # Global scale for FC1 activations
-    fc2_act_global: torch.Tensor,  # Global scale for FC2 activations
+    fc1_act_global_scale: torch.Tensor,  # Global scale for FC1 activations
+    fc2_act_global_scale: torch.Tensor,  # Global scale for FC2 activations
     fc1_alpha: Optional[
         torch.Tensor
     ] = None,  # Precomputed global scale for FC1 (1.0 / (fc1_act_global * fc1_weight_gs))
@@ -322,21 +322,21 @@ def trtllm_quant_nvfp4_moe_fused(

     fc2_weight_block_scale = w2_blockscale_fp8
     fc2_weight_gs = w2_global_scale
-    fc1_alpha = 1.0 / (fc1_act_global * fc1_weight_gs) if fc1_alpha is None else fc1_alpha
-    fc2_alpha = 1.0 / (fc2_act_global * fc2_weight_gs) if fc2_alpha is None else fc2_alpha
+    fc1_alpha = 1.0 / (fc1_act_global_scale * fc1_weight_gs) if fc1_alpha is None else fc1_alpha
+    fc2_alpha = 1.0 / (fc2_act_global_scale * fc2_weight_gs) if fc2_alpha is None else fc2_alpha

     quant_scales = [
-        fc1_act_global,
+        fc1_act_global_scale,
         fc1_weight_blockscale.view(torch.int32),
         fc1_alpha,
-        fc2_act_global,
+        fc2_act_global_scale,
         fc2_weight_block_scale.view(torch.int32),
         fc2_alpha,
     ]

     if x.dtype in (torch.float16, torch.bfloat16):
         x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
-            x, fc1_act_global, NVFP4_BLOCK_SIZE
+            x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
         )
         output_dtype = x.dtype
     else:
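
For context, a minimal sketch of the alpha precomputation that the renamed activation-scale arguments feed into. The scale values below are made up for illustration; the expressions mirror the fallback the op evaluates when fc1_alpha/fc2_alpha are passed as None, so the rename does not change the math.

# Minimal sketch (illustrative values only) of precomputing the FC1/FC2 alphas
# from the activation and weight global scales.
import torch

fc1_act_global_scale = torch.tensor(1.0, dtype=torch.float32)  # global scale for FC1 activations
fc2_act_global_scale = torch.tensor(1.0, dtype=torch.float32)  # global scale for FC2 activations
fc1_weight_gs = torch.tensor(0.5, dtype=torch.float32)         # hypothetical FC1 weight global scale
fc2_weight_gs = torch.tensor(0.25, dtype=torch.float32)        # hypothetical FC2 weight global scale

# Same expressions the op uses as its fallback when the alphas are not supplied.
fc1_alpha = 1.0 / (fc1_act_global_scale * fc1_weight_gs)
fc2_alpha = 1.0 / (fc2_act_global_scale * fc2_weight_gs)

Passing the precomputed alphas lets a caller skip the reciprocal on every invocation.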

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py

Lines changed: 8 additions & 8 deletions
@@ -688,15 +688,15 @@ def round_up(x, y):
         w3_gs,
     ) = _quantize_weights(w1, w2, w3)

-    fc1_act_global = torch.tensor(1.0, device="cuda", dtype=torch.float32)
-    fc2_act_global = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    fc1_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    fc2_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)

     routing_weights, selected_experts = compute_routing(router_logits, top_k)

     if precompute_fc_alphas:
         fc1_weight_gs = torch.max(w3_gs, w1_gs)
-        fc1_alpha = 1.0 / (fc1_act_global * fc1_weight_gs)
-        fc2_alpha = 1.0 / (fc2_act_global * w2_gs)
+        fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs)
+        fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs)
     else:
         fc1_alpha = None
         fc2_alpha = None
@@ -715,8 +715,8 @@ def round_up(x, y):
         w1_blockscale,
         w2_blockscale,
         w3_blockscale,
-        fc1_act_global,
-        fc2_act_global,
+        fc1_activation_gs,
+        fc2_activation_gs,
         fc1_alpha=fc1_alpha,
         fc2_alpha=fc2_alpha,
         input_blockscale=None,
@@ -728,12 +728,12 @@ def round_up(x, y):
     def compute_ref_output(w1_gs, w3_gs):
         # Quantize then dequantize the input to emulate the precision loss.
         a_fp4, a_scale_interleaved = torch.ops.trtllm.fp4_quantize(
-            x, fc1_act_global, NVFP4_BLOCK_SIZE
+            x, fc1_activation_gs, NVFP4_BLOCK_SIZE
        )
         x_dq = dequantize_nvfp4_to_dtype(
             a_fp4,
             a_scale_interleaved,
-            fc1_act_global,
+            fc1_activation_gs,
             dtype=otype,
             device=x.device,
             block_size=NVFP4_BLOCK_SIZE,
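
A minimal sketch of the reference path's quantize step with the renamed test-local scale. It assumes a CUDA device and a TensorRT-LLM install that registers torch.ops.trtllm.fp4_quantize; the tensor shapes are made up, and the 16-element block size is an assumption about the NVFP4 format. The test then dequantizes the result back to the original dtype (via its dequantize_nvfp4_to_dtype helper) to emulate NVFP4 precision loss.

# Sketch assuming TensorRT-LLM custom ops are available; shapes are illustrative.
import torch

NVFP4_BLOCK_SIZE = 16  # assumed: NVFP4 scales are per 16-element block

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
fc1_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)

# Quantize the activations with the global scale; dequantizing afterwards
# (as the test does) reproduces the precision loss of the NVFP4 path.
a_fp4, a_scale_interleaved = torch.ops.trtllm.fp4_quantize(
    x, fc1_activation_gs, NVFP4_BLOCK_SIZE
)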
