76 changes: 56 additions & 20 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py
@@ -5,10 +5,12 @@
 from __future__ import annotations
 
 from threading import RLock
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable, Dict, List, Tuple
 
 import torch
 
+from ..utils.logger import ad_logger
+
 
 class _Singleton(type):
     _instances: Dict[type, Any] = {}
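
Note: the body of `_Singleton.__call__` is collapsed in this view (its signature is visible in the next hunk header). A minimal sketch consistent with the visible `_instances` attribute and that signature, for orientation only rather than the file's verbatim code:

```python
from typing import Any, Dict


class _Singleton(type):
    """Metaclass sketch: caches one instance per class (body collapsed in the diff)."""

    _instances: Dict[type, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
```

With this metaclass, `CudaStreamManager()` always returns the same object, which is why the module-level `cuda_stream_manager` below is safe to import from multiple call sites.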
@@ -23,10 +25,12 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any:


 # A singleton that holds the pointers to the cuda streams and events.
+# In a multi-GPU scenario, each GPU/rank has its own streams and events,
+# keyed by integer device index.
 class CudaStreamManager(metaclass=_Singleton):
     AUX_STREAM_NAME = "aux"
     MAIN_STREAM_NAME = "main"
+    devices: List[int] = []
+    events: Dict[int, Dict[str, Any]] = {}
+    streams: Dict[int, Dict[str, Any]] = {}
 
     def __init__(self) -> None:
         # In case __init__ ever gets called twice, guard against re-init
@@ -35,32 +39,50 @@ def __init__(self) -> None:

         self._lock = RLock()
 
-        # Events needed for stream synchronization
-        self.events: Dict[str, Any] = {
-            self.AUX_STREAM_NAME: torch.cuda.Event(),
-            self.MAIN_STREAM_NAME: torch.cuda.Event(),
-        }
-
-        # Streams for multi-stream execution
-        self.aux_stream = torch.cuda.Stream()
-        self.streams: Dict[str, Any] = {
-            self.AUX_STREAM_NAME: self.aux_stream,
-            self.MAIN_STREAM_NAME: torch.cuda.default_stream(),
-        }
+        # Register streams and events for the current device through
+        # add_device() so the device is also recorded in self.devices and a
+        # later add_device() call cannot silently re-create them.
+        self.add_device(torch.cuda.current_device())

+    def add_device(self, device: int) -> None:
+        # Lazily create the aux/main streams and events for `device`, once.
+        if device not in self.devices:
+            self.devices.append(device)
+            with torch.cuda.device(device):
+                self.events[device] = {
+                    self.AUX_STREAM_NAME: torch.cuda.Event(),
+                    self.MAIN_STREAM_NAME: torch.cuda.Event(),
+                }
+                self.streams[device] = {
+                    self.AUX_STREAM_NAME: torch.cuda.Stream(),
+                    self.MAIN_STREAM_NAME: torch.cuda.default_stream(),
+                }
+        else:
+            # Re-registration is expected (e.g. once per transform apply),
+            # so log quietly instead of warning.
+            ad_logger.debug(f"CudaStreamManager: Device {device} already added")
+
+    def get_stream(self, device: int, stream_name: str) -> torch.cuda.Stream:
+        return self.streams[device][stream_name]
+
+    def get_event(self, device: int, event_name: str) -> torch.cuda.Event:
+        return self.events[device][event_name]
 
 
 cuda_stream_manager = CudaStreamManager()


 @torch.library.custom_op("auto_deploy::record_event", mutates_args=())
 def record_event(stream_name: str) -> None:
-    event = cuda_stream_manager.events[stream_name]
+    event = cuda_stream_manager.get_event(torch.cuda.current_device(), stream_name)
     event.record()
 
 
 @torch.library.custom_op("auto_deploy::wait_event", mutates_args=())
 def wait_event(event_name: str) -> None:
-    event = cuda_stream_manager.events[event_name]
+    event = cuda_stream_manager.get_event(torch.cuda.current_device(), event_name)
     event.wait()
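
Taken together, the new per-device API is used roughly as follows. This is an illustrative sketch, not code from the PR; the random matmul stands in for real kernel work:

```python
import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import cuda_stream_manager

dev = torch.cuda.current_device()    # integer device index, e.g. 0
cuda_stream_manager.add_device(dev)  # one-time registration per device

aux_stream = cuda_stream_manager.get_stream(dev, cuda_stream_manager.AUX_STREAM_NAME)
main_stream = cuda_stream_manager.get_stream(dev, cuda_stream_manager.MAIN_STREAM_NAME)
main_event = cuda_stream_manager.get_event(dev, cuda_stream_manager.MAIN_STREAM_NAME)

x = torch.randn(1024, 1024, device=dev)
main_event.record(main_stream)       # mark "inputs ready" on the main stream
with torch.cuda.stream(aux_stream):
    main_event.wait()                # aux blocks until the main stream hits the mark
    y = x @ x                        # overlapped work running on the aux stream
```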


@@ -83,7 +105,9 @@ def aux_stream_wrapper(
     **kwargs: Dict[str, Any],
 ) -> torch.Tensor:
     stream_name = cuda_stream_manager.AUX_STREAM_NAME
-    with torch.cuda.stream(cuda_stream_manager.streams[stream_name]):
+    with torch.cuda.stream(
+        cuda_stream_manager.get_stream(torch.cuda.current_device(), stream_name)
+    ):
         torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
         output = fn(*args, **kwargs)
         torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
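
The wrapper is one half of a two-event handshake: the aux stream waits on the "main" event before reading its inputs, runs `fn`, then records the "aux" event. The main-stream side must wait on that "aux" event before consuming `output`; presumably the `record_event_wrapper` imported in the transform below covers that half. A self-contained sketch of the full handshake using plain CUDA events, independent of these custom ops:

```python
import torch

main = torch.cuda.default_stream()
aux = torch.cuda.Stream()
evt_main, evt_aux = torch.cuda.Event(), torch.cuda.Event()

x = torch.randn(4096, 4096, device="cuda")
evt_main.record(main)             # "inputs ready" marker on the main stream

with torch.cuda.stream(aux):
    evt_main.wait()               # aux: do not start before the inputs exist
    y_aux = x @ x                 # overlapped branch (e.g. the MoE expert GEMMs)
    evt_aux.record()              # "aux output ready" marker

z_main = x + 1                    # main-stream work overlapping with y_aux
evt_aux.wait(main)                # main: do not consume y_aux too early
out = z_main.sum() + y_aux.sum()  # safe: both branches are event-ordered
```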
@@ -102,7 +126,11 @@ def trtllm_moe_fused_aux(
     mlp_style: str = "gated_mlp",
     act_fn: str = "silu",
 ) -> torch.Tensor:
-    with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
+    with torch.cuda.stream(
+        cuda_stream_manager.get_stream(
+            torch.cuda.current_device(), cuda_stream_manager.AUX_STREAM_NAME
+        )
+    ):
         torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
         output = torch.ops.auto_deploy.trtllm_moe_fused(
             x,
@@ -140,7 +168,11 @@ def triton_moe_fused_aux(
     w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
 ) -> torch.Tensor:
-    with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
+    with torch.cuda.stream(
+        cuda_stream_manager.get_stream(
+            torch.cuda.current_device(), cuda_stream_manager.AUX_STREAM_NAME
+        )
+    ):
         torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
         output = torch.ops.auto_deploy.triton_moe_fused(
             x,
@@ -186,7 +218,11 @@ def trtllm_quant_fp8_moe_fused_aux(
     mlp_style: str = "gated_mlp",
     act_fn: str = "silu",
 ) -> torch.Tensor:
-    with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
+    with torch.cuda.stream(
+        cuda_stream_manager.get_stream(
+            torch.cuda.current_device(), cuda_stream_manager.AUX_STREAM_NAME
+        )
+    ):
         torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
         output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused(
             x,
@@ -5,7 +5,10 @@
 import torch
 from torch.fx import GraphModule
 
-from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import record_event_wrapper
+from tensorrt_llm._torch.auto_deploy.custom_ops.multi_stream import (
+    cuda_stream_manager,
+    record_event_wrapper,
+)
 
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
@@ -67,7 +70,8 @@ def _apply(
             torch.ops.auto_deploy.triton_moe_fused: torch.ops.auto_deploy.triton_moe_fused_aux,
             torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused: torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused_aux,
         }
-
+        # Ensure the aux stream and events for the current device are registered with the CudaStreamManager.
+        cuda_stream_manager.add_device(torch.cuda.current_device())
         gm, num_matches = _execute_op_in_aux_stream(gm, op_dict)
 
         info = TransformInfo(
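
`_execute_op_in_aux_stream` itself is not shown in this diff. For orientation, a plausible minimal shape for that kind of FX rewrite is sketched here; `swap_to_aux_ops` is a hypothetical stand-in, not the PR's implementation:

```python
from typing import Callable, Dict, Tuple

from torch.fx import GraphModule


def swap_to_aux_ops(gm: GraphModule, op_dict: Dict[Callable, Callable]) -> Tuple[GraphModule, int]:
    """Retarget matching call_function nodes to their aux-stream variants."""
    num_matches = 0
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in op_dict:
            node.target = op_dict[node.target]  # e.g. trtllm_moe_fused -> trtllm_moe_fused_aux
            num_matches += 1
    gm.recompile()
    return gm, num_matches
```

After such a rewrite every matched MoE call runs through its `*_aux` variant, which is why `add_device` must have registered the current device's streams and events first.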