
Commit 822cb01

VALLIS-NERIA and djns99 authored
[TRTLLM-6286] [perf] Add NoSmem epilogue schedule and dynamic cluster shape for sm10x group gemm (NVIDIA#7757)
Signed-off-by: Xiwen Yu <[email protected]>
Signed-off-by: djns99 <[email protected]>
Co-authored-by: djns99 <[email protected]>
1 parent 897c4dd commit 822cb01

File tree

12 files changed: +256 -153 lines changed


cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h

Lines changed: 23 additions & 15 deletions
@@ -549,6 +549,9 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     ActivationType mActType = ActivationType::Relu;

     constexpr static int64_t NUM_BUFFERS = 32;
+    int64_t mNumWorkspaceBuffers = NUM_BUFFERS;
+    int64_t mNumInputBuffers = NUM_BUFFERS;
+    int64_t mNumGemmProfilerBuffers = NUM_BUFFERS;

     std::array<QuantParams, NUM_BUFFERS> mQuantParams{};
     bool mUseLora = false;
@@ -619,12 +622,12 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture

     if (gemm_to_profile == GemmToProfile::LAYER)
     {
-
         mWorkspaceSize = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK,
             mActType, parallelism_config, mUseLora, /*use_deepseek_fp8_block_scale=*/false,
             /*min_latency_mode=*/false, mUsePrequantScale);

-        mWorkspace = allocBuffer<char>(mWorkspaceSize * NUM_BUFFERS);
+        mNumWorkspaceBuffers = mWorkspaceSize > 1024 * 1024 * 1024 ? 2 : NUM_BUFFERS;
+        mWorkspace = allocBuffer<char>(mWorkspaceSize * mNumWorkspaceBuffers);

         mExpertBias1 = nullptr;
         mExpertBias2 = nullptr;
@@ -690,9 +693,10 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     mScaleProbsSize = padSize(mTotalTokens * mK);
     mScaleProbs = allocBuffer<float>(mScaleProbsSize * NUM_BUFFERS);
     mInputTensorSize = padSize(mTotalTokens * mHiddenSize);
-    mInputTensor = allocBuffer<DataType>(mInputTensorSize * NUM_BUFFERS);
+    mNumInputBuffers = mInputTensorSize > 1024 * 1024 * 1024 ? 2 : NUM_BUFFERS;
+    mInputTensor = allocBuffer<DataType>(mInputTensorSize * mNumInputBuffers);
     mFinalOutputSize = padSize(mTotalTokens * mHiddenSize);
-    mFinalOutput = allocBuffer<OutputType>(mFinalOutputSize * NUM_BUFFERS);
+    mFinalOutput = allocBuffer<OutputType>(mFinalOutputSize * mNumInputBuffers);

     mSourceToExpandedMapSize = padSize(mTotalTokens * mK);
     mSourceToExpandedMap = allocBuffer<int>(mSourceToExpandedMapSize * NUM_BUFFERS);
@@ -732,10 +736,11 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         = std::max(mGemmProfilerWorkspaceSize, mGemmProfilerBackend.getWorkspaceSize(mTotalTokens));
     }

-    int64_t num_gemm_buffers = gemm_to_profile == GemmToProfile::LAYER ? 1 : NUM_BUFFERS;
     mGemmProfilerWorkspaceSize = padSize(mGemmProfilerWorkspaceSize);
+    mNumGemmProfilerBuffers = mGemmProfilerWorkspaceSize > 1024 * 1024 * 1024 ? 2 : NUM_BUFFERS;
+    mNumGemmProfilerBuffers = gemm_to_profile == GemmToProfile::LAYER ? 1 : mNumGemmProfilerBuffers;
     mGemmProfilerWorkspace = mGemmProfilerWorkspaceSize > 0
-        ? allocBuffer<char>(mGemmProfilerWorkspaceSize * num_gemm_buffers)
+        ? allocBuffer<char>(mGemmProfilerWorkspaceSize * mNumGemmProfilerBuffers)
         : nullptr;

     check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
@@ -748,7 +753,8 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     mGemmProfilerBackend.mGemmToProfile = static_cast<GemmProfilerBackend::GemmToProfile>(gemm_to_profile);
     auto* expert_weights = gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1 : mExpertWeight2;
     auto expert_weights_size = gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1Size : mExpertWeight2Size;
-    mGemmProfilerBackend.prepare(mTotalTokens, mGemmProfilerWorkspace + mGemmProfilerWorkspaceSize * mBufferIndex,
+    mGemmProfilerBackend.prepare(mTotalTokens,
+        mGemmProfilerWorkspace + mGemmProfilerWorkspaceSize * (mBufferIndex % mNumGemmProfilerBuffers),
         /*expert_weights=*/expert_weights + expert_weights_size * mBufferIndex, streamPtr->get());
 }

@@ -865,7 +871,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     }

     // Profile all samples or for 1 sec
-    int const max_iters = mGemmProfilerBackend.NUM_ROUTING_SAMPLES;
+    int const max_iters = mGemmProfilerBackend.NUM_ROUTING_SAMPLES * 2;
     float const max_time_ms = 1000.f;

     float time = 0.f;
@@ -974,7 +980,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     }
     mGemmProfilerBackend.mSampleIndex = mBufferIndex % mGemmProfilerBackend.NUM_ROUTING_SAMPLES;
     mGemmProfilerBackend.runProfiler(mTotalTokens, tactics,
-        mGemmProfilerWorkspace + mGemmProfilerWorkspaceSize * mBufferIndex,
+        mGemmProfilerWorkspace + mGemmProfilerWorkspaceSize * (mBufferIndex % mNumGemmProfilerBuffers),
         /*expert_weights=*/expert_weights + expert_weights_size * mBufferIndex, streamPtr->get());
     break;
 }
@@ -983,26 +989,28 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     auto stream = streamPtr->get();
     MoeMinLatencyParams min_latency_params;
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
-    mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true,
+    mMoERunner.runMoe(mInputTensor + mInputTensorSize * (mBufferIndex % mNumInputBuffers), nullptr, true,
         mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
         mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
         mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
         ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
         mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-        mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
-        mFinalOutput + mFinalOutputSize * mBufferIndex,
+        mHiddenSize, mInterSize, mNumExperts, mK,
+        mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
+        mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
         mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
         /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex],
         /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
 #else
-    mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true,
+    mMoERunner.runMoe(mInputTensor + mInputTensorSize * (mBufferIndex % mNumInputBuffers), nullptr, true,
         mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
         mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
         mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
         ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
         mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
-        mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
-        mFinalOutput + mFinalOutputSize * mBufferIndex,
+        mHiddenSize, mInterSize, mNumExperts, mK,
+        mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
+        mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
         mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
         /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex],
         /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);

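A note on the fixture change above: the benchmark rotates through NUM_BUFFERS (32) copies of each tensor, presumably so successive iterations do not reuse warm data, which multiplies an already large allocation by 32 for big problem sizes. The diff caps the rotation at 2 copies once a single workspace, input/output, or GEMM-profiler buffer exceeds 2^30 in size, and wraps the rotating index with a modulo. A minimal standalone sketch of that heuristic (pickBufferCount and bufferOffset are illustrative names, not helpers from the fixture):

#include <cstdint>

constexpr int64_t kNumBuffers = 32;                             // default rotation depth, as NUM_BUFFERS above
constexpr int64_t kLargeBufferThreshold = 1024LL * 1024 * 1024; // the 2^30 cutoff used in the diff

// How many rotating copies to allocate for a buffer whose single-copy size is sizePerBuffer
// (bytes or element count, whichever unit the caller tracks, as in the fixture).
int64_t pickBufferCount(int64_t sizePerBuffer)
{
    return sizePerBuffer > kLargeBufferThreshold ? 2 : kNumBuffers;
}

// Offset of the copy used on iteration bufferIndex; mirrors mBufferIndex % mNumWorkspaceBuffers etc.
int64_t bufferOffset(int64_t sizePerBuffer, int64_t bufferIndex, int64_t numBuffers)
{
    return sizePerBuffer * (bufferIndex % numBuffers);
}

Smaller per-iteration tensors (expert selections, scales, biases) still rotate through all 32 copies; only the workspace, the input and final-output tensors, and the profiler workspace get the reduced count.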
cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 #include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>
 #include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>

-namespace fused_moe
+namespace fused_moe_oss
 {
 template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int MaxTileM_, int TileN_,
     int TileK_, int Stages_, Activation_Type activation_type_>
@@ -215,4 +215,4 @@ static int fused_gemm_maximum_active_blocks(int smem_capacity = -1)
     CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
     return max_active_blocks;
 }
-} // namespace fused_moe
+} // namespace fused_moe_oss

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 #pragma once
 #include <cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh>

-namespace fused_moe
+namespace fused_moe_oss
 {

 template <typename ElementInput_, typename ElementWeight_, typename ElementOutput_, int TileM_, int TileN_, int TileK_,
@@ -798,4 +798,4 @@ struct Fused_Moe_Kernel_routine_sm80<ElementInput_, ElementWeight_, ElementOutpu
     }
 };

-} // namespace fused_moe
+} // namespace fused_moe_oss

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 #include <cutlass_extensions/gemm/kernel/moe_cute_util.cuh>
 #include <cutlass_extensions/gemm/kernel/moe_problem_visitor.h>

-namespace fused_moe
+namespace fused_moe_oss
 {
 template <typename ElementInput, typename ElementWeight, typename ElementOutput>
 struct Routine_Arguments
@@ -212,4 +212,4 @@ struct Fused_Moe_Kernel_traits_sm80

     // #endif
 };
-} // namespace fused_moe
+} // namespace fused_moe_oss

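The three headers above are a mechanical rename of the namespace wrapping the sm80 fused MoE GEMM kernels from fused_moe to fused_moe_oss; every call site touched by this commit switches to the new qualified names (see the launcher diff further below). If out-of-tree code still referred to the old namespace, a transitional alias like the following would keep it compiling; this shim is purely illustrative and is not part of the commit:

// Hypothetical compatibility shim, NOT in this commit: make old fused_moe::
// qualified names resolve to the renamed namespace.
namespace fused_moe_oss
{
} // ensure the target namespace exists before aliasing
namespace fused_moe = fused_moe_oss;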
cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -138,8 +138,8 @@ function(add_instantiations library base_dir)
   endif()
 endmacro()

-glob_src_create_target(80 "80;86;90;100f;120f") # we use 80 kernels to support
-                                                # fp16 of all archs
+glob_src_create_target(80 "80;86;90;100f;120f") # we use sm80 kernels to
+                                                # support fp16 of all archs
 glob_src_create_target(90 90)
 glob_src_create_target(100 100f)
 glob_src_create_target(103 103)

cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Lines changed: 18 additions & 18 deletions
@@ -367,9 +367,9 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can
     return candidate_configs;
 }

-std::vector<CutlassGemmConfig> get_candidate_configs_sm100_dynamic_cluster_shape(
+std::vector<CutlassGemmConfig> get_candidate_configs_sm100_dynamic_cluster_shape(int sm,
     CutlassGemmConfig::CandidateConfigTypeParam const config, EpilogueScheduleType schedule,
-    ClusterShape const dynamic_cluster_shape, ClusterShape const fallback_cluster_shape, int sm)
+    ClusterShape const dynamic_cluster_shape, ClusterShape const fallback_cluster_shape)
 {
     auto cluster1sm = ClusterShape::ClusterShape_1x1x1;
     auto cluster2sm = ClusterShape::ClusterShape_2x1x1;
@@ -379,8 +379,20 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100_dynamic_cluster_shape
     std::vector<CutlassGemmConfig> candidate_configs;
     if ((config & CutlassGemmConfig::FP4_ONLY) != 0)
     {
-        if (schedule != EpilogueScheduleType::TMA)
-            return {};
+        if (sm == 100)
+        {
+            if (schedule != EpilogueScheduleType::TMA)
+                return {};
+            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
+                MainloopScheduleType::AUTO, schedule, cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm});
+            if (supports_2sm)
+            {
+                candidate_configs.push_back(
+                    CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, MainloopScheduleType::AUTO, schedule,
+                        cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm});
+            }
+        }
+
         candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
             MainloopScheduleType::AUTO, schedule, cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm});
         candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
@@ -392,18 +404,6 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100_dynamic_cluster_shape
         candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
             MainloopScheduleType::AUTO, schedule, cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm});
     }
-    if (sm == 103)
-    {
-        return candidate_configs;
-    }
-
-    candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
-        MainloopScheduleType::AUTO, schedule, cluster1sm, dynamic_cluster_shape, fallback_cluster_shape, sm});
-    if (supports_2sm)
-    {
-        candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
-            MainloopScheduleType::AUTO, schedule, cluster2sm, dynamic_cluster_shape, fallback_cluster_shape, sm});
-    }
     return candidate_configs;
 }

@@ -468,12 +468,12 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(
             ? ClusterShape::ClusterShape_1x1x1
             : ClusterShape::ClusterShape_2x1x1;
         auto configs = get_candidate_configs_sm100_dynamic_cluster_shape(
-            config, schedule, cluster_shape, fallback_cluster_shape, sm);
+            sm, config, schedule, cluster_shape, fallback_cluster_shape);
         candidate_configs.insert(candidate_configs.end(), configs.begin(), configs.end());
     }

     auto configs = get_candidate_configs_sm100_dynamic_cluster_shape(
-        config, schedule, ClusterShape::Undefined, ClusterShape::Undefined, sm);
+        sm, config, schedule, ClusterShape::Undefined, ClusterShape::Undefined);
     candidate_configs.insert(candidate_configs.end(), configs.begin(), configs.end());
 }
 return candidate_configs;

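To summarize the restructuring above: the sm argument moves to the front of get_candidate_configs_sm100_dynamic_cluster_shape (both call sites are updated), the CtaShape128x64x128B candidates now appear only inside the new sm == 100 FP4 branch, and the old trailing if (sm == 103) early return plus the unconditional 128x64 candidates are deleted. The TMA-only early return likewise now applies only to sm == 100, which is presumably how sm103 FP4 picks up the NoSmem epilogue schedule named in the commit title. A condensed, self-contained mock of the new FP4 control flow (real CUTLASS types are stubbed out and the remaining candidates of the real function are elided):

#include <vector>

// Stand-ins for the real CutlassGemmConfig machinery.
enum class Tile { Cta128x64x128B, Cta128x128x128B, Cta128x256x128B };
enum class Epilogue { TMA, NoSmem };
enum class Cluster { C1x1x1, C2x1x1 };

struct ConfigSketch
{
    Tile tile;
    Epilogue schedule;
    Cluster cluster;
};

std::vector<ConfigSketch> fp4CandidatesSketch(int sm, Epilogue schedule, bool supports_2sm)
{
    std::vector<ConfigSketch> out;
    if (sm == 100)
    {
        // sm100 FP4 still requires the TMA epilogue; anything else yields no candidates.
        if (schedule != Epilogue::TMA)
            return {};
        // The small 128x64 tile is now emitted only here, in 1-SM and optionally 2-SM cluster form.
        out.push_back({Tile::Cta128x64x128B, schedule, Cluster::C1x1x1});
        if (supports_2sm)
            out.push_back({Tile::Cta128x64x128B, schedule, Cluster::C2x1x1});
    }
    // sm103 no longer short-circuits here, so NoSmem schedules reach the larger tiles too.
    out.push_back({Tile::Cta128x128x128B, schedule, Cluster::C1x1x1});
    out.push_back({Tile::Cta128x256x128B, schedule, Cluster::C1x1x1});
    // ... further 2-SM candidates follow, as in the diff context.
    return out;
}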
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl

Lines changed: 8 additions & 8 deletions
@@ -36,9 +36,9 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe
     int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream,
     int* kernel_occupancy)
 {
-    constexpr auto activation_type = fused_moe::EpilogueRouting<EpilogueTag>(true);
-    using GemmType = fused_moe::Fused_Moe_Kernel_sm80<ElementType_, CutlassWeightType_, ElementType_, MaxTileM_, TileN_,
-        TileK_, Stages_, activation_type>;
+    constexpr auto activation_type = fused_moe_oss::EpilogueRouting<EpilogueTag>(true);
+    using GemmType = fused_moe_oss::Fused_Moe_Kernel_sm80<ElementType_, CutlassWeightType_, ElementType_, MaxTileM_,
+        TileN_, TileK_, Stages_, activation_type>;

     // make sure GPU has enough resources..
     if (kernel_occupancy != nullptr)
@@ -53,7 +53,7 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe
     tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
     tensorrt_llm::common::check_cuda_error(
         cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
-    tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global<GemmType>));
+    tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe_oss::run_global<GemmType>));
     if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
     {
         // This should mean that
@@ -67,11 +67,11 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe

         int max_active_blocks = -1;
         tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-            &max_active_blocks, fused_moe::run_global<GemmType>, GemmType::kThreadCount, smem_size));
+            &max_active_blocks, fused_moe_oss::run_global<GemmType>, GemmType::kThreadCount, smem_size));
         *kernel_occupancy = max_active_blocks;
         return;
     }
-    int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
+    int occupancy = std::min(2, fused_moe_oss::fused_gemm_maximum_active_blocks<GemmType>());
     int const threadblock_count = multi_processor_count * occupancy;
     TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
     using Arguments = typename GemmType::Arguments;
@@ -83,13 +83,13 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe
     if (GemmType::kSmemSize >= (48 << 10))
     {
         cudaError_t result = cudaFuncSetAttribute(
-            fused_moe::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize);
+            fused_moe_oss::run_global<GemmType>, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize);
         TLLM_CHECK_WITH_INFO(result == cudaSuccess,
             "Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel");
     }
     dim3 grid(params.threadblock_count, 1, 1);
     dim3 block(GemmType::kThreadCount);
-    fused_moe::run_global<GemmType><<<grid, block, GemmType::kSmemSize, stream>>>(params);
+    fused_moe_oss::run_global<GemmType><<<grid, block, GemmType::kSmemSize, stream>>>(params);
     auto result = cudaGetLastError();
     TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result));
 }

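Beyond the namespace rename, the launcher above shows the occupancy-driven launch pattern this kernel uses: opt in to more than 48 KiB of dynamic shared memory when needed, query the maximum resident blocks per SM for the kernel, cap the result at 2, and size the grid as multi_processor_count * occupancy. A self-contained CUDA sketch of the same pattern with a placeholder kernel (myKernel, kSmemBytes, and the 256-thread block size are assumptions for the example, not values from the commit):

#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

// Assumed dynamic shared memory need; sm70+ GPUs allow this much via the opt-in below.
constexpr int kSmemBytes = 64 * 1024;

// Placeholder standing in for fused_moe_oss::run_global<GemmType>.
__global__ void myKernel(float* out)
{
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[blockIdx.x * blockDim.x + threadIdx.x] = smem[threadIdx.x];
}

int main()
{
    int device = 0, sm_count = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);

    // Kernels needing more than 48 KiB of dynamic smem must raise the limit first,
    // exactly as the launcher does with cudaFuncAttributeMaxDynamicSharedMemorySize.
    if (kSmemBytes >= (48 << 10))
        cudaFuncSetAttribute(myKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes);

    int constexpr kBlockSize = 256;
    int max_active_blocks = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, myKernel, kBlockSize, kSmemBytes);

    // Same cap as the launcher: use at most 2 resident blocks per SM to size the grid.
    int const occupancy = std::min(2, max_active_blocks);
    int const threadblock_count = sm_count * occupancy;
    std::printf("occupancy=%d, grid=%d blocks\n", occupancy, threadblock_count);
    if (threadblock_count == 0)
        return 0; // device cannot fit the assumed shared memory per block

    float* out = nullptr;
    cudaMalloc(&out, sizeof(float) * kBlockSize * threadblock_count);
    myKernel<<<threadblock_count, kBlockSize, kSmemBytes, 0>>>(out);
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}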