
Commit efd5037

[#9271][perf] Enable multi-stream MOE optimization in AutoDeploy (#9322)
Signed-off-by: Suyog Gupta <[email protected]>
1 parent d1c7249 commit efd5037

File tree: 15 files changed, +567 -57 lines changed

examples/auto_deploy/nano_v3.yaml

Lines changed: 3 additions & 0 deletions

@@ -15,6 +15,9 @@ transforms:
   detect_sharding:
     sharding_source: ['factory', 'heuristic']
     sharding_dims: ['ep', 'bmm']
+  multi_stream_moe:
+    stage: compile
+    enabled: true
   # tunable mamba cache dtype
   # --> use float32 for accuracy and default (null) for speed
   insert_cached_ssm_attention:
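The new `multi_stream_moe` transform runs at the `compile` stage and, when enabled, routes the fused MoE expert computation onto an auxiliary CUDA stream so it can overlap with work that stays on the default stream (the supporting custom ops are added later in this commit). The snippet below is a minimal, illustrative sketch of that overlap pattern in plain PyTorch; it is not the transform's implementation, and `overlapped_moe`, `run_moe`, and `run_other` are placeholder names.

import torch

aux_stream = torch.cuda.Stream()
main_ready = torch.cuda.Event()   # signals that the MoE inputs are ready on the main stream
aux_done = torch.cuda.Event()     # signals that the aux-stream MoE work has finished

def overlapped_moe(run_moe, run_other, x):
    # Record an event on the main (current) stream once the MoE inputs exist.
    main_ready.record()

    # Run the MoE branch on the auxiliary stream as soon as the inputs are ready.
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(main_ready)
        moe_out = run_moe(x)          # e.g., a fused MoE kernel
        aux_done.record()

    # Independent work stays on the main stream and overlaps with the MoE branch.
    other_out = run_other(x)

    # Make the main stream wait for the aux stream before consuming moe_out.
    torch.cuda.current_stream().wait_event(aux_done)
    return moe_out, other_out

In the commit itself this pattern is wrapped in `torch.library.custom_op`s (`record_event`, `wait_event`, and the `*_moe_fused_aux` ops) so the stream handoff can be represented inside the compiled graph.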

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 15 additions & 12 deletions

@@ -67,7 +67,9 @@ def trtllm_moe_fused_fake(
     return torch.empty_like(x)


-# Todo: refactor this repeating code block
+# NOTE(suyogg): If compile ever fails because of this, just write a triton kernel
+# for this function and use it as a custom op.
+@torch.compile
 def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     """Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear)."""
     FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
@@ -107,6 +109,9 @@ def trtllm_quant_fp8_moe_fused(
     w1_weight_scale: torch.Tensor,  # [E] stacked weight scales
     w2_weight_scale: torch.Tensor,  # [E] stacked weight scales
     w3_weight_scale: torch.Tensor,  # [E] or unused
+    gemm1_dequant: torch.Tensor,  # [E]
+    gemm2_act_quant: torch.Tensor,  # [E]
+    gemm2_dequant: torch.Tensor,  # [E]
     mlp_style: str = "gated_mlp",
     act_fn: str = "silu",
 ) -> torch.Tensor:
@@ -125,6 +130,9 @@ def trtllm_quant_fp8_moe_fused(
         w1_weight_scale: Weight scales for w1 [E]
         w2_weight_scale: Weight scales for w2 [E]
         w3_weight_scale: Weight scales for w3 [E]
+        gemm1_dequant: Precomputed gemm1 dequant scale [E]
+        gemm2_act_quant: Precomputed gemm2 act quant scale [1]
+        gemm2_dequant: Precomputed gemm2 dequant scale [E]
         mlp_style: "gated_mlp" or "mlp"
         act_fn: "silu" for gated_mlp, "relu2" for mlp

@@ -144,28 +152,20 @@ def trtllm_quant_fp8_moe_fused(
     x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0])

     # Scales are stored in float32
-    w1_weight_scale = w1_weight_scale.to(torch.float32)
-    w2_weight_scale = w2_weight_scale.to(torch.float32)
-    w1_input_scale = w1_input_scale.to(torch.float32)[0]
-    w2_input_scale = w2_input_scale.to(torch.float32)[0]
+    w1_input_scale = w1_input_scale[0]

     # Prepare quant_scales for TensorRT-LLM FP8 format:
     # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale]
     # For gated MLP:
+    # These are precomputed in `fused_moe` transform
     # - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
     # - gemm2_act_quant_scale: 1 / w2_input_scale
     # - gemm2_dequant_scale: w2_weight_scale * w2_input_scale
     # - gemm1_input_dequant_scale: w1_input_scale

-    # Compute combined scales
-    gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze()
-    gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32)
-    gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze()
-    gemm1_input_dequant = w1_input_scale.contiguous()
-
     assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D"
     assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D"
-    quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant]
+    quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale]

     # Ensure contiguous tensors
     selected_experts = selected_experts.int().contiguous()
@@ -229,6 +229,9 @@ def trtllm_quant_fp8_moe_fused_fake(
     w1_weight_scale: torch.Tensor,
     w2_weight_scale: torch.Tensor,
     w3_weight_scale: torch.Tensor,
+    gemm1_dequant: torch.Tensor,
+    gemm2_act_quant: torch.Tensor,
+    gemm2_dequant: torch.Tensor,
     mlp_style: str,
     act_fn: str,
 ) -> torch.Tensor:
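The three new `gemm*` arguments replace the per-call scale arithmetic that used to run inside the op; the added comment says they are now precomputed in the `fused_moe` transform, so the op no longer redoes this math on every forward pass. The sketch below shows how those values follow from the per-expert scales, using the same formulas as the removed code. The helper name is illustrative, not the transform's actual API.

import torch

def precompute_fp8_moe_scales(
    w1_weight_scale: torch.Tensor,  # [E]
    w2_weight_scale: torch.Tensor,  # [E]
    w1_input_scale: torch.Tensor,   # [E]; as in the op, only element 0 is used
    w2_input_scale: torch.Tensor,   # [E]
):
    """Fold per-expert scales into the combined scales expected by the fused op."""
    w1_in = w1_input_scale.to(torch.float32)[0]
    w2_in = w2_input_scale.to(torch.float32)[0]
    # gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
    gemm1_dequant = (w1_weight_scale.to(torch.float32) * w1_in).contiguous().squeeze()
    # gemm2_act_quant_scale: 1 / w2_input_scale
    gemm2_act_quant = (1.0 / w2_in).contiguous().to(torch.float32)
    # gemm2_dequant_scale: w2_weight_scale * w2_input_scale
    gemm2_dequant = (w2_weight_scale.to(torch.float32) * w2_in).contiguous().squeeze()
    return gemm1_dequant, gemm2_act_quant, gemm2_dequant

At call time these land in `quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale]`, matching the comment block in the diff.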

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py

Lines changed: 1 addition & 3 deletions

@@ -204,9 +204,7 @@ def _cuda_cached_causal_conv1d(

         if y_dec.dim() == 3:
             y_dec = y_dec.squeeze(-1)
-        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
-            y_dec.to(y_flat.dtype)
-        )
+        y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec)

     # Custom op must not return an alias of any input; return a fresh tensor
     return y

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

Lines changed: 2 additions & 11 deletions

@@ -43,10 +43,6 @@ def _triton_ssm_prepare_metadata(
     seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     num_seq = len(seq_len_sanitized)

-    seq_start = torch.zeros_like(seq_len_sanitized)
-    if num_seq > 1:
-        seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)
-
     # Truncate slot indices to match active sequences
     slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
     # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
@@ -88,7 +84,6 @@ def _triton_ssm_prepare_metadata(

     return (
         seq_len_sanitized,
-        seq_start,
         slot_idx_sanitized,
         use_initial_states,
         cu_seqlens,
@@ -109,7 +104,6 @@ def _triton_ssm_prepare_metadata_fake(
     device = slot_idx.device
     # Always-correct shapes
     seq_len_fake = torch.empty_like(seq_len_sanitized)
-    seq_start_fake = torch.empty_like(seq_len_sanitized)
     slot_idx_fake = torch.empty(num_seq, dtype=torch.long, device=device)
     use_initial_states_fake = torch.empty(num_seq, dtype=torch.bool, device=device)
     cu_seqlens_fake = torch.empty(num_seq + 1, dtype=torch.int32, device=device)
@@ -142,7 +136,6 @@ def _triton_ssm_prepare_metadata_fake(

     return (
         seq_len_fake,
-        seq_start_fake,
         slot_idx_fake,
         use_initial_states_fake,
         cu_seqlens_fake,
@@ -165,7 +158,6 @@ def _triton_cached_ssm(
     dt_bias: torch.Tensor,  # [num_heads]
     # METADATA
     seq_len: torch.Tensor,  # [num_seq]
-    seq_start: torch.Tensor,  # [num_seq]
     slot_idx: torch.Tensor,  # [num_seq]
     use_initial_states: torch.Tensor,  # [num_seq]
     cu_seqlens: torch.Tensor,  # [num_seq + 1]
@@ -290,7 +282,6 @@ def _triton_cached_ssm_fake(
     dt_bias: torch.Tensor,  # [num_heads]
     # METADATA
     seq_len: torch.Tensor,  # [num_seq]
-    seq_start: torch.Tensor,  # [num_seq]
     slot_idx: torch.Tensor,  # [num_seq]
     use_initial_states: torch.Tensor,  # [num_seq]
     cu_seqlens: torch.Tensor,  # [num_seq + 1]
@@ -340,9 +331,9 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
-        # Returns: seq_len, seq_start, slot_idx, use_initial_states,
+        # Returns: seq_len, slot_idx, use_initial_states,
         #          cu_seqlens, chunk_indices, chunk_offsets, seq_idx_prefill, batch_info_tensor
-        return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 9
+        return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 8

     @classmethod
     def get_cache_initializers(
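This change drops the now-unused `seq_start` metadata, shrinking the prepare-metadata tuple from 9 to 8 entries. If a per-sequence start offset were ever needed again, it is recoverable from `cu_seqlens` (assuming, as the shapes here suggest, that `cu_seqlens` is the standard exclusive prefix sum over `seq_len`); a one-line sketch:

seq_start = cu_seqlens[:-1].to(seq_len.dtype)  # [num_seq]; equivalent to the removed zeros_like + cumsum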
Lines changed: 235 additions & 0 deletions (new file)

"""
Custom ops to enable multi-stream execution.
"""

from __future__ import annotations

from threading import RLock
from typing import Any, Callable, Dict, Tuple

import torch


class _Singleton(type):
    _instances: Dict[type, Any] = {}
    _lock = RLock()

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        if cls not in cls._instances:
            with cls._lock:
                if cls not in cls._instances:  # double-checked locking
                    cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


# A singleton that holds the pointers to the cuda streams and events.
# In multi-gpu scenario, each GPU/rank has its own CudaStreamManager.
class CudaStreamManager(metaclass=_Singleton):
    AUX_STREAM_NAME = "aux"
    MAIN_STREAM_NAME = "main"

    def __init__(self) -> None:
        # In case __init__ ever gets called twice, guard against re-init
        if hasattr(self, "streams"):
            return

        self._lock = RLock()

        # Events needed for stream synchronization
        self.events: Dict[str, Any] = {
            self.AUX_STREAM_NAME: torch.cuda.Event(),
            self.MAIN_STREAM_NAME: torch.cuda.Event(),
        }

        # Streams for multi-stream execution
        self.aux_stream = torch.cuda.Stream()
        self.streams: Dict[str, Any] = {
            self.AUX_STREAM_NAME: self.aux_stream,
            self.MAIN_STREAM_NAME: torch.cuda.default_stream(),
        }


cuda_stream_manager = CudaStreamManager()


@torch.library.custom_op("auto_deploy::record_event", mutates_args=())
def record_event(stream_name: str) -> None:
    event = cuda_stream_manager.events[stream_name]
    event.record()


@torch.library.custom_op("auto_deploy::wait_event", mutates_args=())
def wait_event(event_name: str) -> None:
    event = cuda_stream_manager.events[event_name]
    event.wait()


# skip during compilation
@torch._dynamo.disable
def record_event_wrapper(
    fn: Callable,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    output = fn(*args, **kwargs)
    torch.ops.auto_deploy.record_event(cuda_stream_manager.MAIN_STREAM_NAME)
    return output


@torch._dynamo.disable
def aux_stream_wrapper(
    fn: Callable,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> torch.Tensor:
    stream_name = cuda_stream_manager.AUX_STREAM_NAME
    with torch.cuda.stream(cuda_stream_manager.streams[stream_name]):
        torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
        output = fn(*args, **kwargs)
        torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
    torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
    return output


# trtllm bf16
@torch.library.custom_op("auto_deploy::trtllm_moe_fused_aux", mutates_args=())
def trtllm_moe_fused_aux(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    w3_w1_stacked_weight: torch.Tensor,
    w2_stacked_weight: torch.Tensor,
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
        torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
        output = torch.ops.auto_deploy.trtllm_moe_fused(
            x,
            selected_experts,
            routing_weights,
            w3_w1_stacked_weight,
            w2_stacked_weight,
            mlp_style,
            act_fn,
        )
        torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
    torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
    return output


@trtllm_moe_fused_aux.register_fake
def trtllm_moe_fused_aux_fake(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    w3_w1_stacked_weight: torch.Tensor,
    w2_stacked_weight: torch.Tensor,
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    return torch.empty_like(x)


# triton bf16
@torch.library.custom_op("auto_deploy::triton_moe_fused_aux", mutates_args=())
def triton_moe_fused_aux(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    w1_stacked_weight: torch.Tensor,
    w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
    with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
        torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
        output = torch.ops.auto_deploy.triton_moe_fused(
            x,
            selected_experts,
            routing_weights,
            w1_stacked_weight,
            w2_stacked_weight,
        )
        torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
    torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
    return output


@triton_moe_fused_aux.register_fake
def triton_moe_fused_aux_fake(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    w1_stacked_weight: torch.Tensor,
    w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
    return torch.empty_like(x)


# trtllm fp8
@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused_aux", mutates_args=())
def trtllm_quant_fp8_moe_fused_aux(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    w1_weight: torch.Tensor,  # [E, I, H] stacked FP8 weights
    w2_weight: torch.Tensor,  # [E, H, I] stacked FP8 weights
    w3_weight: torch.Tensor,  # [E, I, H] for gated_mlp, unused for mlp
    w1_input_scale: torch.Tensor,  # [E] stacked input scales
    w2_input_scale: torch.Tensor,  # [E] stacked input scales
    w3_input_scale: torch.Tensor,  # [E] or unused
    w1_weight_scale: torch.Tensor,  # [E] stacked weight scales
    w2_weight_scale: torch.Tensor,  # [E] stacked weight scales
    w3_weight_scale: torch.Tensor,  # [E] or unused
    gemm1_dequant: torch.Tensor,  # [E]
    gemm2_act_quant: torch.Tensor,  # [E]
    gemm2_dequant: torch.Tensor,  # [E]
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
        torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
        output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused(
            x,
            selected_experts,
            routing_weights,
            w1_weight,
            w2_weight,
            w3_weight,
            w1_input_scale,
            w2_input_scale,
            w3_input_scale,
            w1_weight_scale,
            w2_weight_scale,
            w3_weight_scale,
            gemm1_dequant,
            gemm2_act_quant,
            gemm2_dequant,
            mlp_style,
            act_fn,
        )
        torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
    torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
    return output


@trtllm_quant_fp8_moe_fused_aux.register_fake
def trtllm_quant_fp8_moe_fused_aux_fake(
    x: torch.Tensor,
    selected_experts: torch.Tensor,
    routing_weights: torch.Tensor,
    w1_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w3_weight: torch.Tensor,
    w1_input_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
    w3_input_scale: torch.Tensor,
    w1_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w3_weight_scale: torch.Tensor,
    gemm1_dequant: torch.Tensor,
    gemm2_act_quant: torch.Tensor,
    gemm2_dequant: torch.Tensor,
    mlp_style: str = "gated_mlp",
    act_fn: str = "silu",
) -> torch.Tensor:
    return torch.empty_like(x)
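For context (not part of the commit): below is a hedged sketch of how these aux-stream ops are meant to be used, assuming the module above has been imported so the `auto_deploy::*` ops are registered. Record the "main" event once the MoE inputs are ready, queue the independent main-stream work, then launch the routed experts through a `*_moe_fused_aux` op. `moe_layer_with_overlap` and `shared_expert` are placeholders; in practice the `multi_stream_moe` transform enabled in nano_v3.yaml is what wires these ops into the compiled graph.

import torch

def moe_layer_with_overlap(x, selected_experts, routing_weights,
                           w3_w1_stacked_weight, w2_stacked_weight, shared_expert):
    # x and the routing tensors were produced on the default ("main") stream;
    # record the "main" event so the aux stream knows when they are ready.
    torch.ops.auto_deploy.record_event("main")

    # Queue the independent main-stream work (placeholder shared-expert MLP) first:
    # its kernels land on the main stream ahead of the "aux" wait issued by the op
    # below, so on the GPU they overlap with the routed-expert kernels.
    shared = shared_expert(x)

    # Routed experts run on the aux stream inside the custom op: it waits on "main",
    # launches the fused MoE there, records "aux", and makes the main stream wait on
    # "aux" before the result is consumed.
    routed = torch.ops.auto_deploy.trtllm_moe_fused_aux(
        x, selected_experts, routing_weights, w3_w1_stacked_weight, w2_stacked_weight
    )

    return routed + shared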
