|
| 1 | +/* |
| 2 | + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "tensorrt_llm/common/cudaTypeUtils.cuh" |
| 18 | +#include "tensorrt_llm/common/envUtils.h" |
| 19 | +#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h" |
| 20 | +#include <climits> // For INT_MAX |
| 21 | +#include <cooperative_groups.h> |
| 22 | +#include <cooperative_groups/reduce.h> |
| 23 | +#include <cub/cub.cuh> |
| 24 | +#include <cuda/std/limits> // For numeric_limits |
| 25 | +#include <math.h> |
| 26 | + |
| 27 | +namespace cg = cooperative_groups; |
| 28 | +using namespace tensorrt_llm::common; |
| 29 | + |
| 30 | +namespace tensorrt_llm::kernels |
| 31 | +{ |
| 32 | + |
| 33 | +static constexpr int BLOCK_SIZE = 1024; |
| 34 | +static constexpr int WARP_SIZE = 32; |
| 35 | +static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; |
| 36 | + |
| 37 | +namespace reduce_topk |
| 38 | +{ |
| 39 | + |
| 40 | +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) |
| 41 | +#define TLLM_GEN_ENABLE_FAST_REDUX |
| 42 | +#endif |
| 43 | + |
| 44 | +template <typename T_> |
| 45 | +struct TopKRedType |
| 46 | +{ |
| 47 | + using T = T_; |
| 48 | + static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>, |
| 49 | + "Top K reduction only implemented for float, float16 and bfloat16"); |
| 50 | + |
| 51 | + using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>; |
| 52 | + using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>; |
| 53 | + static constexpr int moveBits = (sizeof(T) == 4) ? 32 : 16; |
| 54 | + static constexpr int maxIdx = 65535; |
| 55 | + TypeCmp compValIdx; |
| 56 | + |
| 57 | + static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) |
| 58 | + { |
| 59 | + auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val)); |
| 60 | + TypeCmp compactTmp = reinterpret_cast<TypeCmp&>(valueBits); |
| 61 | + compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx)); |
| 62 | + // Use 65535 minus idx to give higher priority to elements with smaller indices. |
| 63 | + return compactTmp; |
| 64 | + } |
| 65 | + |
| 66 | + static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp) |
| 67 | + { |
| 68 | + // Since “65535-idx” is always smaller than 65536 and positive, we can directly use it as the lower 16 bits |
| 69 | + index = maxIdx - static_cast<int32_t>((cmp & 0xFFFF)); |
| 70 | + |
| 71 | + auto compactTmp = cmp >> moveBits; |
| 72 | + auto valueBits |
| 73 | + = cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp)); |
| 74 | + value = reinterpret_cast<T&>(valueBits); |
| 75 | + } |
| 76 | + |
| 77 | + __host__ __device__ TopKRedType() = default; |
| 78 | + |
| 79 | + __host__ __device__ TopKRedType(T val, int32_t idx) |
| 80 | + : compValIdx(makeCmpVal(val, idx)) |
| 81 | + { |
| 82 | + } |
| 83 | + |
| 84 | + __host__ __device__ operator TypeCmp() const noexcept |
| 85 | + { |
| 86 | + return compValIdx; |
| 87 | + } |
| 88 | + |
| 89 | + __device__ inline TypeCmp reduce(cg::thread_block_tile<WARP_SIZE> const& warp) |
| 90 | + { |
| 91 | +#if defined(TLLM_GEN_ENABLE_FAST_REDUX) |
| 92 | + static constexpr bool UseCg = false; |
| 93 | +#else |
| 94 | + static constexpr bool UseCg = true; |
| 95 | +#endif |
| 96 | + if constexpr (UseCg || sizeof(TypeCmp) == 8) |
| 97 | + { |
| 98 | + return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{}); |
| 99 | + } |
| 100 | + else |
| 101 | + { |
| 102 | + TypeCmp result; |
| 103 | + asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx)); |
| 104 | + return result; |
| 105 | + } |
| 106 | + } |
| 107 | +}; |
| 108 | + |
| 109 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 110 | + |
| 111 | +template <int K_, bool Enable_> |
| 112 | +struct TopKIdx |
| 113 | +{ |
| 114 | + // by default, empty |
| 115 | +}; |
| 116 | + |
| 117 | +template <int K_> |
| 118 | +struct TopKIdx<K_, true> |
| 119 | +{ |
| 120 | + static constexpr int K = K_; |
| 121 | + int32_t val[K]; |
| 122 | +}; |
| 123 | + |
| 124 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 125 | + |
| 126 | +#define TOPK_SWAP(I, J) \ |
| 127 | + { \ |
| 128 | + auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \ |
| 129 | + auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \ |
| 130 | + topK[I].compValIdx = pairMax; \ |
| 131 | + topK[J].compValIdx = pairMin; \ |
| 132 | + } |
| 133 | + |
| 134 | +template <int N, typename RedType> |
| 135 | +struct Sort; |
| 136 | + |
| 137 | +template <typename RedType> |
| 138 | +struct Sort<1, RedType> |
| 139 | +{ |
| 140 | + static __device__ void run(RedType* topK) {} |
| 141 | +}; |
| 142 | + |
| 143 | +template <typename RedType> |
| 144 | +struct Sort<2, RedType> |
| 145 | +{ |
| 146 | + static __device__ void run(RedType* topK) |
| 147 | + { |
| 148 | + TOPK_SWAP(0, 1); |
| 149 | + } |
| 150 | +}; |
| 151 | + |
| 152 | +template <typename RedType> |
| 153 | +struct Sort<3, RedType> |
| 154 | +{ |
| 155 | + static __device__ void run(RedType* topK) |
| 156 | + { |
| 157 | + TOPK_SWAP(0, 1); |
| 158 | + TOPK_SWAP(1, 2); |
| 159 | + TOPK_SWAP(0, 1); |
| 160 | + } |
| 161 | +}; |
| 162 | + |
| 163 | +template <typename RedType> |
| 164 | +struct Sort<4, RedType> |
| 165 | +{ |
| 166 | + static __device__ void run(RedType* topK) |
| 167 | + { |
| 168 | + TOPK_SWAP(0, 2); |
| 169 | + TOPK_SWAP(1, 3); |
| 170 | + TOPK_SWAP(0, 1); |
| 171 | + TOPK_SWAP(2, 3); |
| 172 | + TOPK_SWAP(1, 2); |
| 173 | + } |
| 174 | +}; |
| 175 | + |
| 176 | +template <int K, typename Type, int N, bool IsSorted = false> |
| 177 | +__device__ void reduceTopK(cg::thread_block_tile<WARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K], |
| 178 | + Type (&value)[N], int32_t (&idx)[N], Type minValue) |
| 179 | +{ |
| 180 | + static_assert(K > 0, "Top K must have K > 0"); |
| 181 | + static_assert(K < WARP_SIZE, "Top K must have K < WARP_SIZE"); |
| 182 | + static_assert(N > 0, "Top K must have N > 0"); |
| 183 | + static_assert(N < 5, "Only support candidates number less than or equal to 128"); |
| 184 | + using RedType = TopKRedType<Type>; |
| 185 | + RedType topK[N]; |
| 186 | +#pragma unroll |
| 187 | + for (int nn = 0; nn < N; ++nn) |
| 188 | + { |
| 189 | + topK[nn] = RedType{value[nn], idx[nn]}; |
| 190 | + } |
| 191 | + |
| 192 | + if constexpr (!IsSorted) |
| 193 | + { |
| 194 | + Sort<N, RedType>::run(topK); |
| 195 | + } |
| 196 | + typename RedType::TypeCmp packedMax{}; |
| 197 | +#pragma unroll |
| 198 | + for (int kk = 0; kk < K; ++kk) |
| 199 | + { |
| 200 | + bool update = kk > 0 && packedMax == topK[0].compValIdx; |
| 201 | +#pragma unroll |
| 202 | + for (int nn = 0; nn < N; ++nn) |
| 203 | + { |
| 204 | + topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn]; |
| 205 | + } |
| 206 | + // get the next largest value |
| 207 | + packedMax = topK[0].reduce(warp); |
| 208 | + RedType::unpack(out[kk], outIdx[kk], packedMax); |
| 209 | + } |
| 210 | +}; |
| 211 | + |
| 212 | +#undef TOPK_SWAP |
| 213 | + |
| 214 | +} // end of namespace reduce_topk |
| 215 | + |
| 216 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 217 | + |
| 218 | +template <typename T> |
| 219 | +__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts) |
| 220 | +{ |
| 221 | + T maxScore = T{-INFINITY}; |
| 222 | + if (laneIdx < NumTopExperts) |
| 223 | + { |
| 224 | + maxScore = score >= maxScore ? score : maxScore; |
| 225 | + } |
| 226 | + maxScore = cg::reduce(warp, maxScore, cg::greater<T>()); |
| 227 | + |
| 228 | + float sumScore = float{0.f}; |
| 229 | + float newScore; |
| 230 | + // Get the summation of scores for each token |
| 231 | + if (laneIdx < NumTopExperts) |
| 232 | + { |
| 233 | + newScore = static_cast<float>(score) - static_cast<float>(maxScore); |
| 234 | + newScore = static_cast<float>(exp(newScore)); |
| 235 | + sumScore += newScore; |
| 236 | + } |
| 237 | + sumScore = cg::reduce(warp, sumScore, cg::plus<float>()); |
| 238 | + |
| 239 | + if (laneIdx < NumTopExperts) |
| 240 | + { |
| 241 | + score = static_cast<T>(newScore / sumScore); |
| 242 | + } |
| 243 | + |
| 244 | + return score; |
| 245 | +} |
| 246 | + |
| 247 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 248 | + |
| 249 | +template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts> |
| 250 | +__global__ void renormMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, |
| 251 | + int32_t const numTokens, int32_t const numExperts, int32_t const topK) |
| 252 | +{ |
| 253 | + |
| 254 | + uint32_t const blockRank = blockIdx.x; |
| 255 | + uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x; |
| 256 | + uint32_t const warpIdx = tIdx / WARP_SIZE; |
| 257 | + uint32_t const laneIdx = tIdx % WARP_SIZE; |
| 258 | + uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK; |
| 259 | + auto block = cg::this_thread_block(); |
| 260 | + auto warp = cg::tiled_partition<WARP_SIZE>(block); |
| 261 | + |
| 262 | + InputT minScore = InputT{-INFINITY}; |
| 263 | + for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum) |
| 264 | + { |
| 265 | + auto scoreOffset = tokenId * numExperts; |
| 266 | + auto outputOffset = tokenId * topK; |
| 267 | + InputT inputScore[MaxNumExperts / WARP_SIZE]; |
| 268 | + IdxT inputIndex[MaxNumExperts / WARP_SIZE]; |
| 269 | + |
| 270 | + InputT warpTopKScore[MaxNumTopExperts]; |
| 271 | + IdxT warpTopKExpertIdx[MaxNumTopExperts]; |
| 272 | + |
| 273 | + // Load scores and indices for this warp |
| 274 | + for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i) |
| 275 | + { |
| 276 | + auto expertIdx = i * WARP_SIZE + laneIdx; |
| 277 | + inputScore[i] |
| 278 | + = expertIdx < numExperts ? static_cast<InputT>(routerLogits[scoreOffset + expertIdx]) : minScore; |
| 279 | + inputIndex[i] = expertIdx; |
| 280 | + } |
| 281 | + |
| 282 | + // Reduce topK scores and indices for this warp |
| 283 | + reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore); |
| 284 | + |
| 285 | + // Perform softmax on topK scores |
| 286 | + auto score = calcSoftmax(warp, |
| 287 | + laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore), laneIdx, topK); |
| 288 | + if (laneIdx < topK) |
| 289 | + { |
| 290 | + topkValues[outputOffset + laneIdx] = static_cast<OutputT>(score); |
| 291 | + topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx]; |
| 292 | + } |
| 293 | + } // end for tokenId |
| 294 | +} |
| 295 | + |
| 296 | +int nextPowerOfTwo(int num) |
| 297 | +{ |
| 298 | + if (num <= 0) |
| 299 | + { |
| 300 | + return 1; // Handle invalid input |
| 301 | + } |
| 302 | + int power = 1; |
| 303 | + while (power < num) |
| 304 | + { |
| 305 | + // Check for overflow before shifting |
| 306 | + if (power > INT_MAX / 2) |
| 307 | + { |
| 308 | + return power; |
| 309 | + } |
| 310 | + power <<= 1; |
| 311 | + } |
| 312 | + return power; |
| 313 | +} |
| 314 | + |
| 315 | +#define CASE(MAX_NUM_EXPERTS) \ |
| 316 | + case MAX_NUM_EXPERTS: \ |
| 317 | + switch (maxNumTopExperts) \ |
| 318 | + { \ |
| 319 | + case 1: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1>; break; \ |
| 320 | + case 2: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2>; break; \ |
| 321 | + case 4: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4>; break; \ |
| 322 | + case 8: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8>; break; \ |
| 323 | + default: kernelInstance = nullptr; break; \ |
| 324 | + } \ |
| 325 | + break; |
| 326 | + |
| 327 | +template <typename InputT, typename OutputT, typename IdxT> |
| 328 | +void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens, |
| 329 | + int64_t const numExperts, int64_t const topK, cudaStream_t const stream) |
| 330 | +{ |
| 331 | + |
| 332 | + const uint32_t maxNumBlocks = 1024; |
| 333 | + const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks); |
| 334 | + |
| 335 | + uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts); |
| 336 | + uint32_t maxNumTopExperts = nextPowerOfTwo(topK); |
| 337 | + |
| 338 | + auto* kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, 128, 8>; |
| 339 | + |
| 340 | + switch (maxNumExperts) |
| 341 | + { |
| 342 | + CASE(32) |
| 343 | + CASE(64) |
| 344 | + CASE(96) |
| 345 | + CASE(128) |
| 346 | + default: kernelInstance = nullptr; break; |
| 347 | + } |
| 348 | + |
| 349 | + if (kernelInstance == nullptr) |
| 350 | + { |
| 351 | + TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance."); |
| 352 | + } |
| 353 | + |
| 354 | + dim3 renormMoeRoutingGridDim(numBlocks); |
| 355 | + dim3 renormMoeRoutingBlockDim(BLOCK_SIZE); |
| 356 | + cudaLaunchConfig_t config; |
| 357 | + config.gridDim = renormMoeRoutingGridDim; |
| 358 | + config.blockDim = renormMoeRoutingBlockDim; |
| 359 | + config.dynamicSmemBytes = 0; |
| 360 | + config.stream = stream; |
| 361 | + cudaLaunchAttribute attrs[1]; |
| 362 | + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; |
| 363 | + attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); |
| 364 | + config.numAttrs = 1; |
| 365 | + config.attrs = attrs; |
| 366 | + cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast<int32_t>(numTokens), |
| 367 | + static_cast<int32_t>(numExperts), static_cast<int32_t>(topK)); |
| 368 | + sync_check_cuda_error(stream); |
| 369 | +} |
| 370 | + |
| 371 | +#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT) \ |
| 372 | + template void invokeRenormMoeRouting<InputT, OutputT, IdxT>(InputT * routerLogits, OutputT * topkValues, \ |
| 373 | + IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, int64_t const topK, \ |
| 374 | + cudaStream_t const stream); |
| 375 | + |
| 376 | +INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t); |
| 377 | +INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t); |
| 378 | +#ifdef ENABLE_BF16 |
| 379 | +INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t); |
| 380 | +#endif |
| 381 | + |
| 382 | +} // namespace tensorrt_llm::kernels |
0 commit comments