Skip to content

Commit 38e72f8

Browse files
committed
Add ROCm support
1 parent 4585e2c commit 38e72f8

22 files changed

+105
-21
lines changed

cuda_ext.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,30 @@
1111
library_dir = os.path.dirname(os.path.abspath(__file__))
1212
extension_name = "exllama_ext"
1313

14+
if torch.version.hip:
15+
# FIXME: To build, I had to comment "flags += ['-fno-gpu-rdc']" in torch/utils/cpp_extension.py.
16+
# I am not sure if it's possible to find a way to build without editing that file.
17+
# If building without gpu-rdc, build will error with "lld: error: undefined hidden symbol: __llvm_amdgcn_rcp_f16".
18+
extra_cuda_cflags= ["-U__HIP_NO_HALF_CONVERSIONS__", "-fgpu-rdc"]
19+
else:
20+
extra_cuda_cflags = []
21+
1422
exllama_ext = load(
1523
name = extension_name,
1624
sources = [
1725
os.path.join(library_dir, "exllama_ext/cuda_buffers.cu"),
1826
os.path.join(library_dir, "exllama_ext/cpu_func/rep_penalty.cpp"),
19-
os.path.join(library_dir, "exllama_ext/cuda_func/column_remap.cu"),
20-
os.path.join(library_dir, "exllama_ext/cuda_func/half_matmul.cu"),
21-
os.path.join(library_dir, "exllama_ext/cuda_func/q4v2_matmul.cu"),
22-
os.path.join(library_dir, "exllama_ext/cuda_func/q4v2_mlp.cu"),
23-
os.path.join(library_dir, "exllama_ext/cuda_func/q4v2_recons.cu"),
24-
os.path.join(library_dir, "exllama_ext/cuda_func/q4v2_sequential.cu"),
25-
os.path.join(library_dir, "exllama_ext/cuda_func/rms_norm.cu"),
26-
os.path.join(library_dir, "exllama_ext/cuda_func/rope.cu"),
27+
os.path.join(library_dir, "exllama_ext/cu_func/column_remap.cu"),
28+
os.path.join(library_dir, "exllama_ext/cu_func/half_matmul.cu"),
29+
os.path.join(library_dir, "exllama_ext/cu_func/q4v2_matmul.cu"),
30+
os.path.join(library_dir, "exllama_ext/cu_func/q4v2_mlp.cu"),
31+
os.path.join(library_dir, "exllama_ext/cu_func/q4v2_recons.cu"),
32+
os.path.join(library_dir, "exllama_ext/cu_func/q4v2_sequential.cu"),
33+
os.path.join(library_dir, "exllama_ext/cu_func/rms_norm.cu"),
34+
os.path.join(library_dir, "exllama_ext/cu_func/rope.cu"),
2735
os.path.join(library_dir, "exllama_ext/exllama_ext.cpp")
2836
],
37+
extra_cuda_cflags = extra_cuda_cflags
2938
# verbose = True,
3039
# extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
3140
)

exllama_ext/cuda_func/column_remap.cuh renamed to exllama_ext/cu_func/column_remap.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#ifndef _column_remap_cuh
22
#define _column_remap_cuh
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaError_t hipError_t
8+
#else
49
#include <cuda_runtime.h>
510
#include <cuda_fp16.h>
11+
#endif
612
#include <cstdint>
713

814
cudaError_t column_remap_cuda

exllama_ext/cuda_func/half_matmul.cu renamed to exllama_ext/cu_func/half_matmul.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ cudaError_t half_matmul_cublas_cuda
102102
const half alpha = __float2half(1.0f);
103103
const half beta = __float2half(0.0f);
104104

105-
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, w, width, x, dim, &beta, out, width);
105+
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, reinterpret_cast<const rocblas_half*>(&alpha), reinterpret_cast<const rocblas_half*>(w), width, reinterpret_cast<const rocblas_half*>(x), dim, reinterpret_cast<const rocblas_half*>(&beta), reinterpret_cast<rocblas_half*>(out), width);
106106

107107
// cudaDeviceSynchronize();
108108
// _cuda_check(cudaGetLastError());

exllama_ext/cuda_func/half_matmul.cuh renamed to exllama_ext/cu_func/half_matmul.cuh

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
#ifndef _half_matmul_cuh
22
#define _half_matmul_cuh
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#include <rocblas/rocblas.h>
8+
#include <ATen/hip/HIPContext.h>
9+
#define cudaError_t hipError_t
10+
#define cublasHandle_t rocblas_handle
11+
#else
412
#include <cuda_runtime.h>
513
#include <cuda_fp16.h>
6-
#include <cstdint>
714
#include <ATen/cuda/CUDAContext.h>
15+
#endif
16+
#include <cstdint>
817

918
cudaError_t half_matmul_cuda
1019
(

exllama_ext/cuda_func/q4v2_matmul.cuh renamed to exllama_ext/cu_func/q4v2_matmul.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#ifndef _q4v2_matmul_cuh
22
#define _q4v2_matmul_cuh
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaError_t hipError_t
8+
#else
49
#include <cuda_runtime.h>
510
#include <cuda_fp16.h>
11+
#endif
612
#include <cstdint>
713
#include <cstdio>
814

File renamed without changes.

exllama_ext/cuda_func/q4v2_mlp.cuh renamed to exllama_ext/cu_func/q4v2_mlp.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
#ifndef _q4v2_mlp_cuh
22
#define _q4v2_mlp_cuh
33

4+
#if USE_ROCM
5+
#include <hip/hip_runtime.h>
6+
#include <hip/hip_fp16.h>
7+
#define cudaError_t hipError_t
8+
#else
49
#include <cuda_runtime.h>
510
#include <cuda_fp16.h>
11+
#endif
612
#include <cstdint>
713

814
cudaError_t q4v2_mlp_cuda

0 commit comments

Comments
 (0)