Skip to content

Commit 13fbd43

Browse files
authored
[TRTLLM-9370][feat] Integration of CuteDSL NVFP4 grouped GEMM (Part 2: SwiGLU Fusion and Finalize Fusion) (#9288)
Signed-off-by: Enwei Zhu <[email protected]>
1 parent 9b2abb8 commit 13fbd43

File tree

16 files changed

+6229
-270
lines changed

16 files changed

+6229
-270
lines changed

cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,7 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const
142142
#endif
143143

144144
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
145-
int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
145+
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
146146
int32_t const threads = kThreadsPerBlock;
147147

148148
auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
@@ -383,7 +383,7 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
383383
#endif
384384

385385
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
386-
int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
386+
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
387387
int32_t const threads = kThreadsPerBlock;
388388

389389
auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -647,8 +647,7 @@ void run(Data& data, void* stream)
647647
//
648648
// The upper bound is a strict requirement. The number of blocks should be determined by querying
649649
// the device properties, or conservatively low.
650-
// /!\ The following number is not portable!! (but works on H100 and B200)
651-
int const numBlocksCoop = 128;
650+
static int const numBlocksCoop = tensorrt_llm::common::getMultiProcessorCount();
652651

653652
// Maximum number of tokens supported by the kernel using a cooperative launch.
654653
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;

cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -120,8 +120,8 @@ std::vector<torch::Tensor> moe_sort(torch::Tensor const& token_selected_experts,
120120
TORCH_CHECK(token_final_scales.size(0) == num_tokens, "token_final_scales.size(0) must be num_tokens.");
121121
TORCH_CHECK(token_final_scales.size(1) == top_k, "token_final_scales.size(1) must be top_k.");
122122
return moe_topk_sort_impl(std::nullopt, std::nullopt, token_selected_experts, token_final_scales, num_experts,
123-
top_k, std::nullopt, std::nullopt, local_expert_offset, local_num_experts, std::nullopt, tile_tokens_dim,
124-
RoutingMethodType::Renormalize);
123+
top_k, 1, 1, local_expert_offset, local_num_experts, std::nullopt, tile_tokens_dim,
124+
RoutingMethodType::DeepSeekV3);
125125
}
126126

127127
// Permute

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 818 additions & 103 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/grouped_blockscaled_gemm_persistent.py renamed to tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,10 @@
5252
import cutlass.utils.blockscaled_layout as blockscaled_utils
5353
from cutlass.cute.nvgpu import cpasync, tcgen05
5454

55+
from .utils import is_power_of_2
5556

56-
class Sm100BlockScaledPersistentGroupedGemmKernel:
57+
58+
class Sm100BlockScaledContiguousGroupedGemmKernel:
5759
"""This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
5860
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
5961
@@ -88,7 +90,7 @@ class Sm100BlockScaledPersistentGroupedGemmKernel:
8890
- Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors
8991
9092
Example:
91-
>>> gemm = Sm100BlockScaledPersistentGroupedGemmKernel(
93+
>>> gemm = Sm100BlockScaledContiguousGroupedGemmKernel(
9294
... sf_vec_size=16, mma_tiler_mn=(256, 128), cluster_shape_mn=(2, 1)
9395
... )
9496
>>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream)
@@ -2052,9 +2054,6 @@ def is_valid_mma_tiler_and_cluster_shape(
20522054
is_valid = False
20532055

20542056
# Skip invalid cluster shape
2055-
def is_power_of_2(x: int) -> bool:
2056-
return x > 0 and (x & (x - 1)) == 0
2057-
20582057
if (
20592058
cluster_shape_mn[0] * cluster_shape_mn[1] > 16
20602059
or cluster_shape_mn[0] <= 0
@@ -2138,8 +2137,9 @@ def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
21382137
is_valid = False
21392138
return is_valid
21402139

2141-
@staticmethod
2140+
@classmethod
21422141
def can_implement(
2142+
cls,
21432143
ab_dtype: Type[cutlass.Numeric],
21442144
sf_dtype: Type[cutlass.Numeric],
21452145
sf_vec_size: int,
@@ -2198,24 +2198,22 @@ def can_implement(
21982198
"""
21992199
can_implement = True
22002200
# Skip unsupported types
2201-
if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
2201+
if not cls.is_valid_dtypes_and_scale_factor_vec_size(
22022202
ab_dtype, sf_dtype, sf_vec_size, acc_dtype, c_dtype
22032203
):
22042204
can_implement = False
22052205

22062206
# Skip unsupported layouts
2207-
if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_layouts(
2208-
ab_dtype, c_dtype, a_major, b_major, c_major
2209-
):
2207+
if not cls.is_valid_layouts(ab_dtype, c_dtype, a_major, b_major, c_major):
22102208
can_implement = False
22112209

22122210
# Skip invalid mma tile shape and cluster shape
2213-
if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape(
2211+
if not cls.is_valid_mma_tiler_and_cluster_shape(
22142212
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, m_aligned
22152213
):
22162214
can_implement = False
22172215
# Skip illegal problem shape for load/store alignment
2218-
if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_tensor_alignment(
2216+
if not cls.is_valid_tensor_alignment(
22192217
m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
22202218
):
22212219
can_implement = False
@@ -2238,7 +2236,7 @@ def wrapper(
22382236
m: int,
22392237
n: int,
22402238
k: int,
2241-
l: cutlass.Constexpr, # noqa: E741
2239+
l: int, # noqa: E741
22422240
tile_size: cutlass.Constexpr,
22432241
scaling_vector_size: cutlass.Constexpr,
22442242
max_active_clusters: cutlass.Constexpr,
@@ -2266,7 +2264,6 @@ def wrapper(
22662264
tile_idx_to_group_idx = cute.make_tensor(
22672265
tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,))
22682266
)
2269-
tile_idx_to_group_idx.mark_layout_dynamic()
22702267
num_non_exiting_tiles = cute.make_tensor(
22712268
num_non_exiting_tiles_ptr, layout=cute.make_layout((1,))
22722269
)

0 commit comments

Comments
 (0)