Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
22 changes: 20 additions & 2 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF)
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF)
cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_ENABLE_CUDA_FP4_QMOE "Build CUDA QMoE FP4 kernels" OFF)
Comment thread
tianleiwu marked this conversation as resolved.
option(onnxruntime_ENABLE_CUDA_FP8_QMOE "Build CUDA QMoE FP8 kernels" OFF)
option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF)
option(onnxruntime_USE_FP8_KV_CACHE "Build cuda kernels for fp8 kv cache" ON)
Expand Down Expand Up @@ -786,8 +788,8 @@ if (onnxruntime_USE_CUDA)
endif()

# Quick build mode: announce it once and expose ORT_QUICK_BUILD to the provider sources
# so they can reduce the set of kernel instantiations they compile.
if (onnxruntime_QUICK_BUILD)
  message(STATUS "Quick build mode: reducing selected CUDA/CUTLASS kernel instantiations")
  # Append the define exactly once; a second append would only duplicate the flag.
  list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1)
endif()

if (onnxruntime_USE_INT4_KV_CACHE)
Expand Down Expand Up @@ -1440,6 +1442,12 @@ if (Git_FOUND)
if (onnxruntime_USE_FP8_KV_CACHE)
string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ")
endif()
if (onnxruntime_ENABLE_CUDA_FP4_QMOE)
string(APPEND ORT_BUILD_INFO "fp4-qmoe=1, ")
endif()
if (onnxruntime_ENABLE_CUDA_FP8_QMOE)
string(APPEND ORT_BUILD_INFO "fp8-qmoe=1, ")
endif()
if (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN)
string(APPEND ORT_BUILD_INFO "cuda-plugin-ep=1, ")
endif()
Expand All @@ -1466,10 +1474,20 @@ if (onnxruntime_USE_CUDA)
message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_BF16 flag")
add_definitions("-DENABLE_FP8")
message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_FP8 flag")
if(onnxruntime_ENABLE_CUDA_FP8_QMOE)
add_definitions("-DENABLE_CUDA_FP8_QMOE")
message(STATUS "CUDA FP8 QMoE kernels enabled")
endif()

if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
add_definitions("-DENABLE_FP4")
message(STATUS "CUDA Toolkit version is greater or equal than 12.8, enable -DENABLE_FP4 flag")
if(onnxruntime_ENABLE_CUDA_FP4_QMOE)
add_definitions("-DENABLE_CUDA_FP4_QMOE")
message(STATUS "CUDA FP4 QMoE kernels enabled")
endif()
elseif(onnxruntime_ENABLE_CUDA_FP4_QMOE)
message(FATAL_ERROR "onnxruntime_ENABLE_CUDA_FP4_QMOE requires CUDA Toolkit version 12.8 or newer")
endif()

if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0")
Expand Down
16 changes: 14 additions & 2 deletions cmake/external/cuda_configuration.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,30 @@ macro(setup_cuda_architectures)
message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}")

# Record which compute-capability tiers appear in the requested architecture list.
# Each flag is cleared first so a previous invocation cannot leak a stale value.
unset(ORT_HAS_SM80_OR_LATER)
unset(ORT_HAS_SM90_OR_LATER)
unset(ORT_HAS_SM100_OR_LATER)
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES_ORIG)
  # Only the leading digits of an entry are compared; anything after them is ignored.
  if(CUDA_ARCH MATCHES "^([0-9]+)")
    # Do not break() out of the loop after the first hit: every entry must be
    # inspected, otherwise the SM90/SM100 checks below are never reached.
    if(CMAKE_MATCH_1 GREATER_EQUAL 80)
      set(ORT_HAS_SM80_OR_LATER ON)
    endif()
    if(CMAKE_MATCH_1 GREATER_EQUAL 90)
      set(ORT_HAS_SM90_OR_LATER ON)
    endif()
    if(CMAKE_MATCH_1 GREATER_EQUAL 100)
      set(ORT_HAS_SM100_OR_LATER ON)
    endif()
  endif()
endforeach()

# Publish each detected tier as a preprocessor definition.
if(ORT_HAS_SM80_OR_LATER)
  add_definitions("-DHAS_SM80_OR_LATER")
endif()
if(ORT_HAS_SM90_OR_LATER)
  add_definitions("-DHAS_SM90_OR_LATER")
