
Commit c0d747e

dongfengy authored and liji-nv committed
[TRTLLM-7775][feat] Integrate tinygemm2 for gpt-oss (NVIDIA#7916)
Signed-off-by: Dongfeng Yu <[email protected]>
Signed-off-by: dongfengy <[email protected]>
Co-authored-by: Jin Li <[email protected]>
1 parent 24f15ce commit c0d747e

File tree

8 files changed (+690, -3 lines)

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#

file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)
add_library(tinygemm2_src OBJECT ${SRC_CPP} ${SRC_CU})

target_compile_options(tinygemm2_src
                       PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=32>)

set_property(TARGET tinygemm2_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET tinygemm2_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
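
For context, a minimal sketch of how an OBJECT library like this might be consumed by an enclosing CMake target. The parent target name trtllm_example and the directory layout are hypothetical and are not part of this change:

# Hypothetical parent CMakeLists.txt consuming the tinygemm2_src object library.
find_package(CUDAToolkit REQUIRED)
add_subdirectory(tinygemm2)  # assumed directory containing the file above
add_library(trtllm_example STATIC $<TARGET_OBJECTS:tinygemm2_src>)
# cuTensorMapEncodeTiled in the tinygemm2 sources uses the CUDA driver API,
# so link both the runtime and driver libraries.
target_link_libraries(trtllm_example PRIVATE CUDA::cudart CUDA::cuda_driver)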
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "tinygemm2_kernel.cuh"

void launch_tinygemm2(__nv_bfloat16* gA, __nv_bfloat16* gB, __nv_bfloat16* gC, __nv_bfloat16* bias, int batch_size,
    int output_features, int input_features, cudaStream_t stream)
{
    // Tile and pipeline configuration for the kernel template.
    static int const WARP_TILE_M = 16;
    static int const TILE_M = WARP_TILE_M;
    static int const TILE_N = 8;
    static int const TILE_K = 64;
    static int const STAGES = 16;
    static int const STAGE_UNROLL = 4;
    static bool const PROFILE = false;

    CUtensorMap weight_map{};
    CUtensorMap activation_map{};

    // TMA descriptor for the weight matrix gB (output_features x input_features).
    constexpr uint32_t rank = 2;
    uint64_t size[rank] = {(uint64_t) input_features, (uint64_t) output_features};
    uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
    uint32_t box_size[rank] = {TILE_K, TILE_M};
    uint32_t elem_stride[rank] = {1, 1};

    CUresult res = cuTensorMapEncodeTiled(&weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank, gB,
        size, stride, box_size, elem_stride, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
        CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    assert(res == 0);

    // TMA descriptor for the activation matrix gA (batch_size x input_features);
    // reuses the same strides, only the outer extent and box size differ.
    size[1] = batch_size;
    box_size[1] = TILE_N;

    res = cuTensorMapEncodeTiled(&activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank, gA, size,
        stride, box_size, elem_stride, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
        CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    assert(res == 0);

    // Dynamic shared memory for the multi-stage pipeline of weight and activation tiles.
    int smem_size
        = STAGES * STAGE_UNROLL * (TILE_M * TILE_K * sizeof(__nv_bfloat16) + TILE_N * TILE_K * sizeof(__nv_bfloat16));

    gpuErrChk(cudaFuncSetAttribute(kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>,
        cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

    int tiles_m = (output_features + TILE_M - 1) / TILE_M;
    int tiles_n = (batch_size + TILE_N - 1) / TILE_N;

    dim3 grid(tiles_m, tiles_n);

    dim3 block(384);

    // Launch with programmatic stream serialization enabled.
    cudaLaunchConfig_t config;
    cudaLaunchAttribute attrs[1];
    config.gridDim = grid;
    config.blockDim = block;
    config.dynamicSmemBytes = smem_size;
    config.stream = stream;
    config.attrs = attrs;
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;
    config.numAttrs = 1;

    cudaLaunchKernelEx(&config, &kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES, STAGE_UNROLL, PROFILE>, gC, gA, gB,
        bias, output_features, batch_size, input_features, weight_map, activation_map, nullptr);
}

torch::Tensor tinygemm2_cuda_forward(torch::Tensor input, torch::Tensor weight, torch::Tensor bias)
{
    auto const batch_size = input.size(0);
    auto const input_dim = input.size(1);
    auto const output_dim = weight.size(0);

    auto output = torch::empty({batch_size, output_dim}, input.options());

    auto stream = at::cuda::getCurrentCUDAStream();

    if (input.scalar_type() == at::ScalarType::BFloat16)
    {
        launch_tinygemm2((__nv_bfloat16*) input.data_ptr(), (__nv_bfloat16*) weight.data_ptr(),
            (__nv_bfloat16*) output.data_ptr(), (__nv_bfloat16*) bias.data_ptr(), batch_size, output_dim, input_dim,
            stream);
    }
    else
    {
        // Only bfloat16 is supported by this path.
        assert(false);
    }
    return output;
}
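
For illustration, a minimal sketch of how tinygemm2_cuda_forward could be exercised from C++ once this translation unit is linked in. The shapes and the helper function name run_tinygemm2_example are hypothetical; the Python binding/registration code is not part of this diff:

#include <torch/torch.h>

torch::Tensor tinygemm2_cuda_forward(torch::Tensor input, torch::Tensor weight, torch::Tensor bias); // declared above

torch::Tensor run_tinygemm2_example()
{
    // Illustrative shapes only: [batch_size, input_features] activations,
    // [output_features, input_features] weights, [output_features] bias.
    auto opts = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);
    auto input = torch::randn({8, 2880}, opts);
    auto weight = torch::randn({2880, 2880}, opts);
    auto bias = torch::zeros({2880}, opts);
    // The kernel path above handles only bfloat16; other dtypes hit the assert.
    return tinygemm2_cuda_forward(input, weight, bias); // [8, 2880] result
}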
