@@ -364,20 +364,20 @@ def trtllm_quant_nvfp4_moe_fused_fake(
364364 x : torch .Tensor ,
365365 selected_experts : torch .Tensor ,
366366 routing_weights : torch .Tensor ,
367- w1_weight_q : torch .Tensor ,
368- w2_weight_q : torch .Tensor ,
369- w3_weight_q : torch .Tensor ,
370- w1_weight_gs : torch .Tensor ,
371- w2_weight_gs : torch .Tensor ,
372- w3_weight_gs : torch .Tensor ,
373- w1_blockscale : torch .Tensor ,
374- w2_blockscale : torch .Tensor ,
375- w3_blockscale : torch .Tensor ,
367+ w1_fp4 : torch .Tensor ,
368+ w2_fp4 : torch .Tensor ,
369+ w3_fp4 : torch .Tensor ,
370+ w1_global_scale : torch .Tensor ,
371+ w2_global_scale : torch .Tensor ,
372+ w3_global_scale : torch .Tensor ,
373+ w1_blockscale_fp8 : torch .Tensor ,
374+ w2_blockscale_fp8 : torch .Tensor ,
375+ w3_blockscale_fp8 : torch .Tensor ,
376376 fc1_act_global : torch .Tensor ,
377377 fc2_act_global : torch .Tensor ,
378- fc1_global : Optional [torch .Tensor ] = None ,
379- fc2_global : Optional [torch .Tensor ] = None ,
380- input_sf : Optional [torch .Tensor ] = None ,
378+ fc1_alpha : Optional [torch .Tensor ] = None ,
379+ fc2_alpha : Optional [torch .Tensor ] = None ,
380+ input_blockscale : Optional [torch .Tensor ] = None ,
381381 output_dtype : Optional [torch .dtype ] = None ,
382382 mlp_style : str = "gated_mlp" ,
383383 act_fn : str = "silu" ,
0 commit comments