Skip to content

Commit 990f4db

Browse files
authored
Merge branch 'main' into clean_cuda_graph
2 parents d74ed49 + f666ad2 commit 990f4db

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

52 files changed

+1368
-515
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ repos:
6666
additional_dependencies:
6767
- tomli
6868
# add ignore words list
69-
args: ["-L", "Mor,ans,thirdparty", "--skip", "ATTRIBUTIONS-*.md", "--skip", "security_scanning/*"]
69+
args: ["-L", "Mor,ans,thirdparty", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"]
7070
- repo: https://github.com/astral-sh/ruff-pre-commit
7171
rev: v0.9.4
7272
hooks:

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesBlo
130130
{
131131
if (laneIdx < params.mTopK)
132132
{
133-
int offset = warpIdx * MaxNumExperts + params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx];
134-
smemKIdx[offset] = static_cast<int8_t>(laneIdx);
133+
auto expertIdx = params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx];
134+
if (expertIdx != -1)
135+
{
136+
int offset = warpIdx * MaxNumExperts + expertIdx;
137+
smemKIdx[offset] = static_cast<int8_t>(laneIdx);
138+
}
139+
else
140+
{
141+
params.mPtrExpandedIdxToPermutedIdx[warpIdx * params.mTopK + laneIdx] = int32_t{-1};
142+
}
135143
}
136144
}
137145
}

cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
*/
1616

1717
#include "tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h"
18+
#include "tensorrt_llm/thop/thUtils.h"
1819

1920
#include <ATen/ATen.h>
2021
#include <ATen/cuda/CUDAContext.h>
2122
#include <ATen/cuda/EmptyTensor.h>
2223
#include <torch/library.h>
2324

2425
#include <cstdint>
26+
#include <memory>
27+
#include <unordered_map>
2528

2629
namespace torch_ext
2730
{
@@ -316,16 +319,30 @@ class FP8BlockScaleMoeRunner : public torch::CustomClassHolder
316319
{
317320

318321
public:
319-
explicit FP8BlockScaleMoeRunner(int64_t tileTokensDim)
320-
: mTileTokensDim(tileTokensDim)
322+
explicit FP8BlockScaleMoeRunner()
323+
: mSupportedTileN{8, 16, 32, 64}
321324
{
322-
mRunner = std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, mTileTokensDim);
325+
for (int tileN : mSupportedTileN)
326+
{
327+
mRunners.emplace(tileN, std::make_unique<RunnerType>(mDtypeElt, mUseDeepSeekFp8, tileN));
328+
}
323329
}
324330

325-
[[nodiscard]] std::vector<int64_t> getValidConfigs(
331+
[[nodiscard]] std::vector<std::vector<int64_t>> getValidConfigs(
326332
int64_t topK, int64_t hiddenSize, int64_t intermediateSize, int64_t numLocalExperts, int64_t numTokens) const
327333
{
328-
return mRunner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
334+
// returns (tileN, config)
335+
std::vector<std::vector<int64_t>> tactics;
336+
for (auto& [tileN, runner] : mRunners)
337+
{
338+
auto config_indices_per_runner
339+
= runner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
340+
for (auto cfg : config_indices_per_runner)
341+
{
342+
tactics.push_back({tileN, cfg});
343+
}
344+
}
345+
return tactics;
329346
}
330347

331348
[[nodiscard]] at::Tensor run(at::optional<at::Tensor> const& routing_logits,
@@ -334,42 +351,48 @@ class FP8BlockScaleMoeRunner : public torch::CustomClassHolder
334351
at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale, int64_t num_experts, int64_t top_k,
335352
std::optional<int64_t> const n_group, std::optional<int64_t> const topk_group, int64_t const intermediate_size,
336353
int64_t const local_expert_offset, int64_t const local_num_experts,
337-
std::optional<double> const routed_scaling_factor, int64_t routing_method_type, int64_t moeConfigIndex,
338-
std::optional<at::Tensor> const& topk_weights, std::optional<at::Tensor> const& topk_ids)
354+
std::optional<double> const routed_scaling_factor, int64_t routing_method_type,
355+
std::vector<int64_t> tile_config_pair, std::optional<at::Tensor> const& topk_weights,
356+
std::optional<at::Tensor> const& topk_ids)
339357
{
358+
// tile_config_pair corresponds to pair (tileN, config)
359+
auto [tileN, config] = std::tie(tile_config_pair[0], tile_config_pair[1]);
340360

341361
// Autotuner has requested a default or 'fallback' config index
342-
if (moeConfigIndex == -1)
362+
if (tileN == -1 || config == -1)
343363
{
344364
auto const num_tokens = hidden_states.sizes()[0];
345365
auto const hidden_size = hidden_states.sizes()[1];
346366

347-
moeConfigIndex = mRunner->getDefaultValidConfigIndex(
367+
float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / local_num_experts;
368+
tileN = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), mSupportedTileN.front(), mSupportedTileN.back());
369+
370+
config = mRunners.at(tileN)->getDefaultValidConfigIndex(
348371
top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);
349372
}
350373

351374
return run_fp8_block_scale_moe(routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
352375
gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, num_experts, top_k, n_group, topk_group,
353-
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, mTileTokensDim,
354-
routing_method_type, *mRunner, moeConfigIndex, topk_weights, topk_ids);
376+
intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tileN,
377+
routing_method_type, *mRunners.at(tileN), config, topk_weights, topk_ids);
355378
}
356379

