1 | | -// Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
2 | | -// Licensed under the Apache License, Version 2.0 (the "License"); |
3 | | -// http://www.apache.org/licenses/LICENSE-2.0 |
4 | | -// Unless required by applicable law or agreed to in writing, software |
5 | | -// distributed under the License is distributed on an "AS IS" BASIS, |
6 | | -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
7 | | -// See the License for the specific language governing permissions and |
8 | | -// limitations under the License. |
| 1 | +/* |
| 2 | + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
9 | 16 |
10 | 17 | // Inspired by vLLM's moe_align_kernel.cu and ported to TensorRT-LLM |
11 | 18 |
12 | | -#include <ATen/ATen.h> |
13 | | -#include <ATen/cuda/Atomic.cuh> |
14 | | -#include <ATen/cuda/CUDAContext.h> |
15 | | -#include <c10/cuda/CUDAGuard.h> |
| 19 | +#include "moeAlignKernels.h" |
| 20 | +#include "tensorrt_llm/common/assert.h" |
| 21 | +#include "tensorrt_llm/common/cudaUtils.h" |
16 | 22 | #include <cub/cub.cuh> |
17 | | -#include <torch/extension.h> |
18 | 23 |
19 | 24 | #define CEILDIV(x, y) (((x) + (y) -1) / (y)) |
20 | 25 | #define WARP_SIZE 32 |
21 | 26 |
22 | | -namespace auto_deploy |
23 | | -{ |
24 | | -namespace moe |
| 27 | +namespace tensorrt_llm::kernels |
25 | 28 | { |
26 | 29 |
27 | 30 | template <typename scalar_t> |
@@ -204,68 +207,74 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(scalar_t const* _ |
204 | 207 | } |
205 | 208 | } |
206 | 209 |
207 | | -} // namespace moe |
208 | | -} // namespace auto_deploy |
209 | | - |
210 | | -void moe_align_block_size_cuda(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, |
211 | | - torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) |
| 210 | +template <typename scalar_t> |
| 211 | +void invokeMoeAlignBlockSizeTyped(scalar_t const* topk_ids, int32_t* sorted_token_ids, int32_t* expert_ids, |
| 212 | + int32_t* num_tokens_post_pad, int32_t num_experts, int32_t block_size, int32_t numel, int32_t max_num_tokens_padded, |
| 213 | + cudaStream_t stream) |
212 | 214 | { |
213 | | - |
214 | | - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
215 | | - |
216 | 215 | int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; |
217 | 216 | int experts_per_warp = WARP_SIZE; |
218 | 217 | int threads = 1024; |
219 | 218 | threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; |
220 | 219 |
221 | 220 | // BlockScan uses 1024 threads and assigns one thread per expert. |
222 | | - TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); |
| 221 | + TLLM_CHECK_WITH_INFO(padded_num_experts < 1024, "padded_num_experts must be less than 1024"); |
223 | 222 |
224 | | - AT_DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", |
225 | | - [&] |
226 | | - { |
227 | | - // calc needed amount of shared mem for `cumsum` tensors |
228 | | - auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); |
229 | | - torch::Tensor cumsum_buffer = torch::empty({num_experts + 1}, options_int); |
230 | | - bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); |
231 | | - |
232 | | - if (small_batch_expert_mode) |
233 | | - { |
234 | | - const int32_t threads = std::max((int32_t) num_experts, (int32_t) WARP_SIZE); |
235 | | - const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); |
236 | | - |
237 | | - auto small_batch_expert_kernel |
238 | | - = auto_deploy::moe::moe_align_block_size_small_batch_expert_kernel<scalar_t>; |
239 | | - small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(topk_ids.data_ptr<scalar_t>(), |
240 | | - sorted_token_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(), |
241 | | - num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, topk_ids.numel(), |
242 | | - sorted_token_ids.size(0)); |
243 | | - } |
244 | | - else |
245 | | - { |
246 | | - auto align_kernel = auto_deploy::moe::moe_align_block_size_kernel<scalar_t>; |
247 | | - |
248 | | - size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); |
249 | | - size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); |
250 | | - |
251 | | - align_kernel<<<1, threads, shared_mem_size, stream>>>(topk_ids.data_ptr<scalar_t>(), |
252 | | - sorted_token_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(), |
253 | | - num_tokens_post_pad.data_ptr<int32_t>(), num_experts, padded_num_experts, experts_per_warp, |
254 | | - block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(), sorted_token_ids.size(0)); |
255 | | - |
256 | | - const int block_threads = std::min(256, (int) threads); |
257 | | - const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; |
258 | | - const int max_blocks = 65535; |
259 | | - const int actual_blocks = std::min(num_blocks, max_blocks); |
260 | | - |
261 | | - auto sort_kernel = auto_deploy::moe::count_and_sort_expert_tokens_kernel<scalar_t>; |
262 | | - sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), |
263 | | - sorted_token_ids.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel()); |
264 | | - } |
265 | | - }); |
| 223 | + // Allocate temporary cumsum buffer |
| 224 | + int32_t* cumsum_buffer; |
| 225 | + cudaMallocAsync(&cumsum_buffer, (num_experts + 1) * sizeof(int32_t), stream); |
| 226 | + cudaMemsetAsync(cumsum_buffer, 0, (num_experts + 1) * sizeof(int32_t), stream); |
| 227 | + |
| 228 | + bool small_batch_expert_mode = (numel < 1024) && (num_experts <= 64); |
| 229 | + |
| 230 | + if (small_batch_expert_mode) |
| 231 | + { |
| 232 | + const int32_t thread_count = std::max((int32_t) num_experts, (int32_t) WARP_SIZE); |
| 233 | + const int32_t shared_mem_size = ((thread_count + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); |
| 234 | + |
| 235 | + moe_align_block_size_small_batch_expert_kernel<scalar_t><<<1, thread_count, shared_mem_size, stream>>>(topk_ids, |
| 236 | + sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts, block_size, numel, max_num_tokens_padded); |
| 237 | + } |
| 238 | + else |
| 239 | + { |
| 240 | + size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); |
| 241 | + size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); |
| 242 | + |
| 243 | + moe_align_block_size_kernel<scalar_t><<<1, threads, shared_mem_size, stream>>>(topk_ids, sorted_token_ids, |
| 244 | + expert_ids, num_tokens_post_pad, num_experts, padded_num_experts, experts_per_warp, block_size, numel, |
| 245 | + cumsum_buffer, max_num_tokens_padded); |
| 246 | + |
| 247 | + int const block_threads = std::min(256, (int) threads); |
| 248 | + int const num_blocks = (numel + block_threads - 1) / block_threads; |
| 249 | + int const max_blocks = 65535; |
| 250 | + int const actual_blocks = std::min(num_blocks, max_blocks); |
| 251 | + |
| 252 | + count_and_sort_expert_tokens_kernel<scalar_t> |
| 253 | + <<<actual_blocks, block_threads, 0, stream>>>(topk_ids, sorted_token_ids, cumsum_buffer, numel); |
| 254 | + } |
| 255 | + |
| 256 | + cudaFreeAsync(cumsum_buffer, stream); |
266 | 257 | } |
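Note: the temporary cumsum buffer above is allocated and freed with cudaMallocAsync / cudaFreeAsync, but the return codes are not checked. Below is a minimal sketch of the same sequence with error checking, assuming the TLLM_CUDA_CHECK macro from the tensorrt_llm/common/cudaUtils.h include that this change already adds; it is a suggestion only, not part of the diff.

    // Error-checked variant of the temporary-buffer handling in invokeMoeAlignBlockSizeTyped.
    int32_t* cumsum_buffer = nullptr;
    TLLM_CUDA_CHECK(cudaMallocAsync(&cumsum_buffer, (num_experts + 1) * sizeof(int32_t), stream));
    TLLM_CUDA_CHECK(cudaMemsetAsync(cumsum_buffer, 0, (num_experts + 1) * sizeof(int32_t), stream));
    // ... launch the alignment kernels exactly as above ...
    TLLM_CUDA_CHECK(cudaFreeAsync(cumsum_buffer, stream));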
267 | 258 |
268 | | -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) |
| 259 | +void invokeMoeAlignBlockSize(void const* topk_ids, int32_t topk_ids_dtype_size, int32_t* sorted_token_ids, |
| 260 | + int32_t* expert_ids, int32_t* num_tokens_post_pad, int32_t num_experts, int32_t block_size, int32_t numel, |
| 261 | + int32_t max_num_tokens_padded, cudaStream_t stream) |
269 | 262 | { |
270 | | - m.def("moe_align_block_size", &moe_align_block_size_cuda, "MoE align block size (CUDA)"); |
| 263 | + // Dispatch based on dtype size |
| 264 | + if (topk_ids_dtype_size == sizeof(int32_t)) |
| 265 | + { |
| 266 | + invokeMoeAlignBlockSizeTyped(static_cast<int32_t const*>(topk_ids), sorted_token_ids, expert_ids, |
| 267 | + num_tokens_post_pad, num_experts, block_size, numel, max_num_tokens_padded, stream); |
| 268 | + } |
| 269 | + else if (topk_ids_dtype_size == sizeof(int64_t)) |
| 270 | + { |
| 271 | + invokeMoeAlignBlockSizeTyped(static_cast<int64_t const*>(topk_ids), sorted_token_ids, expert_ids, |
| 272 | + num_tokens_post_pad, num_experts, block_size, numel, max_num_tokens_padded, stream); |
| 273 | + } |
| 274 | + else |
| 275 | + { |
| 276 | + TLLM_CHECK_WITH_INFO(false, "Unsupported topk_ids dtype size: %d", topk_ids_dtype_size); |
| 277 | + } |
271 | 278 | } |
| 279 | + |
| 280 | +} // namespace tensorrt_llm::kernels |
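For reference, a hypothetical host-side caller of the new invokeMoeAlignBlockSize entry point could look like the sketch below. Everything in it is an illustrative assumption rather than part of this change: the buffer names, the sizing of sorted_token_ids as numel + num_experts * (block_size - 1) (the worst-case padding convention used by vLLM's moe_align_block_size), and the minimal error handling.

    // Hypothetical usage sketch; not part of this PR.
    #include "moeAlignKernels.h" // assumed to declare tensorrt_llm::kernels::invokeMoeAlignBlockSize
    #include <cuda_runtime.h>
    #include <vector>

    int main()
    {
        int32_t const num_tokens = 8, top_k = 2, num_experts = 16, block_size = 64;
        int32_t const numel = num_tokens * top_k;
        // Worst case: each expert's token count is rounded up to a multiple of block_size.
        int32_t const max_num_tokens_padded = numel + num_experts * (block_size - 1);
        int32_t const max_num_blocks = (max_num_tokens_padded + block_size - 1) / block_size;

        cudaStream_t stream;
        cudaStreamCreate(&stream);

        int32_t *topk_ids, *sorted_token_ids, *expert_ids, *num_tokens_post_pad;
        cudaMalloc(&topk_ids, numel * sizeof(int32_t));
        cudaMalloc(&sorted_token_ids, max_num_tokens_padded * sizeof(int32_t));
        cudaMalloc(&expert_ids, max_num_blocks * sizeof(int32_t));
        cudaMalloc(&num_tokens_post_pad, sizeof(int32_t));

        // Fill topk_ids with per-token expert indices in [0, num_experts).
        std::vector<int32_t> h_topk(numel);
        for (int32_t i = 0; i < numel; ++i)
        {
            h_topk[i] = i % num_experts;
        }
        cudaMemcpy(topk_ids, h_topk.data(), numel * sizeof(int32_t), cudaMemcpyHostToDevice);

        tensorrt_llm::kernels::invokeMoeAlignBlockSize(topk_ids, sizeof(int32_t), sorted_token_ids, expert_ids,
            num_tokens_post_pad, num_experts, block_size, numel, max_num_tokens_padded, stream);
        cudaStreamSynchronize(stream);

        cudaFree(topk_ids);
        cudaFree(sorted_token_ids);
        cudaFree(expert_ids);
        cudaFree(num_tokens_post_pad);
        cudaStreamDestroy(stream);
        return 0;
    }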