[#9271][perf] Enable multi-stream MOE optimization in AutoDeploy #9322
Conversation
Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
📝 Walkthrough

Adds multi-stream execution support for mixture-of-experts (MoE) layers in AutoDeploy. Introduces a CudaStreamManager singleton for managing CUDA streams and synchronization events, new Torch custom ops for stream-coordinated MoE operations, a transform that rewrites fused MoE ops to auxiliary-stream variants, and optimizes Nemotron-H shared-expert computation by pre-computing it once.
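For readers skimming the thread, here is a minimal sketch of what a stream/event manager of the kind described above could look like. The `_Singleton` metaclass mirrors the snippet shown later in the review diff; the `CudaStreamManager` body, constants, and method names are assumptions for illustration, not the PR's actual API.

```python
# Minimal sketch of a stream/event manager singleton, assuming the design
# described in the walkthrough; CudaStreamManager details are illustrative.
from threading import RLock
from typing import Any, ClassVar, Dict

import torch

MAIN_STREAM = "main"  # assumed event keys, for illustration only
AUX_STREAM = "aux"


class _Singleton(type):
    """Metaclass that hands out one shared instance per class (thread-safe)."""

    _instances: ClassVar[Dict[type, Any]] = {}
    _lock = RLock()

    def __call__(cls, *args, **kwargs):
        with cls._lock:
            if cls not in cls._instances:
                cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class CudaStreamManager(metaclass=_Singleton):
    """Owns an auxiliary CUDA stream and one event per stream for handshakes."""

    def __init__(self) -> None:
        self.aux_stream = torch.cuda.Stream()
        self.events = {
            MAIN_STREAM: torch.cuda.Event(),
            AUX_STREAM: torch.cuda.Event(),
        }

    def record_event(self, name: str) -> None:
        # Record the named event on whichever stream is current at call time.
        self.events[name].record(torch.cuda.current_stream())

    def wait_event(self, name: str) -> None:
        # Make the current stream wait until the named event has been recorded.
        torch.cuda.current_stream().wait_event(self.events[name])
```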
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Main as Main Stream
    participant Aux as Aux Stream
    participant Event as Event Manager

    Main->>Event: record_event(MAIN_STREAM)
    Note over Main: Input ready
    Aux->>Event: wait_event(MAIN_STREAM)
    Note over Aux: Wait for input
    Aux->>Aux: Execute MoE Fused Op<br/>(trtllm_moe_fused_aux)
    Aux->>Event: record_event(AUX_STREAM)
    Note over Aux: Output ready
    Main->>Event: wait_event(AUX_STREAM)
    Note over Main: Wait for result
    Main->>Main: Proceed with next ops
```
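The diagram maps onto a standard two-stream handshake built from CUDA events. A self-contained sketch of that protocol follows, with `run_moe` as a placeholder for the fused MoE op; the PR's actual `*_aux` custom ops are not reproduced here.

```python
# Sketch of the main/aux stream handshake shown in the diagram above.
# `run_moe` is a stand-in for the fused MoE kernel call.
import torch

aux_stream = torch.cuda.Stream()
input_ready = torch.cuda.Event()   # "MAIN_STREAM" event in the diagram
output_ready = torch.cuda.Event()  # "AUX_STREAM" event in the diagram


def run_moe_on_aux_stream(x: torch.Tensor, run_moe) -> torch.Tensor:
    # Main stream: mark the point at which the MoE input is ready.
    input_ready.record(torch.cuda.current_stream())

    with torch.cuda.stream(aux_stream):
        # Aux stream: wait for the input, run the MoE op, mark the output.
        aux_stream.wait_event(input_ready)
        out = run_moe(x)
        output_ready.record(aux_stream)

    # Main stream: wait for the aux result before any op that consumes it.
    torch.cuda.current_stream().wait_event(output_ready)
    return out
```

Work submitted on the main stream before the final wait can overlap with the MoE kernel; anything launched after the wait correctly observes the result.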
```mermaid
sequenceDiagram
    participant GM as GraphModule
    participant Transform as MultiStreamMOE
    participant Detector as Node Detector
    participant Replacer as Node Replacer

    GM->>Transform: _apply(gm, cm, factory, config)
    Transform->>Detector: Build op_dict mapping<br/>(fused_ops → aux_variants)
    Transform->>Replacer: _execute_op_in_aux_stream(gm, op_dict)
    loop For each target node
        Replacer->>Replacer: Wrap input with<br/>record_event_wrapper
        Replacer->>Replacer: Replace op with<br/>aux variant
    end
    Replacer->>Replacer: Dead code elimination<br/>& recompile
    Replacer-->>Transform: Updated gm + replacement count
    Transform->>Transform: Create TransformInfo<br/>(num_matches, status)
    Transform-->>GM: Return updated gm & info
```
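In FX terms, the rewrite loop in the diagram amounts to retargeting matched call nodes and wrapping their inputs. The following is a simplified sketch, not the PR's implementation: `op_dict` and `record_op` are passed in explicitly, and registry/TransformInfo bookkeeping is omitted.

```python
# Simplified sketch of the rewrite loop from the diagram: wrap each fused-MoE
# node's input with an event-recording op and retarget the node to its
# aux-stream variant, then clean up and recompile the graph.
from torch.fx import GraphModule


def execute_ops_in_aux_stream(gm: GraphModule, op_dict: dict, record_op) -> int:
    """op_dict maps a fused-MoE op overload to its aux-stream variant;
    record_op is the event-recording wrapper op (e.g., the PR's
    record_event_wrapper custom op)."""
    num_replaced = 0
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target not in op_dict:
            continue

        hidden_in = node.args[0]
        with gm.graph.inserting_before(node):
            # Record the "input ready" event on the main stream right before
            # the MoE op; the aux variant waits on it internally.
            wrapped = gm.graph.call_function(record_op, args=(hidden_in,))
        node.replace_input_with(hidden_in, wrapped)
        node.target = op_dict[node.target]  # swap to the *_aux variant
        num_replaced += 1

    gm.graph.eliminate_dead_code()
    gm.recompile()
    return num_replaced
```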
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings), ✅ Passed checks (3 passed)
Signed-off-by: Suyog Gupta <[email protected]>
Actionable comments posted: 2
🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py (1)

16-55: Consider documenting the graph structure assumptions. The transformation makes specific assumptions about the graph structure (e.g., the presence of a `view.default` operation before MoE ops). Consider adding documentation or assertions that clarify these assumptions to aid future maintenance and debugging.

tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py (1)

14-14: Annotate the mutable class attribute with ClassVar. The static analysis tool correctly identifies that `_instances` should be annotated with `typing.ClassVar` to clarify that it is a class-level attribute. Apply this diff:

```diff
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable, ClassVar, Dict, Tuple

 class _Singleton(type):
-    _instances: Dict[type, Any] = {}
+    _instances: ClassVar[Dict[type, Any]] = {}
     _lock = RLock()
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
- tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (2 hunks)
- tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py (1 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py (1 hunks)
🧰 Additional context used
🧠 Learnings (5)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Learnt from: ChristinaZ
Repo: NVIDIA/TensorRT-LLM PR: 7068
File: cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh:169-172
Timestamp: 2025-08-20T07:43:36.447Z
Learning: In TensorRT-LLM MOE kernels, when processing up to 128 experts across 32 threads, each thread handles at most 4 experts (N < 5 constraint), where N represents candidates per thread rather than total system capacity.
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 6968
File: cpp/tensorrt_llm/thop/loraOp.cpp:133-141
Timestamp: 2025-08-17T15:07:01.420Z
Learning: In TensorRT-LLM's LoRA implementation, the LoraImpl::run() method handles setStream() internally in _runGemm() (line 51 in lora.cpp), along with setWorkspace(). The stream parameter flows from loraOp.cpp through LoraImpl::run() to _runGemm() where setStream() is called appropriately. Adding setStream() in loraOp.cpp would be redundant and goes against the intended architectural design.
📚 Learning: 2025-08-14T23:23:27.449Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Applied to files:
tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py
📚 Learning: 2025-10-20T17:07:18.745Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py:98-116
Timestamp: 2025-10-20T17:07:18.745Z
Learning: In NemotronH models (tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py), the gate (self.gate) returns topk_indices and topk_weights that are already in the correct shape to be passed directly to torch_ops.auto_deploy.torch_moe without needing to reshape them when hidden_states is flattened.
Applied to files:
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
Applied to files:
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py
tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py (3)
- tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py (4): record_event_wrapper (67-74), trtllm_moe_fused_aux (93-115), triton_moe_fused_aux (133-151), trtllm_quant_fp8_moe_fused_aux (167-203)
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1): is_op (198-221)
- tensorrt_llm/_torch/auto_deploy/transform/interface.py (4): BaseTransform (217-504), SharedConfig (61-66), TransformInfo (121-178), TransformRegistry (507-535)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py (2)
- tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py (2): aux_stream_wrapper (77-88), record_event_wrapper (67-74)
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1): is_op (198-221)
tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py (1)
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py (2): trtllm_moe_fused (7-54), trtllm_quant_fp8_moe_fused (97-215)
🪛 Ruff (0.14.5)
tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py
36-36: Avoid specifying long messages outside the exception class
(TRY003)
65-65: Unused method argument: cm
(ARG002)
66-66: Unused method argument: factory
(ARG002)
67-67: Unused method argument: shared_config
(ARG002)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py
59-59: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py
14-14: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
121-121: Unused function argument: selected_experts
(ARG001)
122-122: Unused function argument: routing_weights
(ARG001)
123-123: Unused function argument: w3_w1_stacked_weight
(ARG001)
124-124: Unused function argument: w2_stacked_weight
(ARG001)
125-125: Unused function argument: mlp_style
(ARG001)
126-126: Unused function argument: act_fn
(ARG001)
157-157: Unused function argument: selected_experts
(ARG001)
158-158: Unused function argument: routing_weights
(ARG001)
159-159: Unused function argument: w1_stacked_weight
(ARG001)
160-160: Unused function argument: w2_stacked_weight
(ARG001)
209-209: Unused function argument: selected_experts
(ARG001)
210-210: Unused function argument: routing_weights
(ARG001)
211-211: Unused function argument: w1_weight
(ARG001)
212-212: Unused function argument: w2_weight
(ARG001)
213-213: Unused function argument: w3_weight
(ARG001)
214-214: Unused function argument: w1_input_scale
(ARG001)
215-215: Unused function argument: w2_input_scale
(ARG001)
216-216: Unused function argument: w3_input_scale
(ARG001)
217-217: Unused function argument: w1_weight_scale
(ARG001)
218-218: Unused function argument: w2_weight_scale
(ARG001)
219-219: Unused function argument: w3_weight_scale
(ARG001)
220-220: Unused function argument: mlp_style
(ARG001)
221-221: Unused function argument: act_fn
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)
131-160: LGTM! Efficient optimization for shared expert computation. The pre-computation of `shared_out` eliminates redundant calls to `self.shared_experts(residuals)`, which is a good performance optimization. The comment clearly explains the dispatch order for multi-stream execution.
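As a rough illustration of the optimization being praised, with hypothetical attribute and helper names rather than the patch's exact code:

```python
# Illustrative only: compute the shared-expert branch once, dispatch the
# routed experts (which the transform may move to the auxiliary stream),
# then combine. Names here are assumptions, not copied from nemotron_h.py.
def forward(self, hidden_states, residuals):
    shared_out = self.shared_experts(residuals)        # evaluated exactly once
    topk_idx, topk_w = self.gate(hidden_states)
    routed_out = self.run_routed_experts(hidden_states, topk_idx, topk_w)
    return routed_out + shared_out
```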
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

164-166: LGTM! Configuration correctly enables the multi-stream MoE transform. The new transform entry is properly structured and aligns with the transform implementation in multi_stream_moe.py.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py (3)
14-27: LGTM! Well-structured custom op with fake implementation. The custom op and its fake variant are properly implemented for testing purposes.
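For context, the custom-op-plus-fake-implementation pattern this test relies on is registered roughly like the generic sketch below; the op name and body are illustrative, not the test's actual op.

```python
# Generic sketch of a Torch custom op plus its fake (meta) implementation.
import torch


@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # The fake impl only describes output metadata, so the op can be traced
    # and shape-propagated (e.g., by FX/export) without launching CUDA work.
    return torch.empty_like(x)
```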
52-59: Note the different heuristic from multi_stream_moe.py. This test uses `len(input_node.users) > 1` to identify the target input node, while multi_stream_moe.py checks for `torch.ops.aten.view.default`. Ensure both heuristics are appropriate for their respective use cases, as they may behave differently on various graph structures.
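A compact, simplified illustration of the two heuristics being compared (not the exact code from either file):

```python
# Two ways of picking the input node whose output should carry the
# "input ready" event; simplified from the descriptions above.
from typing import Optional

import torch
from torch.fx import Node


def pick_by_fanout(moe_node: Node) -> Optional[Node]:
    """Test-style heuristic: choose the input feeding more than one consumer."""
    for arg in moe_node.all_input_nodes:
        if len(arg.users) > 1:
            return arg
    return None


def pick_by_view(moe_node: Node) -> Optional[Node]:
    """Transform-style heuristic: choose the input produced by aten.view.default."""
    for arg in moe_node.all_input_nodes:
        if arg.op == "call_function" and arg.target == torch.ops.aten.view.default:
            return arg
    return None
```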
98-134: LGTM! Comprehensive test coverage. The test properly validates:
- Graph transformation and recompilation
- Numerical correctness after transformation
- CUDA graph capture and replay compatibility
tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py (1)
92-115: LGTM! Consistent multi-stream execution pattern. The aux-stream wrappers for all three MoE variants follow a consistent pattern:
- Switch to auxiliary stream
- Wait for main stream event
- Execute the fused MoE operation
- Record auxiliary stream event
- Wait for auxiliary stream event before returning
This ensures proper synchronization between streams.
Also applies to: 132-151, 166-203
tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py (1)
29-36: Based on verification of the actual file, the review comment raises a valid and actionable concern about fragile assumptions in the code.
Verify the hardcoded `view.default` target assumption.
The code assumes input nodes will have `torch.ops.aten.view.default` as their target when processing MoE graphs. This assumption is fragile because:
- Different ATen operators (over 3000 exist) may be generated depending on how graphs are exported
- FX graphs contain ATen-level operations which vary based on how the original code manipulates tensors
- Different reshaping methods (reshape vs view) can be optimized differently—reshape may return a view or a copy depending on memory layout compatibility
The code will break when MoE implementations or PyTorch versions produce `torch.ops.aten.reshape` or other ATen reshape variants instead of specifically `view.default`, and it raises a `ValueError` without context about which operation was found instead. Recommendations (a sketch of one possible check follows the list below): Replace the hardcoded operation check with a more robust approach that:
- Checks for multiple reshaping operations (view.default, reshape, flatten, etc.)
- Provides detailed error messages showing which operations were actually present in the graph
- Documents why view.default is specifically expected or makes the logic operation-agnostic
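Following those recommendations, one tolerant variant of the check could look like the sketch below; the accepted op set is an assumption and should be adapted to the transform's actual invariants.

```python
# Sketch of a more tolerant input check than a hard-coded view.default match:
# accept several reshape-like ATen ops and report what was actually found.
import torch
from torch.fx import Node

_RESHAPE_LIKE_OPS = {
    torch.ops.aten.view.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.flatten.using_ints,
}


def find_reshape_like_input(moe_node: Node) -> Node:
    for arg in moe_node.all_input_nodes:
        if arg.op == "call_function" and arg.target in _RESHAPE_LIKE_OPS:
            return arg
    found = [str(a.target) for a in moe_node.all_input_nodes]
    raise ValueError(
        f"Expected a reshape-like producer feeding {moe_node.name}; "
        f"found inputs with targets: {found}"
    )
```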
Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Outdated review thread on tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py (resolved).
Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
/bot run

PR_Github #25476 [ run ] triggered by Bot. Commit:

PR_Github #25476 [ run ] completed with state
Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
/bot run

PR_Github #25491 [ run ] triggered by Bot. Commit:

PR_Github #25491 [ run ] completed with state
Outdated review threads (resolved) on:
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_multi_stream.py
- tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py (two threads)
Signed-off-by: Suyog Gupta <[email protected]>
/bot run

PR_Github #25599 [ run ] triggered by Bot. Commit:

PR_Github #25599 [ run ] completed with state
Signed-off-by: Suyog Gupta <[email protected]>
/bot run

PR_Github #25609 [ run ] triggered by Bot. Commit:
Signed-off-by: Suyog Gupta <[email protected]>
/bot run

PR_Github #25611 [ run ] triggered by Bot. Commit:

PR_Github #25609 [ run ] completed with state

PR_Github #25611 [ run ] completed with state
Summary by CodeRabbit (Release Notes): New Features, Tests, Chores
fixes #9271