endif()
if(ORT_HAS_SM100_OR_LATER)
  add_definitions("-DHAS_SM100_OR_LATER")
endif()

set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "110" "120")
foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
Expand Down
15 changes: 14 additions & 1 deletion cmake/external/cutlass.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,20 @@ onnxruntime_fetchcontent_declare(
PATCH_COMMAND ${Patch_EXECUTABLE} --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass_4.4.2.patch
)

# We only consume CUTLASS as a header-only dependency. Avoid FetchContent_MakeAvailable here
# because CUTLASS ships its own CMakeLists.txt that adds many test/example/tool targets and
# may override compile flags. Calling FetchContent_Populate downloads the source tree without
# invoking add_subdirectory.
# Populate the CUTLASS source tree exactly once.
FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
  if(POLICY CMP0169)
    # CMake >= 3.30 deprecates the single-argument form of FetchContent_Populate. Keep using
    # the OLD policy locally so we can populate without inviting CUTLASS targets into the build.
    # PUSH/POP confines the policy change to this one call.
    cmake_policy(PUSH)
    cmake_policy(SET CMP0169 OLD)
    FetchContent_Populate(cutlass)
    cmake_policy(POP)
  else()
    # Older CMake has no CMP0169, so the plain call is fine.
    FetchContent_Populate(cutlass)
  endif()
endif()
43 changes: 43 additions & 0 deletions cmake/onnxruntime_cuda_source_filters.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# Shared filtering logic for CUDA contrib ops .cu source lists.
# Both the main CUDA provider and the plugin EP build use identical filtering
# rules for flash attention (quick build) and MoE GEMM FP4/FP8 kernels.
#
# Usage:
# onnxruntime_filter_cuda_cu_sources(<list_variable_name>)
#
# The macro modifies the named list variable in the caller's scope.

macro(onnxruntime_filter_cuda_cu_sources CU_SRC_LIST)
  # CU_SRC_LIST is the NAME of a list variable holding CUDA source paths. Being a macro,
  # every list(FILTER) below edits that variable directly in the caller's scope.
  #
  # Build options consulted:
  #   onnxruntime_QUICK_BUILD          - drop flash attention kernels for the listed head dimensions
  #   onnxruntime_ENABLE_CUDA_FP4_QMOE - keep or drop the MoE GEMM FP4 units
  #   onnxruntime_ENABLE_CUDA_FP8_QMOE - keep or drop the MoE GEMM FP8 units

  # --- Flash attention (quick build) ---------------------------------------------------
  # For faster development iteration only the hdim128 fp16 kernels are retained; head
  # dimensions 32, 64, 96, 192 and 256 are removed. Keep the alternation below in sync
  # with the set of head dimensions that actually exist in the source tree.
  if(onnxruntime_QUICK_BUILD)
    message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels")
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)")
  endif()

  # --- MoE GEMM FP4 --------------------------------------------------------------------
  if(onnxruntime_ENABLE_CUDA_FP4_QMOE)
    # FP4 is on: drop only the M=128/N=64 pingpong specializations. CUDA 13 PTXAS does
    # not complete them in this build configuration, and the dispatcher routes that tile
    # through the cooperative mainloop variants, so these generated units are unused.
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_tma_ws_sm90_fp4_(fp16|bf16)_m128_n64_k[0-9]+_cm[12]_cn[12]_pp(_finalize)?\\.generated\\.cu")
  else()
    # FP4 is off: remove the hand-written FP4 kernel sources ...
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_kernels_(fp16|bf16)_fp4\\.cu")
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_kernels_fp8_fp4\\.cu")
    # ... and the generated TMA warp-specialized FP4 units.
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_tma_ws_sm90_fp4_.*\\.generated\\.cu")
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_tma_ws_sm120_fp4_.*\\.generated\\.cu")
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_tma_ws_sm120_fp8_fp4\\.generated\\.cu")
  endif()

  # --- MoE GEMM FP8 --------------------------------------------------------------------
  # The mixed fp8/fp4 patterns appear in both the FP4 and FP8 sections on purpose: those
  # units survive only when both options are enabled.
  if(NOT onnxruntime_ENABLE_CUDA_FP8_QMOE)
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_kernels_(fp16|bf16)_fp8\\.cu")
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_kernels_fp8_fp4\\.cu")
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_tma_ws_sm90_wfp8_.*\\.generated\\.cu")
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_tma_ws_sm120_fp4_fp8_.*\\.generated\\.cu")
    list(FILTER ${CU_SRC_LIST} EXCLUDE REGEX "moe_gemm_tma_ws_sm120_fp8_fp4\\.generated\\.cu")
  endif()
