From 85cd81f75bc6e94fb9e16c807b58ebf86f43d521 Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Sat, 29 Nov 2025 15:30:29 -0800 Subject: [PATCH 01/11] [#9550][feature] AutoDeploy: Add NVFP4 Cutlass MoE kernels Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/trtllm_moe.py | 107 ++++++ .../singlegpu/custom_ops/test_trtllm_moe.py | 333 +++++++++++++++++- 2 files changed, 436 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 8b130d98744..c56ebba0b2d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from tensorrt_llm._torch.utils import ActivationType @@ -235,4 +237,109 @@ def trtllm_quant_fp8_moe_fused_fake( mlp_style: str, act_fn: str, ) -> torch.Tensor: + _validate_mlp_style_and_act_fn(mlp_style, act_fn) return torch.empty_like(x) + + +@torch.library.custom_op("auto_deploy::trtllm_quant_nvfp4_moe_fused", mutates_args=()) +def trtllm_quant_nvfp4_moe_fused( + x: torch.Tensor, # [B, S, H] or [B*S, H], 16-bit float + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight_q: torch.Tensor, # [E, I, H] stacked FP4 weights + w2_weight_q: torch.Tensor, # [E, H, I] stacked FP4 weights + w3_weight_q: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp + w1_weight_gs: torch.Tensor, + w2_weight_gs: torch.Tensor, + w3_weight_gs: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + w3_blockscale: torch.Tensor, + fc1_act_global: torch.Tensor, + fc2_act_global: torch.Tensor, + fc1_global: Optional[torch.Tensor] = None, # Precomputed global scale for FC1 + fc2_global: Optional[torch.Tensor] = None, # Precomputed global scale for FC2 + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + NVFP4_BLOCK_SIZE = 16 + mlp_style = mlp_style.lower() + act_fn = act_fn.lower() + + activation_type = ActivationType.Swiglu + if mlp_style == "gated_mlp": + # For gated MLP, concatenate w1 and w3 as [w3, w1] + w3_w1_stacked = torch.cat([w3_weight_q, w1_weight_q], dim=1).contiguous().view(torch.long) + fc1_expert_weights = w3_w1_stacked + fc1_weight_blockscale = torch.cat([w3_blockscale, w1_blockscale], dim=1) + fc1_weight_gs = torch.max(w3_weight_gs, w1_weight_gs) + if act_fn == "silu": + activation_type = ActivationType.Swiglu + else: + raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") + elif mlp_style == "mlp": + # For non-gated MLP with ReLU^2 + fc1_expert_weights = w1_weight_q.contiguous().view(torch.long) + fc1_weight_blockscale = w1_blockscale.contiguous().view(torch.long) + fc1_weight_gs = w1_weight_gs + if act_fn == "relu2": + activation_type = ActivationType.Relu2 + else: + raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") + else: + raise ValueError(f"Unknown mlp_style '{mlp_style}'. 
Use 'gated_mlp' or 'mlp'.") + + fc2_weight_block_scale = w2_blockscale + fc2_weight_gs = w2_weight_gs + fc1_global = fc1_global or 1.0 / (fc1_act_global * fc1_weight_gs) + fc2_global = fc2_global or 1.0 / (fc2_act_global * fc2_weight_gs) + + quant_scales = [ + fc1_act_global, + fc1_weight_blockscale.view(torch.int32), + fc1_global, + fc2_act_global, + fc2_weight_block_scale.view(torch.int32), + fc2_global, + ] + + if x.dtype in (torch.float16, torch.bfloat16): + x_q_fp4, input_sf = torch.ops.trtllm.fp4_quantize(x, fc1_act_global, NVFP4_BLOCK_SIZE) + output_dtype = x.dtype + else: + x_q_fp4 = x + output_dtype = None + + trtllm_output = torch.ops.trtllm.fused_moe( + x_q_fp4, + selected_experts.to(torch.int), + routing_weights, + fc1_expert_weights=fc1_expert_weights, + fc1_expert_biases=None, + fc2_expert_weights=w2_weight_q.contiguous().view(torch.long), + fc2_expert_biases=None, + output_dtype=output_dtype, + quant_scales=quant_scales, + input_sf=input_sf, + activation_type=activation_type, + )[0].view(x.shape) + + return trtllm_output[0].view(x.shape) + + +@trtllm_quant_nvfp4_moe_fused.register_fake +def trtllm_quant_nvfp4_moe_fused_fake( + hidden_states: torch.Tensor, + router_weight: torch.Tensor, + router_bias: torch.Tensor, + top_k: int, + gate_up_blocks: torch.Tensor, + gate_up_bias: torch.Tensor, + gate_up_scales: torch.Tensor, + alpha: float, + limit: float, + down_blocks: torch.Tensor, + down_bias: torch.Tensor, + down_scales: torch.Tensor, +) -> torch.Tensor: + return torch.empty_like(hidden_states) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 3e13e28a0c5..0ebc64a2735 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -7,7 +7,7 @@ import pytest import torch -from _torch_test_utils import fp8_compatible, trtllm_ops_available # noqa: F401 +from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available # noqa: F401 from torch.nn import functional as F from utils.util import skip_pre_hopper @@ -15,7 +15,9 @@ from tensorrt_llm._torch.utils import ActivationType FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +FLOAT4_E2M1_MAX = 6.0 FP8_DTYPE = torch.float8_e4m3fn +NVFP4_BLOCK_SIZE = 16 def dynamic_per_tensor_fp8_quant(x: torch.tensor) -> tuple[torch.tensor, torch.tensor]: @@ -344,7 +346,6 @@ def test_trtllm_fused_moe_fp8( W_GEN_SCALE = 0.1 def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE): - # input_shape = (batch_size, hidden_size) w31_shape = (num_experts, 2 * intermediate_size, hidden_size) w2_shape = (num_experts, hidden_size, intermediate_size) @@ -397,8 +398,6 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE w3_input_scale = torch.tensor([1.0]).cuda() w2_input_scale = torch.tensor([1.0]).cuda() - torch.cuda.synchronize() - print("before fused_moe.cutlass_fused_moe") # (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size) w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1) @@ -408,6 +407,9 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE gemm1_dequant = (w1_scales * hidden_states_scale).contiguous().squeeze().to(torch.float32) gemm2_act_quant = (1.0 / w2_input_scale[0]).contiguous().to(torch.float32) gemm2_dequant = (w2_scales * 
w2_input_scale[0]).contiguous().squeeze().to(torch.float32) + + print("before fused_moe.cutlass_fused_moe") + torch.cuda.synchronize() ad_test_output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused( x, # Note! unquantized input is expected selected_experts.to(torch.int), @@ -455,3 +457,326 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE _print_diff_if(lambda diff: diff.max() > 1e-1, diff, ad_test_output, ref_output) torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-1, atol=1e-1) + + +NVFP4_TEST_DTYPES = [ + (torch.float16, torch.float8_e4m3fn), + (torch.bfloat16, torch.float8_e4m3fn), +] + + +# Originally from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py +def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids, activation_type): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + if activation_type == ActivationType.Swiglu: + + def act(weight, mask): + m = weight.shape[0] + assert m % 2 == 0 + w1_expert, w3_expert = weight[m // 2 :, :], weight[: m // 2, :] + return F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t()) + elif activation_type == ActivationType.Relu2: + + def act(weight, mask): + return F.relu(a[mask] @ weight.t()) ** 2 + else: + raise ValueError(f"Unsupported activation type {activation_type}") + + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter = act(w1[i], mask) + inter_gs = torch.tensor(1.0).cuda() + inter_q, inter_blockscale = torch.ops.trtllm.fp4_quantize( + inter, inter_gs, NVFP4_BLOCK_SIZE + ) + inter = dequantize_nvfp4_to_dtype( + inter_q, + inter_blockscale, + inter_gs, + dtype=inter.dtype, + device=inter.device, + block_size=16, + ).cuda() + out[mask] = inter @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +# Originally from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py +def dequantize_nvfp4_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + + def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + # Originally from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py + def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1ToFloat = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32) + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to 
final form + return values.reshape(m, n * 2).to(dtype=dtype) + + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype=dtype) + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +@pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES) +@pytest.mark.parametrize("quantized_input", [False, True]) +# relu2 support requires merge of https://github.com/NVIDIA/TensorRT-LLM/pull/9261 +# @pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.parametrize("activation_func", ["silu"]) +@pytest.mark.skipif( + not fp4_compatible() or not trtllm_ops_available(), + reason="Requires fp4 and trtllm support", +) +def test_trtllm_fused_moe_nvfp4( + batch_size, + hidden_size, + num_experts, + top_k, + intermediate_size, + otype, + wtype, + quantized_input, + activation_func, +): + # Skip invalid configurations + if top_k > num_experts: + pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})") + torch.manual_seed(42) + + def _get_test_data( + otype, + wtype, + batch_size, + hidden_size, + num_experts, + intermediate_size, + ): + x = gen_tensor((batch_size, hidden_size), otype) + w31_shape = (num_experts, 2 * intermediate_size, hidden_size) + w31 = gen_tensor(w31_shape, otype, scale=0.1) + w1_n = w31_shape[1] + w2 = gen_tensor((num_experts, hidden_size, intermediate_size), otype, scale=0.1) + w31_d = torch.empty((num_experts, w1_n, hidden_size), device="cuda", dtype=otype) + w2_d = torch.empty( + (num_experts, hidden_size, intermediate_size), device="cuda", dtype=otype + ) + router_logits = torch.randn(batch_size, num_experts, dtype=otype).cuda() + return x, w31, w2, w31_d, w2_d, router_logits + + def _quantize_weights(w31, w2): + def round_up(x, y): + return (x + y - 1) // y * y + + w1_n = w31.shape[1] + sf_w1_2n = round_up(w1_n, 128) + sf_w1_k = round_up(hidden_size // NVFP4_BLOCK_SIZE, 4) + w31_blockscale = torch.empty( + (num_experts, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + ) + sf_w2_k = round_up(hidden_size, 128) + sf_w2_n = round_up(intermediate_size // NVFP4_BLOCK_SIZE, 4) + w2_blockscale = torch.empty( + (num_experts, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn + ) + w31_q = torch.empty((num_experts, w1_n, hidden_size // 2), device="cuda", dtype=torch.uint8) + w2_q = torch.empty( + (num_experts, hidden_size, intermediate_size // 2), device="cuda", dtype=torch.uint8 + ) + + w1_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) + w2_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) + + for expert in range(num_experts): + w31_amax = torch.abs(w31[expert]).max().to(torch.float32) + w2_amax = torch.abs(w2).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w31_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + nvfp4_vals, 
fp8_block_scales = torch.ops.trtllm.fp4_quantize( + w31[expert], w1_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True + ) + w31_q[expert] = nvfp4_vals + w31_blockscale[expert] = fp8_block_scales.reshape(w31_blockscale[expert].shape) + + nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( + w2[expert], w2_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True + ) + w2_q[expert] = nvfp4_vals + w2_blockscale[expert] = fp8_block_scales.reshape(w2_blockscale[expert].shape) + + return w31_q, w31_blockscale, w2_q, w2_blockscale, w1_gs, w2_gs + + x, w31, w2, w31_d, w2_d, router_logits = _get_test_data( + otype, wtype, batch_size, hidden_size, num_experts, intermediate_size + ) + w31_q, w31_blockscale, w2_q, w2_blockscale, w31_gs, w2_gs = _quantize_weights(w31, w2) + a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + routing_weights, selected_experts = compute_routing(router_logits, top_k) + + input_sf = None + if quantized_input: + x_q_fp4, input_sf = torch.ops.trtllm.fp4_quantize(x, a1_gs, NVFP4_BLOCK_SIZE) + else: + x_q_fp4 = x + + if False: + w3_weight_q, w1_weight_q = torch.chunk(w31_q, 2, dim=1) + w3_blockscale, w1_blockscale = torch.chunk(w31_blockscale, 2, dim=1) + + fc1_expert_weights = ( + torch.cat([w3_weight_q, w1_weight_q], dim=1).contiguous().view(torch.long) + ) + fc1_weight_blockscale = torch.cat([w3_blockscale, w1_blockscale], dim=1) + fc2_weight_block_scale = w2_blockscale + fc1_act_global = a1_gs + fc2_act_global = a2_gs + fc1_global = 1.0 / (fc1_act_global * w31_gs) + fc2_global = 1.0 / (fc2_act_global * w2_gs) + + quant_scales = [ + fc1_act_global, + fc1_weight_blockscale.view(torch.int32), # fc1_weight_block + fc1_global, + fc2_act_global, + fc2_weight_block_scale.view(torch.int32), # fc2_weight_block + fc2_global, + ] + + trtllm_output = torch.ops.trtllm.fused_moe( + x_q_fp4, + selected_experts.to(torch.int), + routing_weights, + fc1_expert_weights=fc1_expert_weights, + fc1_expert_biases=None, + fc2_expert_weights=w2_q.contiguous().view(torch.long), + fc2_expert_biases=None, + output_dtype=otype, + quant_scales=quant_scales, + input_sf=input_sf, + activation_type=_activation_type_from_str(activation_func), + )[0].view(x.shape) + + else: + w3_weight_q, w1_weight_q = torch.chunk(w31_q, 2, dim=1) + w2_weight_q = w2_q + w3_blockscale, w1_blockscale = torch.chunk(w31_blockscale, 2, dim=1) + + fc1_act_global = a1_gs + fc2_act_global = a2_gs + fc1_global = 1.0 / (fc1_act_global * w31_gs) + fc2_global = 1.0 / (fc2_act_global * w2_gs) + w1_weight_gs = w31_gs + w2_weight_gs = w2_gs + w3_weight_gs = w31_gs + mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused( + x, + selected_experts.to(torch.int), + routing_weights, + w1_weight_q, + w2_weight_q, + w3_weight_q, + w1_weight_gs, + w2_weight_gs, + w3_weight_gs, + w1_blockscale, + w2_blockscale, + w3_blockscale, + fc1_act_global, + fc2_act_global, + mlp_style=mlp_style, + act_fn=activation_func, + ) + # Ref check + a_fp4, a_scale_interleaved = torch.ops.trtllm.fp4_quantize(x, a1_gs, NVFP4_BLOCK_SIZE) + _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a1_gs, + dtype=otype, + device=x.device, + block_size=NVFP4_BLOCK_SIZE, + ) + + for idx in range(0, num_experts): + w31_d[idx] = dequantize_nvfp4_to_dtype( + w31_q[idx], + w31_blockscale[idx], + w31_gs[idx], + dtype=w31.dtype, + device=w31.device, + 
block_size=NVFP4_BLOCK_SIZE, + ) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=w2.dtype, + device=w2.device, + block_size=NVFP4_BLOCK_SIZE, + ) + + ref_output = torch_moe_nvfp4( + a_in_dtype, + w31_d, + w2_d, + top_k, + routing_weights, + selected_experts, + _activation_type_from_str(activation_func), + ) + print(f"max diff: {(ref_output - trtllm_output).abs().max()}") + print(f"diff = {ref_output - trtllm_output}") + print(f"ref_output = {ref_output}") + print(f"flash_output = {trtllm_output}") + torch.testing.assert_close(ref_output, trtllm_output, rtol=2e-1, atol=2e-1) From b8fd7b4ab133d4c60a7e407813c2e5321b97fc21 Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Sat, 29 Nov 2025 16:26:12 -0800 Subject: [PATCH 02/11] Refactoring Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/trtllm_moe.py | 19 ++-- .../singlegpu/custom_ops/test_trtllm_moe.py | 106 +++++------------- 2 files changed, 40 insertions(+), 85 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index c56ebba0b2d..d98d5438b7f 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -249,16 +249,18 @@ def trtllm_quant_nvfp4_moe_fused( w1_weight_q: torch.Tensor, # [E, I, H] stacked FP4 weights w2_weight_q: torch.Tensor, # [E, H, I] stacked FP4 weights w3_weight_q: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp - w1_weight_gs: torch.Tensor, - w2_weight_gs: torch.Tensor, - w3_weight_gs: torch.Tensor, - w1_blockscale: torch.Tensor, - w2_blockscale: torch.Tensor, - w3_blockscale: torch.Tensor, - fc1_act_global: torch.Tensor, - fc2_act_global: torch.Tensor, + w1_weight_gs: torch.Tensor, # Global scale for w1 + w2_weight_gs: torch.Tensor, # Global scale for w2 + w3_weight_gs: torch.Tensor, # Global scale for w3 + w1_blockscale: torch.Tensor, # Block scale for w1 + w2_blockscale: torch.Tensor, # Block scale for w2 + w3_blockscale: torch.Tensor, # Block scale for w3 + fc1_act_global: torch.Tensor, # Global scale for FC1 activations + fc2_act_global: torch.Tensor, # Global scale for FC2 activations fc1_global: Optional[torch.Tensor] = None, # Precomputed global scale for FC1 fc2_global: Optional[torch.Tensor] = None, # Precomputed global scale for FC2 + input_sf: Optional[torch.Tensor] = None, # Input scale factors for NVFP4 input + output_dtype: Optional[torch.dtype] = None, # Output dtype for NVFP4 input mlp_style: str = "gated_mlp", act_fn: str = "silu", ) -> torch.Tensor: @@ -308,7 +310,6 @@ def trtllm_quant_nvfp4_moe_fused( output_dtype = x.dtype else: x_q_fp4 = x - output_dtype = None trtllm_output = torch.ops.trtllm.fused_moe( x_q_fp4, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 0ebc64a2735..67c73eb2921 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -566,7 +566,6 @@ def break_fp4_bytes(a, dtype): @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES) 
-@pytest.mark.parametrize("quantized_input", [False, True]) # relu2 support requires merge of https://github.com/NVIDIA/TensorRT-LLM/pull/9261 # @pytest.mark.parametrize("activation_func", ["silu", "relu2"]) @pytest.mark.parametrize("activation_func", ["silu"]) @@ -582,7 +581,6 @@ def test_trtllm_fused_moe_nvfp4( intermediate_size, otype, wtype, - quantized_input, activation_func, ): # Skip invalid configurations @@ -662,81 +660,37 @@ def round_up(x, y): routing_weights, selected_experts = compute_routing(router_logits, top_k) - input_sf = None - if quantized_input: - x_q_fp4, input_sf = torch.ops.trtllm.fp4_quantize(x, a1_gs, NVFP4_BLOCK_SIZE) - else: - x_q_fp4 = x - - if False: - w3_weight_q, w1_weight_q = torch.chunk(w31_q, 2, dim=1) - w3_blockscale, w1_blockscale = torch.chunk(w31_blockscale, 2, dim=1) - - fc1_expert_weights = ( - torch.cat([w3_weight_q, w1_weight_q], dim=1).contiguous().view(torch.long) - ) - fc1_weight_blockscale = torch.cat([w3_blockscale, w1_blockscale], dim=1) - fc2_weight_block_scale = w2_blockscale - fc1_act_global = a1_gs - fc2_act_global = a2_gs - fc1_global = 1.0 / (fc1_act_global * w31_gs) - fc2_global = 1.0 / (fc2_act_global * w2_gs) - - quant_scales = [ - fc1_act_global, - fc1_weight_blockscale.view(torch.int32), # fc1_weight_block - fc1_global, - fc2_act_global, - fc2_weight_block_scale.view(torch.int32), # fc2_weight_block - fc2_global, - ] - - trtllm_output = torch.ops.trtllm.fused_moe( - x_q_fp4, - selected_experts.to(torch.int), - routing_weights, - fc1_expert_weights=fc1_expert_weights, - fc1_expert_biases=None, - fc2_expert_weights=w2_q.contiguous().view(torch.long), - fc2_expert_biases=None, - output_dtype=otype, - quant_scales=quant_scales, - input_sf=input_sf, - activation_type=_activation_type_from_str(activation_func), - )[0].view(x.shape) + w3_weight_q, w1_weight_q = torch.chunk(w31_q, 2, dim=1) + w2_weight_q = w2_q + w3_blockscale, w1_blockscale = torch.chunk(w31_blockscale, 2, dim=1) - else: - w3_weight_q, w1_weight_q = torch.chunk(w31_q, 2, dim=1) - w2_weight_q = w2_q - w3_blockscale, w1_blockscale = torch.chunk(w31_blockscale, 2, dim=1) - - fc1_act_global = a1_gs - fc2_act_global = a2_gs - fc1_global = 1.0 / (fc1_act_global * w31_gs) - fc2_global = 1.0 / (fc2_act_global * w2_gs) - w1_weight_gs = w31_gs - w2_weight_gs = w2_gs - w3_weight_gs = w31_gs - mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" - trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused( - x, - selected_experts.to(torch.int), - routing_weights, - w1_weight_q, - w2_weight_q, - w3_weight_q, - w1_weight_gs, - w2_weight_gs, - w3_weight_gs, - w1_blockscale, - w2_blockscale, - w3_blockscale, - fc1_act_global, - fc2_act_global, - mlp_style=mlp_style, - act_fn=activation_func, - ) - # Ref check + fc1_act_global = a1_gs + fc2_act_global = a2_gs + w1_weight_gs = w31_gs + w2_weight_gs = w2_gs + w3_weight_gs = w31_gs + mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused( + x, + selected_experts.to(torch.int), + routing_weights, + w1_weight_q, + w2_weight_q, + w3_weight_q, + w1_weight_gs, + w2_weight_gs, + w3_weight_gs, + w1_blockscale, + w2_blockscale, + w3_blockscale, + fc1_act_global, + fc2_act_global, + input_sf=None, + output_dtype=otype, + mlp_style=mlp_style, + act_fn=activation_func, + ) + # Ref check a_fp4, a_scale_interleaved = torch.ops.trtllm.fp4_quantize(x, a1_gs, NVFP4_BLOCK_SIZE) _, m_k = a_fp4.shape a_in_dtype = dequantize_nvfp4_to_dtype( From 
bf96ca1d770a9c82bda2c6d5ed8b3b645ada2ff6 Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Sun, 30 Nov 2025 00:22:40 -0800 Subject: [PATCH 03/11] Fixes per coderabbit review Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/trtllm_moe.py | 52 ++++++++++++++----- .../singlegpu/custom_ops/test_trtllm_moe.py | 14 ++--- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index d98d5438b7f..0fe69fce244 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -1,3 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from typing import Optional import torch @@ -325,22 +341,30 @@ def trtllm_quant_nvfp4_moe_fused( activation_type=activation_type, )[0].view(x.shape) - return trtllm_output[0].view(x.shape) + return trtllm_output @trtllm_quant_nvfp4_moe_fused.register_fake def trtllm_quant_nvfp4_moe_fused_fake( - hidden_states: torch.Tensor, - router_weight: torch.Tensor, - router_bias: torch.Tensor, - top_k: int, - gate_up_blocks: torch.Tensor, - gate_up_bias: torch.Tensor, - gate_up_scales: torch.Tensor, - alpha: float, - limit: float, - down_blocks: torch.Tensor, - down_bias: torch.Tensor, - down_scales: torch.Tensor, + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight_q: torch.Tensor, + w2_weight_q: torch.Tensor, + w3_weight_q: torch.Tensor, + w1_weight_gs: torch.Tensor, + w2_weight_gs: torch.Tensor, + w3_weight_gs: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + w3_blockscale: torch.Tensor, + fc1_act_global: torch.Tensor, + fc2_act_global: torch.Tensor, + fc1_global: Optional[torch.Tensor] = None, + fc2_global: Optional[torch.Tensor] = None, + input_sf: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", ) -> torch.Tensor: - return torch.empty_like(hidden_states) + return torch.empty_like(x) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 67c73eb2921..0272dc22b3b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -459,12 +459,6 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-1, atol=1e-1) -NVFP4_TEST_DTYPES = [ - (torch.float16, torch.float8_e4m3fn), - (torch.bfloat16, torch.float8_e4m3fn), -] - - # 
Originally from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids, activation_type): B, D = a.shape @@ -560,6 +554,12 @@ def break_fp4_bytes(a, dtype): return out.to(dtype=dtype) +NVFP4_TEST_DTYPES = [ + (torch.float16, torch.float8_e4m3fn), + (torch.bfloat16, torch.float8_e4m3fn), +] + + @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @@ -633,7 +633,7 @@ def round_up(x, y): for expert in range(num_experts): w31_amax = torch.abs(w31[expert]).max().to(torch.float32) - w2_amax = torch.abs(w2).max().to(torch.float32) + w2_amax = torch.abs(w2[expert]).max().to(torch.float32) w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w31_amax w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax From 0b4ec736f5d26b14bc4c66233a6d3f98f3ca71dc Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Sun, 30 Nov 2025 05:07:23 -0800 Subject: [PATCH 04/11] Code refactoring Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/trtllm_moe.py | 65 +++--- .../singlegpu/custom_ops/test_trtllm_moe.py | 207 +++++++++++------- 2 files changed, 172 insertions(+), 100 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 0fe69fce244..96591429f1b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -262,24 +262,37 @@ def trtllm_quant_nvfp4_moe_fused( x: torch.Tensor, # [B, S, H] or [B*S, H], 16-bit float selected_experts: torch.Tensor, routing_weights: torch.Tensor, - w1_weight_q: torch.Tensor, # [E, I, H] stacked FP4 weights - w2_weight_q: torch.Tensor, # [E, H, I] stacked FP4 weights - w3_weight_q: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp - w1_weight_gs: torch.Tensor, # Global scale for w1 - w2_weight_gs: torch.Tensor, # Global scale for w2 - w3_weight_gs: torch.Tensor, # Global scale for w3 - w1_blockscale: torch.Tensor, # Block scale for w1 - w2_blockscale: torch.Tensor, # Block scale for w2 - w3_blockscale: torch.Tensor, # Block scale for w3 + w1_fp4: torch.Tensor, # [E, I, H] stacked FP4 weights (uint8) + w2_fp4: torch.Tensor, # [E, H, I] stacked FP4 weights (uint8) + w3_fp4: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp (uint8) + w1_global_scale: torch.Tensor, # Global scale for w1 (scalar) + w2_global_scale: torch.Tensor, # Global scale for w2 (scalar) + w3_global_scale: torch.Tensor, # Global scale for w3 (scalar) + w1_blockscale_fp8: torch.Tensor, # Block scale for w1 (fp8 ) + w2_blockscale_fp8: torch.Tensor, # Block scale for w2 (fp8 ) + w3_blockscale_fp8: torch.Tensor, # Block scale for w3 (fp8 ) fc1_act_global: torch.Tensor, # Global scale for FC1 activations fc2_act_global: torch.Tensor, # Global scale for FC2 activations - fc1_global: Optional[torch.Tensor] = None, # Precomputed global scale for FC1 - fc2_global: Optional[torch.Tensor] = None, # Precomputed global scale for FC2 - input_sf: Optional[torch.Tensor] = None, # Input scale factors for NVFP4 input + fc1_global: Optional[ + torch.Tensor + ] = None, # Precomputed global scale for FC1 (1.0 / (fc1_act_global * fc1_weight_gs)) + fc2_global: Optional[ + torch.Tensor + ] = None, # Precomputed global 
scale for FC2 (1.0 / (fc2_act_global * fc2_weight_gs)) + input_blockscale: Optional[torch.Tensor] = None, # Input scale factors for NVFP4 input output_dtype: Optional[torch.dtype] = None, # Output dtype for NVFP4 input mlp_style: str = "gated_mlp", act_fn: str = "silu", ) -> torch.Tensor: + """TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP. + + Computes (per expert): + For gated_mlp: + y = act(x @ w1.T) @ (x @ w3.T) @ w2.T # act := SiLU + For mlp: + y = act(x @ w1.T) @ w2.T # act := ReLU^2 + + """ NVFP4_BLOCK_SIZE = 16 mlp_style = mlp_style.lower() act_fn = act_fn.lower() @@ -287,19 +300,19 @@ def trtllm_quant_nvfp4_moe_fused( activation_type = ActivationType.Swiglu if mlp_style == "gated_mlp": # For gated MLP, concatenate w1 and w3 as [w3, w1] - w3_w1_stacked = torch.cat([w3_weight_q, w1_weight_q], dim=1).contiguous().view(torch.long) + w3_w1_stacked = torch.cat([w3_fp4, w1_fp4], dim=1).contiguous().view(torch.long) fc1_expert_weights = w3_w1_stacked - fc1_weight_blockscale = torch.cat([w3_blockscale, w1_blockscale], dim=1) - fc1_weight_gs = torch.max(w3_weight_gs, w1_weight_gs) + fc1_weight_blockscale = torch.cat([w3_blockscale_fp8, w1_blockscale_fp8], dim=1) + fc1_weight_gs = torch.max(w3_global_scale, w1_global_scale) if act_fn == "silu": activation_type = ActivationType.Swiglu else: raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") elif mlp_style == "mlp": # For non-gated MLP with ReLU^2 - fc1_expert_weights = w1_weight_q.contiguous().view(torch.long) - fc1_weight_blockscale = w1_blockscale.contiguous().view(torch.long) - fc1_weight_gs = w1_weight_gs + fc1_expert_weights = w1_fp4.view(torch.long) + fc1_weight_blockscale = w1_blockscale_fp8.view(torch.long) + fc1_weight_gs = w1_global_scale if act_fn == "relu2": activation_type = ActivationType.Relu2 else: @@ -307,10 +320,10 @@ def trtllm_quant_nvfp4_moe_fused( else: raise ValueError(f"Unknown mlp_style '{mlp_style}'. 
Use 'gated_mlp' or 'mlp'.") - fc2_weight_block_scale = w2_blockscale - fc2_weight_gs = w2_weight_gs - fc1_global = fc1_global or 1.0 / (fc1_act_global * fc1_weight_gs) - fc2_global = fc2_global or 1.0 / (fc2_act_global * fc2_weight_gs) + fc2_weight_block_scale = w2_blockscale_fp8 + fc2_weight_gs = w2_global_scale + fc1_global = 1.0 / (fc1_act_global * fc1_weight_gs) if fc1_global is None else fc1_global + fc2_global = 1.0 / (fc2_act_global * fc2_weight_gs) if fc2_global is None else fc2_global quant_scales = [ fc1_act_global, @@ -322,7 +335,9 @@ def trtllm_quant_nvfp4_moe_fused( ] if x.dtype in (torch.float16, torch.bfloat16): - x_q_fp4, input_sf = torch.ops.trtllm.fp4_quantize(x, fc1_act_global, NVFP4_BLOCK_SIZE) + x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize( + x, fc1_act_global, NVFP4_BLOCK_SIZE + ) output_dtype = x.dtype else: x_q_fp4 = x @@ -333,11 +348,11 @@ def trtllm_quant_nvfp4_moe_fused( routing_weights, fc1_expert_weights=fc1_expert_weights, fc1_expert_biases=None, - fc2_expert_weights=w2_weight_q.contiguous().view(torch.long), + fc2_expert_weights=w2_fp4.view(torch.long), fc2_expert_biases=None, output_dtype=output_dtype, quant_scales=quant_scales, - input_sf=input_sf, + input_sf=input_blockscale, activation_type=activation_type, )[0].view(x.shape) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 0272dc22b3b..ad3ba40c7d8 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -461,6 +461,11 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE # Originally from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids, activation_type): + """Reference implementation of NVFP4 MoE. + + The intermediate activations are quantized and dequantized to emulate the precision loss of a real + quantized operation. 
+ """ B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) @@ -583,6 +588,11 @@ def test_trtllm_fused_moe_nvfp4( wtype, activation_func, ): + # In the code below: + # sf := block scale factors for NVFP4 + # blockscale := block scale factor for NVFP4 + # gs := global scale for NVFP4 + # Skip invalid configurations if top_k > num_experts: pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})") @@ -590,58 +600,63 @@ def test_trtllm_fused_moe_nvfp4( def _get_test_data( otype, - wtype, batch_size, hidden_size, num_experts, intermediate_size, ): x = gen_tensor((batch_size, hidden_size), otype) - w31_shape = (num_experts, 2 * intermediate_size, hidden_size) - w31 = gen_tensor(w31_shape, otype, scale=0.1) - w1_n = w31_shape[1] + w1_shape = (num_experts, intermediate_size, hidden_size) + w3_shape = w1_shape + w1 = gen_tensor(w1_shape, otype, scale=0.1) w2 = gen_tensor((num_experts, hidden_size, intermediate_size), otype, scale=0.1) - w31_d = torch.empty((num_experts, w1_n, hidden_size), device="cuda", dtype=otype) - w2_d = torch.empty( - (num_experts, hidden_size, intermediate_size), device="cuda", dtype=otype - ) + w3 = gen_tensor(w3_shape, otype, scale=0.1) router_logits = torch.randn(batch_size, num_experts, dtype=otype).cuda() - return x, w31, w2, w31_d, w2_d, router_logits + return x, w1, w2, w3, router_logits - def _quantize_weights(w31, w2): + def _quantize_weights(w1, w2, w3): def round_up(x, y): return (x + y - 1) // y * y - w1_n = w31.shape[1] - sf_w1_2n = round_up(w1_n, 128) + w1_n = w1.shape[1] + w3_n = w3.shape[1] + sf_w1_n = round_up(w1_n, 128) + sf_w3_n = round_up(w3_n, 128) sf_w1_k = round_up(hidden_size // NVFP4_BLOCK_SIZE, 4) - w31_blockscale = torch.empty( - (num_experts, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + w1_blockscale = torch.empty( + (num_experts, sf_w1_n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn ) sf_w2_k = round_up(hidden_size, 128) sf_w2_n = round_up(intermediate_size // NVFP4_BLOCK_SIZE, 4) w2_blockscale = torch.empty( (num_experts, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn ) - w31_q = torch.empty((num_experts, w1_n, hidden_size // 2), device="cuda", dtype=torch.uint8) + w3_blockscale = torch.empty( + (num_experts, sf_w3_n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn + ) + w1_q = torch.empty((num_experts, w1_n, hidden_size // 2), device="cuda", dtype=torch.uint8) w2_q = torch.empty( (num_experts, hidden_size, intermediate_size // 2), device="cuda", dtype=torch.uint8 ) + w3_q = torch.empty((num_experts, w3_n, hidden_size // 2), device="cuda", dtype=torch.uint8) w1_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) w2_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) + w3_gs = torch.empty((num_experts,), device="cuda", dtype=torch.float32) for expert in range(num_experts): - w31_amax = torch.abs(w31[expert]).max().to(torch.float32) + w1_amax = torch.abs(w1[expert]).max().to(torch.float32) w2_amax = torch.abs(w2[expert]).max().to(torch.float32) - w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w31_amax + w3_amax = torch.abs(w3[expert]).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + w3_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w3_amax nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( - w31[expert], w1_gs[expert], NVFP4_BLOCK_SIZE, 
isSfSwizzledLayout=True + w1[expert], w1_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True ) - w31_q[expert] = nvfp4_vals - w31_blockscale[expert] = fp8_block_scales.reshape(w31_blockscale[expert].shape) + w1_q[expert] = nvfp4_vals + w1_blockscale[expert] = fp8_block_scales.reshape(w1_blockscale[expert].shape) nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( w2[expert], w2_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True @@ -649,86 +664,128 @@ def round_up(x, y): w2_q[expert] = nvfp4_vals w2_blockscale[expert] = fp8_block_scales.reshape(w2_blockscale[expert].shape) - return w31_q, w31_blockscale, w2_q, w2_blockscale, w1_gs, w2_gs + nvfp4_vals, fp8_block_scales = torch.ops.trtllm.fp4_quantize( + w3[expert], w3_gs[expert], NVFP4_BLOCK_SIZE, isSfSwizzledLayout=True + ) + w3_q[expert] = nvfp4_vals + w3_blockscale[expert] = fp8_block_scales.reshape(w3_blockscale[expert].shape) - x, w31, w2, w31_d, w2_d, router_logits = _get_test_data( - otype, wtype, batch_size, hidden_size, num_experts, intermediate_size + return w1_q, w2_q, w3_q, w1_blockscale, w2_blockscale, w3_blockscale, w1_gs, w2_gs, w3_gs + + x, w1, w2, w3, router_logits = _get_test_data( + otype, batch_size, hidden_size, num_experts, intermediate_size ) - w31_q, w31_blockscale, w2_q, w2_blockscale, w31_gs, w2_gs = _quantize_weights(w31, w2) - a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) - a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + ( + w1_q_fp4, + w2_q_fp4, + w3_q_fp4, + w1_blockscale, + w2_blockscale, + w3_blockscale, + w1_gs, + w2_gs, + w3_gs, + ) = _quantize_weights(w1, w2, w3) + + fc1_act_global = torch.tensor(1.0, device="cuda", dtype=torch.float32) + fc2_act_global = torch.tensor(1.0, device="cuda", dtype=torch.float32) routing_weights, selected_experts = compute_routing(router_logits, top_k) - w3_weight_q, w1_weight_q = torch.chunk(w31_q, 2, dim=1) - w2_weight_q = w2_q - w3_blockscale, w1_blockscale = torch.chunk(w31_blockscale, 2, dim=1) + if True: + fc1_weight_gs = torch.max(w3_gs, w1_gs) + fc1_global = 1.0 / (fc1_act_global * fc1_weight_gs) + fc2_global = 1.0 / (fc2_act_global * w2_gs) + else: + fc1_global = None + fc2_global = None - fc1_act_global = a1_gs - fc2_act_global = a2_gs - w1_weight_gs = w31_gs - w2_weight_gs = w2_gs - w3_weight_gs = w31_gs mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused( x, selected_experts.to(torch.int), routing_weights, - w1_weight_q, - w2_weight_q, - w3_weight_q, - w1_weight_gs, - w2_weight_gs, - w3_weight_gs, + w1_q_fp4, + w2_q_fp4, + w3_q_fp4, + w1_gs, + w2_gs, + w3_gs, w1_blockscale, w2_blockscale, w3_blockscale, fc1_act_global, fc2_act_global, - input_sf=None, + fc1_global=fc1_global, + fc2_global=fc2_global, + input_blockscale=None, output_dtype=otype, mlp_style=mlp_style, act_fn=activation_func, ) - # Ref check - a_fp4, a_scale_interleaved = torch.ops.trtllm.fp4_quantize(x, a1_gs, NVFP4_BLOCK_SIZE) - _, m_k = a_fp4.shape - a_in_dtype = dequantize_nvfp4_to_dtype( - a_fp4, - a_scale_interleaved, - a1_gs, - dtype=otype, - device=x.device, - block_size=NVFP4_BLOCK_SIZE, - ) - for idx in range(0, num_experts): - w31_d[idx] = dequantize_nvfp4_to_dtype( - w31_q[idx], - w31_blockscale[idx], - w31_gs[idx], - dtype=w31.dtype, - device=w31.device, - block_size=NVFP4_BLOCK_SIZE, + def compute_ref_output(w1_gs, w3_gs): + # Quantize then dequantize the input to emulate the precision loss. 
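+        # NVFP4 quantization uses 16-element blocks with FP8-E4M3 block scales plus a
+        # float32 global scale (fixed to 1.0 for activations in this test).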
+ a_fp4, a_scale_interleaved = torch.ops.trtllm.fp4_quantize( + x, fc1_act_global, NVFP4_BLOCK_SIZE ) - w2_d[idx] = dequantize_nvfp4_to_dtype( - w2_q[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=w2.dtype, - device=w2.device, + x_dq = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + fc1_act_global, + dtype=otype, + device=x.device, block_size=NVFP4_BLOCK_SIZE, ) - ref_output = torch_moe_nvfp4( - a_in_dtype, - w31_d, - w2_d, - top_k, - routing_weights, - selected_experts, - _activation_type_from_str(activation_func), - ) + concat_w3_w1 = mlp_style == "gated_mlp" + if concat_w3_w1: + w1_gs = w3_gs = torch.max(w1_gs, w3_gs) + + w1_dq = torch.empty(w1.shape, device="cuda", dtype=otype) + w3_dq = torch.empty(w3.shape, device="cuda", dtype=otype) + w2_dq = torch.empty(w2.shape, device="cuda", dtype=otype) + + # Dequantize the weights to emulate the precision loss. + for idx in range(0, num_experts): + w1_dq[idx] = dequantize_nvfp4_to_dtype( + w1_q_fp4[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=w1.dtype, + device=w1.device, + block_size=NVFP4_BLOCK_SIZE, + ) + w2_dq[idx] = dequantize_nvfp4_to_dtype( + w2_q_fp4[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=w2.dtype, + device=w2.device, + block_size=NVFP4_BLOCK_SIZE, + ) + w3_dq[idx] = dequantize_nvfp4_to_dtype( + w3_q_fp4[idx], + w3_blockscale[idx], + w3_gs[idx], + dtype=w3.dtype, + device=w3.device, + block_size=NVFP4_BLOCK_SIZE, + ) + + ref_output = torch_moe_nvfp4( + x_dq, + torch.cat([w3_dq, w1_dq], dim=1) if concat_w3_w1 else w1_dq, + w2_dq, + top_k, + routing_weights, + selected_experts, + _activation_type_from_str(activation_func), + ) + return ref_output + + ref_output = compute_ref_output(w1_gs, w3_gs) print(f"max diff: {(ref_output - trtllm_output).abs().max()}") print(f"diff = {ref_output - trtllm_output}") print(f"ref_output = {ref_output}") From 7fa76f2ba56f09857d1b3427ca8c18d04afc6445 Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Sun, 30 Nov 2025 05:30:43 -0800 Subject: [PATCH 05/11] Address review comments from tcherckez Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/trtllm_moe.py | 14 ++++++------ .../singlegpu/custom_ops/test_trtllm_moe.py | 22 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 96591429f1b..6cdbb36753c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -273,14 +273,14 @@ def trtllm_quant_nvfp4_moe_fused( w3_blockscale_fp8: torch.Tensor, # Block scale for w3 (fp8 ) fc1_act_global: torch.Tensor, # Global scale for FC1 activations fc2_act_global: torch.Tensor, # Global scale for FC2 activations - fc1_global: Optional[ + fc1_alpha: Optional[ torch.Tensor ] = None, # Precomputed global scale for FC1 (1.0 / (fc1_act_global * fc1_weight_gs)) - fc2_global: Optional[ + fc2_alpha: Optional[ torch.Tensor ] = None, # Precomputed global scale for FC2 (1.0 / (fc2_act_global * fc2_weight_gs)) input_blockscale: Optional[torch.Tensor] = None, # Input scale factors for NVFP4 input - output_dtype: Optional[torch.dtype] = None, # Output dtype for NVFP4 input + output_dtype: Optional[torch.dtype] = None, # determines output dtype when input is NVFP4 mlp_style: str = "gated_mlp", act_fn: str = "silu", ) -> torch.Tensor: @@ 
-322,16 +322,16 @@ def trtllm_quant_nvfp4_moe_fused( fc2_weight_block_scale = w2_blockscale_fp8 fc2_weight_gs = w2_global_scale - fc1_global = 1.0 / (fc1_act_global * fc1_weight_gs) if fc1_global is None else fc1_global - fc2_global = 1.0 / (fc2_act_global * fc2_weight_gs) if fc2_global is None else fc2_global + fc1_alpha = 1.0 / (fc1_act_global * fc1_weight_gs) if fc1_alpha is None else fc1_alpha + fc2_alpha = 1.0 / (fc2_act_global * fc2_weight_gs) if fc2_alpha is None else fc2_alpha quant_scales = [ fc1_act_global, fc1_weight_blockscale.view(torch.int32), - fc1_global, + fc1_alpha, fc2_act_global, fc2_weight_block_scale.view(torch.int32), - fc2_global, + fc2_alpha, ] if x.dtype in (torch.float16, torch.bfloat16): diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index ad3ba40c7d8..7bf09cec6d4 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -229,9 +229,6 @@ def test_trtllm_fused_moe( activation_func=activation_func, ) - torch.cuda.synchronize() - print("before fused_moe.cutlass_fused_moe") - assert itype == torch.bfloat16 or itype == torch.float16, ( "F16 test only supports bfloat16 or float16" ) @@ -256,6 +253,7 @@ def get_fc1_expert_weights( _, w1_weight = torch.chunk(w31_weight, 2, dim=1) mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + torch.cuda.synchronize() ad_test_output = torch.ops.auto_deploy.trtllm_moe_fused( x, selected_experts.to(torch.int), @@ -500,7 +498,7 @@ def act(weight, mask): inter_gs, dtype=inter.dtype, device=inter.device, - block_size=16, + block_size=NVFP4_BLOCK_SIZE, ).cuda() out[mask] = inter @ w2[i].transpose(0, 1) return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) @@ -565,6 +563,7 @@ def break_fp4_bytes(a, dtype): ] +@pytest.mark.parametrize("precompute_fc_alphas", [True, False]) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @@ -579,6 +578,7 @@ def break_fp4_bytes(a, dtype): reason="Requires fp4 and trtllm support", ) def test_trtllm_fused_moe_nvfp4( + precompute_fc_alphas, batch_size, hidden_size, num_experts, @@ -693,13 +693,13 @@ def round_up(x, y): routing_weights, selected_experts = compute_routing(router_logits, top_k) - if True: + if precompute_fc_alphas: fc1_weight_gs = torch.max(w3_gs, w1_gs) - fc1_global = 1.0 / (fc1_act_global * fc1_weight_gs) - fc2_global = 1.0 / (fc2_act_global * w2_gs) + fc1_alpha = 1.0 / (fc1_act_global * fc1_weight_gs) + fc2_alpha = 1.0 / (fc2_act_global * w2_gs) else: - fc1_global = None - fc2_global = None + fc1_alpha = None + fc2_alpha = None mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused( @@ -717,8 +717,8 @@ def round_up(x, y): w3_blockscale, fc1_act_global, fc2_act_global, - fc1_global=fc1_global, - fc2_global=fc2_global, + fc1_alpha=fc1_alpha, + fc2_alpha=fc2_alpha, input_blockscale=None, output_dtype=otype, mlp_style=mlp_style, From f27fca4a0194e501398568dc26119c30a863e33c Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Sun, 30 Nov 2025 05:34:32 -0800 Subject: [PATCH 06/11] Address review comments Signed-off-by: Neta Zmora 
<96238833+nzmora-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/trtllm_moe.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 6cdbb36753c..c293f0a246b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -364,20 +364,20 @@ def trtllm_quant_nvfp4_moe_fused_fake( x: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor, - w1_weight_q: torch.Tensor, - w2_weight_q: torch.Tensor, - w3_weight_q: torch.Tensor, - w1_weight_gs: torch.Tensor, - w2_weight_gs: torch.Tensor, - w3_weight_gs: torch.Tensor, - w1_blockscale: torch.Tensor, - w2_blockscale: torch.Tensor, - w3_blockscale: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + w3_fp4: torch.Tensor, + w1_global_scale: torch.Tensor, + w2_global_scale: torch.Tensor, + w3_global_scale: torch.Tensor, + w1_blockscale_fp8: torch.Tensor, + w2_blockscale_fp8: torch.Tensor, + w3_blockscale_fp8: torch.Tensor, fc1_act_global: torch.Tensor, fc2_act_global: torch.Tensor, - fc1_global: Optional[torch.Tensor] = None, - fc2_global: Optional[torch.Tensor] = None, - input_sf: Optional[torch.Tensor] = None, + fc1_alpha: Optional[torch.Tensor] = None, + fc2_alpha: Optional[torch.Tensor] = None, + input_blockscale: Optional[torch.Tensor] = None, output_dtype: Optional[torch.dtype] = None, mlp_style: str = "gated_mlp", act_fn: str = "silu", From ed0ea8a93bf8ba357ca507899591b5ac98612f83 Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Sun, 30 Nov 2025 05:38:40 -0800 Subject: [PATCH 07/11] Address review comments Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/trtllm_moe.py | 14 +++++++------- .../unit/singlegpu/custom_ops/test_trtllm_moe.py | 16 ++++++++-------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index c293f0a246b..b3cb4fac52b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -271,8 +271,8 @@ def trtllm_quant_nvfp4_moe_fused( w1_blockscale_fp8: torch.Tensor, # Block scale for w1 (fp8 ) w2_blockscale_fp8: torch.Tensor, # Block scale for w2 (fp8 ) w3_blockscale_fp8: torch.Tensor, # Block scale for w3 (fp8 ) - fc1_act_global: torch.Tensor, # Global scale for FC1 activations - fc2_act_global: torch.Tensor, # Global scale for FC2 activations + fc1_act_global_scale: torch.Tensor, # Global scale for FC1 activations + fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations fc1_alpha: Optional[ torch.Tensor ] = None, # Precomputed global scale for FC1 (1.0 / (fc1_act_global * fc1_weight_gs)) @@ -322,21 +322,21 @@ def trtllm_quant_nvfp4_moe_fused( fc2_weight_block_scale = w2_blockscale_fp8 fc2_weight_gs = w2_global_scale - fc1_alpha = 1.0 / (fc1_act_global * fc1_weight_gs) if fc1_alpha is None else fc1_alpha - fc2_alpha = 1.0 / (fc2_act_global * fc2_weight_gs) if fc2_alpha is None else fc2_alpha + fc1_alpha = 1.0 / (fc1_act_global_scale * fc1_weight_gs) if fc1_alpha is None else fc1_alpha + fc2_alpha = 1.0 / (fc2_act_global_scale * fc2_weight_gs) if fc2_alpha is None 
else fc2_alpha quant_scales = [ - fc1_act_global, + fc1_act_global_scale, fc1_weight_blockscale.view(torch.int32), fc1_alpha, - fc2_act_global, + fc2_act_global_scale, fc2_weight_block_scale.view(torch.int32), fc2_alpha, ] if x.dtype in (torch.float16, torch.bfloat16): x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize( - x, fc1_act_global, NVFP4_BLOCK_SIZE + x, fc1_act_global_scale, NVFP4_BLOCK_SIZE ) output_dtype = x.dtype else: diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 7bf09cec6d4..b6b8de07575 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -688,15 +688,15 @@ def round_up(x, y): w3_gs, ) = _quantize_weights(w1, w2, w3) - fc1_act_global = torch.tensor(1.0, device="cuda", dtype=torch.float32) - fc2_act_global = torch.tensor(1.0, device="cuda", dtype=torch.float32) + fc1_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) + fc2_activation_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32) routing_weights, selected_experts = compute_routing(router_logits, top_k) if precompute_fc_alphas: fc1_weight_gs = torch.max(w3_gs, w1_gs) - fc1_alpha = 1.0 / (fc1_act_global * fc1_weight_gs) - fc2_alpha = 1.0 / (fc2_act_global * w2_gs) + fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs) + fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs) else: fc1_alpha = None fc2_alpha = None @@ -715,8 +715,8 @@ def round_up(x, y): w1_blockscale, w2_blockscale, w3_blockscale, - fc1_act_global, - fc2_act_global, + fc1_activation_gs, + fc2_activation_gs, fc1_alpha=fc1_alpha, fc2_alpha=fc2_alpha, input_blockscale=None, @@ -728,12 +728,12 @@ def round_up(x, y): def compute_ref_output(w1_gs, w3_gs): # Quantize then dequantize the input to emulate the precision loss. 
a_fp4, a_scale_interleaved = torch.ops.trtllm.fp4_quantize( - x, fc1_act_global, NVFP4_BLOCK_SIZE + x, fc1_activation_gs, NVFP4_BLOCK_SIZE ) x_dq = dequantize_nvfp4_to_dtype( a_fp4, a_scale_interleaved, - fc1_act_global, + fc1_activation_gs, dtype=otype, device=x.device, block_size=NVFP4_BLOCK_SIZE, From d1563170f7ac35c538b0ee93d3e04c69860e752c Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Mon, 1 Dec 2025 05:53:13 -0800 Subject: [PATCH 08/11] fix typo Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index b3cb4fac52b..9629d8df012 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -288,9 +288,9 @@ def trtllm_quant_nvfp4_moe_fused( Computes (per expert): For gated_mlp: - y = act(x @ w1.T) @ (x @ w3.T) @ w2.T # act := SiLU + y = (act(x @ w1.T) * (x @ w3.T)) @ w2.T # act := SiLU For mlp: - y = act(x @ w1.T) @ w2.T # act := ReLU^2 + y = act(x @ w1.T) @ w2.T # act := ReLU^2 """ NVFP4_BLOCK_SIZE = 16 From 49078599cb50ea3fc3f9d2984203da9fee448126 Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Tue, 2 Dec 2025 01:46:48 -0800 Subject: [PATCH 09/11] Change nvfp4 moe kernel interface Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/trtllm_moe.py | 86 +++++++------------ .../singlegpu/custom_ops/test_trtllm_moe.py | 53 +++++++----- 2 files changed, 62 insertions(+), 77 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 9629d8df012..827d47c44ae 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -14,8 +14,6 @@ # limitations under the License. 
-from typing import Optional - import torch from tensorrt_llm._torch.utils import ActivationType @@ -262,25 +260,14 @@ def trtllm_quant_nvfp4_moe_fused( x: torch.Tensor, # [B, S, H] or [B*S, H], 16-bit float selected_experts: torch.Tensor, routing_weights: torch.Tensor, - w1_fp4: torch.Tensor, # [E, I, H] stacked FP4 weights (uint8) - w2_fp4: torch.Tensor, # [E, H, I] stacked FP4 weights (uint8) - w3_fp4: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp (uint8) - w1_global_scale: torch.Tensor, # Global scale for w1 (scalar) - w2_global_scale: torch.Tensor, # Global scale for w2 (scalar) - w3_global_scale: torch.Tensor, # Global scale for w3 (scalar) - w1_blockscale_fp8: torch.Tensor, # Block scale for w1 (fp8 ) - w2_blockscale_fp8: torch.Tensor, # Block scale for w2 (fp8 ) - w3_blockscale_fp8: torch.Tensor, # Block scale for w3 (fp8 ) + fc1_expert_weights_fp4: torch.Tensor, # [E, 2*I, H] or [E, I, H]; uint8 + fc2_expert_weights_fp4: torch.Tensor, # [E, H, I]; uint8 + fc1_weight_blockscale_fp8: torch.Tensor, # Global scale for fc1 (scalar) + fc2_weight_blockscale_fp8: torch.Tensor, # Global scale for w2 (scalar) fc1_act_global_scale: torch.Tensor, # Global scale for FC1 activations fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations - fc1_alpha: Optional[ - torch.Tensor - ] = None, # Precomputed global scale for FC1 (1.0 / (fc1_act_global * fc1_weight_gs)) - fc2_alpha: Optional[ - torch.Tensor - ] = None, # Precomputed global scale for FC2 (1.0 / (fc2_act_global * fc2_weight_gs)) - input_blockscale: Optional[torch.Tensor] = None, # Input scale factors for NVFP4 input - output_dtype: Optional[torch.dtype] = None, # determines output dtype when input is NVFP4 + fc1_alpha: torch.Tensor, # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8)) + fc2_alpha: torch.Tensor, # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8)) mlp_style: str = "gated_mlp", act_fn: str = "silu", ) -> torch.Tensor: @@ -292,6 +279,10 @@ def trtllm_quant_nvfp4_moe_fused( For mlp: y = act(x @ w1.T) @ w2.T # act := ReLU^2 + + FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T) + FC2 implements: fc2_output = fc1_output @ w2.T + """ NVFP4_BLOCK_SIZE = 16 mlp_style = mlp_style.lower() @@ -299,20 +290,11 @@ def trtllm_quant_nvfp4_moe_fused( activation_type = ActivationType.Swiglu if mlp_style == "gated_mlp": - # For gated MLP, concatenate w1 and w3 as [w3, w1] - w3_w1_stacked = torch.cat([w3_fp4, w1_fp4], dim=1).contiguous().view(torch.long) - fc1_expert_weights = w3_w1_stacked - fc1_weight_blockscale = torch.cat([w3_blockscale_fp8, w1_blockscale_fp8], dim=1) - fc1_weight_gs = torch.max(w3_global_scale, w1_global_scale) if act_fn == "silu": activation_type = ActivationType.Swiglu else: raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") elif mlp_style == "mlp": - # For non-gated MLP with ReLU^2 - fc1_expert_weights = w1_fp4.view(torch.long) - fc1_weight_blockscale = w1_blockscale_fp8.view(torch.long) - fc1_weight_gs = w1_global_scale if act_fn == "relu2": activation_type = ActivationType.Relu2 else: @@ -320,18 +302,17 @@ def trtllm_quant_nvfp4_moe_fused( else: raise ValueError(f"Unknown mlp_style '{mlp_style}'. 
Use 'gated_mlp' or 'mlp'.") - fc2_weight_block_scale = w2_blockscale_fp8 - fc2_weight_gs = w2_global_scale - fc1_alpha = 1.0 / (fc1_act_global_scale * fc1_weight_gs) if fc1_alpha is None else fc1_alpha - fc2_alpha = 1.0 / (fc2_act_global_scale * fc2_weight_gs) if fc2_alpha is None else fc2_alpha - + # quant_scales is described by this code: + # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015 quant_scales = [ - fc1_act_global_scale, - fc1_weight_blockscale.view(torch.int32), - fc1_alpha, - fc2_act_global_scale, - fc2_weight_block_scale.view(torch.int32), - fc2_alpha, + fc1_act_global_scale, # torch.float32; [E] or scalar + fc1_weight_blockscale_fp8.view( + torch.int32 + ), # 4 FP8 as packed int32; [E, I*2, H / 16 / 4] or [E, I, H / 16 / 4] + fc1_alpha, # torch.float32; [E] + fc2_act_global_scale, # torch.float32; [E] or scalar + fc2_weight_blockscale_fp8.view(torch.int32), # 4 FP8 as packed int32; [E, H, I / 16 / 4] + fc2_alpha, # torch.float32; [E] ] if x.dtype in (torch.float16, torch.bfloat16): @@ -346,9 +327,9 @@ def trtllm_quant_nvfp4_moe_fused( x_q_fp4, selected_experts.to(torch.int), routing_weights, - fc1_expert_weights=fc1_expert_weights, + fc1_expert_weights=fc1_expert_weights_fp4, fc1_expert_biases=None, - fc2_expert_weights=w2_fp4.view(torch.long), + fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long), fc2_expert_biases=None, output_dtype=output_dtype, quant_scales=quant_scales, @@ -364,21 +345,14 @@ def trtllm_quant_nvfp4_moe_fused_fake( x: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor, - w1_fp4: torch.Tensor, - w2_fp4: torch.Tensor, - w3_fp4: torch.Tensor, - w1_global_scale: torch.Tensor, - w2_global_scale: torch.Tensor, - w3_global_scale: torch.Tensor, - w1_blockscale_fp8: torch.Tensor, - w2_blockscale_fp8: torch.Tensor, - w3_blockscale_fp8: torch.Tensor, - fc1_act_global: torch.Tensor, - fc2_act_global: torch.Tensor, - fc1_alpha: Optional[torch.Tensor] = None, - fc2_alpha: Optional[torch.Tensor] = None, - input_blockscale: Optional[torch.Tensor] = None, - output_dtype: Optional[torch.dtype] = None, + fc1_expert_weights_fp4: torch.Tensor, + fc2_expert_weights_fp4: torch.Tensor, + fc1_weight_blockscale_fp8: torch.Tensor, + fc2_weight_blockscale_fp8: torch.Tensor, + fc1_act_global_scale: torch.Tensor, + fc2_act_global_scale: torch.Tensor, + fc1_alpha: torch.Tensor, + fc2_alpha: torch.Tensor, mlp_style: str = "gated_mlp", act_fn: str = "silu", ) -> torch.Tensor: diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index b6b8de07575..2eeb66dd5b7 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -563,7 +563,6 @@ def break_fp4_bytes(a, dtype): ] -@pytest.mark.parametrize("precompute_fc_alphas", [True, False]) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @@ -578,7 +577,6 @@ def break_fp4_bytes(a, dtype): reason="Requires fp4 and trtllm support", ) def test_trtllm_fused_moe_nvfp4( - precompute_fc_alphas, batch_size, hidden_size, num_experts, @@ -693,34 +691,47 @@ def round_up(x, y): routing_weights, selected_experts = compute_routing(router_logits, top_k) - if precompute_fc_alphas: + fc1_weight_gs = torch.max(w3_gs, 
w1_gs) + fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs) + fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs) + + mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + if mlp_style == "gated_mlp": + # For gated MLP, concatenate w1 and w3 as [w3, w1] + w3_w1_stacked = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous() + fc1_expert_weights_fp4 = w3_w1_stacked + fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1) fc1_weight_gs = torch.max(w3_gs, w1_gs) - fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs) - fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs) + if activation_func != "silu": + raise ValueError( + f"Unsupported activation '{activation_func}' for gated_mlp. Use 'silu'." + ) + elif mlp_style == "mlp": + # For non-gated MLP with ReLU^2 + fc1_expert_weights_fp4 = w1_q_fp4 + fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long) + fc1_weight_gs = w1_gs + if activation_func != "relu2": + raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.") else: - fc1_alpha = None - fc2_alpha = None + raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + + fc2_expert_weights_fp4 = w2_q_fp4.view(torch.long) + fc2_weight_blockscale_fp8 = w2_blockscale.view(torch.long) + fc1_expert_weights_fp4 = fc1_expert_weights_fp4.view(torch.long) - mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" trtllm_output = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused( x, selected_experts.to(torch.int), routing_weights, - w1_q_fp4, - w2_q_fp4, - w3_q_fp4, - w1_gs, - w2_gs, - w3_gs, - w1_blockscale, - w2_blockscale, - w3_blockscale, + fc1_expert_weights_fp4, + fc2_expert_weights_fp4, + fc1_weight_blockscale_fp8, + fc2_weight_blockscale_fp8, fc1_activation_gs, fc2_activation_gs, - fc1_alpha=fc1_alpha, - fc2_alpha=fc2_alpha, - input_blockscale=None, - output_dtype=otype, + fc1_alpha, + fc2_alpha, mlp_style=mlp_style, act_fn=activation_func, ) From f19f4cb29a5f24d79f1e0b243da4f4617246780f Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Tue, 2 Dec 2025 02:09:02 -0800 Subject: [PATCH 10/11] Enable relu2 tests Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../unit/singlegpu/custom_ops/test_trtllm_moe.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 2eeb66dd5b7..f931613d580 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -569,9 +569,7 @@ def break_fp4_bytes(a, dtype): @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES) -# relu2 support requires merge of https://github.com/NVIDIA/TensorRT-LLM/pull/9261 -# @pytest.mark.parametrize("activation_func", ["silu", "relu2"]) -@pytest.mark.parametrize("activation_func", ["silu"]) +@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) @pytest.mark.skipif( not fp4_compatible() or not trtllm_ops_available(), reason="Requires fp4 and trtllm support", @@ -698,8 +696,7 @@ def round_up(x, y): mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" if mlp_style == "gated_mlp": # For gated MLP, concatenate w1 and w3 as [w3, w1] - w3_w1_stacked 
= torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous() - fc1_expert_weights_fp4 = w3_w1_stacked + fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous() fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1) fc1_weight_gs = torch.max(w3_gs, w1_gs) if activation_func != "silu": @@ -709,7 +706,7 @@ def round_up(x, y): elif mlp_style == "mlp": # For non-gated MLP with ReLU^2 fc1_expert_weights_fp4 = w1_q_fp4 - fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long) + fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1) fc1_weight_gs = w1_gs if activation_func != "relu2": raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.") From 7d663c01d8d1385b2a6190a8ce08c0cd5ae72ac0 Mon Sep 17 00:00:00 2001 From: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Date: Tue, 2 Dec 2025 03:27:33 -0800 Subject: [PATCH 11/11] Fix relu2 blockscale after rebase Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> --- .../auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index f931613d580..c9aea8bc607 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -706,7 +706,7 @@ def round_up(x, y): elif mlp_style == "mlp": # For non-gated MLP with ReLU^2 fc1_expert_weights_fp4 = w1_q_fp4 - fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1) + fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long) fc1_weight_gs = w1_gs if activation_func != "relu2": raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.")
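
---

For reference, the per-expert computation described in the op's docstring can be written out in plain PyTorch. The sketch below is illustrative only (the helper names moe_expert_reference and moe_reference do not exist in the repository); weight shapes follow the comments in the op signature (w1/w3: [E, I, H], w2: [E, H, I]).

import torch
import torch.nn.functional as F


def moe_expert_reference(x, w1, w2, w3=None, mlp_style="gated_mlp"):
    # Unquantized per-expert math that the fused NVFP4 kernel approximates.
    x = x.float()
    if mlp_style == "gated_mlp":
        # FC1: SiLU(x @ w1.T) gated elementwise by (x @ w3.T).
        h = F.silu(x @ w1.float().t()) * (x @ w3.float().t())
    elif mlp_style == "mlp":
        # FC1: ReLU^2(x @ w1.T).
        h = torch.relu(x @ w1.float().t()) ** 2
    else:
        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
    # FC2: project back to the hidden size.
    return h @ w2.float().t()


def moe_reference(x, selected_experts, routing_weights, w1, w2, w3=None, mlp_style="gated_mlp"):
    # x: [T, H]; selected_experts / routing_weights: [T, top_k]; w1/w2/w3 stacked per expert.
    out = torch.zeros(x.shape, dtype=torch.float32, device=x.device)
    for t in range(x.shape[0]):
        for k in range(selected_experts.shape[1]):
            e = int(selected_experts[t, k])
            w3_e = w3[e] if w3 is not None else None
            out[t] += routing_weights[t, k].float() * moe_expert_reference(
                x[t], w1[e], w2[e], w3_e, mlp_style
            )
    return out.to(x.dtype)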
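The test drives the new interface by packing FC1 from the per-expert w1/w3 tensors and precomputing the two alpha factors. The condensed sketch below restates that setup; it assumes the quantized tensors (w1_q_fp4, w1_blockscale, w1_gs, and so on) were produced beforehand by NVFP4 per-expert quantization, as the test's _quantize_weights helper does, and that the activation global scales are float32 scalars on the GPU.

# Gated MLP packs [w3, w1] along the intermediate dimension; plain MLP uses w1 alone.
if mlp_style == "gated_mlp":
    fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous().view(torch.long)
    fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1)
    fc1_weight_gs = torch.max(w3_gs, w1_gs)  # one global scale covering both halves
else:  # "mlp" with ReLU^2
    fc1_expert_weights_fp4 = w1_q_fp4.view(torch.long)
    fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long)
    fc1_weight_gs = w1_gs

fc2_expert_weights_fp4 = w2_q_fp4.view(torch.long)
fc2_weight_blockscale_fp8 = w2_blockscale.view(torch.long)

# Each alpha folds the activation and weight global scales into a single
# dequantization factor per GEMM: alpha = 1 / (act_global_scale * weight_global_scale).
fc1_alpha = 1.0 / (fc1_act_global_scale * fc1_weight_gs)
fc2_alpha = 1.0 / (fc2_act_global_scale * w2_gs)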
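With those tensors prepared, the call mirrors the one in the test. A sketch of the invocation; the shape comments restate the ones already given in the op signature:

out = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused(
    x,                              # [B, S, H] or [B*S, H], float16/bfloat16
    selected_experts.to(torch.int),
    routing_weights,
    fc1_expert_weights_fp4,         # [E, 2*I, H] (gated_mlp) or [E, I, H] (mlp), packed FP4
    fc2_expert_weights_fp4,         # [E, H, I], packed FP4
    fc1_weight_blockscale_fp8,
    fc2_weight_blockscale_fp8,
    fc1_act_global_scale,
    fc2_act_global_scale,
    fc1_alpha,
    fc2_alpha,
    mlp_style=mlp_style,
    act_fn="silu" if mlp_style == "gated_mlp" else "relu2",
)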
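Finally, the reference path in the test emulates the kernel's input precision loss by round-tripping the activations through NVFP4, which is the same quantization the op applies internally when it receives a 16-bit input. A minimal sketch, assuming dequantize_nvfp4_to_dtype is available as the test-side helper (it is not a public API):

NVFP4_BLOCK_SIZE = 16  # 16 FP4 values share one FP8 block scale

# Quantize 16-bit activations to packed FP4 plus interleaved FP8 block scales ...
x_fp4, x_blockscale = torch.ops.trtllm.fp4_quantize(x, fc1_act_global_scale, NVFP4_BLOCK_SIZE)

# ... then dequantize back to the original dtype to emulate the precision loss.
x_dq = dequantize_nvfp4_to_dtype(
    x_fp4,
    x_blockscale,
    fc1_act_global_scale,
    dtype=x.dtype,
    device=x.device,
    block_size=NVFP4_BLOCK_SIZE,
)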