@@ -569,9 +569,7 @@ def break_fp4_bytes(a, dtype):
 @pytest.mark.parametrize("top_k", TOP_K_VALUES)
 @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
 @pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES)
-# relu2 support requires merge of https://github.com/NVIDIA/TensorRT-LLM/pull/9261
-# @pytest.mark.parametrize("activation_func", ["silu", "relu2"])
-@pytest.mark.parametrize("activation_func", ["silu"])
+@pytest.mark.parametrize("activation_func", ["silu", "relu2"])
 @pytest.mark.skipif(
     not fp4_compatible() or not trtllm_ops_available(),
     reason="Requires fp4 and trtllm support",
@@ -698,8 +696,7 @@ def round_up(x, y):
     mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
     if mlp_style == "gated_mlp":
         # For gated MLP, concatenate w1 and w3 as [w3, w1]
-        w3_w1_stacked = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous()
-        fc1_expert_weights_fp4 = w3_w1_stacked
+        fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous()
         fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1)
         fc1_weight_gs = torch.max(w3_gs, w1_gs)
         if activation_func != "silu":
@@ -709,7 +706,7 @@ def round_up(x, y):
     elif mlp_style == "mlp":
         # For non-gated MLP with ReLU^2
         fc1_expert_weights_fp4 = w1_q_fp4
-        fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long)
+        fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1)
         fc1_weight_gs = w1_gs
         if activation_func != "relu2":
             raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.")
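
For context, here is a minimal sketch of the fc1 assembly the updated test exercises, assuming the quantized expert weights (w1_q_fp4, w3_q_fp4), their blockscales, and their global scales already exist; the helper name assemble_fc1_inputs is hypothetical and only mirrors the two branches shown in the diff:

import torch

def assemble_fc1_inputs(activation_func, w1_q_fp4, w3_q_fp4,
                        w1_blockscale, w3_blockscale, w1_gs, w3_gs):
    # Gated MLP (silu) stacks w3 and w1 as [w3, w1]; plain MLP (relu2)
    # takes its fc1 expert weights from w1 alone.
    mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
    if mlp_style == "gated_mlp":
        # Concatenate quantized weights and blockscales along dim=1 as [w3, w1].
        weights = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous()
        blockscale = torch.cat([w3_blockscale, w1_blockscale], dim=1)
        global_scale = torch.max(w3_gs, w1_gs)
    else:
        # Plain MLP with ReLU^2: weights come from w1; the blockscale is still
        # the [w3, w1] concatenation, matching the updated test above.
        weights = w1_q_fp4
        blockscale = torch.cat([w3_blockscale, w1_blockscale], dim=1)
        global_scale = w1_gs
    return weights, blockscale, global_scale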