Skip to content

Commit cffa3a2

Browse files
committed
Fix Attention addLayer, make cmake to work with TRT 10.14
1 parent 57c5c8d commit cffa3a2

File tree

12 files changed

+302
-174
lines changed

12 files changed

+302
-174
lines changed

mlir-tensorrt/CMakePresets.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,20 @@
100100
"MLIR_TRT_ENABLE_NCCL": "OFF",
101101
"MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}"
102102
}
103+
},
104+
{
105+
"name": "python-wheel-build",
106+
"displayName": "Configuration for building the compiler/runtime Python package wheels",
107+
"generator": "Ninja",
108+
"binaryDir": "build",
109+
"inherits": "ninja-llvm",
110+
"cacheVariables": {
111+
"CMAKE_BUILD_TYPE": "Release",
112+
"LLVM_ENABLE_ASSERTIONS": "OFF",
113+
"CMAKE_PLATFORM_NO_VERSIONED_SONAME": "ON",
114+
"MLIR_TRT_ENABLE_NCCL": "OFF",
115+
"MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}"
116+
}
103117
}
104118
]
105119
}

mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ macro(configure_tensorrt_python_plugin_header)
5757
find_file(
5858
trt_python_plugin_header
5959
NAMES NvInferPythonPlugin.h plugin.h
60-
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
61-
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
60+
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/include/impl
61+
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/include/impl
6262
REQUIRED
6363
NO_CMAKE_PATH NO_DEFAULT_PATH
6464
NO_CACHE

mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
8080
set(ARG_VERSION "10.12.0.36")
8181
endif()
8282

83+
if(ARG_VERSION VERSION_EQUAL "10.14")
84+
set(ARG_VERSION "10.14.1.48")
85+
endif()
86+
8387
set(downloadable_versions
8488
"8.6.1.6"
8589
"9.0.1.4" "9.1.0.4" "9.2.0.5"
@@ -97,6 +101,7 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
97101
"10.8.0.43"
98102
"10.9.0.34"
99103
"10.12.0.36"
104+
"10.14.1.48"
100105
)
101106

102107
if(NOT ARG_VERSION IN_LIST downloadable_versions)
@@ -164,6 +169,8 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
164169
elseif(ARG_VERSION VERSION_GREATER 10.10
165170
AND ARG_VERSION VERSION_LESS 10.13)
166171
set(TRT_CUDA_VERSION 12.9)
172+
elseif(ARG_VERSION VERSION_GREATER 10.13)
173+
set(TRT_CUDA_VERSION 13.0)
167174
endif()
168175

169176
# Handle TRT 8 versions.

mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def test_attributes():
4949
tensorrt.TripLimitAttr.get("kWHILE"),
5050
tensorrt.FillOperationAttr.get("kRANDOM_UNIFORM"),
5151
tensorrt.ScatterModeAttr.get("kELEMENT"),
52+
tensorrt.AttentionNormalizationOpAttr.get("kSOFTMAX"),
53+
tensorrt.DataTypeAttr.get("kFLOAT"),
5254
]:
5355
print(attr)
5456

@@ -74,3 +76,5 @@ def test_attributes():
7476
# CHECK-NEXT: #tensorrt.trip_limit<kWHILE>
7577
# CHECK-NEXT: #tensorrt.fill_operation<kRANDOM_UNIFORM>
7678
# CHECK-NEXT: #tensorrt.scatter_mode<kELEMENT>
79+
# CHECK-NEXT: #tensorrt.attention_normalization_op<kSOFTMAX>
80+
# CHECK-NEXT: #tensorrt.data_type<kFLOAT>

mlir-tensorrt/compiler/tools/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@ set(LLVM_LINK_COMPONENTS
2121
add_subdirectory(mlir-tensorrt-opt)
2222
add_subdirectory(mlir-tensorrt-compiler)
2323
add_subdirectory(mlir-tensorrt-translate)
24-
add_subdirectory(mlir-tensorrt-lsp-server)
24+
# add_subdirectory(mlir-tensorrt-lsp-server)
2525
add_subdirectory(mlir-tensorrt-runner)

mlir-tensorrt/integrations/python/setup_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import subprocess
1414
import atexit
1515

16-
TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.12")
16+
TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.14")
1717

1818

1919
def log(*args):
@@ -105,8 +105,8 @@ def run_cmake_build(python_package_name: str, python_wheel_staging_dir: Path):
105105

106106
# Environment variable overrides
107107
cmake_preset = os.environ.get("MLIR_TRT_CMAKE_PRESET", "python-wheel-build")
108-
install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", None)
109-
build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", None)
108+
install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", "./install")
109+
build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", "./build")
110110
parallel_jobs = os.environ.get("MLIR_TRT_PARALLEL_JOBS", str(os.cpu_count() or 1))
111111

112112
# Additional CMake options from environment

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,22 @@ DECLARE_ATTR_GETTER_FROM_STRING(ScatterMode)
188188
DECLARE_IS_ATTR(ScatterMode)
189189
DECLARE_STRING_GETTER_FROM_ATTR(ScatterMode)
190190

191+
//===----------------------------------------------------------------------===//
192+
// AttentionNormalizationOp
193+
//===----------------------------------------------------------------------===//
194+
195+
DECLARE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp)
196+
DECLARE_IS_ATTR(AttentionNormalizationOp)
197+
DECLARE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp)
198+
199+
//===----------------------------------------------------------------------===//
200+
// DataType
201+
//===----------------------------------------------------------------------===//
202+
203+
DECLARE_ATTR_GETTER_FROM_STRING(DataType)
204+
DECLARE_IS_ATTR(DataType)
205+
DECLARE_STRING_GETTER_FROM_ATTR(DataType)
206+
191207
#ifdef __cplusplus
192208
}
193209
#endif

0 commit comments

Comments
 (0)