Skip to content

Commit 91a9fac

Browse files
committed
[TensorRT] Extract layer metadata from FusedLocation
Extracts TRT layer metadata from FusedLocation.metadata. Also removes the "metadata" string from TRT dialect as it is no longer needed.
1 parent f821499 commit 91a9fac

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ def TensorRT_Dialect : Dialect {
5151
static constexpr StringRef kTensorRTPerTensorDequantizationMarker = "tensorrt.pt_dq";
5252
static constexpr StringRef kTensorRTPerChannelDequantizationMarker = "tensorrt.pc_dq";
5353
static constexpr StringRef kTensorRTBlockDequantizationMarker = "tensorrt.block_dq";
54-
55-
/// TensorRT layer metadata markder.
56-
static constexpr StringRef kTensorRTLayerMetadataMarker = "metadata";
5754
}];
5855

5956
let dependentDialects = [

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,9 @@ void NvInferNetworkEncoder::setMetadata(nvinfer1::ILayer *layer,
278278
Operation *sourceOp) {
279279
std::string name = createName(namesSet, sourceOp);
280280
layer->setName(name.c_str());
281-
if (auto metadataAttr = sourceOp->getAttrOfType<StringAttr>(
282-
TensorRTDialect::kTensorRTLayerMetadataMarker)) {
283-
layer->setMetadata(metadataAttr.getValue().str().c_str());
284-
}
281+
if (auto fusedLoc = dyn_cast<FusedLoc>(sourceOp->getLoc()))
282+
if (auto metadataAttr = dyn_cast<StringAttr>(fusedLoc.getMetadata()))
283+
layer->setMetadata(metadataAttr.getValue().str().c_str());
285284
}
286285

287286
nvinfer1::ITensor *NvInferNetworkEncoder::lookup(Value v) const {

0 commit comments

Comments
 (0)