diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 8b130d98744..827d47c44ae 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import torch from tensorrt_llm._torch.utils import ActivationType @@ -234,5 +250,110 @@ def trtllm_quant_fp8_moe_fused_fake( gemm2_dequant: torch.Tensor, mlp_style: str, act_fn: str, +) -> torch.Tensor: + _validate_mlp_style_and_act_fn(mlp_style, act_fn) + return torch.empty_like(x) + + +@torch.library.custom_op("auto_deploy::trtllm_quant_nvfp4_moe_fused", mutates_args=()) +def trtllm_quant_nvfp4_moe_fused( + x: torch.Tensor, # [B, S, H] or [B*S, H], 16-bit float + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + fc1_expert_weights_fp4: torch.Tensor, # [E, 2*I, H] or [E, I, H]; uint8 + fc2_expert_weights_fp4: torch.Tensor, # [E, H, I]; uint8 + fc1_weight_blockscale_fp8: torch.Tensor, # Global scale for fc1 (scalar) + fc2_weight_blockscale_fp8: torch.Tensor, # Global scale for w2 (scalar) + fc1_act_global_scale: torch.Tensor, # Global scale for FC1 activations + fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations + fc1_alpha: torch.Tensor, # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8)) + fc2_alpha: torch.Tensor, # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8)) + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + """TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP. + + Computes (per expert): + For gated_mlp: + y = (act(x @ w1.T) * (x @ w3.T)) @ w2.T # act := SiLU + For mlp: + y = act(x @ w1.T) @ w2.T # act := ReLU^2 + + + FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T) + FC2 implements: fc2_output = fc1_output @ w2.T + + """ + NVFP4_BLOCK_SIZE = 16 + mlp_style = mlp_style.lower() + act_fn = act_fn.lower() + + activation_type = ActivationType.Swiglu + if mlp_style == "gated_mlp": + if act_fn == "silu": + activation_type = ActivationType.Swiglu + else: + raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") + elif mlp_style == "mlp": + if act_fn == "relu2": + activation_type = ActivationType.Relu2 + else: + raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") + else: + raise ValueError(f"Unknown mlp_style '{mlp_style}'. 
Use 'gated_mlp' or 'mlp'.") + + # quant_scales is described by this code: + # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015 + quant_scales = [ + fc1_act_global_scale, # torch.float32; [E] or scalar + fc1_weight_blockscale_fp8.view( + torch.int32 + ), # 4 FP8 as packed int32; [E, I*2, H / 16 / 4] or [E, I, H / 16 / 4] + fc1_alpha, # torch.float32; [E] + fc2_act_global_scale, # torch.float32; [E] or scalar + fc2_weight_blockscale_fp8.view(torch.int32), # 4 FP8 as packed int32; [E, H, I / 16 / 4] + fc2_alpha, # torch.float32; [E] + ] + + if x.dtype in (torch.float16, torch.bfloat16): + x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize( + x, fc1_act_global_scale, NVFP4_BLOCK_SIZE + ) + output_dtype = x.dtype + else: + x_q_fp4 = x + + trtllm_output = torch.ops.trtllm.fused_moe( + x_q_fp4, + selected_experts.to(torch.int), + routing_weights, + fc1_expert_weights=fc1_expert_weights_fp4, + fc1_expert_biases=None, + fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long), + fc2_expert_biases=None, + output_dtype=output_dtype, + quant_scales=quant_scales, + input_sf=input_blockscale, + activation_type=activation_type, + )[0].view(x.shape) + + return trtllm_output + + +@trtllm_quant_nvfp4_moe_fused.register_fake +def trtllm_quant_nvfp4_moe_fused_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + fc1_expert_weights_fp4: torch.Tensor, + fc2_expert_weights_fp4: torch.Tensor, + fc1_weight_blockscale_fp8: torch.Tensor, + fc2_weight_blockscale_fp8: torch.Tensor, + fc1_act_global_scale: torch.Tensor, + fc2_act_global_scale: torch.Tensor, + fc1_alpha: torch.Tensor, + fc2_alpha: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 3e13e28a0c5..c9aea8bc607 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -7,7 +7,7 @@ import pytest import torch -from _torch_test_utils import fp8_compatible, trtllm_ops_available # noqa: F401 +from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available # noqa: F401 from torch.nn import functional as F from utils.util import skip_pre_hopper @@ -15,7 +15,9 @@ from tensorrt_llm._torch.utils import ActivationType FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +FLOAT4_E2M1_MAX = 6.0 FP8_DTYPE = torch.float8_e4m3fn +NVFP4_BLOCK_SIZE = 16 def dynamic_per_tensor_fp8_quant(x: torch.tensor) -> tuple[torch.tensor, torch.tensor]: @@ -227,9 +229,6 @@ def test_trtllm_fused_moe( activation_func=activation_func, ) - torch.cuda.synchronize() - print("before fused_moe.cutlass_fused_moe") - assert itype == torch.bfloat16 or itype == torch.float16, ( "F16 test only supports bfloat16 or float16" ) @@ -254,6 +253,7 @@ def get_fc1_expert_weights( _, w1_weight = torch.chunk(w31_weight, 2, dim=1) mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + torch.cuda.synchronize() ad_test_output = torch.ops.auto_deploy.trtllm_moe_fused( x, selected_experts.to(torch.int), @@ -344,7 +344,6 @@ def test_trtllm_fused_moe_fp8( W_GEN_SCALE = 0.1 def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE): - # input_shape = (batch_size, hidden_size) w31_shape 
= (num_experts, 2 * intermediate_size, hidden_size) w2_shape = (num_experts, hidden_size, intermediate_size) @@ -397,8 +396,6 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE w3_input_scale = torch.tensor([1.0]).cuda() w2_input_scale = torch.tensor([1.0]).cuda() - torch.cuda.synchronize() - print("before fused_moe.cutlass_fused_moe") # (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size) w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1) @@ -408,6 +405,9 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE gemm1_dequant = (w1_scales * hidden_states_scale).contiguous().squeeze().to(torch.float32) gemm2_act_quant = (1.0 / w2_input_scale[0]).contiguous().to(torch.float32) gemm2_dequant = (w2_scales * w2_input_scale[0]).contiguous().squeeze().to(torch.float32) + + print("before fused_moe.cutlass_fused_moe") + torch.cuda.synchronize() ad_test_output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused( x, # Note! unquantized input is expected selected_experts.to(torch.int), @@ -455,3 +455,347 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE _print_diff_if(lambda diff: diff.max() > 1e-1, diff, ad_test_output, ref_output) torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-1, atol=1e-1) + + +# Originally from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py +def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids, activation_type): + """Reference implementation of NVFP4 MoE. + + The intermediate activations are quantized and dequantized to emulate the precision loss of a real + quantized operation. + """ + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + if activation_type == ActivationType.Swiglu: + + def act(weight, mask): + m = weight.shape[0] + assert m % 2 == 0 + w1_expert, w3_expert = weight[m // 2 :, :], weight[: m // 2, :] + return F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t()) + elif activation_type == ActivationType.Relu2: + + def act(weight, mask): + return F.relu(a[mask] @ weight.t()) ** 2 + else: + raise ValueError(f"Unsupported activation type {activation_type}") + + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter = act(w1[i], mask) + inter_gs = torch.tensor(1.0).cuda() + inter_q, inter_blockscale = torch.ops.trtllm.fp4_quantize( + inter, inter_gs, NVFP4_BLOCK_SIZE + ) + inter = dequantize_nvfp4_to_dtype( + inter_q, + inter_blockscale, + inter_gs, + dtype=inter.dtype, + device=inter.device, + block_size=NVFP4_BLOCK_SIZE, + ).cuda() + out[mask] = inter @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +# Originally from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py +def dequantize_nvfp4_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + + def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // 
block_size) + return out[0:m, 0:k] + + # Originally from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py + def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1ToFloat = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32) + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) + + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype=dtype) + + +NVFP4_TEST_DTYPES = [ + (torch.float16, torch.float8_e4m3fn), + (torch.bfloat16, torch.float8_e4m3fn), +] + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +@pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES) +@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.skipif( + not fp4_compatible() or not trtllm_ops_available(), + reason="Requires fp4 and trtllm support", +) +def test_trtllm_fused_moe_nvfp4( + batch_size, + hidden_size, + num_experts, + top_k, + intermediate_size, + otype, + wtype, + activation_func, +): + # In the code below: + # sf := block scale factors for NVFP4 + # blockscale := block scale factor for NVFP4 + # gs := global scale for NVFP4 + + # Skip invalid configurations + if top_k > num_experts: + pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})") + torch.manual_seed(42) + + def _get_test_data( + otype, + batch_size, + hidden_size, + num_experts, + intermediate_size, + ): + x = gen_tensor((batch_size, hidden_size), otype) + w1_shape = (num_experts, intermediate_size, hidden_size) + w3_shape = w1_shape + w1 = gen_tensor(w1_shape, otype, scale=0.1) + w2 = gen_tensor((num_experts, hidden_size, intermediate_size), otype, scale=0.1) + w3 = gen_tensor(w3_shape, otype, scale=0.1) + router_logits = torch.randn(batch_size, num_experts, dtype=otype).cuda() + return x, w1, w2, w3, router_logits + + def _quantize_weights(w1, w2, w3): + def round_up(x, y): + return (x + y - 1) // y * y + + w1_n = w1.shape[1] + w3_n = w3.shape[1] + sf_w1_n = round_up(w1_n, 128) + sf_w3_n = round_up(w3_n, 128) + sf_w1_k = round_up(hidden_size // NVFP4_BLOCK_SIZE, 4) + w1_blockscale = torch.empty( + (num_experts, sf_w1_n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + ) + sf_w2_k = round_up(hidden_size, 128) 
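+        # The block-scale buffers use the swizzled layout emitted by
+        # torch.ops.trtllm.fp4_quantize(..., isSfSwizzledLayout=True): the row dim is
+        # padded up to a multiple of 128 and the per-16-element scale columns up to a
+        # multiple of 4 (cf. convert_swizzled_to_linear above).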
+ sf_w2_n = round_up(intermediate_size // NVFP4_BLOCK_SIZE, 4) + w2_blockscale = torch.empty( + (num_experts, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn + ) + w3_blockscale = torch.empty( + (num_experts, sf_w3_n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + ) + w1_q = torch.empty((num_experts, w1_n, hidden_size // 2), device="cuda", dtype=torch.uint8) + w2_q = torch.empty( + (num_experts, hidden_size, intermediate_size // 2), device="cuda", dtype=torch.uint8 + ) + w3_q = torch.empty((num_experts, w3_n, hidden_size // 2), device="cuda", dtype=torch.uint8) + + w1_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) + w2_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) + w3_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) + + for expert in range(num_experts): + w1_amax = torch.abs(w1[expert]).max().to(torch.float32) + w2_amax = torch.abs(w2[expert]).max().to(torch.float32) + w3_amax = torch.abs(w3[expert]).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + w3_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w3_amax + + nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( + w1[expert], w1_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True + ) + w1_q[expert] = nvfp4_vals + w1_blockscale[expert] = fp8_block_scales.reshape(w1_blockscale[expert].shape) + + nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( + w2[expert], w2_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True + ) + w2_q[expert] = nvfp4_vals + w2_blockscale[expert] = fp8_block_scales.reshape(w2_blockscale[expert].shape) + + nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( + w3[expert], w3_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True + ) + w3_q[expert] = nvfp4_vals + w3_blockscale[expert] = fp8_block_scales.reshape(w3_blockscale[expert].shape) + + return w1_q, w2_q, w3_q, w1_blockscale, w2_blockscale, w3_blockscale, w1_gs, w2_gs, w3_gs + + x, w1, w2, w3, router_logits = _get_test_data( + otype, batch_size, hidden_size, num_experts, intermediate_size + ) + + ( + w1_q_fp4, + w2_q_fp4, + w3_q_fp4, + w1_blockscale, + w2_blockscale, + w3_blockscale, + w1_gs, + w2_gs, + w3_gs, + ) = _quantize_weights(w1, w2, w3) + + fc1_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + fc2_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + routing_weights, selected_experts = compute_routing(router_logits, top_k) + + fc1_weight_gs = torch.max(w3_gs, w1_gs) + fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs) + fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs) + + mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + if mlp_style == "gated_mlp": + # For gated MLP, concatenate w1 and w3 as [w3, w1] + fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous() + fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1) + fc1_weight_gs = torch.max(w3_gs, w1_gs) + if activation_func != "silu": + raise ValueError( + f"Unsupported activation '{activation_func}' for gated_mlp. Use 'silu'." + ) + elif mlp_style == "mlp": + # For non-gated MLP with ReLU^2 + fc1_expert_weights_fp4 = w1_q_fp4 + fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long) + fc1_weight_gs = w1_gs + if activation_func != "relu2": + raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.") + else: + raise ValueError(f"Unknown mlp_style '{mlp_style}'. 
Use 'gated_mlp' or 'mlp'.") + + fc2_expert_weights_fp4 = w2_q_fp4.view(torch.long) + fc2_weight_blockscale_fp8 = w2_blockscale.view(torch.long) + fc1_expert_weights_fp4 = fc1_expert_weights_fp4.view(torch.long) + + trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused( + x, + selected_experts.to(torch.int), + routing_weights, + fc1_expert_weights_fp4, + fc2_expert_weights_fp4, + fc1_weight_blockscale_fp8, + fc2_weight_blockscale_fp8, + fc1_activation_gs, + fc2_activation_gs, + fc1_alpha, + fc2_alpha, + mlp_style=mlp_style, + act_fn=activation_func, + ) + + def compute_ref_output(w1_gs, w3_gs): + # Quantize then dequantize the input to emulate the precision loss. + a_fp4, a_scale_interleaved = torch.ops.trtllm.fp4_quantize( + x, fc1_activation_gs, NVFP4_BLOCK_SIZE + ) + x_dq = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + fc1_activation_gs, + dtype=otype, + device=x.device, + block_size=NVFP4_BLOCK_SIZE, + ) + + concat_w3_w1 = mlp_style == "gated_mlp" + if concat_w3_w1: + w1_gs = w3_gs = torch.max(w1_gs, w3_gs) + + w1_dq = torch.empty(w1.shape, device="cuda", dtype=otype) + w3_dq = torch.empty(w3.shape, device="cuda", dtype=otype) + w2_dq = torch.empty(w2.shape, device="cuda", dtype=otype) + + # Dequantize the weights to emulate the precision loss. + for idx in range(0, num_experts): + w1_dq[idx] = dequantize_nvfp4_to_dtype( + w1_q_fp4[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=w1.dtype, + device=w1.device, + block_size=NVFP4_BLOCK_SIZE, + ) + w2_dq[idx] = dequantize_nvfp4_to_dtype( + w2_q_fp4[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=w2.dtype, + device=w2.device, + block_size=NVFP4_BLOCK_SIZE, + ) + w3_dq[idx] = dequantize_nvfp4_to_dtype( + w3_q_fp4[idx], + w3_blockscale[idx], + w3_gs[idx], + dtype=w3.dtype, + device=w3.device, + block_size=NVFP4_BLOCK_SIZE, + ) + + ref_output = torch_moe_nvfp4( + x_dq, + torch.cat([w3_dq, w1_dq], dim=1) if concat_w3_w1 else w1_dq, + w2_dq, + top_k, + routing_weights, + selected_experts, + _activation_type_from_str(activation_func), + ) + return ref_output + + ref_output = compute_ref_output(w1_gs, w3_gs) + print(f"max diff: {(ref_output - trtllm_output).abs().max()}") + print(f"diff = {ref_output - trtllm_output}") + print(f"ref_output = {ref_output}") + print(f"flash_output = {trtllm_output}") + torch.testing.assert_close(ref_output, trtllm_output, rtol=2e-1, atol=2e-1)
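
For reference, below is a minimal sketch of how the per-expert NVFP4 global scales and the precomputed `fc1_alpha`/`fc2_alpha` arguments of `auto_deploy::trtllm_quant_nvfp4_moe_fused` are derived, following the conventions used in the test above. The helper name and its amax input are illustrative only, not part of the op or the test:

```python
import torch

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FLOAT4_E2M1_MAX = 6.0


def nvfp4_global_scale_and_alpha(w_amax: torch.Tensor, act_global_scale: torch.Tensor):
    """Illustrative helper mirroring the test setup.

    w_amax: per-expert max(|w|), shape [E], float32.
    act_global_scale: activation global scale (scalar float32 tensor).
    """
    # The weight global scale maps the per-expert weight range onto the NVFP4 grid
    # (FP8 block scales times E2M1 values), as computed per expert in the test.
    weight_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
    # Alpha is the dequantization factor applied to each GEMM output:
    # 1 / (activation global scale * weight global scale).
    alpha = 1.0 / (act_global_scale * weight_gs)
    return weight_gs, alpha
```

For the gated-MLP path the test uses `fc1_weight_gs = torch.max(w3_gs, w1_gs)` so that the concatenated `[w3, w1]` FC1 weights share a single alpha, and both activation global scales are fixed at 1.0.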