From 91a9facfd3209d889ce22e3d8bb6f5bed30324a4 Mon Sep 17 00:00:00 2001 From: Yizhuo Zhang Date: Tue, 5 Aug 2025 14:57:27 -0700 Subject: [PATCH] [TensorRT] Extract layer metadata from FusedLocation Extracts TRT layer metadata from FusedLocation.metadata. Also removes the "metadata" string from TRT dialect as it is no longer needed. --- .../mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td | 3 --- .../Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp | 7 +++---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td index 4c5e5e6bb..a0d583c2e 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td @@ -51,9 +51,6 @@ def TensorRT_Dialect : Dialect { static constexpr StringRef kTensorRTPerTensorDequantizationMarker = "tensorrt.pt_dq"; static constexpr StringRef kTensorRTPerChannelDequantizationMarker = "tensorrt.pc_dq"; static constexpr StringRef kTensorRTBlockDequantizationMarker = "tensorrt.block_dq"; - - /// TensorRT layer metadata markder. - static constexpr StringRef kTensorRTLayerMetadataMarker = "metadata"; }]; let dependentDialects = [ diff --git a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp index a6d71bc81..f0974160c 100644 --- a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp +++ b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp @@ -278,10 +278,9 @@ void NvInferNetworkEncoder::setMetadata(nvinfer1::ILayer *layer, Operation *sourceOp) { std::string name = createName(namesSet, sourceOp); layer->setName(name.c_str()); - if (auto metadataAttr = sourceOp->getAttrOfType( - TensorRTDialect::kTensorRTLayerMetadataMarker)) { - layer->setMetadata(metadataAttr.getValue().str().c_str()); - } + if (auto fusedLoc = dyn_cast(sourceOp->getLoc())) + if (auto metadataAttr = dyn_cast(fusedLoc.getMetadata())) + layer->setMetadata(metadataAttr.getValue().str().c_str()); } nvinfer1::ITensor *NvInferNetworkEncoder::lookup(Value v) const {