Commit 5649898

[TRTLLM-9082][feat] AutoDeploy: Move the moe Align kernel to AOT (#9106)
Signed-off-by: Chenghao Zhang <[email protected]>
1 parent eb7792e commit 5649898

File tree: 5 files changed, +200 -103 lines changed
cpp/tensorrt_llm/kernels/moeAlignKernels.cu

Lines changed: 78 additions & 69 deletions

@@ -1,27 +1,30 @@
-// Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
-// Licensed under the Apache License, Version 2.0 (the "License");
-// http://www.apache.org/licenses/LICENSE-2.0
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 // Inspired by vLLM's moe_align_kernel.cu and ported to TensorRT-LLM
 
-#include <ATen/ATen.h>
-#include <ATen/cuda/Atomic.cuh>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
+#include "moeAlignKernels.h"
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/cudaUtils.h"
 #include <cub/cub.cuh>
-#include <torch/extension.h>
 
 #define CEILDIV(x, y) (((x) + (y) -1) / (y))
 #define WARP_SIZE 32
 
-namespace auto_deploy
-{
-namespace moe
+namespace tensorrt_llm::kernels
 {
 
 template <typename scalar_t>

@@ -204,68 +207,74 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(scalar_t const* _
     }
 }
 
-} // namespace moe
-} // namespace auto_deploy
-
-void moe_align_block_size_cuda(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
-    torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad)
+template <typename scalar_t>
+void invokeMoeAlignBlockSizeTyped(scalar_t const* topk_ids, int32_t* sorted_token_ids, int32_t* expert_ids,
+    int32_t* num_tokens_post_pad, int32_t num_experts, int32_t block_size, int32_t numel, int32_t max_num_tokens_padded,
+    cudaStream_t stream)
 {
-
-    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
     int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
     int experts_per_warp = WARP_SIZE;
     int threads = 1024;
     threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
 
     // BlockScan uses 1024 threads and assigns one thread per expert.
-    TORCH_CHECK(padded_num_experts < 1024, "padded_num_experts must be less than 1024");
+    TLLM_CHECK_WITH_INFO(padded_num_experts < 1024, "padded_num_experts must be less than 1024");
 
-    AT_DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel",
-        [&]
-        {
-            // calc needed amount of shared mem for `cumsum` tensors
-            auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
-            torch::Tensor cumsum_buffer = torch::empty({num_experts + 1}, options_int);
-            bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64);
-
-            if (small_batch_expert_mode)
-            {
-                const int32_t threads = std::max((int32_t) num_experts, (int32_t) WARP_SIZE);
-                const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
-
-                auto small_batch_expert_kernel
-                    = auto_deploy::moe::moe_align_block_size_small_batch_expert_kernel<scalar_t>;
-                small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(topk_ids.data_ptr<scalar_t>(),
-                    sorted_token_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
-                    num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, topk_ids.numel(),
-                    sorted_token_ids.size(0));
-            }
-            else
-            {
-                auto align_kernel = auto_deploy::moe::moe_align_block_size_kernel<scalar_t>;
-
-                size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
-                size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
-
-                align_kernel<<<1, threads, shared_mem_size, stream>>>(topk_ids.data_ptr<scalar_t>(),
-                    sorted_token_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
-                    num_tokens_post_pad.data_ptr<int32_t>(), num_experts, padded_num_experts, experts_per_warp,
-                    block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(), sorted_token_ids.size(0));
-
-                const int block_threads = std::min(256, (int) threads);
-                const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
-                const int max_blocks = 65535;
-                const int actual_blocks = std::min(num_blocks, max_blocks);
-
-                auto sort_kernel = auto_deploy::moe::count_and_sort_expert_tokens_kernel<scalar_t>;
-                sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
-                    sorted_token_ids.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
-            }
-        });
+    // Allocate temporary cumsum buffer
+    int32_t* cumsum_buffer;
+    cudaMallocAsync(&cumsum_buffer, (num_experts + 1) * sizeof(int32_t), stream);
+    cudaMemsetAsync(cumsum_buffer, 0, (num_experts + 1) * sizeof(int32_t), stream);
+
+    bool small_batch_expert_mode = (numel < 1024) && (num_experts <= 64);
+
+    if (small_batch_expert_mode)
+    {
+        const int32_t thread_count = std::max((int32_t) num_experts, (int32_t) WARP_SIZE);
+        const int32_t shared_mem_size = ((thread_count + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+
+        moe_align_block_size_small_batch_expert_kernel<scalar_t><<<1, thread_count, shared_mem_size, stream>>>(topk_ids,
+            sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts, block_size, numel, max_num_tokens_padded);
+    }
+    else
+    {
+        size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
+        size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
+
+        moe_align_block_size_kernel<scalar_t><<<1, threads, shared_mem_size, stream>>>(topk_ids, sorted_token_ids,
+            expert_ids, num_tokens_post_pad, num_experts, padded_num_experts, experts_per_warp, block_size, numel,
+            cumsum_buffer, max_num_tokens_padded);
+
+        int const block_threads = std::min(256, (int) threads);
+        int const num_blocks = (numel + block_threads - 1) / block_threads;
+        int const max_blocks = 65535;
+        int const actual_blocks = std::min(num_blocks, max_blocks);
+
+        count_and_sort_expert_tokens_kernel<scalar_t>
+            <<<actual_blocks, block_threads, 0, stream>>>(topk_ids, sorted_token_ids, cumsum_buffer, numel);
+    }
+
+    cudaFreeAsync(cumsum_buffer, stream);
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+void invokeMoeAlignBlockSize(void const* topk_ids, int32_t topk_ids_dtype_size, int32_t* sorted_token_ids,
+    int32_t* expert_ids, int32_t* num_tokens_post_pad, int32_t num_experts, int32_t block_size, int32_t numel,
+    int32_t max_num_tokens_padded, cudaStream_t stream)
 {
-    m.def("moe_align_block_size", &moe_align_block_size_cuda, "MoE align block size (CUDA)");
+    // Dispatch based on dtype size
+    if (topk_ids_dtype_size == sizeof(int32_t))
+    {
+        invokeMoeAlignBlockSizeTyped(static_cast<int32_t const*>(topk_ids), sorted_token_ids, expert_ids,
+            num_tokens_post_pad, num_experts, block_size, numel, max_num_tokens_padded, stream);
+    }
+    else if (topk_ids_dtype_size == sizeof(int64_t))
+    {
+        invokeMoeAlignBlockSizeTyped(static_cast<int64_t const*>(topk_ids), sorted_token_ids, expert_ids,
+            num_tokens_post_pad, num_experts, block_size, numel, max_num_tokens_padded, stream);
+    }
+    else
+    {
+        TLLM_CHECK_WITH_INFO(false, "Unsupported topk_ids dtype size: %d", topk_ids_dtype_size);
+    }
 }
+
+} // namespace tensorrt_llm::kernels
cpp/tensorrt_llm/kernels/moeAlignKernels.h

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cuda_runtime.h>
+#include <stdint.h>
+
+namespace tensorrt_llm::kernels
+{
+
+/**
+ * @brief Aligns token distribution across experts to be compatible with block size for matrix multiplication.
+ *
+ * This kernel sorts tokens by expert assignment and pads the distribution to match block size requirements.
+ * Inspired by vLLM's moe_align_kernel and ported to TensorRT-LLM.
+ *
+ * @param topk_ids Input tensor with expert IDs per token [total_tokens, top_k]
+ * @param topk_ids_dtype_size Size of the dtype (e.g., sizeof(int32_t) or sizeof(int64_t))
+ * @param sorted_token_ids Output tensor for sorted token indices
+ * @param expert_ids Output tensor for expert IDs per block
+ * @param num_tokens_post_pad Output tensor for total tokens after padding (single int32)
+ * @param num_experts Total number of experts
+ * @param block_size Block size for matrix multiplication alignment
+ * @param numel Total number of elements in topk_ids (topk_ids.numel())
+ * @param max_num_tokens_padded Maximum number of tokens after padding (sorted_token_ids.size(0))
+ * @param stream CUDA stream for kernel execution
+ */
+void invokeMoeAlignBlockSize(void const* topk_ids, int32_t topk_ids_dtype_size, int32_t* sorted_token_ids,
+    int32_t* expert_ids, int32_t* num_tokens_post_pad, int32_t num_experts, int32_t block_size, int32_t numel,
+    int32_t max_num_tokens_padded, cudaStream_t stream);
+
+} // namespace tensorrt_llm::kernels

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ add_library(
 moeCommOp.cpp
 moeAlltoAllOp.cpp
 moeLoadBalanceOp.cpp
+moeAlignOp.cpp
 mxFp4BlockScaleMoe.cpp
 mxFp8Quantize.cpp
 fp8BlockScaleMoe.cpp
cpp/tensorrt_llm/thop/moeAlignOp.cpp

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/kernels/moeAlignKernels.h"
+#include "thUtils.h"
+#include <torch/extension.h>
+
+namespace tk = tensorrt_llm::kernels;
+
+namespace torch_ext
+{
+
+void moeAlignBlockSizeOp(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids, torch::Tensor num_tokens_post_pad)
+{
+    // Validate inputs
+    CHECK_TH_CUDA(topk_ids);
+    CHECK_CONTIGUOUS(topk_ids);
+    CHECK_INPUT(sorted_token_ids, torch::kInt32);
+    CHECK_INPUT(expert_ids, torch::kInt32);
+    CHECK_INPUT(num_tokens_post_pad, torch::kInt32);
+
+    TORCH_CHECK(topk_ids.scalar_type() == torch::kInt32 || topk_ids.scalar_type() == torch::kInt64,
+        "topk_ids must be int32 or int64");
+
+    auto stream = at::cuda::getCurrentCUDAStream();
+
+    tk::invokeMoeAlignBlockSize(topk_ids.data_ptr(), topk_ids.element_size(), sorted_token_ids.data_ptr<int32_t>(),
+        expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), static_cast<int32_t>(num_experts),
+        static_cast<int32_t>(block_size), static_cast<int32_t>(topk_ids.numel()),
+        static_cast<int32_t>(sorted_token_ids.size(0)), stream);
+}
+
+} // namespace torch_ext
+
+TORCH_LIBRARY_FRAGMENT(trtllm, m)
+{
+    m.def(
+        "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, "
+        "Tensor(a!) sorted_token_ids, Tensor(a!) expert_ids, Tensor(a!) num_tokens_post_pad) -> ()");
+}
+
+TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
+{
+    m.impl("moe_align_block_size", &torch_ext::moeAlignBlockSizeOp);
+}
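
Note: the schema above registers the op as torch.ops.trtllm.moe_align_block_size and marks the three outputs as mutable (Tensor(a!)), so callers pre-allocate them and the op writes them in place. Below is a minimal sketch of calling the registered op directly; the shapes and the vLLM-style sizing of the padded buffers are illustrative assumptions, not part of this diff.

import torch
import tensorrt_llm  # assumed to load the trtllm torch op library

# Illustrative shapes: 8 tokens, top-2 routing over 16 experts (assumed values).
num_tokens, top_k, num_experts, block_size = 8, 2, 16, 64
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), dtype=torch.int32, device="cuda")

# Assumed vLLM-style sizing: in the worst case every expert needs block_size - 1 padding slots.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_token_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda")
expert_ids = torch.empty((max_num_tokens_padded + block_size - 1) // block_size, dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

# Outputs are filled in place (Tensor(a!) in the schema above).
torch.ops.trtllm.moe_align_block_size(
    topk_ids, num_experts, block_size, sorted_token_ids, expert_ids, num_tokens_post_pad
)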
Lines changed: 16 additions & 34 deletions

@@ -1,40 +1,19 @@
 """
-Build moe_align CUDA extension eagerly with a persistent build directory
-(same workflow as agent_ops/load_moe.py).
-"""
-
-import os
-import tempfile
-
-import torch
-from torch.utils.cpp_extension import load
+AOT-compiled moe_align CUDA kernel.
 
-# Recommend explicit arch list so NVCC targets the right GPUs. You can override via env.
-os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.0;8.6;8.9;9.0")
+The moe_align kernel is now compiled ahead-of-time (AOT) as part of the main
+TensorRT-LLM build instead of being JIT-compiled on first use. This reduces
+startup time and avoids compilation overhead.
 
-THIS_DIR = os.path.dirname(os.path.abspath(__file__))
+The kernel implementation is in:
+- cpp/tensorrt_llm/kernels/moeAlignKernels.cu
+- cpp/tensorrt_llm/kernels/moeAlignKernels.h
 
-# Use system temp directory to avoid environment variable dependency
-BUILD_DIR = os.path.join(tempfile.gettempdir(), "ad_cache", "auto_deploy", "fused_moe", "moe_align")
-os.makedirs(BUILD_DIR, exist_ok=True)
+The torch binding is in:
+- cpp/tensorrt_llm/thop/moeAlignOp.cpp
+"""
 
-moe_align_ext = load(
-    name="moe_align_ext",
-    sources=[os.path.join(THIS_DIR, "moe_align_kernel.cu")],
-    extra_cflags=["-O3"],
-    extra_cuda_cflags=[
-        "-O3",
-        "--use_fast_math",
-        "-U__CUDA_NO_HALF_OPERATORS__",
-        "-U__CUDA_NO_HALF_CONVERSIONS__",
-        "--expt-relaxed-constexpr",
-        # Optional: "-Xptxas=-v",
-    ],
-    verbose=True,
-    with_cuda=True,
-    build_directory=BUILD_DIR,
-    is_python_module=True,
-)
+import torch
 
 
 def moe_align_block_size(

@@ -46,7 +25,7 @@ def moe_align_block_size(
     num_tokens_post_pad: torch.Tensor,
 ):
     """
-    Wrapper for the CUDA moe_align_block_size function.
+    Wrapper for the AOT-compiled moe_align_block_size function.
 
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.

@@ -64,6 +43,7 @@
         raise ValueError("topk_ids must be a CUDA tensor")
     if not topk_ids.is_contiguous():
         topk_ids = topk_ids.contiguous()
+
     for t, name in [
        (sorted_token_ids, "sorted_token_ids"),
        (expert_ids, "expert_ids"),

@@ -73,13 +53,15 @@
            raise ValueError(f"{name} must be a CUDA tensor")
        if not t.is_contiguous():
            raise ValueError(f"{name} must be contiguous")
+
    if (
        sorted_token_ids.dtype != torch.int32
        or expert_ids.dtype != torch.int32
        or num_tokens_post_pad.dtype != torch.int32
    ):
        raise TypeError("sorted_token_ids, expert_ids, num_tokens_post_pad must be int32 tensors")
 
-    moe_align_ext.moe_align_block_size(
+    # Call the AOT-compiled kernel via torch ops
+    torch.ops.trtllm.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_token_ids, expert_ids, num_tokens_post_pad
    )
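
For context, a hedged sketch of how this wrapper might be driven from router output; the router_logits shape, top-k value, and buffer sizing below are assumptions for illustration, and the import path for moe_align_block_size depends on where this wrapper module lives in the AutoDeploy tree.

import torch
# from <this wrapper module> import moe_align_block_size  # hypothetical import path

router_logits = torch.randn(8, 16, device="cuda")  # [num_tokens, num_experts], assumed shape
_, topk_ids = torch.topk(router_logits, k=2, dim=-1)
topk_ids = topk_ids.to(torch.int32)

num_experts, block_size = router_logits.size(-1), 64
max_padded = topk_ids.numel() + num_experts * (block_size - 1)  # assumed sizing convention
sorted_token_ids = torch.empty(max_padded, dtype=torch.int32, device="cuda")
expert_ids = torch.empty((max_padded + block_size - 1) // block_size, dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids, expert_ids, num_tokens_post_pad)

# num_tokens_post_pad[0] holds the padded token count; expert_ids[i] is the expert
# that owns the i-th block_size-sized chunk of sorted_token_ids.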
