3 changes: 3 additions & 0 deletions examples/auto_deploy/nano_v3.yaml
@@ -15,6 +15,9 @@ transforms:
detect_sharding:
sharding_source: ['factory', 'heuristic']
sharding_dims: ['ep', 'bmm']
multi_stream_moe:
stage: compile
enabled: true
# tunable mamba cache dtype
# --> use float32 for accuracy, or the default (null) for speed
insert_cached_ssm_attention:
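For reference, the new entry in isolation, with the nesting it presumably has under `transforms:` in nano_v3.yaml (keys exactly as in the hunk above):

```yaml
transforms:
  multi_stream_moe:
    stage: compile
    enabled: true
```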
@@ -67,7 +67,9 @@ def trtllm_moe_fused_fake(
return torch.empty_like(x)


# TODO: refactor this repeated code block
# NOTE(suyogg): If compile ever fails because of this, just write a triton kernel
# for this function and use it as a custom op.
@torch.compile
def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear)."""
FP8_MIN = torch.finfo(torch.float8_e4m3fn).min
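The body of `_quantize_fp8` is truncated by the diff view here; for orientation, a minimal sketch of the clamped FP8 quantization pattern the docstring describes (illustrative only, not necessarily the file's exact code):

```python
import torch

def _quantize_fp8_sketch(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale into FP8 range, clamp to the representable e4m3 interval, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
```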
@@ -107,6 +109,9 @@ def trtllm_quant_fp8_moe_fused(
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
w3_weight_scale: torch.Tensor, # [E] or unused
gemm1_dequant: torch.Tensor, # [E]
gemm2_act_quant: torch.Tensor, # [E]
gemm2_dequant: torch.Tensor, # [E]
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
@@ -125,6 +130,9 @@ def trtllm_quant_fp8_moe_fused(
w1_weight_scale: Weight scales for w1 [E]
w2_weight_scale: Weight scales for w2 [E]
w3_weight_scale: Weight scales for w3 [E]
gemm1_dequant: Precomputed gemm1 dequant scale [E]
gemm2_act_quant: Precomputed gemm2 act quant scale [1]
gemm2_dequant: Precomputed gemm2 dequant scale [E]
mlp_style: "gated_mlp" or "mlp"
act_fn: "silu" for gated_mlp, "relu2" for mlp
@@ -144,28 +152,20 @@ def trtllm_quant_fp8_moe_fused(
x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0])

# Scales are stored in float32
w1_weight_scale = w1_weight_scale.to(torch.float32)
w2_weight_scale = w2_weight_scale.to(torch.float32)
w1_input_scale = w1_input_scale.to(torch.float32)[0]
w2_input_scale = w2_input_scale.to(torch.float32)[0]
w1_input_scale = w1_input_scale[0]

# Prepare quant_scales for TensorRT-LLM FP8 format:
# [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale]
# For gated MLP:
# These are precomputed in `fused_moe` transform
# - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3)
# - gemm2_act_quant_scale: 1 / w2_input_scale
# - gemm2_dequant_scale: w2_weight_scale * w2_input_scale
# - gemm1_input_dequant_scale: w1_input_scale

# Compute combined scales
gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze()
gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32)
gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze()
gemm1_input_dequant = w1_input_scale.contiguous()

assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D"
assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D"
quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant]
quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, w1_input_scale]

# Ensure contiguous tensors
selected_experts = selected_experts.int().contiguous()
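The scale math removed above (gemm1_dequant, gemm2_act_quant, gemm2_dequant) is, per the comment, now precomputed once in the `fused_moe` transform and passed in as arguments. A rough sketch of that precompute, reconstructed from the deleted lines (the transform's actual implementation may differ):

```python
import torch

def precompute_fp8_moe_scales(w1_weight_scale, w2_weight_scale, w1_input_scale, w2_input_scale):
    # The kernel uses a single per-tensor input scale (index 0) for each GEMM.
    w1_in = w1_input_scale.to(torch.float32)[0]
    w2_in = w2_input_scale.to(torch.float32)[0]
    gemm1_dequant = (w1_weight_scale.to(torch.float32) * w1_in).squeeze().contiguous()   # [E]
    gemm2_act_quant = (1.0 / w2_in).reshape(1).contiguous()                              # [1]
    gemm2_dequant = (w2_weight_scale.to(torch.float32) * w2_in).squeeze().contiguous()   # [E]
    return gemm1_dequant, gemm2_act_quant, gemm2_dequant
```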
@@ -229,6 +229,9 @@ def trtllm_quant_fp8_moe_fused_fake(
w1_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
gemm1_dequant: torch.Tensor,
gemm2_act_quant: torch.Tensor,
gemm2_dequant: torch.Tensor,
mlp_style: str,
act_fn: str,
) -> torch.Tensor:
@@ -204,9 +204,7 @@ def _cuda_cached_causal_conv1d(

if y_dec.dim() == 3:
y_dec = y_dec.squeeze(-1)
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
y_dec.to(y_flat.dtype)
)
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec)

# Custom op must not return an alias of any input; return a fresh tensor
return y
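On the "must not return an alias" note above: ops registered through `torch.library.custom_op` are expected to return fresh tensors rather than views of their inputs. A toy illustration (hypothetical op, not part of this PR):

```python
import torch

@torch.library.custom_op("toy::scaled_copy", mutates_args=())
def scaled_copy(x: torch.Tensor, factor: float) -> torch.Tensor:
    # x * factor allocates a new tensor; returning e.g. x.view(-1) or x itself
    # would alias an input and violate the custom-op contract.
    return x * factor

@scaled_copy.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)
```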
@@ -43,10 +43,6 @@ def _triton_ssm_prepare_metadata(
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)

seq_start = torch.zeros_like(seq_len_sanitized)
if num_seq > 1:
seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0)

# Truncate slot indices to match active sequences
slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long)
# TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch
@@ -88,7 +84,6 @@ def _triton_ssm_prepare_metadata(

return (
seq_len_sanitized,
seq_start,
slot_idx_sanitized,
use_initial_states,
cu_seqlens,
@@ -109,7 +104,6 @@ def _triton_ssm_prepare_metadata_fake(
device = slot_idx.device
# Always-correct shapes
seq_len_fake = torch.empty_like(seq_len_sanitized)
seq_start_fake = torch.empty_like(seq_len_sanitized)
slot_idx_fake = torch.empty(num_seq, dtype=torch.long, device=device)
use_initial_states_fake = torch.empty(num_seq, dtype=torch.bool, device=device)
cu_seqlens_fake = torch.empty(num_seq + 1, dtype=torch.int32, device=device)
@@ -142,7 +136,6 @@ def _triton_ssm_prepare_metadata_fake(

return (
seq_len_fake,
seq_start_fake,
slot_idx_fake,
use_initial_states_fake,
cu_seqlens_fake,
@@ -165,7 +158,6 @@ def _triton_cached_ssm(
dt_bias: torch.Tensor, # [num_heads]
# METADATA
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
cu_seqlens: torch.Tensor, # [num_seq + 1]
@@ -290,7 +282,6 @@ def _triton_cached_ssm_fake(
dt_bias: torch.Tensor, # [num_heads]
# METADATA
seq_len: torch.Tensor, # [num_seq]
seq_start: torch.Tensor, # [num_seq]
slot_idx: torch.Tensor, # [num_seq]
use_initial_states: torch.Tensor, # [num_seq]
cu_seqlens: torch.Tensor, # [num_seq + 1]
@@ -340,9 +331,9 @@ def get_cached_attention_op(cls) -> MHACallable:

@classmethod
def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
# Returns: seq_len, seq_start, slot_idx, use_initial_states,
# Returns: seq_len, slot_idx, use_initial_states,
# cu_seqlens, chunk_indices, chunk_offsets, seq_idx_prefill, batch_info_tensor
return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 9
return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 8

@classmethod
def get_cache_initializers(
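The `seq_start` output dropped throughout this file was an exclusive cumulative sum of `seq_len`; if a consumer still needs per-sequence start offsets, they are recoverable from the `cu_seqlens` tensor the metadata op still returns, assuming the usual prefix-sum layout. A small illustration (not from the PR):

```python
import torch

seq_len = torch.tensor([3, 5, 2])
cu_seqlens = torch.zeros(seq_len.numel() + 1, dtype=torch.int64)
cu_seqlens[1:] = torch.cumsum(seq_len, 0)

# Equivalent to the removed computation:
#   seq_start = zeros_like(seq_len); seq_start[1:] = cumsum(seq_len[:-1])
seq_start = cu_seqlens[:-1]  # tensor([0, 3, 8])
```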
235 changes: 235 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/multi_stream.py
@@ -0,0 +1,235 @@
"""
Custom ops to enable multi-stream execution.
"""

from __future__ import annotations

from threading import RLock
from typing import Any, Callable, Dict, Tuple

import torch


class _Singleton(type):
_instances: Dict[type, Any] = {}
_lock = RLock()

def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances:
with cls._lock:
if cls not in cls._instances: # double-checked locking
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]


# A singleton that holds handles to the CUDA streams and events.
# In a multi-GPU scenario, each GPU/rank has its own CudaStreamManager.
class CudaStreamManager(metaclass=_Singleton):
AUX_STREAM_NAME = "aux"
MAIN_STREAM_NAME = "main"

def __init__(self) -> None:
# In case __init__ ever gets called twice, guard against re-init
if hasattr(self, "streams"):
return

self._lock = RLock()

# Events needed for stream synchronization
self.events: Dict[str, Any] = {
self.AUX_STREAM_NAME: torch.cuda.Event(),
self.MAIN_STREAM_NAME: torch.cuda.Event(),
}

# Streams for multi-stream execution
self.aux_stream = torch.cuda.Stream()
self.streams: Dict[str, Any] = {
self.AUX_STREAM_NAME: self.aux_stream,
self.MAIN_STREAM_NAME: torch.cuda.default_stream(),
}


cuda_stream_manager = CudaStreamManager()
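Because `_Singleton` intercepts `__call__`, any later construction returns the same object as the module-level instance above, so each rank ends up with exactly one manager for its GPU. A quick illustrative check (not part of the PR):

```python
assert CudaStreamManager() is cuda_stream_manager
assert CudaStreamManager() is CudaStreamManager()
```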


@torch.library.custom_op("auto_deploy::record_event", mutates_args=())
def record_event(stream_name: str) -> None:
event = cuda_stream_manager.events[stream_name]
event.record()


@torch.library.custom_op("auto_deploy::wait_event", mutates_args=())
def wait_event(event_name: str) -> None:
event = cuda_stream_manager.events[event_name]
event.wait()


# Skipped (not traced) by torch.compile / dynamo during compilation
@torch._dynamo.disable
def record_event_wrapper(
fn: Callable,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> torch.Tensor:
output = fn(*args, **kwargs)
torch.ops.auto_deploy.record_event(cuda_stream_manager.MAIN_STREAM_NAME)
return output


@torch._dynamo.disable
def aux_stream_wrapper(
fn: Callable,
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
) -> torch.Tensor:
stream_name = cuda_stream_manager.AUX_STREAM_NAME
with torch.cuda.stream(cuda_stream_manager.streams[stream_name]):
torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
output = fn(*args, **kwargs)
torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
return output
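This file defines the primitives but not how they compose; below is a rough, purely illustrative sketch (all callables are placeholders, and the actual multi_stream_moe transform may wire things differently) of overlapping aux-stream MoE work with independent main-stream work:

```python
def overlapped_block(x, produce_fn, moe_fn, other_fn):
    # 1) Produce the MoE input on the main stream; record the "main" event so the
    #    aux stream knows when `h` is safe to read.
    h = record_event_wrapper(produce_fn, x)
    # 2) Queue main-stream work that does not depend on the MoE output; it can
    #    overlap with the aux-stream MoE launched below.
    other_out = other_fn(h)
    # 3) Run the MoE on the aux stream: it waits on the "main" event, records the
    #    "aux" event, and the wrapper's trailing wait re-syncs the caller's stream,
    #    so anything queued afterwards sees both results.
    moe_out = aux_stream_wrapper(moe_fn, h)
    return moe_out, other_out
```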


# trtllm bf16
@torch.library.custom_op("auto_deploy::trtllm_moe_fused_aux", mutates_args=())
def trtllm_moe_fused_aux(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
output = torch.ops.auto_deploy.trtllm_moe_fused(
x,
selected_experts,
routing_weights,
w3_w1_stacked_weight,
w2_stacked_weight,
mlp_style,
act_fn,
)
torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
return output


@trtllm_moe_fused_aux.register_fake
def trtllm_moe_fused_aux_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)


# triton bf16
@torch.library.custom_op("auto_deploy::triton_moe_fused_aux", mutates_args=())
def triton_moe_fused_aux(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
output = torch.ops.auto_deploy.triton_moe_fused(
x,
selected_experts,
routing_weights,
w1_stacked_weight,
w2_stacked_weight,
)
torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
return output


@triton_moe_fused_aux.register_fake
def triton_moe_fused_aux_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(x)


# trtllm fp8
@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused_aux", mutates_args=())
def trtllm_quant_fp8_moe_fused_aux(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights
w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights
w3_weight: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp
w1_input_scale: torch.Tensor, # [E] stacked input scales
w2_input_scale: torch.Tensor, # [E] stacked input scales
w3_input_scale: torch.Tensor, # [E] or unused
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
w3_weight_scale: torch.Tensor, # [E] or unused
gemm1_dequant: torch.Tensor, # [E]
gemm2_act_quant: torch.Tensor, # [E]
gemm2_dequant: torch.Tensor, # [E]
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
with torch.cuda.stream(cuda_stream_manager.streams[cuda_stream_manager.AUX_STREAM_NAME]):
torch.ops.auto_deploy.wait_event(cuda_stream_manager.MAIN_STREAM_NAME)
output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused(
x,
selected_experts,
routing_weights,
w1_weight,
w2_weight,
w3_weight,
w1_input_scale,
w2_input_scale,
w3_input_scale,
w1_weight_scale,
w2_weight_scale,
w3_weight_scale,
gemm1_dequant,
gemm2_act_quant,
gemm2_dequant,
mlp_style,
act_fn,
)
torch.ops.auto_deploy.record_event(cuda_stream_manager.AUX_STREAM_NAME)
torch.ops.auto_deploy.wait_event(cuda_stream_manager.AUX_STREAM_NAME)
return output


@trtllm_quant_fp8_moe_fused_aux.register_fake
def trtllm_quant_fp8_moe_fused_aux_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
w1_weight: torch.Tensor,
w2_weight: torch.Tensor,
w3_weight: torch.Tensor,
w1_input_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
w3_input_scale: torch.Tensor,
w1_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
gemm1_dequant: torch.Tensor,
gemm2_act_quant: torch.Tensor,
gemm2_dequant: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
) -> torch.Tensor:
return torch.empty_like(x)
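Presumably the `multi_stream_moe` transform enabled in the YAML change at the top of this diff swaps fused-MoE calls for these `_aux` variants; since the signatures are identical, the rewrite is a drop-in replacement, schematically (illustrative only, not the transform's code):

```python
# before: MoE runs on the default stream
out = torch.ops.auto_deploy.trtllm_moe_fused(
    x, selected_experts, routing_weights, w3_w1_stacked_weight, w2_stacked_weight,
    "gated_mlp", "silu",
)

# after: same arguments, but dispatched on the aux stream with event-based syncing
out = torch.ops.auto_deploy.trtllm_moe_fused_aux(
    x, selected_experts, routing_weights, w3_w1_stacked_weight, w2_stacked_weight,
    "gated_mlp", "silu",
)
```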