diff --git a/CMakeLists.txt b/CMakeLists.txt
index 19fdfa46ca4f1..a7f93612eed32 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -102,6 +102,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
 option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF)
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
 option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
+option(LLAMA_MUSA "llama: use MUSA" OFF)
 option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 option(LLAMA_VULKAN "llama: use Vulkan" OFF)
@@ -574,6 +575,49 @@ if (LLAMA_HIPBLAS)
     set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
 endif()
 
+if (LLAMA_MUSA)
+    set(MUSA_ARCH "21" CACHE STRING "MUSA architecture") # option() is BOOL-only; the arch is a string value
+
+    list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") # FindMUSA.cmake and the MUSA language modules live here
+
+    find_package(MUSA REQUIRED)
+
+    message(STATUS "MUSA found")
+
+    enable_language(MUSA)
+
+    set(GGML_HEADERS_MUSA ggml-cuda.h)
+
+    file(GLOB GGML_SOURCES_MUSA "ggml-cuda/*.cu")
+    list(APPEND GGML_SOURCES_MUSA "ggml-cuda.cu")
+
+    add_compile_definitions(GGML_USE_MUSA GGML_USE_CUDA)
+
+    if (LLAMA_CUDA_FORCE_DMMV)
+        add_compile_definitions(GGML_CUDA_FORCE_DMMV)
+    endif()
+
+    if (LLAMA_CUDA_FORCE_MMQ)
+        add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+    endif()
+
+    if (LLAMA_CUDA_NO_PEER_COPY)
+        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+    endif()
+
+    add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+    add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
+    add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+
+    set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE MUSA) # compile the shared .cu sources with the MUSA language rules
+
+    if (LLAMA_STATIC)
+        message(FATAL_ERROR "Static linking not supported for MUSA")
+    endif()
+
+    set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} PUBLIC MUSA::musa MUSA::mublas MUSA::musart)
+endif()
+
 if (LLAMA_SYCL)
     if (NOT LLAMA_SYCL_TARGET MATCHES "^(INTEL|NVIDIA)$")
         message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL or NVIDIA")
@@ -1160,6 +1204,7 @@ add_library(ggml OBJECT
             ${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE}
             ${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN}
             ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
+            ${GGML_SOURCES_MUSA} ${GGML_HEADERS_MUSA}
             )
 
 target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
diff --git a/Makefile b/Makefile
index 7a69ad1b3c14f..0b8714c24a200 100644
--- a/Makefile
+++ b/Makefile
@@ -565,6 +565,38 @@ ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/com
 
 endif # LLAMA_HIPBLAS
 
+ifdef LLAMA_MUSA
+	MUSA_PATH ?= /usr/local/musa
+	MUSA_ARCH ?= 21
+	MCC ?= $(CCACHE) $(MUSA_PATH)/bin/mcc
+	LLAMA_CUDA_DMMV_X ?= 32
+	LLAMA_CUDA_MMV_Y ?= 1
+	LLAMA_CUDA_KQUANTS_ITER ?= 2
+	MK_CPPFLAGS += -DGGML_USE_MUSA -DGGML_USE_CUDA
+	MK_LDFLAGS += -L$(MUSA_PATH)/lib -Wl,-rpath=$(MUSA_PATH)/lib
+	MK_LDFLAGS += -lmublas -lmusa -lmusart
+	MUSAFLAGS += --cuda-gpu-arch=mp_$(MUSA_ARCH)
+	MUSAFLAGS += -Wno-unknown-warning-option -Wno-gnu-anonymous-struct -Wno-nested-anon-types -Wno-invalid-noreturn
+	MUSAFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
+	MUSAFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
+	MUSAFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
+ifdef LLAMA_CUDA_FORCE_DMMV
+	MUSAFLAGS += -DGGML_CUDA_FORCE_DMMV
+endif # LLAMA_CUDA_FORCE_DMMV
+ifdef LLAMA_CUDA_NO_PEER_COPY
+	MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
+endif # LLAMA_CUDA_NO_PEER_COPY
+	OBJS += ggml-cuda.o
+	OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
+
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
+	$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -c -o $@ $<
+
+ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
+	$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -c -o $@ $<
+
+endif # LLAMA_MUSA
+
 ifdef LLAMA_METAL
 	MK_CPPFLAGS += -DGGML_USE_METAL
 	MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
