@@ -4432,4 +4432,172 @@ def TensorRT_ScatterElementsOp : TensorRT_Op<"scatter_elements",
   }];
 }

+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+def TensorRT_AttentionOp : TensorRT_Op<"attention",
+    [Pure, AttrSizedOperandSegments, TensorRTInferTensorResultTypes,
+     AllElementTypesMatch<["query", "key", "value"]>,
+     AllRanksMatch<["query", "key", "value"]>]> {
+  let summary = "TensorRT attention (IAttention) operation";
+  let description = [{
+    The `tensorrt.attention` operation implements a fused attention mechanism
+    that consumes query, key, and value tensors. The operation implicitly
+    contains two batched matrix multiplications (BMM1 and BMM2) with a
+    normalization operation (typically softmax) between them.
+
+    By default, TensorRT tries to execute the operation as a single fused
+    kernel for efficiency. If no fused kernel is available, the operation can
+    be decomposed into multiple kernels by setting `decomposable` to true
+    (see the examples below).
+
+    #### Architecture:
+
+    ```
+    Query    Key    Value    Mask (optional)    NormalizationQuantizeScale (optional)
+      |       |       |        |                             |
+      |   Transpose   |        |                             |
+      |       |       |        |                             |
+      ---BMM1---      |        |                             |
+          |           |        |                             |
+          *---------------------                             |
+          |           |                                      |
+    Normalization     |                                      |
+          |           |                                      |
+          *---------------------------------------------------
+          |           |
+          ----BMM2-----
+                |
+             Output
+    ```
+
+    #### Inputs:
+
+    - Query: tensor of type f32, f16, or bf16 with shape
+      [batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]
+    - Key: tensor of type f32, f16, or bf16 with shape
+      [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
+    - Value: tensor of type f32, f16, or bf16 with shape
+      [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
+    - Mask (optional): tensor of type i1, or of the same type as the BMM1
+      output, with shape
+      [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue],
+      where the batchSize and numHeadsQuery dimensions are broadcastable.
+      For an i1 mask, true indicates that the position is allowed to attend;
+      for other types, the mask values are added to the BMM1 output (see the
+      masked example below).
+    - NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16
+      with rank 0 or 1, used for quantizing the normalization output.
+
+    #### Attributes:
+
+    - normalization_operation: the normalization operation to apply
+      (default: kSOFTMAX)
+    - causal: whether to apply causal masking (default: false); cannot be
+      combined with the mask input
+    - decomposable: whether the operation may be decomposed into multiple
+      kernels (default: false)
+    - normalization_quantize_to_type: optional target type for quantizing the
+      normalization output; when specified, it must be kFP8 or kINT8, and the
+      normalization_quantize_scale input must be provided
+
+    #### Constraints:
+
+    - The query, key, and value tensors must be rank 4 with shape
+      [batchSize, numHeads, sequenceLength, dimHead].
+    - Query, key, and value must have the same element type (f32, f16, or bf16).
+    - If normalization_quantize_to_type is specified:
+      * it must be kFP8 or kINT8, and
+      * the normalization_quantize_scale input must be provided.
+    - The mask input and causal = true cannot be used simultaneously.
+
+    #### Examples:
+
+    Basic attention:
+    ```mlir
+    %output = tensorrt.attention ins(%query, %key, %value :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+      -> tensor<2x8x128x64xf16>
+    ```
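+
+    Decomposable attention (an illustrative sketch: identical to the basic
+    form except that the `decomposable` attribute permits TensorRT to fall
+    back to multiple kernels when no fused kernel is available):
+    ```mlir
+    %output_decomposed = tensorrt.attention {decomposable = true} ins(%query, %key, %value :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+      -> tensor<2x8x128x64xf16>
+    ```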
+
+    Causal attention:
+    ```mlir
+    %output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+      -> tensor<2x8x128x64xf16>
+    ```
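+
+    Attention with an explicit boolean mask (an illustrative sketch; the i1
+    mask has shape
+    [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue]):
+    ```mlir
+    %output_masked = tensorrt.attention ins(%query, %key, %value, mask = %mask :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
+      tensor<2x8x128x128xi1>)
+      -> tensor<2x8x128x64xf16>
+    ```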
+
+    Attention with quantization:
+    ```mlir
+    %scale = tensorrt.constant dense<1.0> : tensor<f32>
+    %output_quant = tensorrt.attention {
+      normalization_quantize_to_type = #tensorrt.data_type<kFP8>
+    } ins(%query, %key, %value,
+      normalization_quantize_scale = %scale :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
+      tensor<2x8x128x64xf16>, tensor<f32>)
+      -> tensor<2x8x128x64xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query,
+    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key,
+    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
+    Optional<TensorRT_Tensor>:$mask,
+    Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
+    OptionalAttr<TensorRT_AttentionNormalizationOpAttr>:$normalization_operation,
+    DefaultValuedAttr<BoolAttr, "false">:$causal,
+    DefaultValuedAttr<BoolAttr, "false">:$decomposable,
+    OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
+  );
+
+  let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result);
+
+  let assemblyFormat = [{
+    attr-dict `ins` `(` $query `,` $key `,` $value
+    (`,` `mask` `=` $mask^)?
+    (`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)?
+    `:` type($query) `,` type($key) `,` type($value)
+    (`,` type($mask)^)?
+    (`,` type($normalization_quantize_scale)^)?
+    `)` `->` type($result)
+  }];
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true if the op is valid for the given TensorRT major version.
+    bool isValidForTensorRTVersion(int64_t trtMajorVersion);
+  }] # baseClassDeclaration;
+
+  let trtLayerAdd = [{
+    // Default the normalization operation to kSOFTMAX when unspecified.
+    nvinfer1::AttentionNormalizationOp normOp = $normalization_operation
+        ? *$normalization_operation
+        : nvinfer1::AttentionNormalizationOp::kSOFTMAX;
+
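+    // Create the attention layer; causal masking is configured at creation time.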
+    nvinfer1::IAttention *layer =
+        $net->addAttention(*$query, *$key, *$value, normOp, $causal);
+    if (!layer)
+      return failure();
+
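+    // Attach the optional attention mask input.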
+    if ($mask)
+      layer->setMask(*$mask);
+
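+    // Allow decomposition into multiple kernels when no fused kernel is available.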
+    layer->setDecomposable($decomposable);
+
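+    // Forward the optional quantization settings for the normalization output.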
+    if ($normalization_quantize_scale) {
+      layer->setNormalizationQuantizeScale(*$normalization_quantize_scale);
+    }
+
+    if ($normalization_quantize_to_type) {
+      layer->setNormalizationQuantizeToType(*$normalization_quantize_to_type);
+    }
+
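+    // In weakly typed networks, pin the output element type to the MLIR result type.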
+    if (!$e.isStronglyTyped()) {
+      FailureOr<nvinfer1::DataType> outputTrtType =
+          getNvInferDataType($op.getLoc(), $op.getType().getElementType());
+      if (failed(outputTrtType))
+        return failure();
+      layer->setOutputType(0, *outputTrtType);
+    }
+
+    $results.push_back(layer->getOutput(0));
+    $e.setMetadata(layer, $op);
+  }];
+}
+
 #endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTOPS_TD