14 changes: 14 additions & 0 deletions mlir-tensorrt/CMakePresets.json
@@ -100,6 +100,20 @@
"MLIR_TRT_ENABLE_NCCL": "OFF",
"MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}"
}
},
{
"name": "python-wheel-build",
"displayName": "Configuration for building the compiler/runtime Python package wheels",
"generator": "Ninja",
"binaryDir": "build",
"inherits": "ninja-llvm",
"cacheVariables": {
"CMAKE_BUILD_TYPE": "Release",
"LLVM_ENABLE_ASSERTIONS": "OFF",
"CMAKE_PLATFORM_NO_VERSIONED_SONAME": "ON",
"MLIR_TRT_ENABLE_NCCL": "OFF",
"MLIR_TRT_DOWNLOAD_TENSORRT_VERSION": "$env{DOWNLOAD_TENSORRT_VERSION}"
}
}
]
}
4 changes: 2 additions & 2 deletions mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake
@@ -57,8 +57,8 @@ macro(configure_tensorrt_python_plugin_header)
find_file(
trt_python_plugin_header
NAMES NvInferPythonPlugin.h plugin.h
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl ${ARG_INSTALL_DIR}/include/impl
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl ${ARG_INSTALL_DIR}/include/impl
REQUIRED
NO_CMAKE_PATH NO_DEFAULT_PATH
NO_CACHE
7 changes: 7 additions & 0 deletions mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake
@@ -80,6 +80,10 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
set(ARG_VERSION "10.12.0.36")
endif()

if(ARG_VERSION VERSION_EQUAL "10.14")
set(ARG_VERSION "10.14.1.48")
endif()

set(downloadable_versions
"8.6.1.6"
"9.0.1.4" "9.1.0.4" "9.2.0.5"
@@ -97,6 +101,7 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
"10.8.0.43"
"10.9.0.34"
"10.12.0.36"
"10.14.1.48"
)

if(NOT ARG_VERSION IN_LIST downloadable_versions)
@@ -164,6 +169,8 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
elseif(ARG_VERSION VERSION_GREATER 10.10
AND ARG_VERSION VERSION_LESS 10.13)
set(TRT_CUDA_VERSION 12.9)
elseif(ARG_VERSION VERSION_GREATER 10.13)
set(TRT_CUDA_VERSION 13.0)
endif()

# Handle TRT 8 versions.
2 changes: 1 addition & 1 deletion mlir-tensorrt/build_tools/docker/Dockerfile
@@ -35,7 +35,7 @@ case "${LINUX_DISTRO}" in
dnf install -y \
which wget gcc zlib-devel bzip2 bzip2-devel readline-devel sqlite \
sqlite-devel xz xz-devel libffi-devel curl git ncurses-devel \
openssh-clients libcudnn8-devel zip jq \
Review comment (Collaborator Author): cudnn8 conflicts with cudnn9 in the base container.
openssh-clients zip jq \
protobuf-compiler autoconf automake libtool dnf-plugins-core cmake
dnf config-manager --set-enabled powertools
dnf -y install gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
@@ -49,6 +49,8 @@ def test_attributes():
tensorrt.TripLimitAttr.get("kWHILE"),
tensorrt.FillOperationAttr.get("kRANDOM_UNIFORM"),
tensorrt.ScatterModeAttr.get("kELEMENT"),
tensorrt.AttentionNormalizationOpAttr.get("kSOFTMAX"),
tensorrt.DataTypeAttr.get("kFLOAT"),
]:
print(attr)

@@ -74,3 +76,5 @@ def test_attributes():
# CHECK-NEXT: #tensorrt.trip_limit<kWHILE>
# CHECK-NEXT: #tensorrt.fill_operation<kRANDOM_UNIFORM>
# CHECK-NEXT: #tensorrt.scatter_mode<kELEMENT>
# CHECK-NEXT: #tensorrt.attention_normalization_op<kSOFTMAX>
# CHECK-NEXT: #tensorrt.data_type<kFLOAT>
2 changes: 1 addition & 1 deletion mlir-tensorrt/compiler/tools/CMakeLists.txt
@@ -21,5 +21,5 @@ set(LLVM_LINK_COMPONENTS
add_subdirectory(mlir-tensorrt-opt)
add_subdirectory(mlir-tensorrt-compiler)
add_subdirectory(mlir-tensorrt-translate)
add_subdirectory(mlir-tensorrt-lsp-server)
# add_subdirectory(mlir-tensorrt-lsp-server)
add_subdirectory(mlir-tensorrt-runner)
6 changes: 3 additions & 3 deletions mlir-tensorrt/integrations/python/setup_utils.py
@@ -13,7 +13,7 @@
import subprocess
import atexit

TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.12")
TENSORRT_VERSION = os.getenv("MLIR_TRT_DOWNLOAD_TENSORRT_VERSION", "10.14")


def log(*args):
@@ -105,8 +105,8 @@ def run_cmake_build(python_package_name: str, python_wheel_staging_dir: Path):

# Environment variable overrides
cmake_preset = os.environ.get("MLIR_TRT_CMAKE_PRESET", "python-wheel-build")
install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", None)
build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", None)
install_prefix = os.environ.get("MLIR_TRT_INSTALL_DIR", "./install")
build_dir = os.environ.get("MLIR_TRT_BUILD_DIR", "./build")
parallel_jobs = os.environ.get("MLIR_TRT_PARALLEL_JOBS", str(os.cpu_count() or 1))

# Additional CMake options from environment
@@ -188,6 +188,22 @@ DECLARE_ATTR_GETTER_FROM_STRING(ScatterMode)
DECLARE_IS_ATTR(ScatterMode)
DECLARE_STRING_GETTER_FROM_ATTR(ScatterMode)

//===----------------------------------------------------------------------===//
// AttentionNormalizationOp
//===----------------------------------------------------------------------===//

DECLARE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp)
DECLARE_IS_ATTR(AttentionNormalizationOp)
DECLARE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp)

//===----------------------------------------------------------------------===//
// DataType
//===----------------------------------------------------------------------===//

DECLARE_ATTR_GETTER_FROM_STRING(DataType)
DECLARE_IS_ATTR(DataType)
DECLARE_STRING_GETTER_FROM_ATTR(DataType)

#ifdef __cplusplus
}
#endif
@@ -378,4 +378,42 @@ def TensorRT_ScatterMode : TensorRT_I32EnumAttr<
def TensorRT_ScatterModeAttr : TensorRT_EnumAttr<TensorRT_ScatterMode, "scatter_mode">{
}

def TensorRT_AttentionNormalizationOp : TensorRT_I32EnumAttr<
"AttentionNormalizationOp", "",
[
I32EnumAttrCase<"kNONE", 0>,
I32EnumAttrCase<"kSOFTMAX", 1>
]>
{
let cppNamespace = "::mlir::tensorrt";
let genSpecializedAttr = 0;
}

def TensorRT_AttentionNormalizationOpAttr : TensorRT_EnumAttr<TensorRT_AttentionNormalizationOp, "attention_normalization_op">{
}

def TensorRT_DataType : TensorRT_I32EnumAttr<
"DataType", "",
[
I32EnumAttrCase<"kFLOAT", 0>,
I32EnumAttrCase<"kHALF", 1>,
I32EnumAttrCase<"kINT8", 2>,
I32EnumAttrCase<"kINT32", 3>,
I32EnumAttrCase<"kBOOL", 4>,
I32EnumAttrCase<"kUINT8", 5>,
I32EnumAttrCase<"kFP8", 6>,
I32EnumAttrCase<"kBF16", 7>,
I32EnumAttrCase<"kINT64", 8>,
I32EnumAttrCase<"kINT4", 9>,
I32EnumAttrCase<"kFP4", 10>,
I32EnumAttrCase<"kE8M0", 11>
]>
{
let cppNamespace = "::mlir::tensorrt";
let genSpecializedAttr = 0;
}

def TensorRT_DataTypeAttr : TensorRT_EnumAttr<TensorRT_DataType, "data_type">{
}

#endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTENUMS
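Since both new enums above are declared with `genSpecializedAttr = 0`, MLIR's enum TableGen backend should emit the usual `symbolize*`/`stringify*` helpers for them. A small usage sketch follows; the generated names follow MLIR's standard convention, but treat the exact signatures as assumptions, and the dialect enum header include is omitted here.

```cpp
#include <optional>

// Usage sketch for the TableGen-generated helpers of the new DataType enum.
// Assumes the generated dialect enum header (not shown) is on the include
// path; the helper names follow MLIR's stringify*/symbolize* convention.
void dataTypeEnumSketch() {
  // String -> enum; returns std::nullopt for unknown names.
  std::optional<mlir::tensorrt::DataType> dt =
      mlir::tensorrt::symbolizeDataType("kFP8");
  if (dt) {
    // Enum -> string; yields "kFP8".
    llvm::StringRef name = mlir::tensorrt::stringifyDataType(*dt);
    (void)name;
  }
}
```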
@@ -3504,6 +3504,171 @@ def TensorRT_DequantizeOp : TensorRT_Op<"dequantize",
}];
}

//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//

def TensorRT_AttentionOp : TensorRT_Op<"attention",
[Pure, AttrSizedOperandSegments, TensorRTPartiallyInferTensorResultTypes,
AllElementTypesMatch<["query", "key", "value"]>,
AllRanksMatch<["query", "key", "value"]>]>{
let summary = "TensorRT attention (IAttention) operation";
let description = [{
The `tensorrt.attention` operation implements a fused attention mechanism
that consumes query, key, and value tensors. The operation implicitly includes
two matrix multiplication layers (BMM1 and BMM2) and a normalization operation
(typically softmax).

By default, TensorRT will try to use a single fused kernel for better efficiency.
If no fused kernel is available, the operation can optionally be decomposed
into multiple kernels by setting `decomposable` to true.

#### Architecture:

```
Query      Key       Value    Mask (optional)          NormalizationQuantizeScale (optional)
  |         |          |            |                            |
  |     Transpose      |            |                            |
  |         |          |            |                            |
  ----BMM1---          |            |                            |
       |               |            |                            |
       *-----------------------------                            |
       |               |                                         |
  Normalization        |                                         |
       |               |                                         |
       *---------------------------------------------------------
       |               |
       -------BMM2------
             |
           Output
```

#### Inputs:

- Query: tensor of type f32, f16, or bf16 with shape
[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]
- Key: tensor of type f32, f16, or bf16 with shape
[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
- Value: tensor of type f32, f16, or bf16 with shape
[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
- Mask (optional): tensor of type i1, or the same type as the BMM1 output, with shape
[batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue],
where batchSize and numHeadsQuery are broadcastable. For an i1 mask, true
indicates that the position is allowed to attend; for other types, the mask
values are added to the BMM1 output.
- NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16
with rank 0 (scalar) or 1 (1D tensor), used for quantizing the normalization output.
Required when normalization_quantize_to_type is specified.

#### Attributes:

- normalization_operation: The normalization operation to use (default: kSOFTMAX)
- causal: Whether to use causal masking (default: false). Cannot be used with mask input.
- decomposable: Whether the operation can be decomposed (default: false)
- normalization_quantize_to_type: Optional output type for quantized normalization.
When specified, must be one of kFP8 or kINT8. Requires normalization_quantize_scale input to be provided.

#### Constraints:

- All query, key, and value tensors must be rank 4 with shape [batchSize, numHeads, sequenceLength, dimHead]
- Query, key, and value must have the same element type (f32, f16, or bf16)
- If normalization_quantize_to_type is specified:
* It must be kFP8 or kINT8
* normalization_quantize_scale input must be provided
- If normalization_quantize_scale is provided:
* normalization_quantize_to_type must be specified
* Element type must be f32, f16, or bf16
* Rank must be 0 (scalar) or 1 (1D tensor)
- Cannot use both mask input and causal=true simultaneously

#### Examples:

Basic attention:
```mlir
%output = tensorrt.attention ins(%query, %key, %value :
tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
-> tensor<2x8x128x64xf16>
```

Causal attention:
```mlir
%output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value :
tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
-> tensor<2x8x128x64xf16>
```

Attention with quantization:
```mlir
%scale = tensorrt.constant dense<1.0> : tensor<f32>
%output_quant = tensorrt.attention {
normalization_quantize_to_type = #tensorrt.data_type<kFP8>
} ins(%query, %key, %value,
normalization_quantize_scale = %scale :
tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
tensor<2x8x128x64xf16>, tensor<f32>)
-> tensor<2x8x128x64xf16>
```
}];

let arguments = (ins
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query,
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key,
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
Optional<TensorRT_Tensor>:$mask,
Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
DefaultValuedAttr<TensorRT_AttentionNormalizationOpAttr, "tensorrt::AttentionNormalizationOp::kSOFTMAX">:$normalization_operation,
DefaultValuedAttr<BoolAttr, "false">:$causal,
DefaultValuedAttr<BoolAttr, "false">:$decomposable,
OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
);

let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result);

let assemblyFormat = [{
attr-dict `ins` `(` $query `,` $key `,` $value
(`,` `mask` `=` $mask^)?
(`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)?
`:` type($query) `,` type($key) `,` type($value)
(`,` type($mask)^)?
(`,` type($normalization_quantize_scale)^)?
`)` `->` type($result)
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns true if created op is valid for TensorRT major version.
bool isValidForTensorRTVersion(int64_t trtMajorVersion);
}] # baseClassDeclaration;

let trtLayerAdd = [{
nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, *$normalization_operation, $causal);
if (!layer)
return failure();

if ($mask)
layer->setMask(*$mask);

layer->setDecomposable($decomposable);

if ($normalization_quantize_scale) {
layer->setNormalizationQuantizeScale(*$normalization_quantize_scale);
}

if ($normalization_quantize_to_type) {
auto convertedDataType = ::mlir::tensorrt::convertDataTypeToNvInferEnum(*$normalization_quantize_to_type);
if (!convertedDataType)
return emitError($op->getLoc()) << "failed to convert DataType to nvinfer enum";
layer->setNormalizationQuantizeToType(*convertedDataType);
}

$results.push_back(layer->getOutput(0));
#if MLIR_TRT_COMPILE_TIME_TENSORRT_VERSION_GTE(10, 15, 0)
layer->setMetadata($op);
#endif
}];
}

//===----------------------------------------------------------------------===//
// TensorRT Dialect Extension Operations
//
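For context on the `trtLayerAdd` snippet above: `convertDataTypeToNvInferEnum` maps the dialect's `DataType` enum onto `nvinfer1::DataType`. Below is a rough sketch of such a mapping for the two values accepted by `normalization_quantize_to_type`; the real helper lives elsewhere in the codebase and may differ, and the dialect enum header include is assumed.

```cpp
#include <optional>
#include "NvInfer.h"

// Rough sketch only: maps the dialect enum values allowed for
// normalization_quantize_to_type onto the corresponding nvinfer1 enum.
// The actual convertDataTypeToNvInferEnum implementation may differ.
static std::optional<nvinfer1::DataType>
convertQuantizeTypeSketch(mlir::tensorrt::DataType dt) {
  switch (dt) {
  case mlir::tensorrt::DataType::kFP8:
    return nvinfer1::DataType::kFP8;
  case mlir::tensorrt::DataType::kINT8:
    return nvinfer1::DataType::kINT8;
  default:
    return std::nullopt; // other values are rejected by the op verifier
  }
}
```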
@@ -77,4 +77,6 @@ PYBIND11_MODULE(_tensorrt, m) {
ADD_PYTHON_ATTRIBUTE_ADAPTOR(TripLimit)
ADD_PYTHON_ATTRIBUTE_ADAPTOR(FillOperation)
ADD_PYTHON_ATTRIBUTE_ADAPTOR(ScatterMode)
ADD_PYTHON_ATTRIBUTE_ADAPTOR(AttentionNormalizationOp)
ADD_PYTHON_ATTRIBUTE_ADAPTOR(DataType)
}
8 changes: 8 additions & 0 deletions mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp
@@ -121,3 +121,11 @@ DEFINE_STRING_GETTER_FROM_ATTR(FillOperation)
DEFINE_ATTR_GETTER_FROM_STRING(ScatterMode)
DEFINE_IS_ATTR(ScatterMode)
DEFINE_STRING_GETTER_FROM_ATTR(ScatterMode)

DEFINE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp)
DEFINE_IS_ATTR(AttentionNormalizationOp)
DEFINE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp)

DEFINE_ATTR_GETTER_FROM_STRING(DataType)
DEFINE_IS_ATTR(DataType)
DEFINE_STRING_GETTER_FROM_ATTR(DataType)
@@ -914,3 +914,16 @@ bool tensorrt::ScatterElementsOp::isValidForTensorRTVersion(
return isValidForTensorRTVersionScatterOpImpl(
trtMajorVersion, dataElementType, indicesElementType);
}

//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//

bool tensorrt::AttentionOp::isValidForTensorRTVersion(
int64_t trtMajorVersion) {
Review comment (Collaborator Author): We should also check the minor version here.
Review comment (Collaborator Author): @christopherbate Could you advise if there's a convenient way to do this? Thanks!
// IAttention layer is only supported in TensorRT >= 10.14.0
if (trtMajorVersion < 10)
return false;

return true;
}
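A minimal sketch of the combined major/minor check raised in the review comments above, assuming the hook were extended to also receive the minor version; the current hook only receives `trtMajorVersion`, so the two-argument signature is hypothetical.

```cpp
#include <cstdint>

// Hypothetical extended hook: the second parameter does not exist in the
// current interface and is shown only to illustrate the requested check.
bool isValidForTensorRTVersion(int64_t trtMajorVersion,
                               int64_t trtMinorVersion) {
  // IAttention requires TensorRT >= 10.14.
  if (trtMajorVersion != 10)
    return trtMajorVersion > 10;
  return trtMinorVersion >= 14;
}
```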