-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[#9271][perf] Enable multi-stream MOE optimization in AutoDeploy #9322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
9603844
Enable multi-stream MOE optimization in AutoDeploy
suyoggupta 1a0f90b
remove duplicate test file
suyoggupta 4692586
revert tp plan changes
suyoggupta 948e5d4
remove redundant casts before rmsnorm
suyoggupta 0439b3b
fused quant scale op
suyoggupta 9c95592
more quant moe opt
suyoggupta 3f71b77
merge main
suyoggupta c5179ce
merge main
suyoggupta f78ae5a
minor refactoring
suyoggupta 73e2b1b
update some comments
suyoggupta df1daca
fix tests
suyoggupta 9a1cf7e
update ad test
suyoggupta 5b411c5
fix test failures, address review comments
suyoggupta 74335a6
materialize args before calling the op
suyoggupta 19288a2
fix how get_attr nodes are added
suyoggupta File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
235 changes: 235 additions & 0 deletions
235
tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,235 @@ | ||
| """ | ||
| Custom ops to enable multi-stream execution. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from threading import RLock | ||
| from typing import Any, Callable, Dict, Tuple | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| class _Singleton(type): | ||
| _instances: Dict[type, Any] = {} | ||
| _lock = RLock() | ||
|
|
||
| def __call__(cls, *args: Any, **kwargs: Any) -> Any: | ||
| if cls not in cls._instances: | ||
| with cls._lock: | ||
| if cls not in cls._instances: # double-checked locking | ||
| cls._instances[cls] = super().__call__(*args, **kwargs) | ||
| return cls._instances[cls] | ||
|
|
||
|
|
||
| # A singleton that holds the pointers to the cuda streams and events. | ||
| # In multi-gpu scenario, each GPU/rank has its own CudaStreamManager. | ||
| class CudaStreamManager(metaclass=_Singleton): | ||
| AUX_STREAM_NAME = "aux" | ||
| MAIN_STREAM_NAME = "main" | ||
|
|
||
| def __init__(self) -> None: | ||
| # In case __init__ ever gets called twice, guard against re-init | ||
| if hasattr(self, "streams"): | ||
| return | ||
|
|
||
| self._lock = RLock() | ||
|
|
||
| # Events needed for stream synchronization | ||
| self.events: Dict[str, Any] = { | ||
| self.AUX_STREAM_NAME: torch.cuda.Event(), | ||
| self.MAIN_STREAM_NAME: torch.cuda.Event(), | ||
| } | ||
|
|
||
| # Streams for multi-stream execution | ||
| self.aux_stream = torch.cuda.Stream() | ||
| self.streams: Dict[str, Any] = { | ||
| self.AUX_STREAM_NAME: self.aux_stream, | ||
| self.MAIN_STREAM_NAME: torch.cuda.default_stream(), | ||
| } | ||
|
|
||
|
|
||
| cuda_stream_manager = CudaStreamManager() | ||
|
|
||
|
|
||
| @torch.library.custom_op("auto_deploy::record_event", mutates_args=()) | ||
| def record_event(stream_name: str) -> None: | ||
| event = cuda_stream_manager.events[stream_name] | ||
| event.record() | ||
|
|
||
|
|
||
| @torch.library.custom_op("auto_deploy::wait_event", mutates_args=()) | ||
| def wait_event(event_name: str) -> None: | ||
| event = cuda_stream_manager.events[event_name] | ||
| event.wait() | ||
|
|
||
|
|
||
| # skip during compilation | ||
| @torch._dynamo.disable | ||
| def record_event_wrapper( | ||
| fn: Callable, | ||
| *args: Tuple[Any, ...], | ||
| **kwargs: Dict[str, Any], | ||
| ) -> torch.Tensor: | ||
| output = fn(*args, **kwargs) | ||
| torch.ops.auto_deploy.record_event(cuda_stream_manager.MAIN_STREAM_NAME) | ||
| return output | ||
|
|
||
|
|
||
| @torch._dynamo.disable | ||
| def aux_stream_wrapper( | ||
| fn: Callable, | ||
| *args: Tuple[Any, ...], | ||
| **kwargs: Dict[str, Any], | ||
| ) -> torch.Tensor: | ||
| stream_name = cuda_stream_manager.AUX_STREAM_NAME | ||
| with torch.cuda.stream(cuda_stream_manager.streams[stream_name]): | ||
| torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) | ||
| output = fn(*args, **kwargs) | ||
| torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) | ||
| torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) | ||
| return output | ||
|
|
||
|
|
||
| # trtllm bf16 | ||
| @torch.library.custom_op("auto_deploy::trtllm_moe_fused_aux", mutates_args=()) | ||
| def trtllm_moe_fused_aux( | ||
| x: torch.Tensor, | ||
| selected_experts: torch.Tensor, | ||
| routing_weights: torch.Tensor, | ||
| w3_w1_stacked_weight: torch.Tensor, | ||
| w2_stacked_weight: torch.Tensor, | ||
| mlp_style: str = "gated_mlp", | ||
| act_fn: str = "silu", | ||
| ) -> torch.Tensor: | ||
| with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]): | ||
| torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) | ||
| output = torch.ops.auto_deploy.trtllm_moe_fused( | ||
| x, | ||
| selected_experts, | ||
| routing_weights, | ||
| w3_w1_stacked_weight, | ||
| w2_stacked_weight, | ||
| mlp_style, | ||
| act_fn, | ||
| ) | ||
| torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) | ||
| torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) | ||
| return output | ||
|
|
||
|
|
||
| @trtllm_moe_fused_aux.register_fake | ||
| def trtllm_moe_fused_aux_fake( | ||
| x: torch.Tensor, | ||
| selected_experts: torch.Tensor, | ||
| routing_weights: torch.Tensor, | ||
| w3_w1_stacked_weight: torch.Tensor, | ||
| w2_stacked_weight: torch.Tensor, | ||
| mlp_style: str = "gated_mlp", | ||
| act_fn: str = "silu", | ||
| ) -> torch.Tensor: | ||
| return torch.empty_like(x) | ||
|
|
||
|
|
||
| # triton bf16 | ||
| @torch.library.custom_op("auto_deploy::triton_moe_fused_aux", mutates_args=()) | ||
| def triton_moe_fused_aux( | ||
| x: torch.Tensor, | ||
| selected_experts: torch.Tensor, | ||
| routing_weights: torch.Tensor, | ||
| w1_stacked_weight: torch.Tensor, | ||
| w2_stacked_weight: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]): | ||
| torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) | ||
| output = torch.ops.auto_deploy.triton_moe_fused( | ||
| x, | ||
| selected_experts, | ||
| routing_weights, | ||
| w1_stacked_weight, | ||
| w2_stacked_weight, | ||
| ) | ||
| torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) | ||
| torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) | ||
| return output | ||
|
|
||
|
|
||
| @triton_moe_fused_aux.register_fake | ||
| def triton_moe_fused_aux_fake( | ||
| x: torch.Tensor, | ||
| selected_experts: torch.Tensor, | ||
| routing_weights: torch.Tensor, | ||
| w1_stacked_weight: torch.Tensor, | ||
| w2_stacked_weight: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| return torch.empty_like(x) | ||
|
|
||
|
|
||
| # trtllm fp8 | ||
| @torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused_aux", mutates_args=()) | ||
| def trtllm_quant_fp8_moe_fused_aux( | ||
| x: torch.Tensor, | ||
| selected_experts: torch.Tensor, | ||
| routing_weights: torch.Tensor, | ||
| w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights | ||
| w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights | ||
| w3_weight: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp | ||
| w1_input_scale: torch.Tensor, # [E] stacked input scales | ||
| w2_input_scale: torch.Tensor, # [E] stacked input scales | ||
| w3_input_scale: torch.Tensor, # [E] or unused | ||
| w1_weight_scale: torch.Tensor, # [E] stacked weight scales | ||
| w2_weight_scale: torch.Tensor, # [E] stacked weight scales | ||
| w3_weight_scale: torch.Tensor, # [E] or unused | ||
| gemm1_dequant: torch.Tensor, # [E] | ||
| gemm2_act_quant: torch.Tensor, # [E] | ||
| gemm2_dequant: torch.Tensor, # [E] | ||
| mlp_style: str = "gated_mlp", | ||
| act_fn: str = "silu", | ||
| ) -> torch.Tensor: | ||
| with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]): | ||
| torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME) | ||
| output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused( | ||
| x, | ||
| selected_experts, | ||
| routing_weights, | ||
| w1_weight, | ||
| w2_weight, | ||
| w3_weight, | ||
| w1_input_scale, | ||
| w2_input_scale, | ||
| w3_input_scale, | ||
| w1_weight_scale, | ||
| w2_weight_scale, | ||
| w3_weight_scale, | ||
| gemm1_dequant, | ||
| gemm2_act_quant, | ||
| gemm2_dequant, | ||
| mlp_style, | ||
| act_fn, | ||
| ) | ||
| torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME) | ||
| torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME) | ||
| return output | ||
|
|
||
|
|
||
| @trtllm_quant_fp8_moe_fused_aux.register_fake | ||
| def trtllm_quant_fp8_moe_fused_aux_fake( | ||
| x: torch.Tensor, | ||
| selected_experts: torch.Tensor, | ||
| routing_weights: torch.Tensor, | ||
| w1_weight: torch.Tensor, | ||
| w2_weight: torch.Tensor, | ||
| w3_weight: torch.Tensor, | ||
| w1_input_scale: torch.Tensor, | ||
| w2_input_scale: torch.Tensor, | ||
| w3_input_scale: torch.Tensor, | ||
| w1_weight_scale: torch.Tensor, | ||
| w2_weight_scale: torch.Tensor, | ||
| w3_weight_scale: torch.Tensor, | ||
| gemm1_dequant: torch.Tensor, | ||
| gemm2_act_quant: torch.Tensor, | ||
| gemm2_dequant: torch.Tensor, | ||
| mlp_style: str = "gated_mlp", | ||
| act_fn: str = "silu", | ||
| ) -> torch.Tensor: | ||
| return torch.empty_like(x) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.