Skip to content

Commit ef7f8f6

Browse files
authored
[TensorRT] Extract layer metadata from FusedLocation (#689)
Extracts TRT layer metadata from FusedLocation.metadata. Also removes the "metadata" string from TRT dialect as it is no longer needed.
1 parent 04f6cf4 commit ef7f8f6

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ def TensorRT_Dialect : Dialect {
5151
static constexpr StringRef kTensorRTPerTensorDequantizationMarker = "tensorrt.pt_dq";
5252
static constexpr StringRef kTensorRTPerChannelDequantizationMarker = "tensorrt.pc_dq";
5353
static constexpr StringRef kTensorRTBlockDequantizationMarker = "tensorrt.block_dq";
54-
55-
/// TensorRT layer metadata markder.
56-
static constexpr StringRef kTensorRTLayerMetadataMarker = "metadata";
5754
}];
5855

5956
let dependentDialects = [

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,10 +287,9 @@ void NvInferNetworkEncoder::setMetadata(nvinfer1::ILayer *layer,
287287
Operation *sourceOp) {
288288
std::string name = createName(namesSet, sourceOp);
289289
layer->setName(name.c_str());
290-
if (auto metadataAttr = sourceOp->getAttrOfType<StringAttr>(
291-
TensorRTDialect::kTensorRTLayerMetadataMarker)) {
292-
layer->setMetadata(metadataAttr.getValue().str().c_str());
293-
}
290+
if (auto fusedLoc = dyn_cast<FusedLoc>(sourceOp->getLoc()))
291+
if (auto metadataAttr = dyn_cast<StringAttr>(fusedLoc.getMetadata()))
292+
layer->setMetadata(metadataAttr.getValue().str().c_str());
294293
}
295294

296295
nvinfer1::ITensor *NvInferNetworkEncoder::lookup(Value v) const {

0 commit comments

Comments
 (0)