
Commit e84963f

Enable relu2 tests
Signed-off-by: Neta Zmora <[email protected]>
1 parent 64dd811 commit e84963f

File tree

1 file changed (+3, -6 lines)


tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py

Lines changed: 3 additions & 6 deletions
@@ -569,9 +569,7 @@ def break_fp4_bytes(a, dtype):
 @pytest.mark.parametrize("top_k", TOP_K_VALUES)
 @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
 @pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES)
-# relu2 support requires merge of https://github.com/NVIDIA/TensorRT-LLM/pull/9261
-# @pytest.mark.parametrize("activation_func", ["silu", "relu2"])
-@pytest.mark.parametrize("activation_func", ["silu"])
+@pytest.mark.parametrize("activation_func", ["silu", "relu2"])
 @pytest.mark.skipif(
     not fp4_compatible() or not trtllm_ops_available(),
     reason="Requires fp4 and trtllm support",
@@ -698,8 +696,7 @@ def round_up(x, y):
     mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
     if mlp_style == "gated_mlp":
         # For gated MLP, concatenate w1 and w3 as [w3, w1]
-        w3_w1_stacked = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous()
-        fc1_expert_weights_fp4 = w3_w1_stacked
+        fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous()
         fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1)
         fc1_weight_gs = torch.max(w3_gs, w1_gs)
         if activation_func != "silu":
@@ -709,7 +706,7 @@ def round_up(x, y):
     elif mlp_style == "mlp":
         # For non-gated MLP with ReLU^2
         fc1_expert_weights_fp4 = w1_q_fp4
-        fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long)
+        fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1)
         fc1_weight_gs = w1_gs
         if activation_func != "relu2":
             raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.")
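
For context: "relu2" in this test is the squared-ReLU activation, relu(x) ** 2, used with a non-gated MLP ("mlp" style), while "silu" is used with a gated MLP ("gated_mlp" style) whose fc1 weights are the [w3, w1] concatenation seen in the diff. The sketch below only illustrates those two activation paths with plain float tensors; the helper names, shapes, and the absence of FP4 quantization and blockscales are assumptions for readability, not code from the test file.

    import torch
    import torch.nn.functional as F

    def relu2(x):
        # Squared ReLU: relu(x) ** 2 -- the activation enabled for the "mlp" style.
        return torch.relu(x).square()

    # Hypothetical, unquantized per-expert weights (the real test uses FP4 tensors).
    hidden, inter = 16, 32
    x = torch.randn(4, hidden)
    w1 = torch.randn(inter, hidden)
    w3 = torch.randn(inter, hidden)

    # "mlp" style (relu2): fc1 is just w1, so only w1 is passed as the fc1 expert weight.
    h_mlp = relu2(x @ w1.t())

    # "gated_mlp" style (silu): fc1 is the [w3, w1] concatenation, mirroring
    # torch.cat([w3_q_fp4, w1_q_fp4], dim=1) in the diff (dim=1 there is the per-expert
    # output dimension; the expert dimension is omitted here).
    fc1 = torch.cat([w3, w1], dim=0)
    x_w3, x_w1 = (x @ fc1.t()).chunk(2, dim=-1)
    h_gated = F.silu(x_w1) * x_w3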
