
Commit 3fd94a8

YUNQIUGUO, wejoncy, xadupre, fs-eire, and HectorSVC authored
[ORT 1.17.0 Release] Cherry pick 1st round (#19243)
### Description

[ORT 1.17.0 Release] Cherry pick 1st round. PR authors, please take a look and let me know if there are any questions about the changes, or approve accordingly.

### Motivation and Context

---------

Co-authored-by: wejoncy <[email protected]>
Co-authored-by: Xavier Dupré <[email protected]>
Co-authored-by: Yulong Wang <[email protected]>
Co-authored-by: Hector Li <[email protected]>
Co-authored-by: luoyu-intel <[email protected]>
Co-authored-by: kunal-vaishnavi <[email protected]>
Co-authored-by: Chi Lo <[email protected]>
Co-authored-by: Ye Wang <[email protected]>
Co-authored-by: Adrian Lizarraga <[email protected]>
Co-authored-by: snadampal <[email protected]>
Co-authored-by: Tianlei Wu <[email protected]>
Co-authored-by: Heflin Stephen Raj <[email protected]>
Co-authored-by: Yifan Li <[email protected]>
Co-authored-by: Yufeng Li <[email protected]>
Co-authored-by: Changming Sun <[email protected]>
1 parent daafe63 commit 3fd94a8

174 files changed, +6708 −25834 lines changed

cgmanifests/generated/cgmanifest.json

Lines changed: 11 additions & 1 deletion
@@ -36,7 +36,7 @@
       "component": {
         "type": "git",
         "git": {
-          "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f",
+          "commitHash": "4a2c63365eff8823a5221db86ef490e828306f9d",
           "repositoryUrl": "https://github.com/abseil/abseil-cpp.git"
         },
         "comments": "abseil_cpp"
@@ -192,6 +192,16 @@
         "comments": "mp11"
       }
     },
+    {
+      "component": {
+        "type": "git",
+        "git": {
+          "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a",
+          "repositoryUrl": "https://github.com/intel/neural-speed.git"
+        },
+        "comments": "neural_speed"
+      }
+    },
     {
       "component": {
         "type": "git",

cmake/CMakeLists.txt

Lines changed: 14 additions & 23 deletions
@@ -87,7 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
 option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
 option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
 option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
-option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON)
+option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON)
 option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
 option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
 option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
@@ -96,7 +96,6 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
 option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
 option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
 
-cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF)
 cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
@@ -706,20 +705,16 @@ if (onnxruntime_USE_CUDA)
   enable_language(CUDA)
   message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")
 
+  if (onnxruntime_DISABLE_CONTRIB_OPS)
+    set(onnxruntime_USE_FLASH_ATTENTION OFF)
+    set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
+  endif()
   if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
-    message( STATUS "Turn off cutlass since CUDA compiler version < 11.6")
-    set(onnxruntime_USE_CUTLASS OFF)
+    message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
+    set(onnxruntime_USE_FLASH_ATTENTION OFF)
+    set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
   endif()
 else()
-  set(onnxruntime_USE_CUTLASS OFF)
-endif()
-
-if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS)
-  if (onnxruntime_DISABLE_CONTRIB_OPS)
-    message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled")
-  else()
-    message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled")
-  endif()
   set(onnxruntime_USE_FLASH_ATTENTION OFF)
   set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
 endif()
@@ -905,8 +900,8 @@ function(onnxruntime_set_compile_flags target_name)
     target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
   endif()
 
-  if (onnxruntime_USE_CUTLASS)
-    target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
+  if(USE_NEURAL_SPEED)
+    target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED)
   endif()
 
   set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON)
@@ -1193,14 +1188,10 @@ if (onnxruntime_USE_DNNL)
   add_compile_definitions(DNNL_OPENMP)
 endif()
 
-set(USE_JBLAS FALSE)
-if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD)
-  if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
-    add_compile_definitions(MLAS_JBLAS)
-    set(USE_JBLAS TRUE)
-  elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
-    add_compile_definitions(MLAS_JBLAS)
-    set(USE_JBLAS TRUE)
+if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD)
+  include(neural_speed)
+  if (USE_NEURAL_SPEED)
+    list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla)
   endif()
 endif()

cmake/deps.txt

Lines changed: 3 additions & 2 deletions
@@ -12,7 +12,7 @@
 # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI.
 # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29
 #
-abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c
+abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9
 cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0
 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159
 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445
@@ -34,6 +34,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36
 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
+neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939
 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11
 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459)
 onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035
@@ -54,4 +55,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2
 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee
 utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
-composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299
+composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299
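
Each entry in deps.txt is a single line of the form name;download URL;SHA1 checksum, which is the format the abseil_cpp, neural_speed, and composable_kernel lines above follow. A minimal Python sketch of that format, where parse_dep is a hypothetical helper written only for illustration and not part of the repository's tooling:

# Hypothetical helper illustrating the deps.txt line format: name;url;sha1.
# Comment lines starting with '#' would be skipped by the caller.
def parse_dep(line: str) -> tuple[str, str, str]:
    name, url, sha1 = line.strip().split(";")
    return name, url, sha1

name, url, sha1 = parse_dep(
    "neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;"
    "65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939"
)
assert sha1 == "65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939"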

cmake/external/abseil-cpp.cmake

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ if(WIN32 AND NOT Patch_FOUND)
   set(ABSL_ENABLE_INSTALL ON)
 endif()
 # NB! Advancing Abseil version changes its internal namespace,
-# currently absl::lts_20230125 which affects abseil-cpp.natvis debugger
+# currently absl::lts_20240116 which affects abseil-cpp.natvis debugger
 # visualization file, that must be adjusted accordingly, unless we eliminate
 # that namespace at build time.
 FetchContent_Declare(

cmake/external/abseil-cpp.natvis

Lines changed: 5 additions & 5 deletions
@@ -1,6 +1,6 @@
 <?xml version="1.0" encoding="utf-8"?>
 <AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010">
-  <Type Name="absl::lts_20230802::InlinedVector&lt;*&gt;">
+  <Type Name="absl::lts_20240116::InlinedVector&lt;*&gt;">
     <Intrinsic Name="_size" Expression="storage_.metadata_.value >> 1"/>
     <Intrinsic Name="_is_allocated" Expression="(storage_.metadata_.value &amp; 1) == 1"/>
     <Intrinsic Name="_inlined_data" Expression="($T1*)storage_.data_.inlined.inlined_data"/>
@@ -24,7 +24,7 @@
     </Expand>
   </Type>
   <!-- Should handle both flat hash_set and hash_map -->
-  <Type Name="absl::lts_20230802::container_internal::raw_hash_set&lt;*&gt;">
+  <Type Name="absl::lts_20240116::container_internal::raw_hash_set&lt;*&gt;">
     <Intrinsic Name="_commonfields" Expression="settings_.value"/>
     <Intrinsic Name="_size" Expression="settings_.value.compressed_tuple_.value"/>
     <Intrinsic Name="_capacity" Expression="_commonfields().capacity_"/>
@@ -51,7 +51,7 @@
   </Type>
 
   <!-- Primitive types stored as a value -->
-  <Type Name="absl::lts_20230802::container_internal::Storage&lt;*,*,0&gt;">
+  <Type Name="absl::lts_20240116::container_internal::Storage&lt;*,*,0&gt;">
     <DisplayString IncludeView="noparens">*($T1 *){value}</DisplayString>
     <DisplayString ExcludeView="noparens">(*($T1 *){value})</DisplayString>
     <Expand>
@@ -60,15 +60,15 @@
   </Type>
 
   <!-- For storage inherited from the type -->
-  <Type Name="absl::lts_20230802::container_internal::Storage&lt;*,*,1&gt;">
+  <Type Name="absl::lts_20240116::container_internal::Storage&lt;*,*,1&gt;">
     <DisplayString IncludeView="noparens">*($T1 *)this</DisplayString>
     <DisplayString ExcludeView="noparens">(*($T1 *)this)</DisplayString>
     <Expand>
       <ExpandedItem>*($T1 *)this</ExpandedItem>
     </Expand>
   </Type>
 
-  <Type Name="absl::lts_20230802::container_internal::map_slot_type&lt;*&gt;">
+  <Type Name="absl::lts_20240116::container_internal::map_slot_type&lt;*&gt;">
     <DisplayString IncludeView="noparens">{value.first}, {value.second}</DisplayString>
     <DisplayString ExcludeView="noparens">({value.first}, {value.second})</DisplayString>
     <Expand>

cmake/external/cutlass.cmake

Lines changed: 9 additions & 11 deletions
@@ -1,13 +1,11 @@
-if (onnxruntime_USE_CUTLASS)
-  include(FetchContent)
-  FetchContent_Declare(
-    cutlass
-    URL ${DEP_URL_cutlass}
-    URL_HASH SHA1=${DEP_SHA1_cutlass}
-  )
+include(FetchContent)
+FetchContent_Declare(
+  cutlass
+  URL ${DEP_URL_cutlass}
+  URL_HASH SHA1=${DEP_SHA1_cutlass}
+)
 
-  FetchContent_GetProperties(cutlass)
-  if(NOT cutlass_POPULATED)
-    FetchContent_Populate(cutlass)
-  endif()
+FetchContent_GetProperties(cutlass)
+if(NOT cutlass_POPULATED)
+  FetchContent_Populate(cutlass)
 endif()

cmake/external/neural_speed.cmake

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
+  set(USE_NEURAL_SPEED TRUE)
+elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
+  set(USE_NEURAL_SPEED TRUE)
+endif()
+
+if(USE_NEURAL_SPEED)
+  FetchContent_Declare(
+    neural_speed
+    URL ${DEP_URL_neural_speed}
+    URL_HASH SHA1=${DEP_SHA1_neural_speed}
+  )
+  set(BTLA_USE_OPENMP OFF)
+  onnxruntime_fetchcontent_makeavailable(neural_speed)
+endif()

cmake/onnxruntime_mlas.cmake

Lines changed: 4 additions & 13 deletions
@@ -57,15 +57,6 @@ endif()
 
 set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
 
-function(add_jblas)
-  add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
-  target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
-  target_sources(onnxruntime_mlas PRIVATE
-    ${MLAS_SRC_DIR}/jblas_gemm.cpp
-  )
-  set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
-endfunction()
-
 #TODO: set MASM flags properly
 function(setup_mlas_source_for_windows)
 
@@ -364,19 +355,23 @@ else()
   ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
   ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
   ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
+  ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S
   ${MLAS_SRC_DIR}/activate_fp16.cpp
   ${MLAS_SRC_DIR}/dwconv.cpp
   ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
   ${MLAS_SRC_DIR}/pooling_fp16.cpp
   ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
   ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
+  ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
 )
 set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
 set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
+set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
 set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
 endif()
 
 if(ONNXRUNTIME_MLAS_MULTI_ARCH)
@@ -622,10 +617,6 @@ else()
   target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs})
 endif()
 
-if(USE_JBLAS)
-  add_jblas()
-endif()
-
 foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
   target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
   onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})

cmake/onnxruntime_providers_cpu.cmake

Lines changed: 15 additions & 0 deletions
@@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
       "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc"
     )
   endif()
+  set(onnxruntime_cpu_neural_speed_srcs
+    "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h"
+    "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h"
+    "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc"
+    "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h"
+  )
+  if(NOT USE_NEURAL_SPEED)
+    list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs})
+  endif()
   # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
   source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs})
   list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs})
@@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL)
   target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical")
 endif()
 
+if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
+  if(USE_NEURAL_SPEED)
+    onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla)
+  endif()
+endif()
+
 if (MSVC)
   target_compile_options(onnxruntime_providers PRIVATE "/bigobj")
   # if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)

docs/ContribOperators.md

Lines changed: 9 additions & 3 deletions
@@ -3031,6 +3031,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Number of attention heads</dd>
 <dt><tt>scale</tt> : float</dt>
 <dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
+<dt><tt>unidirectional</tt> : int</dt>
+<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
 </dl>
 
 #### Inputs (1 - 8)
@@ -5021,6 +5023,10 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dl>
 <dt><tt>interleaved</tt> : int</dt>
 <dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
+<dt><tt>num_heads</tt> : int</dt>
+<dd>Number of attention heads. Default value is 0. Must use with rotary_embedding_dim</dd>
+<dt><tt>rotary_embedding_dim</tt> : int</dt>
+<dd>Rotary embedding dimension. Default value is 0.</dd>
 <dt><tt>scale</tt> : float</dt>
 <dd>Custom scale will be used if specified. Default value is 1.0</dd>
 </dl>
@@ -5033,9 +5039,9 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>position_ids</tt> : M</dt>
 <dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
 <dt><tt>cos_cache</tt> : T</dt>
-<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
+<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
 <dt><tt>sin_cache</tt> : T</dt>
-<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
+<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
 </dl>
 
 #### Outputs
@@ -5048,7 +5054,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Type Constraints
 
 <dl>
-<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
+<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
 <dd>Constrain input and output types to float tensors.</dd>
 <dt><tt>M</tt> : tensor(int64)</dt>
 <dd>Constrain input and output types to integer tensors</dd>
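
For context on the new num_heads / rotary_embedding_dim attributes and the widened cos_cache / sin_cache shapes above, here is a minimal NumPy sketch of a non-interleaved rotary embedding in which only the first rotary_embedding_dim channels of each head are rotated and the rest pass through unchanged. This is an illustrative assumption based on the attribute descriptions in the diff, not a reproduction of the ORT kernel.

import numpy as np

def rotary_embedding(x, cos_cache, sin_cache, position_ids, rotary_dim):
    # x: (batch, num_heads, seq_len, head_size); position_ids: (batch, seq_len)
    # cos_cache / sin_cache: (max_sequence_length, rotary_dim // 2)
    cos = cos_cache[position_ids][:, None, :, :]  # -> (batch, 1, seq_len, rotary_dim // 2)
    sin = sin_cache[position_ids][:, None, :, :]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot[..., :rotary_dim // 2], x_rot[..., rotary_dim // 2:]
    # Non-interleaved (half-split) rotation; interleaved=1 would pair even/odd channels instead.
    rotated = np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return np.concatenate([rotated, x_pass], axis=-1)

# Example: rotate the first 4 of 8 head channels for a (batch=1, heads=2, seq_len=3) input.
theta = 1.0 / 10000.0 ** (np.arange(2) / 2.0)     # rotary_dim // 2 = 2 frequencies
angles = np.arange(16)[:, None] * theta[None, :]  # (max_sequence_length, rotary_dim // 2)
out = rotary_embedding(np.random.randn(1, 2, 3, 8),
                       np.cos(angles), np.sin(angles),
                       np.arange(3)[None, :], rotary_dim=4)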
