Skip to content

Commit aad62e4

Browse files
committed
Fix Attention addLayer; make CMake work with TRT 10.14
1 parent 57c5c8d commit aad62e4

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ macro(configure_tensorrt_python_plugin_header)
5757
find_file(
5858
trt_python_plugin_header
5959
NAMES NvInferPythonPlugin.h plugin.h
60-
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
61-
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
60+
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/include/impl
61+
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/include/impl
6262
REQUIRED
6363
NO_CMAKE_PATH NO_DEFAULT_PATH
6464
NO_CACHE

mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
8080
set(ARG_VERSION "10.12.0.36")
8181
endif()
8282

83+
if(ARG_VERSION VERSION_EQUAL "10.14")
84+
set(ARG_VERSION "10.14.1.48")
85+
endif()
86+
8387
set(downloadable_versions
8488
"8.6.1.6"
8589
"9.0.1.4" "9.1.0.4" "9.2.0.5"
@@ -97,6 +101,7 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
97101
"10.8.0.43"
98102
"10.9.0.34"
99103
"10.12.0.36"
104+
"10.14.1.48"
100105
)
101106

102107
if(NOT ARG_VERSION IN_LIST downloadable_versions)
@@ -164,6 +169,8 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
164169
elseif(ARG_VERSION VERSION_GREATER 10.10
165170
AND ARG_VERSION VERSION_LESS 10.13)
166171
set(TRT_CUDA_VERSION 12.9)
172+
elseif(ARG_VERSION VERSION_GREATER 10.13)
173+
set(TRT_CUDA_VERSION 13.0)
167174
endif()
168175

169176
# Handle TRT 8 versions.

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4539,7 +4539,7 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45394539
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
45404540
Optional<TensorRT_Tensor>:$mask,
45414541
Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
4542-
OptionalAttr<TensorRT_AttentionNormalizationOpAttr>:$normalization_operation,
4542+
DefaultValuedAttr<TensorRT_AttentionNormalizationOpAttr, "tensorrt::AttentionNormalizationOp::kSOFTMAX">:$normalization_operation,
45434543
DefaultValuedAttr<BoolAttr, "false">:$causal,
45444544
DefaultValuedAttr<BoolAttr, "false">:$decomposable,
45454545
OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
@@ -4565,12 +4565,7 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45654565
}] # baseClassDeclaration;
45664566

45674567
let trtLayerAdd = [{
4568-
// Get normalization operation, default to kSOFTMAX
4569-
nvinfer1::AttentionNormalizationOp normOp = $normalization_operation
4570-
? *$normalization_operation
4571-
: nvinfer1::AttentionNormalizationOp::kSOFTMAX;
4572-
4573-
nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, normOp, $causal);
4568+
nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, *$normalization_operation, $causal);
45744569
if (!layer)
45754570
return failure();
45764571

@@ -4584,19 +4579,22 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45844579
}
45854580

45864581
if ($normalization_quantize_to_type) {
4587-
layer->setNormalizationQuantizeToType(*$normalization_quantize_to_type);
4582+
auto convertedDataType = ::mlir::tensorrt::convertDataTypeToNvInferEnum(*$normalization_quantize_to_type);
4583+
if (!convertedDataType)
4584+
return emitError($op->getLoc()) << "failed to convert DataType to nvinfer enum";
4585+
layer->setNormalizationQuantizeToType(*convertedDataType);
45884586
}
45894587

45904588
if (!$e.isStronglyTyped()){
45914589
FailureOr<nvinfer1::DataType> outputTrtType = getNvInferDataType($op.getLoc(),
45924590
$op.getType().getElementType());
45934591
if (failed(outputTrtType))
45944592
return failure();
4595-
layer->setOutputType(0, *outputTrtType);
45964593
}
45974594

45984595
$results.push_back(layer->getOutput(0));
4599-
$e.setMetadata(layer, $op);
4596+
// TODO: nvinfer1::IAttention does not have setMetadata API in 10.14
4597+
// layer->setMetadata($op);
46004598
}];
46014599
}
46024600

0 commit comments

Comments (0)