diff --git a/cmake/CMakeDetermineMUSACompiler.cmake b/cmake/CMakeDetermineMUSACompiler.cmake
new file mode 100644
index 0000000000000..35dadda54f7d4
--- /dev/null
+++ b/cmake/CMakeDetermineMUSACompiler.cmake
@@ -0,0 +1,11 @@
+
+set(CMAKE_MUSA_ARCHITECTURES "mp_${MUSA_ARCH}")
+set(CMAKE_MUSA_COMPILER "${MUSA_MCC}") # MUSA_MCC and MUSA_ARCH are set by FindMUSA.cmake
+set(CMAKE_MUSA_COMPILER_ID "Clang")
+set(CMAKE_MUSA_COMPILER_ARG1 "")
+set(CMAKE_MUSA_COMPILER_ENV_VAR "MCC")
+
+configure_file(
+    ${CMAKE_CURRENT_LIST_DIR}/CMakeMUSACompiler.cmake.in
+    ${CMAKE_PLATFORM_INFO_DIR}/CMakeMUSACompiler.cmake @ONLY
+)
diff --git a/cmake/CMakeMUSACompiler.cmake.in b/cmake/CMakeMUSACompiler.cmake.in
new file mode 100644
index 0000000000000..cd611c4dc150a
--- /dev/null
+++ b/cmake/CMakeMUSACompiler.cmake.in
@@ -0,0 +1,6 @@
+set(CMAKE_MUSA_COMPILER "@CMAKE_MUSA_COMPILER@")
+set(CMAKE_MUSA_COMPILER_ARG1 "@CMAKE_MUSA_COMPILER_ARG1@")
+set(CMAKE_MUSA_COMPILER_LOADED 1)
+set(CMAKE_MUSA_SOURCE_FILE_EXTENSIONS mu;cu)
+set(CMAKE_MUSA_OUTPUT_EXTENSION .o)
+set(CMAKE_MUSA_COMPILER_ENV_VAR "MCC")
diff --git a/cmake/CMakeMUSAInformation.cmake b/cmake/CMakeMUSAInformation.cmake
new file mode 100644
index 0000000000000..bb0244d8e9a0d
--- /dev/null
+++ b/cmake/CMakeMUSAInformation.cmake
@@ -0,0 +1,26 @@
+
+# reuse cxx things
+
+include(CMakeLanguageInformation)
+include(CMakeCommonLanguageInclude)
+
+include(Compiler/Clang)
+
+__compiler_clang(MUSA)
+__compiler_clang_cxx_standards(MUSA)
+
+set(CMAKE_INCLUDE_FLAG_MUSA "-I")
+
+set(CMAKE_MUSA_RUNTIME_LIBRARY_DEFAULT "SHARED")
+set(CMAKE_MUSA_RUNTIME_LIBRARY_LINK_OPTIONS_STATIC "")
+set(CMAKE_MUSA_RUNTIME_LIBRARY_LINK_OPTIONS_SHARED "")
+
+# Left empty on purpose: the runtime is linked explicitly through the MUSA::* imported targets
+set(CMAKE_MUSA_RUNTIME_LIBRARIES_STATIC "")
+set(CMAKE_MUSA_RUNTIME_LIBRARIES_SHARED "")
+
+# compile a MUSA (.mu/.cu) source file into an object file
+if(NOT CMAKE_MUSA_COMPILE_OBJECT)
+    set(CMAKE_MUSA_COMPILE_OBJECT
+        "<CMAKE_MUSA_COMPILER> <DEFINES> <INCLUDES> <FLAGS> -x musa --cuda-gpu-arch=${CMAKE_MUSA_ARCHITECTURES} -fPIC -o <OBJECT> -c <SOURCE>")
+endif()
diff --git a/cmake/CMakeTestMUSACompiler.cmake b/cmake/CMakeTestMUSACompiler.cmake
new file mode 100644
index 0000000000000..6bc32198d0a87
--- /dev/null
+++ b/cmake/CMakeTestMUSACompiler.cmake
@@ -0,0 +1 @@
+# do nothing, make cmake happy
diff --git a/cmake/FindMUSA.cmake b/cmake/FindMUSA.cmake
new file mode 100644
index 0000000000000..6841cf372896f
--- /dev/null
+++ b/cmake/FindMUSA.cmake
@@ -0,0 +1,101 @@
+# find MUSA things: sets MUSA_MCC / MUSA_ARCH and defines the MUSA::musa, MUSA::mublas, MUSA::musart imported targets
+
+include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)
+include(${CMAKE_ROOT}/Modules/SelectLibraryConfigurations.cmake)
+include(${CMAKE_ROOT}/Modules/CMakeFindDependencyMacro.cmake)
+
+if(DEFINED ENV{MUSA_HOME})
+    set(MUSA_HOME $ENV{MUSA_HOME})
+else()
+    set(MUSA_HOME /usr/local/musa)
+endif()
+
+set(MUSA_MCC ${MUSA_HOME}/bin/mcc)
+
+if (DEFINED ENV{MUSA_ARCH})
+    set(MUSA_ARCH $ENV{MUSA_ARCH})
+elseif(NOT MUSA_ARCH)
+    set(MUSA_ARCH "21")
+endif()
+
+if(NOT MUSA_INCLUDE_DIR)
+    set(MUSA_INCLUDE_DIR ${MUSA_HOME}/include)
+endif()
+
+if(NOT MUSA_LIBRARY_DIR)
+    set(MUSA_LIBRARY_DIR ${MUSA_HOME}/lib)
+endif()
+
+if(NOT MUSA_LIBRARIES)
+    find_library(
+        MUSA_MUSA_LIBRARY
+        NAMES libmusa.so
+        PATHS ${MUSA_LIBRARY_DIR}
+    )
+
+    find_library(
+        MUSA_MUBLAS_LIBRARY
+        NAMES libmublas.so
+        PATHS ${MUSA_LIBRARY_DIR}
+    )
+
+    find_library(
+        MUSA_MUSART_LIBRARY
+        NAMES libmusart.so
+        PATHS ${MUSA_LIBRARY_DIR}
+    )
+
+    if(MUSA_MUSA_LIBRARY AND MUSA_MUBLAS_LIBRARY AND MUSA_MUSART_LIBRARY)
+        set(MUSA_LIBRARIES "${MUSA_MUSA_LIBRARY};${MUSA_MUBLAS_LIBRARY};${MUSA_MUSART_LIBRARY}")
+        set(MUSA_MUSA_LIBRARY "${MUSA_MUSA_LIBRARY}" CACHE FILEPATH "Path to libmusa") # signature is set(<var> <value> CACHE <type> <doc>)
+        set(MUSA_MUBLAS_LIBRARY "${MUSA_MUBLAS_LIBRARY}" CACHE FILEPATH "Path to libmublas")
+        set(MUSA_MUSART_LIBRARY "${MUSA_MUSART_LIBRARY}" CACHE FILEPATH "Path to libmusart")
+    endif()
+endif()
+
+if(MUSA_LIBRARIES)
+    if(NOT TARGET MUSA::musa)
+        add_library(MUSA::musa SHARED IMPORTED)
+        set_target_properties(MUSA::musa PROPERTIES
+            IMPORTED_LOCATION "${MUSA_MUSA_LIBRARY}"
+            INTERFACE_INCLUDE_DIRECTORIES "${MUSA_INCLUDE_DIR}"
+        )
+    endif()
+
+    if(NOT TARGET MUSA::mublas)
+        add_library(MUSA::mublas SHARED IMPORTED)
+        set_target_properties(MUSA::mublas PROPERTIES
+            IMPORTED_LOCATION "${MUSA_MUBLAS_LIBRARY}"
+            INTERFACE_INCLUDE_DIRECTORIES "${MUSA_INCLUDE_DIR}"
+        )
+    endif()
+
+    if(NOT TARGET MUSA::musart)
+        add_library(MUSA::musart SHARED IMPORTED)
+        set_target_properties(MUSA::musart PROPERTIES
+            IMPORTED_LOCATION "${MUSA_MUSART_LIBRARY}"
+            INTERFACE_INCLUDE_DIRECTORIES "${MUSA_INCLUDE_DIR}"
+        )
+    endif()
+
+    set(MUSA_INCLUDE_DIR "${MUSA_INCLUDE_DIR}" CACHE PATH "MUSA include directory")
+    set(MUSA_LIBRARY_DIR "${MUSA_LIBRARY_DIR}" CACHE PATH "MUSA library directory")
+    set(MUSA_LIBRARIES "${MUSA_LIBRARIES}" CACHE STRING "MUSA libraries to link (musa, mublas, musart)")
+endif()
+
+find_package_handle_standard_args(
+    MUSA
+    REQUIRED_VARS
+        MUSA_ARCH
+        MUSA_MCC
+        MUSA_INCLUDE_DIR
+        MUSA_LIBRARIES
+        MUSA_LIBRARY_DIR
+)
+mark_as_advanced(
+    MUSA_INCLUDE_DIR
+    MUSA_LIBRARIES
+    MUSA_LIBRARY_DIR
+    CMAKE_MUSA_ARCHITECTURES
+    CMAKE_MUSA_COMPILER
+)
diff --git a/ggml-common.h b/ggml-common.h
index 43c7978a0982d..993906cbb9fe4 100644
--- a/ggml-common.h
+++ b/ggml-common.h
@@ -17,6 +17,15 @@ typedef half2 ggml_half2;
 
 #define GGML_COMMON_AGGR
 
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_MUSA)
+#include <musa_fp16.h> // NOTE(review): header name was lost in extraction; confirm against the MUSA SDK
+
+typedef half ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR
+
 #define GGML_COMMON_DECL
 #elif defined(GGML_COMMON_DECL_CUDA)
 #include <cuda_fp16.h>
