4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu
@@ -142,7 +142,7 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const
#endif

static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
-int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
+int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;

auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
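This launch-bound change (and the identical one in moeActivation below) raises the grid from at most one block per SM to up to eight, still capped by max_num_permuted_tokens so no block is launched without a row to process. Oversubscribing the SMs this way is only safe and useful when the kernel strides over its work instead of assuming one block per work item. The following is a minimal sketch of that pattern, not the real moePermuteKernel; permuteRowsKernel and srcRowForDst are hypothetical names used for illustration.

// Sketch only: grid-stride row permutation, illustrating why min(smCount * 8, rows) is a valid grid size.
#include <algorithm>
#include <cuda_runtime.h>

__global__ void permuteRowsKernel(
    float const* input, float* output, int const* srcRowForDst, int numRows, int hiddenSize)
{
    // Each block strides over destination rows, so any grid size <= numRows covers all rows.
    for (int dstRow = blockIdx.x; dstRow < numRows; dstRow += gridDim.x)
    {
        int const srcRow = srcRowForDst[dstRow];
        for (int col = threadIdx.x; col < hiddenSize; col += blockDim.x)
        {
            output[dstRow * hiddenSize + col] = input[srcRow * hiddenSize + col];
        }
    }
}

void launchPermuteRows(float const* input, float* output, int const* srcRowForDst, int numRows,
    int hiddenSize, int smCount, cudaStream_t stream)
{
    // Several resident blocks per SM help hide memory latency; never launch more blocks than rows.
    int const blocks = std::min(smCount * 8, numRows);
    permuteRowsKernel<<<blocks, 256, 0, stream>>>(input, output, srcRowForDst, numRows, hiddenSize);
}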
@@ -383,7 +383,7 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
#endif

static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
-int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
+int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;

auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
@@ -647,8 +647,7 @@ void run(Data& data, void* stream)
//
// The upper bound is a strict requirement. The number of blocks should be determined by querying
// the device properties, or conservatively low.
-// /!\ The following number is not portable!! (but works on H100 and B200)
-int const numBlocksCoop = 128;
+static int const numBlocksCoop = tensorrt_llm::common::getMultiProcessorCount();

// Maximum number of tokens supported by the kernel using a cooperative launch.
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
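Replacing the hard-coded 128 with tensorrt_llm::common::getMultiProcessorCount() follows the comment's own advice: a cooperative launch requires every block of the grid to be resident at once, so the block count should come from the device rather than a constant tuned for H100/B200. Below is a minimal sketch of how such a bound can be derived and used; histogramCoopKernel and numThreadsHist are placeholders, not the kernel in this file.

// Sketch only: portable block count for a cooperative launch.
#include <algorithm>
#include <cooperative_groups.h>
#include <cuda_runtime.h>

__global__ void histogramCoopKernel(int const* expertIds, int* counts, int numTokens)
{
    namespace cg = cooperative_groups;
    // Grid-stride accumulation into a global histogram.
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numTokens; i += gridDim.x * blockDim.x)
    {
        atomicAdd(&counts[expertIds[i]], 1);
    }
    // Grid-wide barrier: legal only because the whole grid is resident (cooperative launch).
    cg::this_grid().sync();
}

void launchCooperative(int const* expertIds, int* counts, int numTokens, cudaStream_t stream)
{
    int constexpr numThreadsHist = 256;

    int device = 0;
    cudaGetDevice(&device);
    int smCount = 0;
    cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, device);

    // Blocks of this kernel that can be co-resident per SM; a cooperative grid may not exceed
    // smCount * blocksPerSm. Using just smCount, as the diff does, is valid whenever blocksPerSm >= 1.
    int blocksPerSm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocksPerSm, histogramCoopKernel, numThreadsHist, /*dynamicSmemBytes=*/0);
    int const numBlocksCoop = std::min(smCount, smCount * blocksPerSm);

    void* args[] = {(void*) &expertIds, (void*) &counts, (void*) &numTokens};
    cudaLaunchCooperativeKernel((void*) histogramCoopKernel, dim3(numBlocksCoop),
        dim3(numThreadsHist), args, /*sharedMem=*/0, stream);
}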
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp
@@ -120,8 +120,8 @@ std::vector<torch::Tensor> moe_sort(torch::Tensor const& token_selected_experts,
TORCH_CHECK(token_final_scales.size(0) == num_tokens, "token_final_scales.size(0) must be num_tokens.");
TORCH_CHECK(token_final_scales.size(1) == top_k, "token_final_scales.size(1) must be top_k.");
return moe_topk_sort_impl(std::nullopt, std::nullopt, token_selected_experts, token_final_scales, num_experts,
-    top_k, std::nullopt, std::nullopt, local_expert_offset, local_num_experts, std::nullopt, tile_tokens_dim,
-    RoutingMethodType::Renormalize);
+    top_k, 1, 1, local_expert_offset, local_num_experts, std::nullopt, tile_tokens_dim,
+    RoutingMethodType::DeepSeekV3);
}

// Permute
921 changes: 818 additions & 103 deletions tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Large diffs are not rendered by default.

@@ -52,8 +52,10 @@
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass.cute.nvgpu import cpasync, tcgen05

+from .utils import is_power_of_2
+
-class Sm100BlockScaledPersistentGroupedGemmKernel:
+class Sm100BlockScaledContiguousGroupedGemmKernel:
"""This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.

@@ -88,7 +90,7 @@ class Sm100BlockScaledPersistentGroupedGemmKernel:
- Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors

Example:
->>> gemm = Sm100BlockScaledPersistentGroupedGemmKernel(
+>>> gemm = Sm100BlockScaledContiguousGroupedGemmKernel(
... sf_vec_size=16, mma_tiler_mn=(256, 128), cluster_shape_mn=(2, 1)
... )
>>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream)
@@ -2052,9 +2054,6 @@ def is_valid_mma_tiler_and_cluster_shape(
is_valid = False

# Skip invalid cluster shape
-def is_power_of_2(x: int) -> bool:
-    return x > 0 and (x & (x - 1)) == 0
-
if (
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
or cluster_shape_mn[0] <= 0
@@ -2138,8 +2137,9 @@ def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
is_valid = False
return is_valid

-@staticmethod
+@classmethod
def can_implement(
+cls,
ab_dtype: Type[cutlass.Numeric],
sf_dtype: Type[cutlass.Numeric],
sf_vec_size: int,
@@ -2198,24 +2198,22 @@ def can_implement(
"""
can_implement = True
# Skip unsupported types
-if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
+if not cls.is_valid_dtypes_and_scale_factor_vec_size(
ab_dtype, sf_dtype, sf_vec_size, acc_dtype, c_dtype
):
can_implement = False

# Skip unsupported layouts
-if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_layouts(
-    ab_dtype, c_dtype, a_major, b_major, c_major
-):
+if not cls.is_valid_layouts(ab_dtype, c_dtype, a_major, b_major, c_major):
can_implement = False

# Skip invalid mma tile shape and cluster shape
-if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape(
+if not cls.is_valid_mma_tiler_and_cluster_shape(
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, m_aligned
):
can_implement = False
# Skip illegal problem shape for load/store alignment
-if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_tensor_alignment(
+if not cls.is_valid_tensor_alignment(
m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
):
can_implement = False
@@ -2238,7 +2236,7 @@ def wrapper(
m: int,
n: int,
k: int,
-l: cutlass.Constexpr,  # noqa: E741
+l: int,  # noqa: E741
tile_size: cutlass.Constexpr,
scaling_vector_size: cutlass.Constexpr,
max_active_clusters: cutlass.Constexpr,
@@ -2266,7 +2264,6 @@ def wrapper(
tile_idx_to_group_idx = cute.make_tensor(
tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,))
)
-tile_idx_to_group_idx.mark_layout_dynamic()
num_non_exiting_tiles = cute.make_tensor(
num_non_exiting_tiles_ptr, layout=cute.make_layout((1,))
)