121 changes: 121 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
@@ -1,3 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from tensorrt_llm._torch.utils import ActivationType
@@ -234,5 +250,110 @@ def trtllm_quant_fp8_moe_fused_fake(
    gemm2_dequant: torch.Tensor,
    mlp_style: str,
    act_fn: str,
) -> torch.Tensor:
    _validate_mlp_style_and_act_fn(mlp_style, act_fn)
    return torch.empty_like(x)


@torch.library.custom_op("auto_deploy::trtllm_quant_nvfp4_moe_fused", mutates_args=())
def trtllm_quant_nvfp4_moe_fused(
    x: torch.Tensor,  # [B, S, H] or [B*S, H], 16-bit float
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    fc1_expert_weights_fp4: torch.Tensor,  # [E, 2*I, H] or [E, I, H]; uint8
    fc2_expert_weights_fp4: torch.Tensor,  # [E, H, I]; uint8
    fc1_weight_blockscale_fp8: torch.Tensor,  # Per-block FP8 scales for fc1 weights (see quant_scales below)
    fc2_weight_blockscale_fp8: torch.Tensor,  # Per-block FP8 scales for fc2 weights (see quant_scales below)
    fc1_act_global_scale: torch.Tensor,  # Global scale for FC1 activations
    fc2_act_global_scale: torch.Tensor,  # Global scale for FC2 activations
    fc1_alpha: torch.Tensor,  # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8))
    fc2_alpha: torch.Tensor,  # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8))
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
"""TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP.

Computes (per expert):
For gated_mlp:
y = (act(x @ w1.T) * (x @ w3.T)) @ w2.T # act := SiLU
For mlp:
y = act(x @ w1.T) @ w2.T # act := ReLU^2


FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T)
FC2 implements: fc2_output = fc1_output @ w2.T

"""
    NVFP4_BLOCK_SIZE = 16
    mlp_style = mlp_style.lower()
    act_fn = act_fn.lower()

    if mlp_style == "gated_mlp":
        if act_fn == "silu":
            activation_type = ActivationType.Swiglu
        else:
            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
    elif mlp_style == "mlp":
        if act_fn == "relu2":
            activation_type = ActivationType.Relu2
        else:
            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
    else:
        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

    # quant_scales is described by this code:
    # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
    quant_scales = [
        fc1_act_global_scale,  # torch.float32; [E] or scalar
        fc1_weight_blockscale_fp8.view(
            torch.int32
        ),  # 4 FP8 as packed int32; [E, I*2, H / 16 / 4] or [E, I, H / 16 / 4]
        fc1_alpha,  # torch.float32; [E]
        fc2_act_global_scale,  # torch.float32; [E] or scalar
        fc2_weight_blockscale_fp8.view(torch.int32),  # 4 FP8 as packed int32; [E, H, I / 16 / 4]
        fc2_alpha,  # torch.float32; [E]
    ]
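    # Packing arithmetic (illustrative, assuming H = 1024): each weight row carries
    # H / 16 = 64 FP8 block scales (one per 16-element NVFP4 block); viewing them
    # as int32 packs four FP8 scales per lane, i.e. 64 / 4 = 16 int32 entries per
    # row, matching the H / 16 / 4 shapes noted above.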

    if x.dtype in (torch.float16, torch.bfloat16):
        x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
            x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
        )
        output_dtype = x.dtype
    else:
        # Passing other dtypes through would leave input_blockscale and
        # output_dtype undefined below, so reject them explicitly.
        raise ValueError(f"Unsupported input dtype {x.dtype}; expected float16 or bfloat16.")

    trtllm_output = torch.ops.trtllm.fused_moe(
        x_q_fp4,
        selected_experts.to(torch.int),
        routing_weights,
        fc1_expert_weights=fc1_expert_weights_fp4,
        fc1_expert_biases=None,
        fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long),
        fc2_expert_biases=None,
        output_dtype=output_dtype,
        quant_scales=quant_scales,
        input_sf=input_blockscale,
        activation_type=activation_type,
    )[0].view(x.shape)

    return trtllm_output
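
# Hedged usage sketch (illustrative assumption, not executed here): a typical
# eager call on GPU, with all weight/scale tensors prepared offline in the
# packed NVFP4 layouts documented in the signature above.
#
#   out = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused(
#       x,                 # [B*S, H], bfloat16 or float16, CUDA
#       selected_experts,  # [B*S, top_k], integer expert indices
#       routing_weights,   # [B*S, top_k], per-expert routing weights
#       fc1_expert_weights_fp4, fc2_expert_weights_fp4,
#       fc1_weight_blockscale_fp8, fc2_weight_blockscale_fp8,
#       fc1_act_global_scale, fc2_act_global_scale,
#       fc1_alpha, fc2_alpha,
#       mlp_style="gated_mlp", act_fn="silu",
#   )  # out: [B*S, H], same dtype as x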


@trtllm_quant_nvfp4_moe_fused.register_fake
def trtllm_quant_nvfp4_moe_fused_fake(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    fc1_expert_weights_fp4: torch.Tensor,
    fc2_expert_weights_fp4: torch.Tensor,
    fc1_weight_blockscale_fp8: torch.Tensor,
    fc2_weight_blockscale_fp8: torch.Tensor,
    fc1_act_global_scale: torch.Tensor,
    fc2_act_global_scale: torch.Tensor,
    fc1_alpha: torch.Tensor,
    fc2_alpha: torch.Tensor,
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    # Mirror the fp8 fake above: validate style/activation so invalid configs
    # fail at trace time as well.
    _validate_mlp_style_and_act_fn(mlp_style, act_fn)
    return torch.empty_like(x)
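

# Hedged, runnable smoke test (an illustrative addition, not part of the op):
# under FakeTensorMode the custom op dispatches to the fake registered above,
# so shape/dtype propagation can be checked without a GPU or real NVFP4
# weights. All shapes (2 experts, top-k 2, H=1024, I=2048) are assumptions.
def _nvfp4_moe_fake_trace_smoke_test() -> None:
    from torch._subclasses.fake_tensor import FakeTensorMode

    num_experts, hidden, inter = 2, 1024, 2048
    with FakeTensorMode():
        x = torch.empty(8, hidden, dtype=torch.bfloat16)
        out = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused(
            x,
            torch.empty(8, 2, dtype=torch.int32),  # selected_experts
            torch.empty(8, 2, dtype=torch.float32),  # routing_weights
            torch.empty(num_experts, 2 * inter, hidden, dtype=torch.uint8),  # fc1 weights
            torch.empty(num_experts, hidden, inter, dtype=torch.uint8),  # fc2 weights
            torch.empty(num_experts, 2 * inter, hidden // 16, dtype=torch.float8_e4m3fn),
            torch.empty(num_experts, hidden, inter // 16, dtype=torch.float8_e4m3fn),
            torch.empty(num_experts, dtype=torch.float32),  # fc1_act_global_scale
            torch.empty(num_experts, dtype=torch.float32),  # fc2_act_global_scale
            torch.empty(num_experts, dtype=torch.float32),  # fc1_alpha
            torch.empty(num_experts, dtype=torch.float32),  # fc2_alpha
        )
        # The fake simply propagates x's shape and dtype.
        assert out.shape == x.shape and out.dtype == x.dtype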