diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 79ee204d195..b24479ec705 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -287,6 +287,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) + #ifndef GGML_USE_HIP FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_TQ3_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TQ3_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_TQ3_0) @@ -297,11 +298,14 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_BF16) +#endif // GGML_USE_HIP #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) +#ifndef GGML_USE_HIP FATTN_VEC_CASES_ALL_D(GGML_TYPE_TQ3_0, GGML_TYPE_TQ3_0) +#endif // GGML_USE_HIP FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index ac8c8c205a4..ded13ee1b66 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,5 +1,7 @@ #include "gated_delta_net.cuh" +#ifndef GGML_USE_HIP #include +#endif #include // Tree-mode parent index sentinel: a node whose parent is the pre-block state diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 898fec31e36..f187dfbc9e6 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -45,6 +45,7 @@ #define cublasGemmEx hipblasGemmEx #define cublasGemmBatchedEx hipblasGemmBatchedEx #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx +#define cublasSgemmStridedBatched hipblasSgemmStridedBatched #define cublasHandle_t hipblasHandle_t #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS #define cublasSetStream hipblasSetStream @@ -140,6 +141,9 @@ #define cudaGraphExecUpdate hipGraphExecUpdate #define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed #define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaStreamCaptureStatus hipStreamCaptureStatus +#define cudaStreamCaptureStatusNone hipStreamCaptureStatusNone +#define cudaStreamIsCapturing hipStreamIsCapturing #define cudaGraph_t hipGraph_t #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess