Commit f45aff2

Add customized renormalized moe routing kernel for moe cutlass backend (#4955)
Signed-off-by: Christina Zhang <[email protected]>
1 parent c104388 commit f45aff2
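
For reference, the routing this commit implements (a sketch of the math, inferred from the kernel code below): for each token, the K largest router logits $s_{(1)}, \dots, s_{(K)}$ are selected, then renormalized with a numerically stable softmax over just those K values,

$$p_k = \frac{e^{s_{(k)} - m}}{\sum_{j=1}^{K} e^{s_{(j)} - m}}, \qquad m = \max_{1 \le j \le K} s_{(j)},$$

so the returned top-K expert weights sum to 1.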

File tree

11 files changed: +577 -39 lines changed
Lines changed: 382 additions & 0 deletions
@@ -0,0 +1,382 @@

/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h"
#include <climits> // For INT_MAX
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda/std/limits> // For numeric_limits
#include <math.h>

namespace cg = cooperative_groups;
using namespace tensorrt_llm::common;

namespace tensorrt_llm::kernels
{

static constexpr int BLOCK_SIZE = 1024;
static constexpr int WARP_SIZE = 32;
static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

namespace reduce_topk
{

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL))
#define TLLM_GEN_ENABLE_FAST_REDUX
#endif

template <typename T_>
struct TopKRedType
{
    using T = T_;
    static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>,
        "Top K reduction only implemented for float, float16 and bfloat16");

    using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
    using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;
    static constexpr int moveBits = (sizeof(T) == 4) ? 32 : 16;
    static constexpr int maxIdx = 65535;
    TypeCmp compValIdx;

    static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
    {
        auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
        TypeCmp compactTmp = reinterpret_cast<TypeCmp&>(valueBits);
        // Use 65535 minus idx to give higher priority to elements with smaller indices.
        compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx));
        return compactTmp;
    }

    static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp)
    {
        // Since "65535 - idx" is always positive and smaller than 65536, it can be used directly as the lower 16 bits.
        index = maxIdx - static_cast<int32_t>(cmp & 0xFFFF);

        auto compactTmp = cmp >> moveBits;
        auto valueBits
            = cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
        value = reinterpret_cast<T&>(valueBits);
    }

    __host__ __device__ TopKRedType() = default;

    __host__ __device__ TopKRedType(T val, int32_t idx)
        : compValIdx(makeCmpVal(val, idx))
    {
    }

    __host__ __device__ operator TypeCmp() const noexcept
    {
        return compValIdx;
    }

    __device__ inline TypeCmp reduce(cg::thread_block_tile<WARP_SIZE> const& warp)
    {
#if defined(TLLM_GEN_ENABLE_FAST_REDUX)
        static constexpr bool UseCg = false;
#else
        static constexpr bool UseCg = true;
#endif
        if constexpr (UseCg || sizeof(TypeCmp) == 8)
        {
            return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
        }
        else
        {
            TypeCmp result;
            asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx));
            return result;
        }
    }
};
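
// Note added for exposition (not part of the original commit): a worked example of the packed
// compare key. For T = float, TypeCmp is uint64_t. Packing a score of 1.0f at expert index 3
// twiddles the float's bits into an order-preserving unsigned integer, shifts it into the high
// 32 bits, and stores (65535 - 3) = 0xFFFC in the low 16 bits. A warp-wide max over such keys
// therefore selects the largest score first and, on ties, the smallest expert index.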

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int K_, bool Enable_>
struct TopKIdx
{
    // by default, empty
};

template <int K_>
struct TopKIdx<K_, true>
{
    static constexpr int K = K_;
    int32_t val[K];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compare-and-swap: puts the larger packed key in topK[I] and the smaller in topK[J].
#define TOPK_SWAP(I, J)                                                                                                \
    {                                                                                                                  \
        auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx);                                                    \
        auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx);                                                    \
        topK[I].compValIdx = pairMax;                                                                                  \
        topK[J].compValIdx = pairMin;                                                                                  \
    }

template <int N, typename RedType>
struct Sort;

template <typename RedType>
struct Sort<1, RedType>
{
    static __device__ void run(RedType* topK) {}
};

template <typename RedType>
struct Sort<2, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
    }
};

template <typename RedType>
struct Sort<3, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
        TOPK_SWAP(1, 2);
        TOPK_SWAP(0, 1);
    }
};

template <typename RedType>
struct Sort<4, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 2);
        TOPK_SWAP(1, 3);
        TOPK_SWAP(0, 1);
        TOPK_SWAP(2, 3);
        TOPK_SWAP(1, 2);
    }
};
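
// Note added for exposition (not part of the original commit): each Sort<N> specialization is a
// small sorting network over the packed keys; Sort<4> is the classic 5-comparator network. After
// run(), topK[0] >= topK[1] >= ... in packed-key order, which reduceTopK below relies on when it
// shifts candidates after extracting each maximum.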

template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopK(cg::thread_block_tile<WARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
    Type (&value)[N], int32_t (&idx)[N], Type minValue)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < WARP_SIZE, "Top K must have K < WARP_SIZE");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N < 5, "Only supports up to 128 candidates (N * WARP_SIZE) per warp");
    using RedType = TopKRedType<Type>;
    RedType topK[N];
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
    {
        topK[nn] = RedType{value[nn], idx[nn]};
    }

    if constexpr (!IsSorted)
    {
        Sort<N, RedType>::run(topK);
    }
    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < K; ++kk)
    {
        // Lanes that contributed the previous maximum shift their candidates left
        // and refill the tail with minValue, so each candidate is extracted once.
        bool update = kk > 0 && packedMax == topK[0].compValIdx;
#pragma unroll
        for (int nn = 0; nn < N; ++nn)
        {
            topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
        }
        // Get the next largest value.
        packedMax = topK[0].reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
}
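
// Note added for exposition (not part of the original commit): with K = 2, N = 1 and one candidate
// per lane, reduceTopK performs two warp-wide max reductions. The first writes the global maximum
// (value and index) to out[0]/outIdx[0] on every lane; the contributing lane then replaces its
// candidate with minValue, and the second reduction yields the runner-up in out[1]/outIdx[1].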

#undef TOPK_SWAP

} // end of namespace reduce_topk

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts)
{
    T maxScore = T{-INFINITY};
    if (laneIdx < NumTopExperts)
    {
        maxScore = score >= maxScore ? score : maxScore;
    }
    maxScore = cg::reduce(warp, maxScore, cg::greater<T>());

    float sumScore = 0.f;
    float newScore = 0.f;
    // Sum the max-shifted exponentials across the top experts of this token.
    if (laneIdx < NumTopExperts)
    {
        newScore = static_cast<float>(score) - static_cast<float>(maxScore);
        newScore = static_cast<float>(exp(newScore));
        sumScore += newScore;
    }
    sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

    if (laneIdx < NumTopExperts)
    {
        score = static_cast<T>(newScore / sumScore);
    }

    return score;
}
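
// Note added for exposition (not part of the original commit): a worked example of the
// renormalization above. With NumTopExperts = 2 and lane scores {2.0, 1.0}: maxScore = 2.0,
// the shifted exponentials are {e^0, e^-1} ~= {1.000, 0.368}, sumScore ~= 1.368, and the
// returned weights are ~= {0.731, 0.269}, which sum to 1.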

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts>
__global__ void renormMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices,
    int32_t const numTokens, int32_t const numExperts, int32_t const topK)
{
    uint32_t const blockRank = blockIdx.x;
    uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x;
    uint32_t const warpIdx = tIdx / WARP_SIZE;
    uint32_t const laneIdx = tIdx % WARP_SIZE;
    uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK;
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<WARP_SIZE>(block);

    InputT minScore = InputT{-INFINITY};
    // One warp handles one token; the grid-stride loop covers all tokens.
    for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
    {
        auto scoreOffset = tokenId * numExperts;
        auto outputOffset = tokenId * topK;
        InputT inputScore[MaxNumExperts / WARP_SIZE];
        IdxT inputIndex[MaxNumExperts / WARP_SIZE];

        InputT warpTopKScore[MaxNumTopExperts];
        IdxT warpTopKExpertIdx[MaxNumTopExperts];

        // Load scores and indices for this warp; out-of-range experts get minScore.
        for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i)
        {
            auto expertIdx = i * WARP_SIZE + laneIdx;
            inputScore[i]
                = expertIdx < numExperts ? static_cast<InputT>(routerLogits[scoreOffset + expertIdx]) : minScore;
            inputIndex[i] = expertIdx;
        }

        // Reduce topK scores and indices for this warp.
        reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);

        // Perform softmax on topK scores.
        auto score = calcSoftmax(warp,
            laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore), laneIdx, topK);
        if (laneIdx < topK)
        {
            topkValues[outputOffset + laneIdx] = static_cast<OutputT>(score);
            topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
        }
    } // end for tokenId
}

int nextPowerOfTwo(int num)
{
    if (num <= 0)
    {
        return 1; // Handle invalid input
    }
    int power = 1;
    while (power < num)
    {
        // Check for overflow before shifting
        if (power > INT_MAX / 2)
        {
            return power;
        }
        power <<= 1;
    }
    return power;
}
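
// Note added for exposition (not part of the original commit): nextPowerOfTwo rounds the runtime
// expert count and top-K up to the compiled template sizes. For example, 60 experts round up to
// maxNumExperts = 64 and topK = 6 rounds up to maxNumTopExperts = 8, selecting the
// renormMoeRoutingKernel<..., 64, 8> instantiation below.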

#define CASE(MAX_NUM_EXPERTS)                                                                                          \
    case MAX_NUM_EXPERTS:                                                                                              \
        switch (maxNumTopExperts)                                                                                      \
        {                                                                                                              \
        case 1: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1>; break;            \
        case 2: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2>; break;            \
        case 4: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4>; break;            \
        case 8: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8>; break;            \
        default: kernelInstance = nullptr; break;                                                                      \
        }                                                                                                              \
        break;

template <typename InputT, typename OutputT, typename IdxT>
void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
    int64_t const numExperts, int64_t const topK, cudaStream_t const stream)
{
    const uint32_t maxNumBlocks = 1024;
    const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks);

    uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts);
    uint32_t maxNumTopExperts = nextPowerOfTwo(topK);

    auto* kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, 128, 8>;

    switch (maxNumExperts)
    {
        CASE(32)
        CASE(64)
        CASE(96)
        CASE(128)
    default: kernelInstance = nullptr; break;
    }

    TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Cannot find a corresponding kernel instance.");

    dim3 renormMoeRoutingGridDim(numBlocks);
    dim3 renormMoeRoutingBlockDim(BLOCK_SIZE);
    cudaLaunchConfig_t config;
    config.gridDim = renormMoeRoutingGridDim;
    config.blockDim = renormMoeRoutingBlockDim;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast<int32_t>(numTokens),
        static_cast<int32_t>(numExperts), static_cast<int32_t>(topK));
    sync_check_cuda_error(stream);
}

#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT)                                                          \
    template void invokeRenormMoeRouting<InputT, OutputT, IdxT>(InputT * routerLogits, OutputT * topkValues,           \
        IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, int64_t const topK,                     \
        cudaStream_t const stream);

INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t);
#ifdef ENABLE_BF16
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t);
#endif

} // namespace tensorrt_llm::kernels
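
For context, a minimal host-side sketch of how the new entry point might be called (illustrative only; the wrapper function and buffer names are assumed, not part of this commit). Router logits are laid out as [numTokens, numExperts]; the outputs are [numTokens, topK]:

    #include <cuda_runtime.h>

    void runRouting(float const* logitsHost, int numTokens, int numExperts, int topK, cudaStream_t stream)
    {
        float* dLogits;     // [numTokens, numExperts] router logits
        float* dTopkValues; // [numTokens, topK] renormalized weights
        int32_t* dTopkIdx;  // [numTokens, topK] selected expert ids
        cudaMalloc(&dLogits, sizeof(float) * numTokens * numExperts);
        cudaMalloc(&dTopkValues, sizeof(float) * numTokens * topK);
        cudaMalloc(&dTopkIdx, sizeof(int32_t) * numTokens * topK);
        cudaMemcpyAsync(dLogits, logitsHost, sizeof(float) * numTokens * numExperts, cudaMemcpyHostToDevice, stream);

        tensorrt_llm::kernels::invokeRenormMoeRouting<float, float, int32_t>(
            dLogits, dTopkValues, dTopkIdx, numTokens, numExperts, topK, stream);
        // ... consume dTopkValues / dTopkIdx on `stream`, then free the buffers.
    }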
