diff --git a/csrc/grouped_gemm.cu b/csrc/grouped_gemm.cu
index 3729862..6a2af86 100644
--- a/csrc/grouped_gemm.cu
+++ b/csrc/grouped_gemm.cu
@@ -206,6 +206,164 @@ void cublas_handle_init()
   }
 }
 
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 12500
+
+#define MAX_GROUPSIZE 1024
+
+cublasOperation_t trans_array_T[MAX_GROUPSIZE];
+cublasOperation_t trans_array_N[MAX_GROUPSIZE];
+int m_array[MAX_GROUPSIZE];
+int n_array[MAX_GROUPSIZE];
+int k_array[MAX_GROUPSIZE];
+float alpha_array[MAX_GROUPSIZE];
+float beta_array[MAX_GROUPSIZE];
+
+void * Aarray[MAX_GROUPSIZE];
+int lda_array[MAX_GROUPSIZE];
+void * Barray[MAX_GROUPSIZE];
+int ldb_array[MAX_GROUPSIZE];
+void * Carray[MAX_GROUPSIZE];
+int ldc_array[MAX_GROUPSIZE];
+
+// on device
+void **d_Aarray = nullptr;
+void **d_Barray = nullptr;
+void **d_Carray = nullptr;
+
+int group_size[MAX_GROUPSIZE];
+
+bool cublas_grouped_gemm_init = false;
+
+void cublas_grouped_gemm_global_var_init()
+{
+  cublas_grouped_gemm_init = true;
+
+  for (int i = 0; i < MAX_GROUPSIZE; i++)
+  {
+    alpha_array[i] = 1.0;
+    beta_array[i] = 0.0;
+    group_size[i] = 1;
+    trans_array_T[i] = CUBLAS_OP_T;
+    trans_array_N[i] = CUBLAS_OP_N;
+  }
+
+  CUDA_CALL(cudaMallocAsync(
+      &d_Aarray,
+      MAX_GROUPSIZE * sizeof(void *),
+      c10::cuda::getCurrentCUDAStream()));
+  CUDA_CALL(cudaMallocAsync(
+      &d_Barray,
+      MAX_GROUPSIZE * sizeof(void *),
+      c10::cuda::getCurrentCUDAStream()));
+  CUDA_CALL(cudaMallocAsync(
+      &d_Carray,
+      MAX_GROUPSIZE * sizeof(void *),
+      c10::cuda::getCurrentCUDAStream()));
+}
+
+void CublasGemmGroupedBatched(torch::Tensor a,
+                              torch::Tensor b,
+                              torch::Tensor c,
+                              torch::Tensor batch_sizes,
+                              bool trans_a, bool trans_b)
+{
+  if (!cublas_grouped_gemm_init)
+    cublas_grouped_gemm_global_var_init();
+
+  c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
+  c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
+  c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
+
+  int a_rows, a_cols, b_rows, b_cols, c_rows, c_cols;
+
+  int group_count = 0;
+  for (int i = 0; i < batch_sizes.size(0); i++)
+  {
+    int bs = batch_sizes.data_ptr<int64_t>()[i];
+    if (trans_a) {
+      a_rows = bs;
+      a_cols = a.size(1);
+
+      // b.dims() == 2 here
+      b_rows = bs;
+      b_cols = b.size(1);
+
+      c_rows = a_cols;
+      c_cols = b_cols;
+    } else {
+      a_rows = bs;
+      a_cols = a.size(1);
+
+      // b.dims() == 3 here
+      b_rows = b.size(1);
+      b_cols = b.size(2);
+
+      c_rows = a_rows;
+      c_cols = trans_b ? b_rows : b_cols;
+    }
+
+    if (bs != 0) {
+      int m = trans_b ? b_rows : b_cols;
+      int k = trans_b ? b_cols : b_rows;
+      int n = trans_a ? a_cols : a_rows;
+      m_array[group_count] = m;
+      n_array[group_count] = n;
+      k_array[group_count] = k;
+
+      lda_array[group_count] = trans_a ? n : k;
+      ldb_array[group_count] = trans_b ? k : m;
+      ldc_array[group_count] = c_cols;
+
+      Aarray[group_count] = a_ptr;
+      Barray[group_count] = b_ptr;
+      Carray[group_count] = c_ptr;
+
+      group_count++;
+    }
+
+    a_ptr += a_rows * a_cols;
+    b_ptr += b_rows * b_cols;
+    c_ptr += c_rows * c_cols;
+  }
+
+  CUDA_CALL(cudaMemcpyAsync(d_Aarray, Aarray,
+                            sizeof(void *) * group_count,
+                            cudaMemcpyHostToDevice,
+                            c10::cuda::getCurrentCUDAStream()));
+  CUDA_CALL(cudaMemcpyAsync(d_Barray, Barray,
+                            sizeof(void *) * group_count,
+                            cudaMemcpyHostToDevice,
+                            c10::cuda::getCurrentCUDAStream()));
+  CUDA_CALL(cudaMemcpyAsync(d_Carray, Carray,
+                            sizeof(void *) * group_count,
+                            cudaMemcpyHostToDevice,
+                            c10::cuda::getCurrentCUDAStream()));
+
+  CUBLAS_CALL(cublasGemmGroupedBatchedEx(
+      at::cuda::getCurrentCUDABlasHandle(),
+      trans_b ? trans_array_T : trans_array_N,
+      trans_a ? trans_array_T : trans_array_N,
+      m_array,
+      n_array,
+      k_array,
+      alpha_array,
+      d_Barray,
+      CUDA_R_16BF,
+      ldb_array,
+      d_Aarray,
+      CUDA_R_16BF,
+      lda_array,
+      beta_array,
+      d_Carray,
+      CUDA_R_16BF,
+      ldc_array,
+      group_count,
+      group_size,
+      CUBLAS_COMPUTE_32F));
+}
+
+#endif
+
 inline void cublas_current_wait_streams(cudaStream_t stream)
 {
   for (int s = 0; s < NUM_STREAM; s++)
@@ -259,6 +417,12 @@ void CublasGroupedGemm(torch::Tensor a,
                        torch::Tensor c,
                        torch::Tensor batch_sizes,
                        bool trans_b) {
+
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 12500
+  CublasGemmGroupedBatched(a, b, c, batch_sizes, false, trans_b);
+  return;
+#endif
+
   if (!cublas_init)
     cublas_handle_init();
 
@@ -289,6 +453,11 @@ void CublasGroupedGemmVariableK(torch::Tensor a,
                                 torch::Tensor b,
                                 torch::Tensor c,
                                 torch::Tensor batch_sizes) {
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 12500
+  CublasGemmGroupedBatched(a, b, c, batch_sizes, true, false);
+  return;
+#endif
+
   if (!cublas_init)
     cublas_handle_init();
 
diff --git a/setup.py b/setup.py
index 8798172..36d1acf 100644
--- a/setup.py
+++ b/setup.py
@@ -37,6 +37,9 @@
     f"-DGROUPED_GEMM_DEVICE_CAPABILITY={device_capability}",
 ])
 
+if "CUBLAS_VERSION" in os.environ:
+    nvcc_flags.append(f"-DCUBLAS_VERSION={os.environ['CUBLAS_VERSION']}")
+
 ext_modules = [
     CUDAExtension(
         "grouped_gemm_backend",
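
Usage sketch: a minimal example of how the new path might be exercised, assuming the package's Python wrapper `grouped_gemm.ops.gmm` (the wrapper name and signature are assumptions, not part of this diff) and a build where `CUBLAS_VERSION` was exported (e.g. `CUBLAS_VERSION=12500 pip install .`) so that setup.py passes `-DCUBLAS_VERSION` to nvcc and the `cublasGemmGroupedBatchedEx` branch is compiled in.

# Sketch only: assumes the wrapper `grouped_gemm.ops.gmm` and a build with
# CUBLAS_VERSION >= 12500 exported, so the guarded branch above is active.
import torch
from grouped_gemm import ops

num_groups, k, n = 4, 16, 32
batch_sizes = torch.tensor([8, 24, 16, 16])  # per-group row counts (int64, CPU)
a = torch.randn(int(batch_sizes.sum()), k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_groups, k, n, device="cuda", dtype=torch.bfloat16)

# Forward pass goes through CublasGroupedGemm, which with this patch dispatches
# to CublasGemmGroupedBatched, i.e. a single cublasGemmGroupedBatchedEx call,
# when the CUBLAS_VERSION >= 12500 guard holds at compile time.
out = ops.gmm(a, b, batch_sizes, trans_b=False)  # shape: (sum(batch_sizes), n)

On builds where `CUBLAS_VERSION` is unset or below 12500, the same call falls through to the existing per-group cublasGemm loop guarded by `cublas_handle_init`.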