Commit 57c5c8d

Add AttentionOp
1 parent b344e42 commit 57c5c8d

4 files changed: +220, -1 lines changed

mlir-tensorrt/build_tools/docker/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ case "${LINUX_DISTRO}" in
   dnf install -y \
     which wget gcc zlib-devel bzip2 bzip2-devel readline-devel sqlite \
     sqlite-devel xz xz-devel libffi-devel curl git ncurses-devel \
-    openssh-clients libcudnn8-devel zip jq \
+    openssh-clients zip jq \
     protobuf-compiler autoconf automake libtool dnf-plugins-core cmake
   dnf config-manager --set-enabled powertools
   dnf -y install gcc-toolset-11-gcc gcc-toolset-11-gcc-c++

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td

Lines changed: 38 additions & 0 deletions
@@ -378,4 +378,42 @@ def TensorRT_ScatterMode : TensorRT_I32EnumAttr<
 def TensorRT_ScatterModeAttr : TensorRT_EnumAttr<TensorRT_ScatterMode, "scatter_mode">{
 }

+def TensorRT_AttentionNormalizationOp : TensorRT_I32EnumAttr<
+  "AttentionNormalizationOp", "",
+  [
+    I32EnumAttrCase<"kNONE", 0>,
+    I32EnumAttrCase<"kSOFTMAX", 1>
+  ]>
+{
+  let cppNamespace = "::mlir::tensorrt";
+  let genSpecializedAttr = 0;
+}
+
+def TensorRT_AttentionNormalizationOpAttr : TensorRT_EnumAttr<TensorRT_AttentionNormalizationOp, "attention_normalization_op">{
+}
+
+def TensorRT_DataType : TensorRT_I32EnumAttr<
+  "DataType", "",
+  [
+    I32EnumAttrCase<"kFLOAT", 0>,
+    I32EnumAttrCase<"kHALF", 1>,
+    I32EnumAttrCase<"kINT8", 2>,
+    I32EnumAttrCase<"kINT32", 3>,
+    I32EnumAttrCase<"kBOOL", 4>,
+    I32EnumAttrCase<"kUINT8", 5>,
+    I32EnumAttrCase<"kFP8", 6>,
+    I32EnumAttrCase<"kBF16", 7>,
+    I32EnumAttrCase<"kINT64", 8>,
+    I32EnumAttrCase<"kINT4", 9>,
+    I32EnumAttrCase<"kFP4", 10>,
+    I32EnumAttrCase<"kE8M0", 11>
+  ]>
+{
+  let cppNamespace = "::mlir::tensorrt";
+  let genSpecializedAttr = 0;
+}
+
+def TensorRT_DataTypeAttr : TensorRT_EnumAttr<TensorRT_DataType, "data_type">{
+}
+
 #endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTENUMS
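(Illustration, not part of the diff.) With `genSpecializedAttr = 0`, these enums are wrapped by `TensorRT_EnumAttr`, so they should appear in IR as dialect attributes. A minimal sketch of that spelling, assuming the usual `#tensorrt.<mnemonic><kCASE>` form; the analogous `#tensorrt.data_type<kFP8>` spelling is confirmed by the quantization example in the op description below, while the `attention_normalization_op` spelling is inferred from the mnemonic, and the value names are placeholders:

```mlir
// Sketch only (assumed attribute syntax): explicitly selecting the
// normalization operation on the new attention op.
%out = tensorrt.attention {
  normalization_operation = #tensorrt.attention_normalization_op<kSOFTMAX>
} ins(%q, %k, %v :
  tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
  -> tensor<2x8x128x64xf16>
```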

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 168 additions & 0 deletions
@@ -4432,4 +4432,172 @@ def TensorRT_ScatterElementsOp : TensorRT_Op<"scatter_elements",
 }];
 }

+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+def TensorRT_AttentionOp : TensorRT_Op<"attention",
+    [Pure, AttrSizedOperandSegments, TensorRTInferTensorResultTypes,
+     AllElementTypesMatch<["query", "key", "value"]>,
+     AllRanksMatch<["query", "key", "value"]>]> {
+  let summary = "TensorRT attention (IAttention) operation";
+  let description = [{
+    The `tensorrt.attention` operation implements a fused attention mechanism
+    that consumes query, key, and value tensors. The operation implicitly includes
+    two matrix multiplication layers (BMM1 and BMM2) and a normalization operation
+    (typically softmax).
+
+    By default, TensorRT will try to use a single fused kernel for better efficiency.
+    If no fused kernel is available, the operation can optionally be decomposed into
+    multiple kernels by setting `decomposable` to true.
+
+    #### Architecture:
+
+    ```
+    Query    Key    Value   Mask (optional)    NormalizationQuantizeScale (optional)
+      |       |       |       |                            |
+      |   Transpose   |       |                            |
+      |       |       |       |                            |
+      ----BMM1----    |       |                            |
+           |          |       |                            |
+           *-------------------                            |
+           |          |                                    |
+     Normalization    |                                    |
+           |          |                                    |
+           *------------------------------------------------
+           |          |
+         -------BMM2------
+                 |
+              Output
+    ```
+
+    #### Inputs:
+
+    - Query: tensor of type f32, f16, or bf16 with shape
+      [batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]
+    - Key: tensor of type f32, f16, or bf16 with shape
+      [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
+    - Value: tensor of type f32, f16, or bf16 with shape
+      [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
+    - Mask (optional): tensor of type i1, or of the same type as the BMM1 output, with shape
+      [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue],
+      where batchSize and numHeadsQuery are broadcastable. For an i1 mask, true
+      indicates the position is allowed to attend. For other types, mask values
+      are added to the BMM1 output.
+    - NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16
+      with rank 0 or 1, used for quantizing the normalization output.
+
+    #### Attributes:
+
+    - normalization_operation: The normalization operation to use (default: kSOFTMAX)
+    - causal: Whether to use causal masking (default: false). Cannot be used with the mask input.
+    - decomposable: Whether the operation can be decomposed (default: false)
+    - normalization_quantize_to_type: Optional output type for quantized normalization.
+      When specified, it must be kFP8 or kINT8, and the normalization_quantize_scale
+      input must be provided.
+
+    #### Constraints:
+
+    - Query, key, and value tensors must be rank 4 with shape [batchSize, numHeads, sequenceLength, dimHead]
+    - Query, key, and value must have the same element type (f32, f16, or bf16)
+    - If normalization_quantize_to_type is specified:
+      * It must be kFP8 or kINT8
+      * The normalization_quantize_scale input must be provided
+    - The mask input and causal=true cannot be used together
+
+    #### Examples:
+
+    Basic attention:
+    ```mlir
+    %output = tensorrt.attention ins(%query, %key, %value :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+      -> tensor<2x8x128x64xf16>
+    ```
+
+    Causal attention:
+    ```mlir
+    %output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+      -> tensor<2x8x128x64xf16>
+    ```
+
+    Attention with quantization:
+    ```mlir
+    %scale = tensorrt.constant dense<1.0> : tensor<f32>
+    %output_quant = tensorrt.attention {
+      normalization_quantize_to_type = #tensorrt.data_type<kFP8>
+    } ins(%query, %key, %value,
+      normalization_quantize_scale = %scale :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
+      tensor<2x8x128x64xf16>, tensor<f32>)
+      -> tensor<2x8x128x64xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query,
+    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key,
+    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
+    Optional<TensorRT_Tensor>:$mask,
+    Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
+    OptionalAttr<TensorRT_AttentionNormalizationOpAttr>:$normalization_operation,
+    DefaultValuedAttr<BoolAttr, "false">:$causal,
+    DefaultValuedAttr<BoolAttr, "false">:$decomposable,
+    OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
+  );
+
+  let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result);
+
+  let assemblyFormat = [{
+    attr-dict `ins` `(` $query `,` $key `,` $value
+    (`,` `mask` `=` $mask^)?
+    (`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)?
+    `:` type($query) `,` type($key) `,` type($value)
+    (`,` type($mask)^)?
+    (`,` type($normalization_quantize_scale)^)?
+    `)` `->` type($result)
+  }];
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true if the created op is valid for the given TensorRT major version.
+    bool isValidForTensorRTVersion(int64_t trtMajorVersion);
+  }] # baseClassDeclaration;
+
+  let trtLayerAdd = [{
+    // Get the normalization operation, defaulting to kSOFTMAX.
+    nvinfer1::AttentionNormalizationOp normOp = $normalization_operation
+        ? *$normalization_operation
+        : nvinfer1::AttentionNormalizationOp::kSOFTMAX;
+
+    nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, normOp, $causal);
+    if (!layer)
+      return failure();
+
+    if ($mask)
+      layer->setMask(*$mask);
+
+    layer->setDecomposable($decomposable);
+
+    if ($normalization_quantize_scale) {
+      layer->setNormalizationQuantizeScale(*$normalization_quantize_scale);
+    }
+
+    if ($normalization_quantize_to_type) {
+      layer->setNormalizationQuantizeToType(*$normalization_quantize_to_type);
+    }
+
+    if (!$e.isStronglyTyped()) {
+      FailureOr<nvinfer1::DataType> outputTrtType = getNvInferDataType($op.getLoc(),
+          $op.getType().getElementType());
+      if (failed(outputTrtType))
+        return failure();
+      layer->setOutputType(0, *outputTrtType);
+    }
+
+    $results.push_back(layer->getOutput(0));
+    $e.setMetadata(layer, $op);
+  }];
+}
+
 #endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTOPS_TD
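(Illustration, not part of the diff.) The examples in the description cover the basic, causal, and quantized forms but not the optional mask operand. A minimal sketch of a masked call, assuming the assembly format above and a boolean mask of shape [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue]; the value names are hypothetical:

```mlir
// Sketch only: i1 mask where true marks positions that are allowed to attend.
%output_masked = tensorrt.attention ins(%query, %key, %value, mask = %mask :
  tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
  tensor<2x8x128x128xi1>)
  -> tensor<2x8x128x64xf16>
```

Per the constraints above, such a mask cannot be combined with `causal = true`.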

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp

Lines changed: 13 additions & 0 deletions
@@ -914,3 +914,16 @@ bool tensorrt::ScatterElementsOp::isValidForTensorRTVersion(
   return isValidForTensorRTVersionScatterOpImpl(
       trtMajorVersion, dataElementType, indicesElementType);
 }
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+bool tensorrt::AttentionOp::isValidForTensorRTVersion(
+    int64_t trtMajorVersion) {
+  // The IAttention layer is only supported in TensorRT >= 10.14.0.
+  if (trtMajorVersion < 10)
+    return false;
+
+  return true;
+}
