
Commit a560ba5

[#9550][feat] AutoDeploy: Add NVFP4 Cutlass MoE kernels (#9551)
Signed-off-by: Neta Zmora <[email protected]>
1 parent 227d42e

2 files changed (+472, -7 lines)

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 121 additions & 0 deletions
@@ -1,3 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tensorrt_llm._torch.utils import ActivationType
@@ -234,5 +250,110 @@ def trtllm_quant_fp8_moe_fused_fake(
    gemm2_dequant: torch.Tensor,
    mlp_style: str,
    act_fn: str,
) -> torch.Tensor:
    _validate_mlp_style_and_act_fn(mlp_style, act_fn)
    return torch.empty_like(x)


@torch.library.custom_op("auto_deploy::trtllm_quant_nvfp4_moe_fused", mutates_args=())
def trtllm_quant_nvfp4_moe_fused(
    x: torch.Tensor,  # [B, S, H] or [B*S, H], 16-bit float
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    fc1_expert_weights_fp4: torch.Tensor,  # [E, 2*I, H] or [E, I, H]; uint8
    fc2_expert_weights_fp4: torch.Tensor,  # [E, H, I]; uint8
    fc1_weight_blockscale_fp8: torch.Tensor,  # FP8 block scales for fc1 weights
    fc2_weight_blockscale_fp8: torch.Tensor,  # FP8 block scales for fc2 weights
    fc1_act_global_scale: torch.Tensor,  # Global scale for FC1 activations
    fc2_act_global_scale: torch.Tensor,  # Global scale for FC2 activations
    fc1_alpha: torch.Tensor,  # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_global_scale))
    fc2_alpha: torch.Tensor,  # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_global_scale))
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    """TensorRT-LLM Cutlass NVFP4 W4A4 MoE for gated and non-gated MLP.

    Computes (per expert):
        For gated_mlp:
            y = (act(x @ w1.T) * (x @ w3.T)) @ w2.T  # act := SiLU
        For mlp:
            y = act(x @ w1.T) @ w2.T  # act := ReLU^2

    FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T)
    FC2 implements: fc2_output = fc1_output @ w2.T
    """
    NVFP4_BLOCK_SIZE = 16
    mlp_style = mlp_style.lower()
    act_fn = act_fn.lower()

    activation_type = ActivationType.Swiglu
    if mlp_style == "gated_mlp":
        if act_fn == "silu":
            activation_type = ActivationType.Swiglu
        else:
            raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
    elif mlp_style == "mlp":
        if act_fn == "relu2":
            activation_type = ActivationType.Relu2
        else:
            raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
    else:
        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

    # quant_scales is described by this code:
    # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
    quant_scales = [
        fc1_act_global_scale,  # torch.float32; [E] or scalar
        fc1_weight_blockscale_fp8.view(
            torch.int32
        ),  # 4 FP8 as packed int32; [E, I*2, H / 16 / 4] or [E, I, H / 16 / 4]
        fc1_alpha,  # torch.float32; [E]
        fc2_act_global_scale,  # torch.float32; [E] or scalar
        fc2_weight_blockscale_fp8.view(torch.int32),  # 4 FP8 as packed int32; [E, H, I / 16 / 4]
        fc2_alpha,  # torch.float32; [E]
    ]

    if x.dtype in (torch.float16, torch.bfloat16):
        x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize(
            x, fc1_act_global_scale, NVFP4_BLOCK_SIZE
        )
        output_dtype = x.dtype
    else:
        # Assumed fallback: x arrives pre-quantized to packed NVFP4, so no
        # separate input block scales are available here and the output dtype
        # defaults to bfloat16.
        x_q_fp4 = x
        input_blockscale = None
        output_dtype = torch.bfloat16

    trtllm_output = torch.ops.trtllm.fused_moe(
        x_q_fp4,
        selected_experts.to(torch.int),
        routing_weights,
        fc1_expert_weights=fc1_expert_weights_fp4,
        fc1_expert_biases=None,
        fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long),
        fc2_expert_biases=None,
        output_dtype=output_dtype,
        quant_scales=quant_scales,
        input_sf=input_blockscale,
        activation_type=activation_type,
    )[0].view(x.shape)

    return trtllm_output
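The fc1_alpha/fc2_alpha inputs consumed in quant_scales follow the formula in the signature comments. A minimal sketch of precomputing them, assuming per-expert float32 global scales of shape [E] (the helper name is hypothetical, not part of this commit):

def _precompute_moe_alphas(
    fc1_act_global_scale: torch.Tensor,  # float32; [E] or scalar
    fc1_weight_global_scale: torch.Tensor,  # float32; [E] or scalar
    fc2_act_global_scale: torch.Tensor,
    fc2_weight_global_scale: torch.Tensor,
):
    """Hypothetical helper: alpha = 1 / (act_global_scale * weight_global_scale)."""
    fc1_alpha = 1.0 / (fc1_act_global_scale.float() * fc1_weight_global_scale.float())
    fc2_alpha = 1.0 / (fc2_act_global_scale.float() * fc2_weight_global_scale.float())
    return fc1_alpha, fc2_alpha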


@trtllm_quant_nvfp4_moe_fused.register_fake
def trtllm_quant_nvfp4_moe_fused_fake(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    fc1_expert_weights_fp4: torch.Tensor,
    fc2_expert_weights_fp4: torch.Tensor,
    fc1_weight_blockscale_fp8: torch.Tensor,
    fc2_weight_blockscale_fp8: torch.Tensor,
    fc1_act_global_scale: torch.Tensor,
    fc2_act_global_scale: torch.Tensor,
    fc1_alpha: torch.Tensor,
    fc2_alpha: torch.Tensor,
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    return torch.empty_like(x)
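For intuition, the per-expert math described in the op's docstring can be written as an unquantized PyTorch reference (illustrative only; w1, w2, w3 stand for one expert's dense weights, not the packed FP4 tensors above):

import torch.nn.functional as F


def _expert_mlp_reference(x, w1, w2, w3=None, mlp_style="gated_mlp"):
    """Unquantized single-expert reference (illustrative only)."""
    if mlp_style == "gated_mlp":
        # y = (SiLU(x @ w1.T) * (x @ w3.T)) @ w2.T
        return (F.silu(x @ w1.T) * (x @ w3.T)) @ w2.T
    # mlp: y = ReLU(x @ w1.T)^2 @ w2.T
    return F.relu(x @ w1.T).pow(2) @ w2.T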

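A usage sketch of the new op. All sizes are hypothetical, tensor shapes simply follow the signature comments above, and the weight/scale tensors are left uninitialized, so this shows the calling convention rather than meaningful numerics; actually running it requires a TensorRT-LLM build that ships these Cutlass NVFP4 kernels and a GPU with NVFP4 support:

E, H, I, TOP_K, T = 8, 4096, 14336, 2, 16  # experts, hidden, intermediate, top-k, tokens
x = torch.randn(T, H, dtype=torch.bfloat16, device="cuda")
selected_experts = torch.randint(0, E, (T, TOP_K), device="cuda")
routing_weights = torch.rand(T, TOP_K, device="cuda").softmax(dim=-1)

fc1_w = torch.empty(E, 2 * I, H, dtype=torch.uint8, device="cuda")  # packed FP4, gated_mlp layout
fc2_w = torch.empty(E, H, I, dtype=torch.uint8, device="cuda")  # packed FP4
fc1_bs = torch.empty(E, 2 * I, H // 16, dtype=torch.float8_e4m3fn, device="cuda")  # block scales
fc2_bs = torch.empty(E, H, I // 16, dtype=torch.float8_e4m3fn, device="cuda")
fc1_act = torch.ones(E, dtype=torch.float32, device="cuda")
fc2_act = torch.ones(E, dtype=torch.float32, device="cuda")
fc1_alpha = 1.0 / fc1_act  # placeholder: weight global scale taken as 1.0
fc2_alpha = 1.0 / fc2_act

out = torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused(
    x, selected_experts, routing_weights, fc1_w, fc2_w,
    fc1_bs, fc2_bs, fc1_act, fc2_act, fc1_alpha, fc2_alpha,
    mlp_style="gated_mlp", act_fn="silu",
)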