Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions csrc/grouped_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,164 @@ void cublas_handle_init()
}
}

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 12500

#define MAX_GROUPSIZE 1024

cublasOperation_t trans_array_T[MAX_GROUPSIZE];
cublasOperation_t trans_array_N[MAX_GROUPSIZE];
int m_array[MAX_GROUPSIZE];
int n_array[MAX_GROUPSIZE];
int k_array[MAX_GROUPSIZE];
float alpha_array[MAX_GROUPSIZE];
float beta_array[MAX_GROUPSIZE];

void * Aarray[MAX_GROUPSIZE];
int lda_array[MAX_GROUPSIZE];
void * Barray[MAX_GROUPSIZE];
int ldb_array[MAX_GROUPSIZE];
void * Carray[MAX_GROUPSIZE];
int ldc_array[MAX_GROUPSIZE];

// on device
void **d_Aarray = nullptr;
void **d_Barray = nullptr;
void **d_Carray = nullptr;

int group_size[MAX_GROUPSIZE];

bool cublas_grouped_gemm_init = false;

void cublas_grouped_gemm_global_var_init()
{
cublas_grouped_gemm_init = true;

for (int i = 0; i < MAX_GROUPSIZE; i++)
{
alpha_array[i] = 1.0;
beta_array[i] = 0.0;
group_size[i] = 1;
trans_array_T[i] = CUBLAS_OP_T;
trans_array_N[i] = CUBLAS_OP_N;
}

CUDA_CALL(cudaMallocAsync(
&d_Aarray,
MAX_GROUPSIZE * sizeof(void *),
c10::cuda::getCurrentCUDAStream()));
CUDA_CALL(cudaMallocAsync(
&d_Barray,
MAX_GROUPSIZE * sizeof(void *),
c10::cuda::getCurrentCUDAStream()));
CUDA_CALL(cudaMallocAsync(
&d_Carray,
MAX_GROUPSIZE * sizeof(void *),
c10::cuda::getCurrentCUDAStream()));
}

void CublasGemmGroupedBatched(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes,
bool trans_a, bool trans_b)
{
if (!cublas_grouped_gemm_init)
cublas_grouped_gemm_global_var_init();

c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();

int a_rows, a_cols, b_rows, b_cols, c_rows, c_cols;

int group_count = 0;
for (int i = 0; i < batch_sizes.size(0); i++)
{
int bs = batch_sizes.data_ptr<int64_t>()[i];
if (trans_a) {
a_rows = bs;
a_cols = a.size(1);

// b.dims() == 2 here
b_rows = bs;
b_cols = b.size(1);

c_rows = a_cols;
c_cols = b_cols;
} else {
a_rows = bs;
a_cols = a.size(1);

// b.dims() == 3 here
b_rows = b.size(1);
b_cols = b.size(2);

c_rows = a_rows;
c_cols = trans_b ? b_rows : b_cols;
}

if (bs != 0) {
int m = trans_b ? b_rows : b_cols;
int k = trans_b ? b_cols : b_rows;
int n = trans_a ? a_cols : a_rows;
m_array[group_count] = m;
n_array[group_count] = n;
k_array[group_count] = k;

lda_array[group_count] = trans_a ? n : k;
ldb_array[group_count] = trans_b ? k : m;
ldc_array[group_count] = c_cols;

Aarray[group_count] = a_ptr;
Barray[group_count] = b_ptr;
Carray[group_count] = c_ptr;

group_count++;
}

a_ptr += a_rows * a_cols;
b_ptr += b_rows * b_cols;
c_ptr += c_rows * c_cols;
}

CUDA_CALL(cudaMemcpyAsync(d_Aarray, Aarray,
sizeof(void *) * group_count,
cudaMemcpyHostToDevice,
c10::cuda::getCurrentCUDAStream()));
CUDA_CALL(cudaMemcpyAsync(d_Barray, Barray,
sizeof(void *) * group_count,
cudaMemcpyHostToDevice,
c10::cuda::getCurrentCUDAStream()));
CUDA_CALL(cudaMemcpyAsync(d_Carray, Carray,
sizeof(void *) * group_count,
cudaMemcpyHostToDevice,
c10::cuda::getCurrentCUDAStream()));

CUBLAS_CALL(cublasGemmGroupedBatchedEx(
at::cuda::getCurrentCUDABlasHandle(),
trans_b ? trans_array_T : trans_array_N,
trans_a ? trans_array_T : trans_array_N,
m_array,
n_array,
k_array,
alpha_array,
d_Barray,
CUDA_R_16BF,
ldb_array,
d_Aarray,
CUDA_R_16BF,
lda_array,
beta_array,
d_Carray,
CUDA_R_16BF,
ldc_array,
group_count,
group_size,
CUBLAS_COMPUTE_32F));
}

#endif

inline void cublas_current_wait_streams(cudaStream_t stream)
{
for (int s = 0; s < NUM_STREAM; s++)
Expand Down Expand Up @@ -259,6 +417,12 @@ void CublasGroupedGemm(torch::Tensor a,
torch::Tensor c,
torch::Tensor batch_sizes,
bool trans_b) {

#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 12500
CublasGemmGroupedBatched(a, b, c, batch_sizes, false, trans_b);
return;
#endif

if (!cublas_init)
cublas_handle_init();

Expand Down Expand Up @@ -289,6 +453,11 @@ void CublasGroupedGemmVariableK(torch::Tensor a,
torch::Tensor b,
torch::Tensor c,
torch::Tensor batch_sizes) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 12500
CublasGemmGroupedBatched(a, b, c, batch_sizes, true, false);
return;
#endif

if (!cublas_init)
cublas_handle_init();

Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
])

if "CUBLAS_VERSION" in os.environ:
nvcc_flags.append(f"-DCUBLAS_VERSION={os.environ['CUBLAS_VERSION']}")

ext_modules = [
CUDAExtension(
"grouped_gemm_backend",
Expand Down