Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -929,11 +929,15 @@ endif()
${MLAS_SRC_DIR}/riscv64/sgemm_pack_b_rvv.cpp
${MLAS_SRC_DIR}/riscv64/sgemm_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/qgemm_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/activation_kernel_rvv.cpp
)
set_source_files_properties(
${MLAS_SRC_DIR}/riscv64/sgemm_pack_b_rvv.cpp
${MLAS_SRC_DIR}/riscv64/sgemm_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/qgemm_kernel_rvv.cpp
${MLAS_SRC_DIR}/riscv64/activation_kernel_rvv.cpp
PROPERTIES COMPILE_FLAGS "-march=rv64gcv -mabi=lp64d")
list(APPEND mlas_private_compile_definitions MLAS_USE_RVV=1)
else()
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ Return Value:

--*/
{
#if defined(MLAS_TARGET_AMD64)
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_RISCV64)
GetMlasPlatform().ComputeExpF32Kernel(Input, Output, N);
#else
MlasComputeExpF32Kernel(Input, Output, N);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/erf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ Return Value:

--*/
{
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE)
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) || defined(MLAS_TARGET_RISCV64)
GetMlasPlatform().ErfKernelRoutine(Input, Output, N);
#else
MlasErfKernel(Input, Output, N);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/gelu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ MlasComputeGeluErf(
size_t N
)
{
#if defined(MLAS_TARGET_AMD64)
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_RISCV64)
// TODO: Add an intermediate fused AVX2/FMA3 GELU(erf) path on AMD64.
// Today the dispatch jumps from the generic multi-pass implementation to
// AVX512F, so non-AVX512 x64 machines fall back to the generic kernel.
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/logistic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ Return Value:

--*/
{
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE)
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) || defined(MLAS_TARGET_RISCV64)
GetMlasPlatform().LogisticKernelRoutine(Input, Output, N);
#else
MlasLogisticKernel(Input, Output, N);
Expand Down
24 changes: 23 additions & 1 deletion onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,7 @@ extern "C" {
#if defined(MLAS_TARGET_AMD64)
MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1Avx;
MLAS_SGEMM_KERNEL_M1_ROUTINE MlasSgemmKernelM1TransposeBAvx;
#elif defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_WASM)
#elif defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_WASM) || defined(MLAS_TARGET_RISCV64)
MLAS_GEMV_FLOAT_KERNEL MlasGemvFloatKernel;
#endif

Expand Down Expand Up @@ -1203,6 +1203,12 @@ extern "C" {
MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelRvv;
MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeSoftmaxOutputF32KernelRvv;
MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL MlasComputeLogSoftmaxOutputF32KernelRvv;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelRvv;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasLogisticKernelRvv;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasGeluErfKernelRvv;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasSiluKernelRvv;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasTanhKernelRvv;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelRvv;
#endif
#if defined(MLAS_TARGET_AMD64)
MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32KernelAvx;
Expand Down Expand Up @@ -1300,6 +1306,10 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8S8DispatchRvv;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchRvv;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchRvv;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8U8DispatchRvv;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10;
extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchZVECTOR;
Expand Down Expand Up @@ -1500,6 +1510,12 @@ struct MLAS_PLATFORM {
MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel;
uint32_t NchwcBlockSize;
#endif
#if defined(MLAS_TARGET_RISCV64)
const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch;
const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch;
const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch;
const MLAS_GEMM_QUANT_DISPATCH* GemmS8U8Dispatch;
#endif
#if defined(MLAS_TARGET_AMD64_IX86)
const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch;
const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch;
Expand Down Expand Up @@ -1552,6 +1568,12 @@ struct MLAS_PLATFORM {
MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel;
MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel;
#endif
#if defined(MLAS_TARGET_RISCV64)
MLAS_COMPUTE_UNARY_FLOAT_KERNEL* GeluErfKernelRoutine;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL* SiluKernelRoutine;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL* TanhKernelRoutine;
MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ComputeExpF32Kernel;
#endif

MLAS_COMPUTE_ERF_FP16_KERNEL* ErfFP16KernelRoutine = nullptr;
MLAS_COMPUTE_GELU_FP16_KERNEL* GeluFP16KernelRoutine = nullptr;
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,16 @@ Return Value:

#if defined(MLAS_TARGET_RISCV64)
this->GemmFloatKernel = nullptr;
this->GemmU8S8Dispatch = &MlasGemmQuantDispatchDefault;
this->GemmU8U8Dispatch = &MlasGemmQuantDispatchDefault;
this->GemmS8S8Dispatch = &MlasGemmQuantDispatchDefault;
this->GemmS8U8Dispatch = &MlasGemmQuantDispatchDefault;
this->ErfKernelRoutine = MlasErfKernel;
this->LogisticKernelRoutine = MlasLogisticKernel;
this->GeluErfKernelRoutine = MlasGeluErfKernel;
this->SiluKernelRoutine = MlasSiluKernel;
this->TanhKernelRoutine = MlasTanhKernel;
this->ComputeExpF32Kernel = MlasComputeExpF32Kernel;
this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32Kernel;
this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel;
Expand All @@ -334,6 +342,16 @@ Return Value:
}
if (has_rvv) {
this->GemmFloatKernel = MlasGemmFloatKernelRvv;
this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchRvv;
this->GemmU8U8Dispatch = &MlasGemmU8U8DispatchRvv;
this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchRvv;
this->GemmS8U8Dispatch = &MlasGemmS8U8DispatchRvv;
this->ErfKernelRoutine = MlasErfKernelRvv;
this->LogisticKernelRoutine = MlasLogisticKernelRvv;
this->GeluErfKernelRoutine = MlasGeluErfKernelRvv;
this->SiluKernelRoutine = MlasSiluKernelRvv;
this->TanhKernelRoutine = MlasTanhKernelRvv;
this->ComputeExpF32Kernel = MlasComputeExpF32KernelRvv;
this->ReduceMaximumF32Kernel = MlasReduceMaximumF32KernelRvv;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelRvv;
this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32KernelRvv;
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/mlas/lib/qgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,14 @@ MlasGemmQuantGetDispatch(
GemmQuantDispatch =
BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch;
}
#elif defined(MLAS_TARGET_RISCV64)
if (AIsSigned) {
GemmQuantDispatch =
BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch;
} else { // !AIsSigned
GemmQuantDispatch =
BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch;
}
#elif defined(MLAS_TARGET_S390X)
if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchZVECTOR) {
GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch;
Expand Down
Loading
Loading