/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cassert>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "tinygemm2_kernel.cuh"

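// Host-side launcher for the tinygemm2 kernel: a linear-layer style GEMM (output = activations x weight^T + bias)
// on bf16 data, with the weight stored row-major as [output_features, input_features] and the activations as
// [batch_size, input_features].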
void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, __nv_bfloat16* bias, int batch_size,
    int output_features, int input_features, cudaStream_t stream)
{

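    // Tile configuration: each thread block computes a TILE_M x TILE_N output tile, streaming the reduction
    // dimension in TILE_K-wide slices through a STAGES-deep, STAGE_UNROLL-wide pipeline.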
    static int const WARP_TILE_M = 16;
    static int const TILE_M = WARP_TILE_M;
    static int const TILE_N = 8;
    static int const TILE_K = 64;
    static int const STAGES = 16;
    static int const STAGE_UNROLL = 4;
    static bool const PROFILE = false;

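    // Build TMA descriptors (cuTensorMapEncodeTiled) for the weight and activation matrices:
    // row-major bf16 with 128-byte swizzling, where each box is one shared-memory tile.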
    CUtensorMap weight_map{};
    CUtensorMap activation_map{};

    constexpr uint32_t rank = 2;
    uint64_t size[rank] = {(uint64_t) input_features, (uint64_t) output_features};
    uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
    uint32_t box_size[rank] = {TILE_K, TILE_M};
    uint32_t elem_stride[rank] = {1, 1};

    CUresult res = cuTensorMapEncodeTiled(&weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank, gB,
        size, stride, box_size, elem_stride, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
        CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    assert(res == CUDA_SUCCESS);

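    // The activation descriptor reuses the same parameters, changing only the outer dimension
    // (batch_size instead of output_features) and the outer box size (TILE_N instead of TILE_M).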
    size[1] = batch_size;
    box_size[1] = TILE_N;

    res = cuTensorMapEncodeTiled(&activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank, gA, size,
        stride, box_size, elem_stride, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
        CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    assert(res == CUDA_SUCCESS);

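    // Dynamic shared memory holds STAGES * STAGE_UNROLL pipeline buffers, each containing one
    // TILE_M x TILE_K weight tile and one TILE_N x TILE_K activation tile; cudaFuncSetAttribute
    // opts the kernel in to more than the default 48 KB of dynamic shared memory.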
    int smem_size
        = STAGES * STAGE_UNROLL * (TILE_M * TILE_K * sizeof(__nv_bfloat16) + TILE_N * TILE_K * sizeof(__nv_bfloat16));

    gpuErrChk(cudaFuncSetAttribute(kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>,
        cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

    int tiles_m = (output_features + TILE_M - 1) / TILE_M;
    int tiles_n = (batch_size + TILE_N - 1) / TILE_N;

    dim3 grid(tiles_m, tiles_n);

    dim3 block(384);

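    // Configure the extended launch API so the programmatic stream serialization attribute
    // (programmatic dependent launch) can be attached to this kernel launch.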
    cudaLaunchConfig_t config{};
    cudaLaunchAttribute attrs[1];
    config.gridDim = grid;
    config.blockDim = block;
    config.dynamicSmemBytes = smem_size;
    config.stream = stream;
    config.attrs = attrs;
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;
    config.numAttrs = 1;

    gpuErrChk(cudaLaunchKernelEx(&config, &kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>,
        gC, gA, gB, bias, output_features, batch_size, input_features, weight_map, activation_map, nullptr));
}

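// PyTorch-facing entry point: takes input [batch_size, input_dim], weight [output_dim, input_dim] and bias
// tensors, allocates the [batch_size, output_dim] output, and dispatches the bf16 launcher on the current
// CUDA stream.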
torch::Tensor tinygemm2_cuda_forward(torch::Tensor input, torch::Tensor weight, torch::Tensor bias)
{

    auto const batch_size = input.size(0);
    auto const input_dim = input.size(1);
    auto const output_dim = weight.size(0);

    auto output = torch::empty({batch_size, output_dim}, input.options());

    auto stream = at::cuda::getCurrentCUDAStream();

    if (input.scalar_type() == at::ScalarType::BFloat16)
    {
        // Tensor sizes come back from PyTorch as int64_t; the launcher takes 32-bit ints.
        launch_tinygemm2((__nv_bfloat16*) input.data_ptr(), (__nv_bfloat16*) weight.data_ptr(),
            (__nv_bfloat16*) output.data_ptr(), (__nv_bfloat16*) bias.data_ptr(), static_cast<int>(batch_size),
            static_cast<int>(output_dim), static_cast<int>(input_dim), stream);
    }
    else
    {
        TORCH_CHECK(false, "tinygemm2_cuda_forward only supports bfloat16 tensors");
    }
    return output;
}