3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -161,6 +161,9 @@ transforms:
############################################################################################
fuse_causal_conv_activation:
stage: compile
multi_stream_moe:
stage: compile
enabled: true
compile_model:
stage: compile
run_per_gm: false
@@ -67,7 +67,9 @@ def trtllm_moe_fused_fake(
return torch.empty_like(x)


# Todo: refactor this repeating code block
# NOTE(suyogg): If compile ever fails because of this, just write a triton kernel
# for this function and use it as a custom op.
@torch.compile
def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear)."""
FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
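The hunk is truncated here. For context, a minimal sketch of what the clamped FP8 quantization presumably does, inferred from the docstring and the FP8_MIN constant above (an assumption, not the exact body from the file):

import torch

FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def _quantize_fp8_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Assumed behavior: scale the input, clamp to the representable e4m3 range,
    # then cast to torch.float8_e4m3fn.
    return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)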
@@ -107,6 +109,9 @@ def trtllm_quant_fp8_moe_fused(
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
w3_weight_scale: torch.Tensor, # [E] or unused
gemm1_dequant: torch.Tensor, # [E]
gemm2_act_quant: torch.Tensor, # [E]
gemm2_dequant: torch.Tensor, # [E]
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
@@ -144,10 +149,10 @@ def trtllm_quant_fp8_moe_fused(
x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0])

# Scales are stored in float32
w1_weight_scale = w1_weight_scale.to(torch.float32)
w2_weight_scale = w2_weight_scale.to(torch.float32)
# w1_weight_scale = w1_weight_scale.to(torch.float32)
# w2_weight_scale = w2_weight_scale.to(torch.float32)
w1_input_scale = w1_input_scale.to(torch.float32)[0]
w2_input_scale = w2_input_scale.to(torch.float32)[0]
# w2_input_scale = w2_input_scale.to(torch.float32)[0]

# Prepare quant_scales for TensorRT-LLM FP8 format:
# [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale]
@@ -158,14 +163,14 @@
# - gemm1_input_dequant_scale: w1_input_scale

# Compute combined scales
gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze()
gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32)
gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze()
gemm1_input_dequant = w1_input_scale.contiguous()
# gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze()
# gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32)
# gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze()
# gemm1_input_dequant = w1_input_scale.contiguous()

assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D"
assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D"
quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant]
quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale]

# Ensure contiguous tensors
selected_experts = selected_experts.int().contiguous()
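The new gemm1_dequant, gemm2_act_quant, and gemm2_dequant arguments replace the per-call scale math that is commented out above; presumably they are computed once ahead of time (e.g., when the expert weights are stacked) using the same formulas. A minimal sketch, assuming a hypothetical helper that runs at load/transform time rather than inside the op:

import torch


def precompute_fp8_moe_scales(
    w1_weight_scale: torch.Tensor,  # [E]
    w2_weight_scale: torch.Tensor,  # [E]
    w1_input_scale: torch.Tensor,   # [E], identical entries
    w2_input_scale: torch.Tensor,   # [E], identical entries
):
    # Same combined-scale formulas as the commented-out lines above, just hoisted
    # out of the forward path so they are not recomputed on every call.
    w1_ws = w1_weight_scale.to(torch.float32)
    w2_ws = w2_weight_scale.to(torch.float32)
    w1_is = w1_input_scale.to(torch.float32)[0]
    w2_is = w2_input_scale.to(torch.float32)[0]
    gemm1_dequant = (w1_ws * w1_is).contiguous().squeeze()
    gemm2_act_quant = (1.0 / w2_is).contiguous()
    gemm2_dequant = (w2_ws * w2_is).contiguous().squeeze()
    return gemm1_dequant, gemm2_act_quant, gemm2_dequant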
@@ -229,6 +234,9 @@ def trtllm_quant_fp8_moe_fused_fake(
w1_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
gemm1_dequant: torch.Tensor,
gemm2_act_quant: torch.Tensor,
gemm2_dequant: torch.Tensor,
mlp_style: str,
act_fn: str,
) -> torch.Tensor:
@@ -204,9 +204,7 @@ def _cuda_cached_causal_conv1d(

if y_dec.dim() == 3:
y_dec = y_dec.squeeze(-1)
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
y_dec.to(y_flat.dtype)
)
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec)

# Custom op must not return an alias of any input; return a fresh tensor
return y
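Note on the change above: Tensor.copy_ performs dtype conversion on its own, which is presumably why the explicit .to(y_flat.dtype) cast could be dropped. A tiny illustration:

import torch

dst = torch.empty(4, dtype=torch.bfloat16)
src = torch.randn(4, dtype=torch.float32)
dst.copy_(src)  # copy_ casts float32 -> bfloat16 implicitly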
@@ -43,10 +43,6 @@ def _triton_ssm_prepare_metadata(
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)

seq_start = torch.zeros_like(seq_len_sanitized)
if num_seq > 1:
seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)

# Truncate slot indices to match active sequences
slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
# TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
@@ -88,7 +84,6 @@ def _triton_ssm_prepare_metadata(

return (
seq_len_sanitized,
seq_start,
slot_idx_sanitized,
use_initial_states,
cu_seqlens,
@@ -109,7 +104,6 @@ def _triton_ssm_prepare_metadata_fake(
device = slot_idx.device
# Always-correct shapes
seq_len_fake = torch.empty_like(seq_len_sanitized)
seq_start_fake = torch.empty_like(seq_len_sanitized)
slot_idx_fake = torch.empty(num_seq, dtype=torch.long, device=device)
use_initial_states_fake = torch.empty(num_seq, dtype=torch.bool, device=device)
cu_seqlens_fake = torch.empty(num_seq + 1, dtype=torch.int32, device=device)
@@ -142,7 +136,6 @@ def _triton_ssm_prepare_metadata_fake(

return (
seq_len_fake,
seq_start_fake,
slot_idx_fake,
use_initial_states_fake,
cu_seqlens_fake,
@@ -165,7 +158,6 @@ def _triton_cached_ssm(
dt_bias: torch.Tensor, # [num_heads]
# METADATA
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
cu_seqlens: torch.Tensor, # [num_seq + 1]
@@ -290,7 +282,6 @@ def _triton_cached_ssm_fake(
dt_bias: torch.Tensor, # [num_heads]
# METADATA
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
cu_seqlens: torch.Tensor, # [num_seq + 1]
@@ -340,9 +331,9 @@ def get_cached_attention_op(cls) -> MHACallable:

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns: seq_len, seq_start, slot_idx, use_initial_states,
# Returns: seq_len, slot_idx, use_initial_states,
# cu_seqlens, chunk_indices, chunk_offsets, seq_idx_prefill, batch_info_tensor
return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 9
return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 8
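The seq_start output was dropped from the metadata tuple (9 outputs down to 8). Assuming cu_seqlens keeps the usual exclusive-prefix-sum layout [0, len_0, len_0 + len_1, ...] built by this op, the old per-sequence start offsets remain recoverable if a caller ever needs them:

# Hypothetical recovery of the removed seq_start from cu_seqlens.
seq_start = cu_seqlens[:-1].to(torch.long)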

@classmethod
def get_cache_initializers(
235 changes: 235 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py
@@ -0,0 +1,235 @@
"""
Custom ops to enable multi-stream execution.
"""

from __future__ import annotations

from threading import RLock
from typing import Any, Callable, Dict, Tuple

import torch


class _Singleton(type):
_instances: Dict[type, Any] = {}
_lock = RLock()

def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances:
with cls._lock:
if cls not in cls._instances: # double-checked locking
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]


# A singleton that holds the CUDA streams and events used for multi-stream execution.
# In a multi-GPU scenario, each GPU/rank has its own CudaStreamManager.
class CudaStreamManager(metaclass=_Singleton):
AUX_STREAM_NAME = "aux"
MAIN_STREAM_NAME = "main"

def __init__(self) -> None:
# In case __init__ ever gets called twice, guard against re-init
if hasattr(self, "streams"):
return

self._lock = RLock()

# Events needed for stream synchronization
self.events: Dict[str, Any] = {
self.AUX_STREAM_NAME: torch.cuda.Event(),
self.MAIN_STREAM_NAME: torch.cuda.Event(),
}

# Streams for multi-stream execution
self.aux_stream = torch.cuda.Stream()
self.streams: Dict[str, Any] = {
self.AUX_STREAM_NAME: self.aux_stream,
self.MAIN_STREAM_NAME: torch.cuda.default_stream(),
}


cuda_stream_manager = CudaStreamManager()
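Because of the singleton metaclass, repeated construction within a process (e.g., from several transforms on the same rank) returns the same manager, so all ops share one set of streams and events:

assert CudaStreamManager() is cuda_stream_manager  # same instance per process/rank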


@torch.library.custom_op("auto_deploy::record_event", mutates_args=())
def record_event(stream_name: str) -> None:
event = cuda_stream_manager.events[stream_name]
event.record()


@torch.library.custom_op("auto_deploy::wait_event", mutates_args=())
def wait_event(event_name: str) -> None:
event = cuda_stream_manager.events[event_name]
event.wait()


# skip during compilation
@torch._dynamo.disable
def record_event_wrapper(
fn: Callable,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> torch.Tensor:
output = fn(*args, **kwargs)
torch.ops.auto_deploy.record_event(cuda_stream_manager.MAIN_STREAM_NAME)
return output


@torch._dynamo.disable
def aux_stream_wrapper(
fn: Callable,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> torch.Tensor:
stream_name = cuda_stream_manager.AUX_STREAM_NAME
with torch.cuda.stream(cuda_stream_manager.streams[stream_name]):
torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
output = fn(*args, **kwargs)
torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
return output
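A hedged usage sketch of the two wrappers (main_branch_fn and aux_branch_fn are illustrative names, not part of this PR): the producer runs on the default stream and records the MAIN event, then an independent branch is dispatched to the aux stream, with the event record/wait pairs ordering the two streams.

y_main = record_event_wrapper(main_branch_fn, x)  # default stream, records MAIN event
y_aux = aux_stream_wrapper(aux_branch_fn, x)      # aux stream, waits on MAIN, records/waits AUX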


# bf16
@torch.library.custom_op("auto_deploy::trtllm_moe_fused_aux", mutates_args=())
def trtllm_moe_fused_aux(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
output = torch.ops.auto_deploy.trtllm_moe_fused(
x,
selected_experts,
routing_weights,
w3_w1_stacked_weight,
w2_stacked_weight,
mlp_style,
act_fn,
)
torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
return output
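A possible call pattern for the aux-stream op (an assumption about how the multi_stream_moe transform wires it up, not taken from this diff): the default stream records the MAIN event right before the call so the aux stream can safely consume x.

torch.ops.auto_deploy.record_event(cuda_stream_manager.MAIN_STREAM_NAME)
out = torch.ops.auto_deploy.trtllm_moe_fused_aux(
    x, selected_experts, routing_weights, w3_w1_stacked_weight, w2_stacked_weight
)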


@trtllm_moe_fused_aux.register_fake
def trtllm_moe_fused_aux_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)


# triton bf16
@torch.library.custom_op("auto_deploy::triton_moe_fused_aux", mutates_args=())
def triton_moe_fused_aux(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
output = torch.ops.auto_deploy.triton_moe_fused(
x,
selected_experts,
routing_weights,
w1_stacked_weight,
w2_stacked_weight,
)
torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
return output


@triton_moe_fused_aux.register_fake
def triton_moe_fused_aux_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)


# trtllm fp8
@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused_aux", mutates_args=())
def trtllm_quant_fp8_moe_fused_aux(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights
w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights
w3_weight: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp
w1_input_scale: torch.Tensor, # [E] stacked input scales
w2_input_scale: torch.Tensor, # [E] stacked input scales
w3_input_scale: torch.Tensor, # [E] or unused
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
w3_weight_scale: torch.Tensor, # [E] or unused
gemm1_dequant: torch.Tensor, # [E]
gemm2_act_quant: torch.Tensor, # [E]
gemm2_dequant: torch.Tensor, # [E]
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused(
x,
selected_experts,
routing_weights,
w1_weight,
w2_weight,
w3_weight,
w1_input_scale,
w2_input_scale,
w3_input_scale,
w1_weight_scale,
w2_weight_scale,
w3_weight_scale,
gemm1_dequant,
gemm2_act_quant,
gemm2_dequant,
mlp_style,
act_fn,
)
torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
return output


@trtllm_quant_fp8_moe_fused_aux.register_fake
def trtllm_quant_fp8_moe_fused_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_weight: torch.Tensor,
w2_weight: torch.Tensor,
w3_weight: torch.Tensor,
w1_input_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
w3_input_scale: torch.Tensor,
w1_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
gemm1_dequant: torch.Tensor,
gemm2_act_quant: torch.Tensor,
gemm2_dequant: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)