@@ -38,17 +38,18 @@ static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

// //////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename T>
-__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts)
+template <typename DataType>
+__device__ DataType calcSoftmax(
+    cg::thread_block_tile<WARP_SIZE> const& warp, DataType score, int32_t laneIdx, int32_t NumTopExperts)
{
-    T maxScore = T{-INFINITY};
+    float maxScore = -INFINITY;
    if (laneIdx < NumTopExperts)
    {
-        maxScore = score >= maxScore ? score : maxScore;
+        maxScore = float(score) >= maxScore ? float(score) : maxScore;
    }
-    maxScore = cg::reduce(warp, maxScore, cg::greater<T>());
+    maxScore = cg::reduce(warp, maxScore, cg::greater<float>());

-    float sumScore{0.f};
+    float sumScore = 0.f;
    float newScore;
    // Get the summation of scores for each token
    if (laneIdx < NumTopExperts)
@@ -61,7 +62,7 @@ __device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score,

    if (laneIdx < NumTopExperts)
    {
-        score = static_cast<T>(newScore / sumScore);
+        score = static_cast<DataType>(newScore / sumScore);
    }

    return score;
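
The pattern above promotes to float before both warp reductions, regardless of the input `DataType`. The motivation: half has a 10-bit mantissa and bfloat16 only 7, so reducing a max or a sum of `exp()` terms in the native type loses precision on every step. A minimal standalone sketch of the same pattern (kernel name, buffer setup, and the one-warp launch are illustrative assumptions, not part of this PR):

```cpp
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda_bf16.h>
#include <cmath>

namespace cg = cooperative_groups;

// Hypothetical kernel: one warp computes softmax over up to 32 bf16 logits.
__global__ void warpSoftmaxSketch(__nv_bfloat16 const* logits, float* probs, int n)
{
    auto warp = cg::tiled_partition<32>(cg::this_thread_block());
    int lane = warp.thread_rank();

    // Promote to float once; lanes past n contribute -inf / 0 to the reductions.
    float x = lane < n ? __bfloat162float(logits[lane]) : -INFINITY;
    float maxX = cg::reduce(warp, x, cg::greater<float>());
    float e = lane < n ? expf(x - maxX) : 0.f;
    float sum = cg::reduce(warp, e, cg::plus<float>());
    if (lane < n)
    {
        probs[lane] = e / sum; // the sum was never rounded down to bf16
    }
}
```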
@@ -70,31 +71,35 @@ __device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score,
template <typename DataType, int VecSize>
__device__ void calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, DataType (&scores)[VecSize])
{
-    DataType maxScore = DataType{-INFINITY};
-    DataType sumScore = DataType{0.f};
-
+    // Compute in float to support half/bfloat16 inputs safely.
+    float maxScore = -INFINITY;
+    float sumScore = 0.f;
    // Get the max score for each token
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
-        maxScore = scores[i] >= maxScore ? scores[i] : maxScore;
+        float si = static_cast<float>(scores[i]);
+        maxScore = si >= maxScore ? si : maxScore;
    }
-    maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());
+    maxScore = cg::reduce(warp, maxScore, cg::greater<float>());

    // Get the summation of scores for each token
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
-        scores[i] = static_cast<DataType>(exp(scores[i] - maxScore));
-        sumScore += scores[i];
+        float si = static_cast<float>(scores[i]);
+        float e = expf(si - maxScore);
+        scores[i] = static_cast<DataType>(e);
+        sumScore += e;
    }
-    sumScore = cg::reduce(warp, sumScore, cg::plus<DataType>());
+    sumScore = cg::reduce(warp, sumScore, cg::plus<float>());

    // Normalize the scores
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
-        scores[i] = static_cast<DataType>(scores[i] / sumScore);
+        float si = static_cast<float>(scores[i]) / sumScore;
+        scores[i] = static_cast<DataType>(si);
    }
}

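This vectorized overload assumes each lane already holds its `VecSize` scores in registers. For context, a hedged sketch of how a caller might stage one token's logits into that layout; the kernel name and the row-major `[numTokens, numExperts]` layout are assumptions, while `WARP_SIZE`, the `cg` alias, and `calcSoftmax` come from this file:

```cpp
// Hypothetical caller: one warp per token, VecSize logits per lane.
template <typename DataType, int VecSize>
__global__ void softmaxPerTokenSketch(DataType const* logits, DataType* probs, int numExperts)
{
    auto warp = cg::tiled_partition<WARP_SIZE>(cg::this_thread_block());
    int const token = blockIdx.x;
    int const lane = warp.thread_rank();

    DataType scores[VecSize];
#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        int const e = lane * VecSize + i;
        // Pad missing experts with -inf so they get zero weight after exp().
        scores[i] = e < numExperts ? logits[token * numExperts + e] : DataType{-INFINITY};
    }

    calcSoftmax<DataType, VecSize>(warp, scores);

#pragma unroll
    for (int i = 0; i < VecSize; ++i)
    {
        int const e = lane * VecSize + i;
        if (e < numExperts)
        {
            probs[token * numExperts + e] = scores[i];
        }
    }
}
```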
@@ -205,7 +210,7 @@ int nextPowerOfTwo(int num)
        break;

template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
-void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
+void invokeCustomMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
    int64_t const numExperts, int64_t const topK, cudaStream_t const stream)
{

@@ -249,20 +254,25 @@ void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* top
}

#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT, DoSoftmaxBeforeTopK)                                     \
-    template void invokeRenormMoeRouting<InputT, OutputT, IdxT, DoSoftmaxBeforeTopK>(InputT* routerLogits,            \
+    template void invokeCustomMoeRouting<InputT, OutputT, IdxT, DoSoftmaxBeforeTopK>(InputT* routerLogits,            \
        OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens, int64_t const numExperts,                     \
        int64_t const topK, cudaStream_t const stream);

INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, false);
-#ifdef ENABLE_BF16
-INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false);
-#endif
-
INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, true);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, true);
+
#ifdef ENABLE_BF16
+INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false);
+INSTANTIATE_RENORM_MOE_ROUTING(float, __nv_bfloat16, int32_t, false);
+INSTANTIATE_RENORM_MOE_ROUTING(half, __nv_bfloat16, int32_t, false);
+INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, int32_t, false);
+
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, true);
+INSTANTIATE_RENORM_MOE_ROUTING(float, __nv_bfloat16, int32_t, true);
+INSTANTIATE_RENORM_MOE_ROUTING(half, __nv_bfloat16, int32_t, true);
+INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, int32_t, true);
#endif

} // namespace tensorrt_llm::kernels
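
With the rename and the expanded `ENABLE_BF16` block, every `InputT` now also links with `OutputT = __nv_bfloat16`, for both values of `DoSoftmaxBeforeTopK`. A hedged host-side usage sketch of the renamed launcher; the wrapper and buffer names are illustrative, and judging by the flag name, `true` presumably applies the softmax over all logits before selecting top-k, while `false` renormalizes only the selected top-k scores:

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical wrapper; dLogits, dTopkValues, dTopkIndices are device buffers
// of shape [numTokens, numExperts], [numTokens, topK], [numTokens, topK].
void runRoutingSketch(float* dLogits, float* dTopkValues, int32_t* dTopkIndices,
    int64_t numTokens, int64_t numExperts, int64_t topK, cudaStream_t stream)
{
    tensorrt_llm::kernels::invokeCustomMoeRouting<float, float, int32_t, /*DoSoftmaxBeforeTopK=*/true>(
        dLogits, dTopkValues, dTopkIndices, numTokens, numExperts, topK, stream);
}
```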