@@ -73,7 +82,7 @@ typedef sycl::half2 ggml_half2;
 #define K_SCALE_SIZE 12
 #endif // GGML_QKK_64
 
-#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL)
+#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL) || defined(GGML_COMMON_DECL_MUSA)
 // QR = QK / number of values before dequantization
 // QI = number of 32 bit integers before dequantization
 
@@ -439,7 +448,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
 #define GGML_TABLE_END() };
 
 #define GGML_COMMON_IMPL
-#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
+#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
 #include <cstdint>
 
 #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index bff8ad9d96e88..0a21ec9ed2af2 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -112,7 +112,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
         CUdevice device;
         CU_CHECK(cuDeviceGet(&device, id));
         CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -124,7 +124,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
             alloc_prop.location.id = id;
             CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
         }
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
         info.devices[id].vmm = !!device_vmm;
 
         cudaDeviceProp prop;
@@ -257,7 +257,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
 };
 
 // pool with virtual memory
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
 struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
     static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
 
@@ -351,10 +351,10 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
         GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
     }
 };
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
 
 std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
     }
@@ -1596,7 +1596,7 @@ static void ggml_cuda_op_mul_mat(
                     float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                     GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
                     dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
                     // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
                     cudaMemcpy3DPeerParms p = {};
                     p.dstDevice = ctx.device;
@@ -1793,7 +1793,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-#if 0
+#if defined(GGML_USE_MUSA)
     // use cublasGemmEx
     {
         for (int i13 = 0; i13 < ne13; ++i13) {
@@ -1802,10 +1802,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 int i02 = i12 / r2;
 
                 CUBLAS_CHECK(
-                        cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+                        cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
-                            alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
-                            (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                            alpha, (const char *) src0_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
+                            (const char *) src1_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
                             beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
                             cu_compute_type,
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 481065b2a3484..0d35a15d438b5 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -8,6 +8,9 @@
 #if defined(GGML_USE_HIPBLAS)
 #define GGML_COMMON_DECL_HIP
 #define GGML_COMMON_IMPL_HIP
+#elif defined(GGML_USE_MUSA)
+#define GGML_COMMON_DECL_MUSA
+#define GGML_COMMON_IMPL_MUSA
 #else
 #define GGML_COMMON_DECL_CUDA
 #define GGML_COMMON_IMPL_CUDA
@@ -117,6 +120,10 @@
 #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+#elif defined(GGML_USE_MUSA)
+#include <musa_runtime.h> // NOTE(review): header name was lost in extraction; confirm against the MUSA SDK
+#include <mublas.h> // NOTE(review): header name was lost in extraction; confirm against the MUSA SDK
+#include "musa_compatible.cuh"
 #else
 #include <cuda_runtime.h>
 #include <cuda.h>
@@ -189,6 +196,26 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
     static const char * cublas_get_error_str(const cublasStatus_t err) {
         return cublasGetStatusString(err);
     }
+#elif defined(GGML_USE_MUSA)
+    static const char * cublas_get_error_str(const cublasStatus_t err) {
+        switch (err) {
+            case MUBLAS_STATUS_SUCCESS: return "MUBLAS_STATUS_SUCCESS";
+            case MUBLAS_STATUS_INVALID_HANDLE: return "MUBLAS_STATUS_INVALID_HANDLE";
+            case MUBLAS_STATUS_NOT_IMPLEMENTED: return "MUBLAS_STATUS_NOT_IMPLEMENTED";
+            case MUBLAS_STATUS_INVALID_POINTER: return "MUBLAS_STATUS_INVALID_POINTER";
+            case MUBLAS_STATUS_INVALID_SIZE: return "MUBLAS_STATUS_INVALID_SIZE";
+            case MUBLAS_STATUS_MEMORY_ERROR: return "MUBLAS_STATUS_MEMORY_ERROR";
+            case MUBLAS_STATUS_INTERNAL_ERROR: return "MUBLAS_STATUS_INTERNAL_ERROR";
+            case MUBLAS_STATUS_PERF_DEGRADED: return "MUBLAS_STATUS_PERF_DEGRADED";
+            case MUBLAS_STATUS_SIZE_QUERY_MISMATCH: return "MUBLAS_STATUS_SIZE_QUERY_MISMATCH";
+            case MUBLAS_STATUS_SIZE_INCREASED: return "MUBLAS_STATUS_SIZE_INCREASED";
+            case MUBLAS_STATUS_SIZE_UNCHANGED: return "MUBLAS_STATUS_SIZE_UNCHANGED";
+            case MUBLAS_STATUS_INVALID_VALUE: return "MUBLAS_STATUS_INVALID_VALUE";
+            case MUBLAS_STATUS_CONTINUE: return "MUBLAS_STATUS_CONTINUE";
+
+            default: return "unknown error";
+        }
+    }
 #else
     static const char * cublas_get_error_str(const cublasStatus_t err) {
         switch (err) {
diff --git a/ggml-cuda/musa_compatible.cuh b/ggml-cuda/musa_compatible.cuh
new file mode 100644
index 0000000000000..16e0683d62482
--- /dev/null
+++ b/ggml-cuda/musa_compatible.cuh
@@ -0,0 +1,202 @@
+
+#ifndef _MUSA_COMPATIBLE_CUH
+#define _MUSA_COMPATIBLE_CUH
+
+
+#define CUresult MUresult
+#define CUdevice MUdevice
+#define CUdeviceptr MUdeviceptr
+
+#define cudaDataType_t musaDataType_t
+#define cudaError_t musaError_t
+#define cudaEvent_t musaEvent_t
+#define cudaStream_t musaStream_t
+#define cudaDeviceProp musaDeviceProp
+
+#define cublasStatus_t mublasStatus_t
+#define cublasHandle_t mublasHandle_t
+#define cublasComputeType_t musaDataType_t // reserved in musa
+
+#define cuGetErrorString muGetErrorString
+#define cuDeviceGet muDeviceGet
+#define cuDeviceGetAttribute muDeviceGetAttribute
+// #define cuMemGetAllocationGranularity muMemGetAllocationGranularity // so far, not implemented
+// #define CUmemAllocationProp MUmemAllocationProp
+
+#define cudaGetErrorString musaGetErrorString
+#define cudaGetLastError musaGetLastError
+#define cudaMemGetInfo musaMemGetInfo
+#define cudaMemset musaMemset
+#define cudaMalloc musaMalloc
+#define cudaMallocHost musaMallocHost
+#define cudaFree musaFree
+#define cudaFreeHost musaFreeHost
+#define cudaHostUnregister musaHostUnregister
+#define cudaMemcpyAsync musaMemcpyAsync
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+#define cudaDeviceSynchronize musaDeviceSynchronize
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+#define cudaGetDevice musaGetDevice
+#define cudaGetDeviceCount musaGetDeviceCount
+#define cudaGetDeviceProperties musaGetDeviceProperties
+#define cudaSetDevice musaSetDevice
+#define cudaEventRecord musaEventRecord
+#define cudaEventDestroy musaEventDestroy
+#define cudaEventCreate musaEventCreate
+#define cudaEventSynchronize musaEventSynchronize
+#define cudaEventDisableTiming musaEventDisableTiming
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
+#define cudaStreamPerThread musaStreamPerThread
+#define cudaStreamSynchronize musaStreamSynchronize
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+#define cudaStreamNonBlocking musaStreamNonBlocking
+#define cudaStreamDestroy musaStreamDestroy
+#define cudaStreamWaitEvent musaStreamWaitEvent
+
+#define cublasCreate mublasCreate
+#define cublasDestroy mublasDestroy
+#define cublasSetMathMode mublasSetMathMode
+#define cublasSetStream mublasSetStream
+#define cublasGemmEx mublasGemmEx
+#define cublasSgemm mublasSgemm
+#ifdef mublasGemmStridedBatchedEx
+#undef mublasGemmStridedBatchedEx
+#endif // mublasGemmStridedBatchedEx
+#define cublasGemmStridedBatchedEx( \
+    handle, \
+    transA, \
+    transB, \
+    m, \
+    n, \
+    k, \
+    alpha, \
+    A, \
+    Atype, \
+    lda, \
+    strideA, \
+    B, \
+    Btype, \
+    ldb, \
+    strideB, \
+    beta, \
+    C, \
+    Ctype, \
+    ldc, \
+    strideC, \
+    batchCount, \
+    computeType, \
+    algo \
+) \
+mublasGemmStridedBatchedEx( \
+    handle, \
+    transA, \
+    transB, \
+    m, \
+    n, \
+    k, \
+    alpha, \
+    A, \
+    Atype, \
+    lda, \
+    strideA, \
+    B, \
+    Btype, \
+    ldb, \
+    strideB, \
+    beta, \
+    C, \
+    Ctype, \
+    ldc, \
+    strideC, \
+    C /* D */, \
+    Ctype, \
+    ldc, \
+    strideC, \
+    batchCount, \
+    computeType, \
+    algo, \
+    0 /* solution type, reserved */, \
+    0 /* flags */ \
+)
+
+#define cublasGemmBatchedEx( \
+    handle, \
+    transA, \
+    transB, \
+    m, \
+    n, \
+    k, \
+    alpha, \
+    A, \
+    Atype, \
+    lda, \
+    B, \
+    Btype, \
+    ldb, \
+    beta, \
+    C, \
+    Ctype, \
+    ldc, \
+    batchCount, \
+    computeType, \
+    algo \
+) \
+mublasGemmBatchedEx( \
+    handle, \
+    transA, \
+    transB, \
+    m, \
+    n, \
+    k, \
+    alpha, \
+    A, \
+    Atype, \
+    lda, \
+    B, \
+    Btype, \
+    ldb, \
+    beta, \
+    C, \
+    Ctype, \
+    ldc, \
+    C /* D */, \
+    Ctype, \
+    ldc, \
+    batchCount, \
+    computeType, \
+    algo, \
+    0 /* solution type, reserved */, \
+    0 /* flags */ \
+)
+
+#define CUDART_VERSION MUSART_VERSION
+
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+// #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+// #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED
+
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+#define CUBLAS_STATUS_NOT_INITIALIZED MUBLAS_STATUS_INVALID_HANDLE // closest equivalent; must differ from ALLOC_FAILED so both can be switch cases. NOTE(review): confirm against muBLAS docs
+#define CUBLAS_STATUS_ALLOC_FAILED MUBLAS_STATUS_MEMORY_ERROR // allocation failure maps to the memory-error status
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_TP32_TENSOR // ???
+#define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_OP_N MUBLAS_OP_N
+#define CUBLAS_COMPUTE_16F MUSA_R_16F // reserved in musa
+#define CUBLAS_COMPUTE_32F MUSA_R_32F // reserved in musa
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT_TENSOR_OP
+
+#define CUDA_SUCCESS MUSA_SUCCESS
+#define CUDA_R_16F MUSA_R_16F
+#define CUDA_R_32F MUSA_R_32F
+#define cudaSuccess musaSuccess
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+
+#endif // _MUSA_COMPATIBLE_CUH