|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | + |
1 | 17 | import torch |
2 | 18 |
|
3 | 19 | from tensorrt_llm._torch.utils import ActivationType |
@@ -234,5 +250,110 @@ def trtllm_quant_fp8_moe_fused_fake( |
234 | 250 | gemm2_dequant: torch.Tensor, |
235 | 251 | mlp_style: str, |
236 | 252 | act_fn: str, |
| 253 | +) -> torch.Tensor: |
| 254 | + _validate_mlp_style_and_act_fn(mlp_style, act_fn) |
| 255 | + return torch.empty_like(x) |
| 256 | + |
| 257 | + |
| 258 | +@torch.library.custom_op("auto_deploy::trtllm_quant_nvfp4_moe_fused", mutates_args=()) |
| 259 | +def trtllm_quant_nvfp4_moe_fused( |
| 260 | + x: torch.Tensor, # [B, S, H] or [B*S, H], 16-bit float |
| 261 | + selected_experts: torch.Tensor, |
| 262 | + routing_weights: torch.Tensor, |
| 263 | + fc1_expert_weights_fp4: torch.Tensor, # [E, 2*I, H] or [E, I, H]; uint8 |
| 264 | + fc2_expert_weights_fp4: torch.Tensor, # [E, H, I]; uint8 |
| 265 | + fc1_weight_blockscale_fp8: torch.Tensor, # Global scale for fc1 (scalar) |
| 266 | + fc2_weight_blockscale_fp8: torch.Tensor, # Global scale for w2 (scalar) |
| 267 | + fc1_act_global_scale: torch.Tensor, # Global scale for FC1 activations |
| 268 | + fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations |
| 269 | + fc1_alpha: torch.Tensor, # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8)) |
| 270 | + fc2_alpha: torch.Tensor, # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8)) |
| 271 | + mlp_style: str = "gated_mlp", |
| 272 | + act_fn: str = "silu", |
| 273 | +) -> torch.Tensor: |
| 274 | + """TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP. |
| 275 | +
|
| 276 | + Computes (per expert): |
| 277 | + For gated_mlp: |
| 278 | + y = (act(x @ w1.T) * (x @ w3.T)) @ w2.T # act := SiLU |
| 279 | + For mlp: |
| 280 | + y = act(x @ w1.T) @ w2.T # act := ReLU^2 |
| 281 | +
|
| 282 | +
|
| 283 | + FC1 implements: fc1_output = (act(x @ w1.T) * (x @ w3.T)) or fc1_output = act(x @ w1.T) |
| 284 | + FC2 implements: fc2_output = fc1_output @ w2.T |
| 285 | +
|
| 286 | + """ |
| 287 | + NVFP4_BLOCK_SIZE = 16 |
| 288 | + mlp_style = mlp_style.lower() |
| 289 | + act_fn = act_fn.lower() |
| 290 | + |
| 291 | + activation_type = ActivationType.Swiglu |
| 292 | + if mlp_style == "gated_mlp": |
| 293 | + if act_fn == "silu": |
| 294 | + activation_type = ActivationType.Swiglu |
| 295 | + else: |
| 296 | + raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") |
| 297 | + elif mlp_style == "mlp": |
| 298 | + if act_fn == "relu2": |
| 299 | + activation_type = ActivationType.Relu2 |
| 300 | + else: |
| 301 | + raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") |
| 302 | + else: |
| 303 | + raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") |
| 304 | + |
| 305 | + # quant_scales is described by this code: |
| 306 | + # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015 |
| 307 | + quant_scales = [ |
| 308 | + fc1_act_global_scale, # torch.float32; [E] or scalar |
| 309 | + fc1_weight_blockscale_fp8.view( |
| 310 | + torch.int32 |
| 311 | + ), # 4 FP8 as packed int32; [E, I*2, H / 16 / 4] or [E, I, H / 16 / 4] |
| 312 | + fc1_alpha, # torch.float32; [E] |
| 313 | + fc2_act_global_scale, # torch.float32; [E] or scalar |
| 314 | + fc2_weight_blockscale_fp8.view(torch.int32), # 4 FP8 as packed int32; [E, H, I / 16 / 4] |
| 315 | + fc2_alpha, # torch.float32; [E] |
| 316 | + ] |
| 317 | + |
| 318 | + if x.dtype in (torch.float16, torch.bfloat16): |
| 319 | + x_q_fp4, input_blockscale = torch.ops.trtllm.fp4_quantize( |
| 320 | + x, fc1_act_global_scale, NVFP4_BLOCK_SIZE |
| 321 | + ) |
| 322 | + output_dtype = x.dtype |
| 323 | + else: |
| 324 | + x_q_fp4 = x |
| 325 | + |
| 326 | + trtllm_output = torch.ops.trtllm.fused_moe( |
| 327 | + x_q_fp4, |
| 328 | + selected_experts.to(torch.int), |
| 329 | + routing_weights, |
| 330 | + fc1_expert_weights=fc1_expert_weights_fp4, |
| 331 | + fc1_expert_biases=None, |
| 332 | + fc2_expert_weights=fc2_expert_weights_fp4.view(torch.long), |
| 333 | + fc2_expert_biases=None, |
| 334 | + output_dtype=output_dtype, |
| 335 | + quant_scales=quant_scales, |
| 336 | + input_sf=input_blockscale, |
| 337 | + activation_type=activation_type, |
| 338 | + )[0].view(x.shape) |
| 339 | + |
| 340 | + return trtllm_output |
| 341 | + |
| 342 | + |
@trtllm_quant_nvfp4_moe_fused.register_fake
def trtllm_quant_nvfp4_moe_fused_fake(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    fc1_expert_weights_fp4: torch.Tensor,
    fc2_expert_weights_fp4: torch.Tensor,
    fc1_weight_blockscale_fp8: torch.Tensor,
    fc2_weight_blockscale_fp8: torch.Tensor,
    fc1_act_global_scale: torch.Tensor,
    fc2_act_global_scale: torch.Tensor,
    fc1_alpha: torch.Tensor,
    fc2_alpha: torch.Tensor,
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    """Fake implementation registered for ``trtllm_quant_nvfp4_moe_fused``.

    No computation is performed: the result is an uninitialized tensor created with
    ``torch.empty_like(x)``. All arguments other than ``x`` are accepted only to mirror
    the real op's signature and are ignored.
    """
    output_placeholder = torch.empty_like(x)
    return output_placeholder
0 commit comments