diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td index 4c5e5e6bb..a0d583c2e 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td @@ -51,9 +51,6 @@ def TensorRT_Dialect : Dialect { static constexpr StringRef kTensorRTPerTensorDequantizationMarker = "tensorrt.pt_dq"; static constexpr StringRef kTensorRTPerChannelDequantizationMarker = "tensorrt.pc_dq"; static constexpr StringRef kTensorRTBlockDequantizationMarker = "tensorrt.block_dq"; - - /// TensorRT layer metadata markder. - static constexpr StringRef kTensorRTLayerMetadataMarker = "metadata"; }]; let dependentDialects = [ diff --git a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp index a6d71bc81..f0974160c 100644 --- a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp +++ b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp @@ -278,10 +278,9 @@ void NvInferNetworkEncoder::setMetadata(nvinfer1::ILayer *layer, Operation *sourceOp) { std::string name = createName(namesSet, sourceOp); layer->setName(name.c_str()); - if (auto metadataAttr = sourceOp->getAttrOfType( - TensorRTDialect::kTensorRTLayerMetadataMarker)) { - layer->setMetadata(metadataAttr.getValue().str().c_str()); - } + if (auto fusedLoc = dyn_cast(sourceOp->getLoc())) + if (auto metadataAttr = dyn_cast(fusedLoc.getMetadata())) + layer->setMetadata(metadataAttr.getValue().str().c_str()); } nvinfer1::ITensor *NvInferNetworkEncoder::lookup(Value v) const {