Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,24 @@ endif()
set(MLAS_SOURCE_IS_NOT_SET 0)
endif()
endif()
if(RISCV64 AND MLAS_SOURCE_IS_NOT_SET)
  # RISC-V 64 platform sources.
  #
  # * FP32 SGEMM has no RVV kernel here: it uses the upstream scalar fallback
  #   (scalar/*.cpp -> MlasSgemmKernelZero/Add), invoked via the #else branch
  #   in sgemm.cpp.
  # * Two RVV translation units are added on top of the scalar sources:
  #     - qgemm_kernel_rvv.cpp          : RVV INT8 (quantized) GEMM kernel
  #     - riscv64/sgemv_kernel_rvv.cpp  : RVV FP32 GEMV kernel
  #       NOTE(review): presumably this is what provides MlasGemvFloatKernel
  #       for RISC-V -- confirm against mlasi.h.
  file(GLOB_RECURSE mlas_platform_srcs_scalar "${MLAS_SRC_DIR}/scalar/*.cpp")

  # Keep the RVV sources in one list so the source list and the per-file
  # -march flag below can never get out of sync.
  set(mlas_platform_srcs_rvv
    ${MLAS_SRC_DIR}/qgemm_kernel_rvv.cpp
    ${MLAS_SRC_DIR}/riscv64/sgemv_kernel_rvv.cpp
  )

  set(mlas_platform_srcs
    ${mlas_platform_srcs_scalar}
    ${mlas_platform_srcs_rvv}
  )

  # The vector extension is enabled per source file rather than globally, so
  # only the RVV kernels are compiled with -march=rv64gcv; the scalar sources
  # keep the toolchain's default -march.
  set_source_files_properties(${mlas_platform_srcs_rvv}
    PROPERTIES COMPILE_FLAGS "-march=rv64gcv")

  # Outside of multi-arch builds, mark the platform sources as chosen so the
  # later platform blocks guarded by MLAS_SOURCE_IS_NOT_SET are skipped.
  if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH)
    set(MLAS_SOURCE_IS_NOT_SET 0)
  endif()
endif()
if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET)
set(mlas_platform_srcs
${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,7 @@ extern "C" {
#if defined(MLAS_TARGET_AMD64)
MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1Avx;
MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1TransposeBAvx;
#elif defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_WASM)
#elif defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_WASM) || defined(MLAS_TARGET_RISCV64)
MLAS_GEMV_FLOAT_KERNEL MlasGemvFloatKernel;
#endif

Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/mlas/lib/qgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,11 @@ MlasGemmQuantGetDispatch(
if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchZVECTOR) {
GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch;
}
#elif defined(MLAS_TARGET_RISCV64)
// RISC-V 64: a single RVV INT8 dispatch covers U8S8/S8S8/U8U8/S8U8.
// See onnxruntime/core/mlas/lib/qgemm_kernel_rvv.cpp.
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchRvv;
GemmQuantDispatch = &MlasGemmU8S8DispatchRvv;
#endif
#endif // !defined(FORCE_GENERIC_ALGORITHMS)

Expand Down
Loading