@@ -4539,7 +4539,7 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45394539 TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
45404540 Optional<TensorRT_Tensor>:$mask,
45414541 Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
4542- OptionalAttr <TensorRT_AttentionNormalizationOpAttr>:$normalization_operation,
4542+ DefaultValuedAttr <TensorRT_AttentionNormalizationOpAttr, "tensorrt::AttentionNormalizationOp::kSOFTMAX" >:$normalization_operation,
45434543 DefaultValuedAttr<BoolAttr, "false">:$causal,
45444544 DefaultValuedAttr<BoolAttr, "false">:$decomposable,
45454545 OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
@@ -4565,12 +4565,7 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45654565 }] # baseClassDeclaration;
45664566
45674567 let trtLayerAdd = [{
4568- // Get normalization operation, default to kSOFTMAX
4569- nvinfer1::AttentionNormalizationOp normOp = $normalization_operation
4570- ? *$normalization_operation
4571- : nvinfer1::AttentionNormalizationOp::kSOFTMAX;
4572-
4573- nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, normOp, $causal);
4568+ nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, *$normalization_operation, $causal);
45744569 if (!layer)
45754570 return failure();
45764571
@@ -4584,19 +4579,22 @@ def TensorRT_AttentionOp : TensorRT_Op<"attention",
45844579 }
45854580
45864581 if ($normalization_quantize_to_type) {
4587- layer->setNormalizationQuantizeToType(*$normalization_quantize_to_type);
4582+ auto convertedDataType = ::mlir::tensorrt::convertDataTypeToNvInferEnum(*$normalization_quantize_to_type);
4583+ if (!convertedDataType)
4584+ return emitError($op->getLoc()) << "failed to convert DataType to nvinfer enum";
4585+ layer->setNormalizationQuantizeToType(*convertedDataType);
45884586 }
45894587
45904588 if (!$e.isStronglyTyped()){
45914589 FailureOr<nvinfer1::DataType> outputTrtType = getNvInferDataType($op.getLoc(),
45924590 $op.getType().getElementType());
45934591 if (failed(outputTrtType))
45944592 return failure();
4595- layer->setOutputType(0, *outputTrtType);
45964593 }
45974594
45984595 $results.push_back(layer->getOutput(0));
4599- $e.setMetadata(layer, $op);
4596+ // TODO: nvinfer1::IAttention does not have setMetadata API in 10.14
4597+ // layer->setMetadata($op);
46004598 }];
46014599}
46024600
0 commit comments