357380
private:
358381
using RunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
359382

360-
std::unique_ptr<RunnerType> mRunner;
383+
std::vector<int32_t> const mSupportedTileN;
384+
std::unordered_map<int32_t, std::unique_ptr<RunnerType>> mRunners;
361385

362386
btg::Dtype mDtypeElt{btg::Dtype::E4m3}; // FP8 runner so hard-coded
363387
bool mUseDeepSeekFp8{true}; // Always true for BlockScaleMoe
364-
int64_t mTileTokensDim;
365388
};
366389

367390
} // namespace torch_ext
368391

369392
TORCH_LIBRARY_FRAGMENT(trtllm, m)
370393
{
371394
m.class_<torch_ext::FP8BlockScaleMoeRunner>("FP8BlockScaleMoERunner")
372-
.def(torch::init<int64_t>())
395+
.def(torch::init<>())
373396
.def("get_valid_configs", &torch_ext::FP8BlockScaleMoeRunner::getValidConfigs)
374397
.def("run_moe", &torch_ext::FP8BlockScaleMoeRunner::run);
375398
}

cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization32)
217217
/*numExperts=*/32, /*topK=*/8,
218218
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
219219
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
220-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
220+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
221221
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
222222
this->runTest(param);
223223
};
@@ -228,7 +228,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization72)
228228
/*numExperts=*/72, /*topK=*/6,
229229
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
230230
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
231-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
231+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
232232
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
233233
this->runTest(param);
234234
};
@@ -239,7 +239,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384)
239239
/*numExperts=*/384, /*topK=*/8,
240240
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
241241
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
242-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
242+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
243243
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
244244
this->runTest(param);
245245
};
@@ -250,7 +250,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
250250
/*numExperts=*/256, /*topK=*/8,
251251
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
252252
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
253-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
253+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
254254
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
255255
this->runTest(param);
256256
};
@@ -261,7 +261,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput
261261
/*numExperts=*/256, /*topK=*/8,
262262
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192,
263263
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
264-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
264+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
265265
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
266266
this->runTest(param);
267267
};
@@ -272,7 +272,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput
272272
/*numExperts=*/384, /*topK=*/8,
273273
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
274274
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
275-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true,
275+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/false,
276276
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
277277
this->runTest(param);
278278
};
@@ -283,7 +283,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
283283
/*numExperts=*/256, /*topK=*/8,
284284
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
285285
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
286-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
286+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
287287
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
288288
this->runTest(param);
289289
};
@@ -294,7 +294,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization)
294294
/*numExperts=*/256, /*topK=*/8,
295295
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
296296
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
297-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
297+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
298298
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
299299
this->runTest(param);
300300
};
@@ -305,7 +305,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384)
305305
/*numExperts=*/384, /*topK=*/8,
306306
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
307307
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
308-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
308+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
309309
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
310310
this->runTest(param);
311311
};
@@ -316,7 +316,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
316316
/*numExperts=*/256, /*topK=*/8,
317317
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
318318
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
319-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
319+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
320320
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
321321
this->runTest(param);
322322
};
@@ -327,7 +327,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384)
327327
/*numExperts=*/384, /*topK=*/8,
328328
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
329329
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
330-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
330+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
331331
/*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
332332
this->runTest(param);
333333
};
@@ -338,7 +338,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2)
338338
/*numExperts=*/256, /*topK=*/2,
339339
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
340340
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
341-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
341+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
342342
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
343343
this->runTest(param);
344344
};
@@ -349,7 +349,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
349349
/*numExperts=*/256, /*topK=*/2,
350350
/*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192,
351351
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
352-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
352+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
353353
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
354354
this->runTest(param);
355355
};
@@ -360,7 +360,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop2)
360360
/*numExperts=*/256, /*topK=*/2,
361361
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
362362
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
363-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
363+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false,
364364
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
365365
this->runTest(param);
366366
};
@@ -371,7 +371,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop8)
371371
/*numExperts=*/32, /*topK=*/8,
372372
/*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256,
373373
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
374-
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false,
374+
/*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true,
375375
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
376376
this->runTest(param);
377377
};

0 commit comments

Comments
 (0)