endmacro()
19 changes: 0 additions & 19 deletions cmake/onnxruntime_providers_cpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,6 @@ file(GLOB_RECURSE onnxruntime_cpu_contrib_ops_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/*.cc"
)

file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cc"
)

file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cu"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh"
)

# Quick build mode: Filter flash attention kernels for faster development iteration.
# - We keep only hdim128 fp16 flash attention kernels in quick build mode.
# - All other listed head dimensions are excluded (e.g., 32, 64, 96, 192, 256).
# If new head dimensions are added or removed, update this list to match the supported set.
if(onnxruntime_QUICK_BUILD)
message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels")
list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)")
endif()

file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc"
Expand Down
34 changes: 29 additions & 5 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})
set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})

# Collect CUDA contrib ops sources
file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cc"
)

file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cu"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cuh"
)

include(onnxruntime_cuda_source_filters.cmake)
onnxruntime_filter_cuda_cu_sources(onnxruntime_cuda_contrib_ops_cu_srcs)

# disable contrib ops conditionally
if(NOT onnxruntime_DISABLE_CONTRIB_OPS AND NOT onnxruntime_CUDA_MINIMAL)
if (NOT onnxruntime_ENABLE_ATEN)
Expand All @@ -78,7 +92,6 @@
)
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs})
list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs})
endif()

Expand Down Expand Up @@ -194,8 +207,7 @@
endif()
endforeach()

# Note: The minimum required CUDA version is greater than 11.3.
# CUDA 11.3+ supports parallel compilation
# Note: CUDA 11.3+ supports parallel compilation
# https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads
set(onnxruntime_NVCC_THREADS "1" CACHE STRING "Number of threads that NVCC can use for compilation.")
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">")
Expand All @@ -214,6 +226,8 @@
endif()
# skip diagnosis error caused by cuda header files
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=221>")
# CUDA 12.8 also reports deprecated implicit by-copy 'this' captures from CUTLASS headers.
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=2908>")
endif()

if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
Expand Down Expand Up @@ -310,7 +324,7 @@
message( WARNING "To compile with NHWC ops enabled please compile against cuDNN 9 or newer." )
endif()
endif()
# Link the CUDA provider against the CUDA/cuDNN libraries it uses, including
# CUDA::nvrtc and CUDA::cuda_driver, in one well-formed call.
target_link_libraries(${target} PRIVATE
  CUDA::cublasLt CUDA::cublas CUDNN::cudnn_all cudnn_frontend CUDA::curand CUDA::cufft CUDA::cudart
  CUDA::nvrtc CUDA::cuda_driver
  ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
endif()

Expand All @@ -333,15 +347,25 @@
set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA)
set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime")

# Hopper TMA GEMM support: enabled when ORT_HAS_SM90_OR_LATER is set, rather than only
# when the literal entry "90" is present in the architecture list.
if(ORT_HAS_SM90_OR_LATER)
  # -w is passed to ptxas via -Xptxas for CUDA sources.
  target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xptxas=-w>)
  target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-DCUTLASS_ENABLE_GDC_FOR_SM90=1>)
  target_compile_definitions(${target} PRIVATE COMPILE_HOPPER_TMA_GEMMS)
  target_compile_definitions(${target} PRIVATE COMPILE_HOPPER_TMA_GROUPED_GEMMS)
  if (MSVC)
    # Host-compiler switches must be forwarded through -Xcompiler for CUDA sources.
    target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /bigobj>")
    target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /wd4172>")
  endif()
endif()

if("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
target_compile_definitions(${target} PRIVATE COMPILE_BLACKWELL_TMA_GROUPED_GEMMS)
endif()

if("120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
target_compile_definitions(${target} PRIVATE COMPILE_BLACKWELL_SM120_TMA_GROUPED_GEMMS)
endif()

if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling
target_link_libraries(${target} PRIVATE CUDA::cupti)
endif()
Expand Down
24 changes: 19 additions & 5 deletions cmake/onnxruntime_providers_cuda_plugin.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/size\\.cc$")
# which cannot convert to ep::adapter::OpKernel in the plugin build.
list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/shape_op\\.cc$")

# Exclude contrib llm/ for now. The core CUDA llm kernels are adapter-safe, but
# contrib llm kernels still need their own plugin pass.
list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*")
list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*")

# Exclude contrib training ops (shrunken_gather depends on provider_api.h in header).
list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shrunken_gather\\.cc$")

Expand All @@ -106,6 +101,10 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shr
list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*")
list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*")

# Apply shared CUDA .cu source filtering (flash attention quick build, MoE GEMM FP4/FP8).
include(onnxruntime_cuda_source_filters.cmake)
onnxruntime_filter_cuda_cu_sources(CUDA_PLUGIN_EP_CU_SRCS)

# Create shared library target using the ORT helper function for plugins
onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_plugin
${CUDA_PLUGIN_EP_CC_SRCS}
Expand Down Expand Up @@ -191,6 +190,7 @@ if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE
"$<$<COMPILE_LANGUAGE:CUDA>:--static-global-template-stub=false>"
"$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=221>"
"$<$<COMPILE_LANGUAGE:CUDA>:--diag-suppress=2908>"
)

if (MSVC)
Expand All @@ -203,6 +203,20 @@ endif()
include(cudnn_frontend)
include(cutlass)

# TMA compile definitions — mirror config_cuda_provider_shared_module in onnxruntime_providers_cuda.cmake
if(ORT_HAS_SM90_OR_LATER)
target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xptxas=-w>)
target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-DCUTLASS_ENABLE_GDC_FOR_SM90=1>)
target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE COMPILE_HOPPER_TMA_GEMMS)
target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE COMPILE_HOPPER_TMA_GROUPED_GEMMS)
endif()
if("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE COMPILE_BLACKWELL_TMA_GROUPED_GEMMS)
endif()
if("120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE COMPILE_BLACKWELL_SM120_TMA_GROUPED_GEMMS)
endif()

# --- Find cuDNN (may be at a custom path via onnxruntime_CUDNN_HOME) ---
set(_CUDNN_SEARCH_PATHS "")
if(onnxruntime_CUDNN_HOME)
Expand Down
19 changes: 18 additions & 1 deletion cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ if(MSVC)
target_sources(onnxruntime_pybind11_state PRIVATE "${ONNXRUNTIME_ROOT}/core/dll/delay_load_hook.cc")

target_compile_options(onnxruntime_pybind11_state PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>" "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
# /bigobj is an MSVC host-compiler switch. Forward it through -Xcompiler for CUDA sources
# and pass it directly for every other language; never hand the raw switch to nvcc.
target_compile_options(onnxruntime_pybind11_state PRIVATE
  "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /bigobj>"
  "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/bigobj>")
endif()
if(HAS_CAST_FUNCTION_TYPE)
target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type")
Expand Down Expand Up @@ -230,6 +230,23 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE
${pybind11_lib}
Python::NumPy
)

# Starting with Python 3.8 on Windows, the PATH environment variable is no longer used to resolve DLL dependencies
# for extension modules or libraries loaded via ctypes.
# To avoid package import issues, we do not link the pybind module against the CUDA runtime on Windows and instead
# rely on os.add_dll_directory() to deal with CUDA paths.
# CUDA-specific configuration of the pybind module, split by platform.
if (onnxruntime_USE_CUDA)
  if (WIN32)
    # Windows: no CUDA sources are compiled into the module; signal that to the C++ code.
    target_compile_definitions(onnxruntime_pybind11_state PRIVATE ORT_NO_CUDA_IN_PYBIND)
  else()
    # Other platforms: build the fpA_intB GEMM helper kernels into the module and make
    # the CUTLASS headers they need available.
    target_sources(onnxruntime_pybind11_state PRIVATE
      "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu"
      "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu"
    )
    include(cutlass)
    target_include_directories(onnxruntime_pybind11_state PRIVATE ${cutlass_SOURCE_DIR}/include)
  endif()
endif()

set(onnxruntime_pybind11_state_dependencies
${onnxruntime_EXTERNAL_DEPENDENCIES}
${pybind11_dep}
Expand Down
Loading
Loading