
Commit c56ce5f

Fix Attention addLayer, make CMake work with TRT 10.14
1 parent 57c5c8d commit c56ce5f

File tree: 8 files changed (+113, -17 lines)


mlir-tensorrt/CMakePresets.json

Lines changed: 14 additions & 0 deletions
@@ -100,6 +100,20 @@
         "MLIR_TRT_ENABLE_NCCL": "OFF",
         "MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}"
       }
+    },
+    {
+      "name": "python-wheel-build",
+      "displayName": "Configuration for building the compiler/runtime Python package wheels",
+      "generator": "Ninja",
+      "binaryDir": "build",
+      "inherits": "ninja-llvm",
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Release",
+        "LLVM_ENABLE_ASSERTIONS": "OFF",
+        "CMAKE_PLATFORM_NO_VERSIONED_SONAME": "ON",
+        "MLIR_TRT_ENABLE_NCCL": "OFF",
+        "MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}"
+      }
     }
   ]
 }
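As a usage sketch (assuming the standard CMake presets CLI; this exact invocation is not part of the commit): configuring with "cmake --preset python-wheel-build", with DOWNLOAD_TENSORRT_VERSION exported in the environment, reproduces what setup_utils.py below selects by default through MLIR_TRT_CMAKE_PRESET.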

mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake

Lines changed: 2 additions & 2 deletions
@@ -57,8 +57,8 @@ macro(configure_tensorrt_python_plugin_header)
   find_file(
     trt_python_plugin_header
     NAMES NvInferPythonPlugin.h plugin.h
-    HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
-    PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
+    HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/include/impl
+    PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/include/impl
     REQUIRED
     NO_CMAKE_PATH NO_DEFAULT_PATH
     NO_CACHE
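This presumably tracks the TensorRT 10.14 package layout, where NvInferPythonPlugin.h sits under include/impl of the unpacked archive rather than python/include/impl.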

mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake

Lines changed: 7 additions & 0 deletions
@@ -80,6 +80,10 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
     set(ARG_VERSION "10.12.0.36")
   endif()

+  if(ARG_VERSION VERSION_EQUAL "10.14")
+    set(ARG_VERSION "10.14.1.48")
+  endif()
+
   set(downloadable_versions
     "8.6.1.6"
     "9.0.1.4" "9.1.0.4" "9.2.0.5"
@@ -97,6 +101,7 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
     "10.8.0.43"
     "10.9.0.34"
     "10.12.0.36"
+    "10.14.1.48"
   )

   if(NOT ARG_VERSION IN_LIST downloadable_versions)
@@ -164,6 +169,8 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
   elseif(ARG_VERSION VERSION_GREATER 10.10
       AND ARG_VERSION VERSION_LESS 10.13)
     set(TRT_CUDA_VERSION 12.9)
+  elseif(ARG_VERSION VERSION_GREATER 10.13)
+    set(TRT_CUDA_VERSION 13.0)
   endif()

   # Handle TRT 8 versions.
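Net effect: a requested version of 10.14 is normalized to the 10.14.1.48 tarball and paired with CUDA 13.0, mirroring the existing handling of 10.12 → 10.12.0.36 with CUDA 12.9.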

mlir-tensorrt/compiler/tools/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -21,5 +21,5 @@ set(LLVM_LINK_COMPONENTS
 add_subdirectory(mlir-tensorrt-opt)
 add_subdirectory(mlir-tensorrt-compiler)
 add_subdirectory(mlir-tensorrt-translate)
-add_subdirectory(mlir-tensorrt-lsp-server)
+# add_subdirectory(mlir-tensorrt-lsp-server)
 add_subdirectory(mlir-tensorrt-runner)

mlir-tensorrt/integrations/python/setup_utils.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 import subprocess
 import atexit

-TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.12")
+TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.14")


 def log(*args):
@@ -105,8 +105,8 @@ def run_cmake_build(python_package_name: str, python_wheel_staging_dir: Path):

     # Environment variable overrides
     cmake_preset = os.environ.get("MLIR_TRT_CMAKE_PRESET", "python-wheel-build")
-    install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", None)
-    build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", None)
+    install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", "./install")
+    build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", "./build")
     parallel_jobs = os.environ.get("MLIR_TRT_PARALLEL_JOBS", str(os.cpu_count() or 1))

     # Additional CMake options from environment
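The new ./install and ./build fallbacks mean the wheel build no longer requires MLIR_TRT_INSTALL_DIR and MLIR_TRT_BUILD_DIR to be set; ./build also matches the binaryDir declared by the python-wheel-build preset above.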

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 14 additions & 11 deletions
@@ -4485,7 +4485,8 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
       indicates the position is allowed to attend. For other types, mask values
       are added to BMM1 output.
     - NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16
-      with rank 0 or 1, used for quantizing the normalization output.
+      with rank 0 (scalar) or 1 (1D tensor), used for quantizing the normalization output.
+      Required when normalization_quantize_to_type is specified.

     #### Attributes:

@@ -4502,6 +4503,10 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
     - If normalization_quantize_to_type is specified:
       * It must be kFP8 or kINT8
       * normalization_quantize_scale input must be provided
+    - If normalization_quantize_scale is provided:
+      * normalization_quantize_to_type must be specified
+      * Element type must be f32, f16, or bf16
+      * Rank must be 0 (scalar) or 1 (1D tensor)
     - Cannot use both mask input and causal=true simultaneously

     #### Examples:
@@ -4539,7 +4544,7 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
     TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
     Optional<TensorRT_Tensor>:$mask,
     Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
-    OptionalAttr<TensorRT_AttentionNormalizationOpAttr>:$normalization_operation,
+    DefaultValuedAttr<TensorRT_AttentionNormalizationOpAttr, "tensorrt::AttentionNormalizationOp::kSOFTMAX">:$normalization_operation,
     DefaultValuedAttr<BoolAttr, "false">:$causal,
     DefaultValuedAttr<BoolAttr, "false">:$decomposable,
     OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
@@ -4565,12 +4570,7 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
   }] # baseClassDeclaration;

   let trtLayerAdd = [{
-    // Get normalization operation, default to kSOFTMAX
-    nvinfer1::AttentionNormalizationOp normOp = $normalization_operation
-        ? *$normalization_operation
-        : nvinfer1::AttentionNormalizationOp::kSOFTMAX;
-
-    nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, normOp, $causal);
+    nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, *$normalization_operation, $causal);
     if (!layer)
       return failure();

@@ -4584,19 +4584,22 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
     }

     if ($normalization_quantize_to_type) {
-      layer->setNormalizationQuantizeToType(*$normalization_quantize_to_type);
+      auto convertedDataType = ::mlir::tensorrt::convertDataTypeToNvInferEnum(*$normalization_quantize_to_type);
+      if (!convertedDataType)
+        return emitError($op->getLoc()) << "failed to convert DataType to nvinfer enum";
+      layer->setNormalizationQuantizeToType(*convertedDataType);
     }

     if (!$e.isStronglyTyped()){
       FailureOr<nvinfer1::DataType> outputTrtType = getNvInferDataType($op.getLoc(),
         $op.getType().getElementType());
       if (failed(outputTrtType))
         return failure();
-      layer->setOutputType(0, *outputTrtType);
     }

     $results.push_back(layer->getOutput(0));
-    $e.setMetadata(layer, $op);
+    // TODO: nvinfer1::IAttention does not have setMetadata API in 10.14
+    // layer->setMetadata($op);
   }];
 }
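For context, the sketch below shows roughly how the generated trtLayerAdd body drives the TensorRT 10.14 C++ API. Network construction and the buildAttention() wrapper are illustrative assumptions; only addAttention, setNormalizationQuantizeToType, getOutput, and the absence of IAttention::setMetadata are taken from the diff above.

    #include "NvInfer.h"

    // Hypothetical wrapper mirroring the generated trtLayerAdd body.
    nvinfer1::ITensor *buildAttention(nvinfer1::INetworkDefinition &net,
                                      nvinfer1::ITensor &query,
                                      nvinfer1::ITensor &key,
                                      nvinfer1::ITensor &value, bool causal) {
      // The normalization op is now a defaulted attribute (kSOFTMAX) on the
      // MLIR side, so it is passed straight through instead of being
      // null-checked first.
      nvinfer1::IAttention *layer = net.addAttention(
          query, key, value, nvinfer1::AttentionNormalizationOp::kSOFTMAX,
          causal);
      if (!layer)
        return nullptr;
      // Optional quantization of the normalization output (kFP8 or kINT8).
      layer->setNormalizationQuantizeToType(nvinfer1::DataType::kFP8);
      // Note: IAttention exposes no setMetadata() in TRT 10.14, hence the
      // commented-out call and TODO in the diff above.
      return layer->getOutput(0);
    }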

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp

Lines changed: 16 additions & 0 deletions
@@ -1633,3 +1633,19 @@ LogicalResult tensorrt::DequantizeOp::inferReturnTypeComponents(
       /*elementType=*/nullptr);
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult tensorrt::AttentionOp::inferReturnTypeComponents(
+    MLIRContext *ctx, std::optional<Location> loc, ValueShapeRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  AttentionOp::Adaptor adaptor(operands, attributes, properties, regions);
+  auto queryType = cast<RankedTensorType>(adaptor.getQuery().getType());
+  inferredReturnShapes.emplace_back(
+      /*vec=*/queryType.getShape(),
+      /*elementType=*/queryType.getElementType());
+  return success();
+}
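In other words, the op's result is inferred to have exactly the query operand's shape and element type; for example, a query of type tensor<2x8x128x64xf16> yields a tensor<2x8x128x64xf16> result.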

mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp

Lines changed: 56 additions & 0 deletions
@@ -1464,3 +1464,59 @@ static LogicalResult verifyAllowedDataTypes(UnaryOp op) {
 LogicalResult tensorrt::UnaryOp::verify() {
   return verifyAllowedDataTypes(*this);
 }
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult tensorrt::AttentionOp::verify() {
+  // Check 1: Cannot use both mask input and causal=true simultaneously
+  if (getMask() && getCausal())
+    return emitOpError(
+        "cannot use both mask input and causal=true simultaneously");
+
+  // Check 2: If normalization_quantize_to_type is specified, it must be kFP8
+  // or kINT8 and normalization_quantize_scale must be provided
+  std::optional<DataType> quantizeType = getNormalizationQuantizeToType();
+  if (quantizeType.has_value()) {
+    if (*quantizeType != DataType::kFP8 && *quantizeType != DataType::kINT8)
+      return emitOpError("normalization_quantize_to_type must be kFP8 or "
+                         "kINT8, but got ")
+             << stringifyDataType(*quantizeType);
+
+    if (!getNormalizationQuantizeScale())
+      return emitOpError(
+          "normalization_quantize_scale input must be provided when "
+          "normalization_quantize_to_type is specified");
+  }
+
+  // Check 3: If normalization_quantize_scale is provided,
+  // normalization_quantize_to_type must be specified
+  if (getNormalizationQuantizeScale() && !quantizeType.has_value())
+    return emitOpError(
+        "normalization_quantize_to_type must be specified when "
+        "normalization_quantize_scale input is provided");
+
+  // Check 4: If normalization_quantize_scale is provided, validate its type
+  if (getNormalizationQuantizeScale()) {
+    RankedTensorType scaleType = getNormalizationQuantizeScale().getType();
+    Type scaleElemType = scaleType.getElementType();
+
+    // Check that element type is f32, f16, or bf16
+    if (!scaleElemType.isF32() && !scaleElemType.isF16() &&
+        !scaleElemType.isBF16())
+      return emitOpError(
+                 "normalization_quantize_scale element type must be f32, f16, "
+                 "or bf16, but got ")
+             << scaleElemType;
+
+    // Check that scale is rank 0 or 1
+    if (scaleType.getRank() != 0 && scaleType.getRank() != 1)
+      return emitOpError(
+                 "normalization_quantize_scale must be rank 0 or 1, but got "
+                 "rank ")
+             << scaleType.getRank();
+  }
+
+  return success();
+}
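Together with the constraint list added in TensorRTOps.td above, these verifier checks reject, for example, an attention op that sets normalization_quantize_to_type = kINT8 without a scale input, or one that supplies a rank-2 scale tensor.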
