diff --git a/examples/auto_deploy/nano_v3.yaml b/examples/auto_deploy/nano_v3.yaml index 411037cc175..9d9acf6ef7f 100644 --- a/examples/auto_deploy/nano_v3.yaml +++ b/examples/auto_deploy/nano_v3.yaml @@ -15,6 +15,9 @@ transforms: detect_sharding: sharding_source: ['factory', 'heuristic'] sharding_dims: ['ep', 'bmm'] + multi_stream_moe: + stage: compile + enabled: true # tunable mamba cache dtype # --> use float32 for accuracy and default (null) for speed insert_cached_ssm_attention: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 62e7b36dd94..8b130d98744 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -67,7 +67,9 @@ def trtllm_moe_fused_fake( return torch.empty_like(x) -# Todo: refactor this repeating code block +# NOTE(suyogg): If compile ever fails because of this, just write a triton kernel +# for this function and use it as a custom op. +@torch.compile def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear).""" FP8_MIN = torch.finfo(torch.float8_e4m3fn).min @@ -107,6 +109,9 @@ def trtllm_quant_fp8_moe_fused( w1_weight_scale: torch.Tensor, # [E] stacked weight scales w2_weight_scale: torch.Tensor, # [E] stacked weight scales w3_weight_scale: torch.Tensor, # [E] or unused + gemm1_dequant: torch.Tensor, # [E] + gemm2_act_quant: torch.Tensor, # [E] + gemm2_dequant: torch.Tensor, # [E] mlp_style: str = "gated_mlp", act_fn: str = "silu", ) -> torch.Tensor: @@ -125,6 +130,9 @@ def trtllm_quant_fp8_moe_fused( w1_weight_scale: Weight scales for w1 [E] w2_weight_scale: Weight scales for w2 [E] w3_weight_scale: Weight scales for w3 [E] + gemm1_dequant: Precomputed gemm1 dequant scale [E] + gemm2_act_quant: Precomputed gemm2 act quant scale [1] + gemm2_dequant: Precomputed gemm2 dequant scale [E] mlp_style: "gated_mlp" or "mlp" act_fn: "silu" for gated_mlp, "relu2" for mlp @@ -144,28 +152,20 @@ def trtllm_quant_fp8_moe_fused( x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0]) # Scales are stored in float32 - w1_weight_scale = w1_weight_scale.to(torch.float32) - w2_weight_scale = w2_weight_scale.to(torch.float32) - w1_input_scale = w1_input_scale.to(torch.float32)[0] - w2_input_scale = w2_input_scale.to(torch.float32)[0] + w1_input_scale = w1_input_scale[0] # Prepare quant_scales for TensorRT-LLM FP8 format: # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale] # For gated MLP: + # These are precomputed in `fused_moe` transform # - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3) # - gemm2_act_quant_scale: 1 / w2_input_scale # - gemm2_dequant_scale: w2_weight_scale * w2_input_scale # - gemm1_input_dequant_scale: w1_input_scale - # Compute combined scales - gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze() - gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32) - gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze() - gemm1_input_dequant = w1_input_scale.contiguous() - assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D" assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D" - quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant] + quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale] # Ensure contiguous 
tensors selected_experts = selected_experts.int().contiguous() @@ -229,6 +229,9 @@ def trtllm_quant_fp8_moe_fused_fake( w1_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor, w3_weight_scale: torch.Tensor, + gemm1_dequant: torch.Tensor, + gemm2_act_quant: torch.Tensor, + gemm2_dequant: torch.Tensor, mlp_style: str, act_fn: str, ) -> torch.Tensor: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py index 33a7eb2a284..5f0a6d429ea 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py @@ -204,9 +204,7 @@ def _cuda_cached_causal_conv1d( if y_dec.dim() == 3: y_dec = y_dec.squeeze(-1) - y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_( - y_dec.to(y_flat.dtype) - ) + y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec) # Custom op must not return an alias of any input; return a fresh tensor return y diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 3ab13309009..f8339865abc 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -43,10 +43,6 @@ def _triton_ssm_prepare_metadata( seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) - seq_start = torch.zeros_like(seq_len_sanitized) - if num_seq > 1: - seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0) - # Truncate slot indices to match active sequences slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long) # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch @@ -88,7 +84,6 @@ def _triton_ssm_prepare_metadata( return ( seq_len_sanitized, - seq_start, slot_idx_sanitized, use_initial_states, cu_seqlens, @@ -109,7 +104,6 @@ def _triton_ssm_prepare_metadata_fake( device = slot_idx.device # Always-correct shapes seq_len_fake = torch.empty_like(seq_len_sanitized) - seq_start_fake = torch.empty_like(seq_len_sanitized) slot_idx_fake = torch.empty(num_seq, dtype=torch.long, device=device) use_initial_states_fake = torch.empty(num_seq, dtype=torch.bool, device=device) cu_seqlens_fake = torch.empty(num_seq + 1, dtype=torch.int32, device=device) @@ -142,7 +136,6 @@ def _triton_ssm_prepare_metadata_fake( return ( seq_len_fake, - seq_start_fake, slot_idx_fake, use_initial_states_fake, cu_seqlens_fake, @@ -165,7 +158,6 @@ def _triton_cached_ssm( dt_bias: torch.Tensor, # [num_heads] # METADATA seq_len: torch.Tensor, # [num_seq] - seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] use_initial_states: torch.Tensor, # [num_seq] cu_seqlens: torch.Tensor, # [num_seq + 1] @@ -290,7 +282,6 @@ def _triton_cached_ssm_fake( dt_bias: torch.Tensor, # [num_heads] # METADATA seq_len: torch.Tensor, # [num_seq] - seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] use_initial_states: torch.Tensor, # [num_seq] cu_seqlens: torch.Tensor, # [num_seq + 1] @@ -340,9 +331,9 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: - # Returns: seq_len, seq_start, slot_idx, use_initial_states, + # Returns: seq_len, slot_idx, use_initial_states, # cu_seqlens, 
chunk_indices, chunk_offsets, seq_idx_prefill, batch_info_tensor - return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 9 + return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 8 @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py new file mode 100644 index 00000000000..871374155e8 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py @@ -0,0 +1,235 @@ +""" +Custom ops to enable multi-stream execution. +""" + +from __future__ import annotations + +from threading import RLock +from typing import Any, Callable, Dict, Tuple + +import torch + + +class _Singleton(type): + _instances: Dict[type, Any] = {} + _lock = RLock() + + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + if cls not in cls._instances: + with cls._lock: + if cls not in cls._instances: # double-checked locking + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +# A singleton that holds the pointers to the cuda streams and events. +# In multi-gpu scenario, each GPU/rank has its own CudaStreamManager. +class CudaStreamManager(metaclass=_Singleton): + AUX_STREAM_NAME = "aux" + MAIN_STREAM_NAME = "main" + + def __init__(self) -> None: + # In case __init__ ever gets called twice, guard against re-init + if hasattr(self, "streams"): + return + + self._lock = RLock() + + # Events needed for stream synchronization + self.events: Dict[str, Any] = { + self.AUX_STREAM_NAME: torch.cuda.Event(), + self.MAIN_STREAM_NAME: torch.cuda.Event(), + } + + # Streams for multi-stream execution + self.aux_stream = torch.cuda.Stream() + self.streams: Dict[str, Any] = { + self.AUX_STREAM_NAME: self.aux_stream, + self.MAIN_STREAM_NAME: torch.cuda.default_stream(), + } + + +cuda_stream_manager = CudaStreamManager() + + +@torch.library.custom_op("auto_deploy::record_event", mutates_args=()) +def record_event(stream_name: str) -> None: + event = cuda_stream_manager.events[stream_name] + event.record() + + +@torch.library.custom_op("auto_deploy::wait_event", mutates_args=()) +def wait_event(event_name: str) -> None: + event = cuda_stream_manager.events[event_name] + event.wait() + + +# skip during compilation +@torch._dynamo.disable +def record_event_wrapper( + fn: Callable, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> torch.Tensor: + output = fn(*args, **kwargs) + torch.ops.auto_deploy.record_event(cuda_stream_manager.MAIN_STREAM_NAME) + return output + + +@torch._dynamo.disable +def aux_stream_wrapper( + fn: Callable, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> torch.Tensor: + stream_name = cuda_stream_manager.AUX_STREAM_NAME + with torch.cuda.stream(cuda_stream_manager.streams[stream_name]): + torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) + output = fn(*args, **kwargs) + torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) + torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) + return output + + +# trtllm bf16 +@torch.library.custom_op("auto_deploy::trtllm_moe_fused_aux", mutates_args=()) +def trtllm_moe_fused_aux( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w3_w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]): + 
torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) + output = torch.ops.auto_deploy.trtllm_moe_fused( + x, + selected_experts, + routing_weights, + w3_w1_stacked_weight, + w2_stacked_weight, + mlp_style, + act_fn, + ) + torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) + torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) + return output + + +@trtllm_moe_fused_aux.register_fake +def trtllm_moe_fused_aux_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w3_w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + return torch.empty_like(x) + + +# triton bf16 +@torch.library.custom_op("auto_deploy::triton_moe_fused_aux", mutates_args=()) +def triton_moe_fused_aux( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, +) -> torch.Tensor: + with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]): + torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) + output = torch.ops.auto_deploy.triton_moe_fused( + x, + selected_experts, + routing_weights, + w1_stacked_weight, + w2_stacked_weight, + ) + torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) + torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) + return output + + +@triton_moe_fused_aux.register_fake +def triton_moe_fused_aux_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_stacked_weight: torch.Tensor, + w2_stacked_weight: torch.Tensor, +) -> torch.Tensor: + return torch.empty_like(x) + + +# trtllm fp8 +@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused_aux", mutates_args=()) +def trtllm_quant_fp8_moe_fused_aux( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights + w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights + w3_weight: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp + w1_input_scale: torch.Tensor, # [E] stacked input scales + w2_input_scale: torch.Tensor, # [E] stacked input scales + w3_input_scale: torch.Tensor, # [E] or unused + w1_weight_scale: torch.Tensor, # [E] stacked weight scales + w2_weight_scale: torch.Tensor, # [E] stacked weight scales + w3_weight_scale: torch.Tensor, # [E] or unused + gemm1_dequant: torch.Tensor, # [E] + gemm2_act_quant: torch.Tensor, # [E] + gemm2_dequant: torch.Tensor, # [E] + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]): + torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) + output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused( + x, + selected_experts, + routing_weights, + w1_weight, + w2_weight, + w3_weight, + w1_input_scale, + w2_input_scale, + w3_input_scale, + w1_weight_scale, + w2_weight_scale, + w3_weight_scale, + gemm1_dequant, + gemm2_act_quant, + gemm2_dequant, + mlp_style, + act_fn, + ) + torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) + torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) + return output + + +@trtllm_quant_fp8_moe_fused_aux.register_fake +def trtllm_quant_fp8_moe_fused_aux_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + 
routing_weights: torch.Tensor, + w1_weight: torch.Tensor, + w2_weight: torch.Tensor, + w3_weight: torch.Tensor, + w1_input_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + w3_input_scale: torch.Tensor, + w1_weight_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + w3_weight_scale: torch.Tensor, + gemm1_dequant: torch.Tensor, + gemm2_act_quant: torch.Tensor, + gemm2_dequant: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py index f4b98d49df0..708449ea732 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py @@ -113,21 +113,20 @@ def triton_rmsnorm_gated( # Flatten to (M, H), ensure last-dim contiguous, and run in fp32 x_shape = x.shape - x2 = x.to(torch.float32).reshape(-1, H) + x2 = x.reshape(-1, H) if x2.stride(-1) != 1: x2 = x2.contiguous() z2 = None if gate is not None: - z2 = gate.to(torch.float32).reshape(-1, H) + z2 = gate.reshape(-1, H) if z2.stride(-1) != 1: z2 = z2.contiguous() - - w = weight.to(torch.float32).contiguous() + assert weight.is_contiguous(), "weight must be contiguous" out2, _, _ = _layer_norm_fwd( x2, - w, + weight, None, # bias eps, z=z2, diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py index fb6a6dafe7d..8248ab209f2 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py @@ -128,6 +128,10 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor): topk_indices, topk_weights = self.gate(hidden_states) x_flat = hidden_states.view(-1, hidden_states.shape[-1]) + # NOTE: So far we've seen that the dispatch order in eager code is the same as the node order in the exported graph. + # We dispatch shared expert first so that we can easily fork the execution of the routed experts + # (using the custom op below) to an auxiliary stream. 
+ shared_out = self.shared_experts(residuals) # Check if this is a latent MOE (has fc1_latent_proj and fc2_latent_proj) has_latent_proj = hasattr(self, "fc1_latent_proj") and hasattr(self, "fc2_latent_proj") @@ -151,8 +155,8 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor): # Latent MOE: project back from latent space out_flat = self.fc2_latent_proj(out_flat) - out = out_flat.view(*orig_shape) - out = out + self.shared_experts(residuals) + routed_out = out_flat.view(*orig_shape) + out = shared_out + routed_out return out diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 62328ccf309..c7ea32fb0e2 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -676,20 +676,33 @@ def get_param_or_buffer(target): ) ) - w1_weight_scale_stacked = torch.stack( - [get_param_or_buffer(n.target) for n in w1_weight_scale], dim=0 + w1_weight_scale_stacked = ( + torch.stack([get_param_or_buffer(n.target) for n in w1_weight_scale], dim=0) + .to(torch.float32) + .contiguous() ) - w2_weight_scale_stacked = torch.stack( - [get_param_or_buffer(n.target) for n in w2_weight_scale], dim=0 + w2_weight_scale_stacked = ( + torch.stack([get_param_or_buffer(n.target) for n in w2_weight_scale], dim=0) + .to(torch.float32) + .contiguous() ) w3_weight_scale_stacked = ( - torch.stack([get_param_or_buffer(n.target) for n in w3_weight_scale], dim=0) - if w3_weight_scale - else torch.empty( - 0, device=w1_weight_scale_stacked.device, dtype=w1_weight_scale_stacked.dtype + ( + torch.stack([get_param_or_buffer(n.target) for n in w3_weight_scale], dim=0) + if w3_weight_scale + else torch.empty( + 0, device=w1_weight_scale_stacked.device, dtype=w1_weight_scale_stacked.dtype + ) ) + .to(torch.float32) + .contiguous() + ) + assert torch.all(w1_input_scale_stacked[0] == w1_input_scale_stacked), ( + "All w1 scales should have the same value." + ) + assert torch.all(w2_input_scale_stacked[0] == w2_input_scale_stacked), ( + "All w2 scales should have the same value." ) - # Register stacked tensors as new parameters new_key_w1 = f"quant_moe_w1_stacked_{fused_key_counter}" new_key_w2 = f"quant_moe_w2_stacked_{fused_key_counter}" @@ -729,24 +742,55 @@ def get_param_or_buffer(target): torch.nn.Parameter(w3_weight_scale_stacked, requires_grad=False), ) + additional_nodes = [] + if backend == "trtllm": + # For optimization reasons, we precompute a few additional arguments to the trtllm_quant_fp8_moe_fused op + # to avoid computing them at runtime. 
+ gemm1_dequant = (w1_weight_scale_stacked * w1_input_scale_stacked[0]).squeeze() + gemm2_act_quant = (1.0 / w2_input_scale_stacked[0]).to(torch.float32) + gemm2_dequant = (w2_weight_scale_stacked * w2_input_scale_stacked[0]).squeeze() + + new_key_gemm1_dequant = f"quant_moe_gemm1_dequant_stacked_{fused_key_counter}" + new_key_gemm2_act_quant = f"quant_moe_gemm2_act_quant_stacked_{fused_key_counter}" + new_key_gemm2_dequant = f"quant_moe_gemm2_dequant_stacked_{fused_key_counter}" + gm.register_parameter( + new_key_gemm1_dequant, + torch.nn.Parameter(gemm1_dequant, requires_grad=False), + ) + gm.register_parameter( + new_key_gemm2_act_quant, + torch.nn.Parameter(gemm2_act_quant, requires_grad=False), + ) + gm.register_parameter( + new_key_gemm2_dequant, + torch.nn.Parameter(gemm2_dequant, requires_grad=False), + ) + additional_nodes = [ + new_key_gemm1_dequant, + new_key_gemm2_act_quant, + new_key_gemm2_dequant, + ] + # Create new node with get_attr for stacked parameters with graph.inserting_before(node): + args = ( + hidden_states, + selected_experts, + routing_weights, + graph.get_attr(new_key_w1), + graph.get_attr(new_key_w2), + graph.get_attr(new_key_w3), + graph.get_attr(new_key_w1_input_scale), + graph.get_attr(new_key_w2_input_scale), + graph.get_attr(new_key_w3_input_scale), + graph.get_attr(new_key_w1_weight_scale), + graph.get_attr(new_key_w2_weight_scale), + graph.get_attr(new_key_w3_weight_scale), + ) + additional_args = (graph.get_attr(node) for node in additional_nodes) new_node = graph.call_function( replacement_op, - args=( - hidden_states, - selected_experts, - routing_weights, - graph.get_attr(new_key_w1), - graph.get_attr(new_key_w2), - graph.get_attr(new_key_w3), - graph.get_attr(new_key_w1_input_scale), - graph.get_attr(new_key_w2_input_scale), - graph.get_attr(new_key_w3_input_scale), - graph.get_attr(new_key_w1_weight_scale), - graph.get_attr(new_key_w2_weight_scale), - graph.get_attr(new_key_w3_weight_scale), - ), + args=(*args, *additional_args), kwargs=node.kwargs, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py new file mode 100644 index 00000000000..a0ec07777b3 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py @@ -0,0 +1,80 @@ +"""Transform for multi-stream execution of MoE layers that have shared experts and routed experts.""" + +from typing import Callable, Dict, Tuple + +import torch +from torch.fx import GraphModule + +from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import record_event_wrapper + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.node_utils import is_op +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +def _execute_op_in_aux_stream( + gm: GraphModule, op_dict: Dict[Callable, Callable] +) -> Tuple[GraphModule, int]: + graph = gm.graph + num_replaced = 0 + + # Collect targets first to avoid mutating while iterating + target_nodes = [n for n in graph.nodes if is_op(n, op_dict.keys())] + + for n in target_nodes: + target_input_node = None + for input_node in n.all_input_nodes: + if input_node.target == torch.ops.aten.view.default: + target_input_node = input_node + break + + assert target_input_node is not None, f"Target input node not found for node {n}" + with graph.inserting_before(target_input_node): + new_node = graph.call_function( + record_event_wrapper, + 
args=(target_input_node.target, *target_input_node.args), + kwargs=target_input_node.kwargs, + ) + target_input_node.replace_all_uses_with(new_node) + graph.erase_node(target_input_node) + with graph.inserting_after(n): + new_node = graph.call_function(op_dict[n.target], args=n.args, kwargs=n.kwargs) + n.replace_all_uses_with(new_node) + graph.erase_node(n) + num_replaced += 1 + if num_replaced: + graph.eliminate_dead_code() + graph.lint() + gm.recompile() + + return gm, num_replaced + + +@TransformRegistry.register("multi_stream_moe") +class MultiStreamMOE(BaseTransform): + """Multi-stream execution of MoE layers that have shared experts and routed experts.""" + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + op_dict = { + torch.ops.auto_deploy.trtllm_moe_fused: torch.ops.auto_deploy.trtllm_moe_fused_aux, + torch.ops.auto_deploy.triton_moe_fused: torch.ops.auto_deploy.triton_moe_fused_aux, + torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused: torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused_aux, + } + + gm, num_matches = _execute_op_in_aux_stream(gm, op_dict) + + info = TransformInfo( + skipped=False, + num_matches=num_matches, + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 7e8d2bf1bc4..c2959999599 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -168,6 +168,7 @@ def get_default_kwargs(self): # Keep max_batch_size as in the PyTorch test to avoid OOM "max_batch_size": 128, # Model context length is 8K + "enable_chunked_prefill": True, "max_seq_len": 8192, # Set explicitly to match default build_config behavior "max_num_tokens": 8192, diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index 6b9bf92a9f7..af821955d49 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -502,6 +502,12 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1): "num_hidden_layers": 2, }, }, + "nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024": { + "llm_models_subdir": "Nemotron-Nano-3-30B-A3.5B-dev-1024", + "model_kwargs": { + "num_hidden_layers": 8, + }, + }, } diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py new file mode 100644 index 00000000000..972cf013a3b --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py @@ -0,0 +1,132 @@ +from typing import Tuple + +import torch +import torch.nn as nn +from torch.fx import GraphModule, Node + +from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import ( + aux_stream_wrapper, + record_event_wrapper, +) +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + + +@torch.library.custom_op("auto_deploy::multi_stream_linear", mutates_args=()) +def multi_stream_linear( + input: torch.Tensor, weight0: torch.Tensor, weight1: torch.Tensor +) -> torch.Tensor: + output = torch.ops.aten.linear(input, weight0) + output = torch.ops.aten.linear(output, weight1) + return output + + 
+@multi_stream_linear.register_fake +def multi_stream_linear_fake(input, weight0, weight1): + """Fake implementation of multi_stream_linear.""" + output = torch.ops.aten.linear(input, weight0) + return torch.ops.aten.linear(output, weight1) + + +def replace_multi_stream_linear_with_aux_stream_wrapper(gm: GraphModule) -> Tuple[GraphModule, int]: + """Traverse ``gm`` and replace all ``auto_deploy::multi_stream_linear`` ops with ``aux_stream_wrapper``. + + The replacement preserves the original args/kwargs of the node. + After rewriting, the graph is cleaned and recompiled. + + Args: + gm: The FX graph module to transform. + aux_stream_wrapper: A callable to replace the custom op with. + + Returns: + A tuple of (gm, num_replaced) + """ + graph = gm.graph + num_replaced = 0 + + # Collect targets first to avoid mutating while iterating + target_nodes: list[Node] = [] + target_nodes = [n for n in graph.nodes if is_op(n, torch.ops.auto_deploy.multi_stream_linear)] + + for n in target_nodes: + target_input_node = None + for input_node in n.all_input_nodes: + if len(input_node.users) > 1: + target_input_node = input_node + break + if target_input_node is None: + raise ValueError(f"Target input node not found for node {n}") + with graph.inserting_before(target_input_node): + new_node = graph.call_function( + record_event_wrapper, + args=(target_input_node.target, *target_input_node.args), + kwargs=target_input_node.kwargs, + ) + target_input_node.replace_all_uses_with(new_node) + graph.erase_node(target_input_node) + with graph.inserting_after(n): + new_node = graph.call_function( + aux_stream_wrapper, args=(n.target, *n.args), kwargs=n.kwargs + ) + n.replace_all_uses_with(new_node) + graph.erase_node(n) + num_replaced += 1 + + if num_replaced: + graph.eliminate_dead_code() + graph.lint() + gm.recompile() + + return gm, num_replaced + + +class ParallelTwoLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.fc10 = nn.Linear(in_dim, in_dim) + self.fc11 = nn.Linear(in_dim, out_dim) + self.fc2 = nn.Linear(in_dim, out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.relu(x) + y0 = self.fc2(x) + y1 = torch.ops.auto_deploy.multi_stream_linear(x, self.fc10.weight, self.fc11.weight) + return y0 + y1 + + +def test_multi_stream_linear(): + in_dim, out_dim = 128, 256 + + model = ( + nn.Sequential(ParallelTwoLinear(in_dim, out_dim), ParallelTwoLinear(out_dim, out_dim)) + .eval() + .to("cuda") + ) + + # Example input used for export + example_input = torch.randn(4, in_dim).to("cuda") + + # Export the graph + egm = torch.export.export(model, (example_input,)) + gm = egm.module() + + test_x = torch.randn(4, in_dim).to("cuda") + ref_output = model(test_x) + + # pattern matching and replace + gm, num_replaced = replace_multi_stream_linear_with_aux_stream_wrapper(gm) + + assert num_replaced == 2 + y = gm(test_x) + assert torch.allclose(y, ref_output) + + static_x = torch.randn(4, in_dim).to("cuda") + static_output = torch.randn(4, out_dim).to("cuda") + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + static_output.copy_(gm(static_x)) + + static_x.copy_(test_x) + graph.replay() + + assert torch.allclose(static_output, ref_output) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py index 4b1c373b0fc..917cdbaca2e 100644 --- 
a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py @@ -163,7 +163,6 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): dt, dt_bias, seq_len, - seq_start, slot_idx, use_initial_states, cu_seqlens, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index 5eb6bcfaa38..3e13e28a0c5 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -404,6 +404,10 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1) mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + # compute quant_scales + gemm1_dequant = (w1_scales * hidden_states_scale).contiguous().squeeze().to(torch.float32) + gemm2_act_quant = (1.0 / w2_input_scale[0]).contiguous().to(torch.float32) + gemm2_dequant = (w2_scales * w2_input_scale[0]).contiguous().squeeze().to(torch.float32) ad_test_output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused( x, # Note! unquantized input is expected selected_experts.to(torch.int), @@ -417,6 +421,9 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE w1_weight_scale=w1_scales, w2_weight_scale=w2_scales, w3_weight_scale=w3_scales, + gemm1_dequant=gemm1_dequant, + gemm2_act_quant=gemm2_act_quant, + gemm2_dequant=gemm2_dequant, mlp_style=mlp_style, act_fn=activation_func, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index 4e1e78bd97d..320dbdcfa62 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -186,6 +186,14 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): }, }, ), + ( + "nvidia/Nemotron-Nano-3-30B-A3.5B-dev-1024", + { + "transforms": { + "multi_stream_moe": {"stage": "compile", "enabled": True}, + }, + }, + ), ], ) def test_build_ad(model_hub_id: str, llm_extra_args: dict):
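# --- Illustration (not part of the patch): FP8 quantization helper ----------
# The hunk above decorates `_quantize_fp8` with @torch.compile but only shows
# the FP8_MIN constant. This is a hedged sketch of the clamp-then-cast recipe
# its docstring describes ("quantize to FP8 with clamping"); treat it as an
# illustration of the pattern, not the exact TensorRT-LLM implementation.
import torch

FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def quantize_fp8_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Divide by the input scale, clamp to the e4m3 range, and cast to FP8."""
    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)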
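# --- Illustration (not part of the patch): precomputed FP8 MoE quant scales --
# The `fused_moe` transform above now computes gemm1_dequant, gemm2_act_quant,
# and gemm2_dequant once at transform time and passes them as extra arguments
# to trtllm_quant_fp8_moe_fused instead of rebuilding them on every call. A
# minimal sketch of that precomputation with illustrative shapes and values
# (E experts; per-expert input scales are asserted to share a single value):
import torch

E = 8  # example expert count
w1_weight_scale = torch.rand(E, dtype=torch.float32)
w2_weight_scale = torch.rand(E, dtype=torch.float32)
w1_input_scale = torch.full((E,), 0.5, dtype=torch.float32)
w2_input_scale = torch.full((E,), 0.25, dtype=torch.float32)

# The transform asserts that all per-expert input scales are identical before
# folding the [0] element into the combined scales.
assert torch.all(w1_input_scale[0] == w1_input_scale)
assert torch.all(w2_input_scale[0] == w2_input_scale)

gemm1_dequant = (w1_weight_scale * w1_input_scale[0]).squeeze()  # [E]
gemm2_act_quant = (1.0 / w2_input_scale[0]).to(torch.float32)    # scalar
gemm2_dequant = (w2_weight_scale * w2_input_scale[0]).squeeze()  # [E]

# Layout consumed by the TensorRT-LLM FP8 MoE kernel:
# [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant]
quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale[0]]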
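# --- Illustration (not part of the patch): fork/join behind the *_aux ops ----
# multi_stream.py forks the routed-expert op onto a single auxiliary stream
# while the shared experts run on the default stream, using two CUDA events to
# order the hand-off. This is a self-contained sketch of the same pattern with
# plain torch.cuda primitives; the function and variable names below are
# illustrative and are not the registered auto_deploy ops.
import torch


def fork_to_aux_stream(routed_fn, shared_fn, x):
    main_stream = torch.cuda.default_stream()
    aux_stream = torch.cuda.Stream()
    main_ready = torch.cuda.Event()
    aux_done = torch.cuda.Event()

    # Main stream: mark the point where the shared/routed input is ready.
    main_ready.record(main_stream)

    # Aux stream: wait for the producer, run the routed experts, signal done.
    with torch.cuda.stream(aux_stream):
        main_ready.wait()       # wait is issued on the current (aux) stream
        routed = routed_fn(x)
        aux_done.record()

    shared = shared_fn(x)       # overlaps with routed_fn on the aux stream
    aux_done.wait(main_stream)  # join before combining the two branches
    return shared + routed


if __name__ == "__main__":
    w_routed = torch.randn(64, 64, device="cuda")
    w_shared = torch.randn(64, 64, device="cuda")
    x = torch.randn(8, 64, device="cuda")
    print(fork_to_aux_stream(lambda t: t @ w_routed, lambda t: t @ w_shared, x).shape)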