Skip to content

Commit 50c6388

Browse files
Adds infrastructure to automatically map operations to layers
1 parent c1cb619 commit 50c6388

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,16 @@ class NvInferNetworkEncoder {
8989
/// Lookup the TRT ITensor* equivalents of a ValueRange.
9090
SmallVector<nvinfer1::ITensor *> lookupValues(ValueRange values);
9191

92-
/// Add a map from a Value to a TRT ITEnsor*.
92+
/// Add a map from a Value to a TRT ITensor*.
9393
void map(Value from, nvinfer1::ITensor *to);
9494

9595
/// Remap values in `from` to each layer in `to` using the output at index 0
9696
/// for each layer.
9797
void map(ValueRange from, ArrayRef<nvinfer1::ILayer *> to);
9898

99+
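/// Add a map from an Operation to a TRT ILayer*. One operation may map to
/// multiple layers; each call appends `layer` to those recorded for `op`.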
100+
void map(Operation *op, nvinfer1::ILayer *layer);
101+
99102
/// Check whether the value map contains `v`.
100103
size_t contains(Value v) {
  // Forward to the underlying map's count(): non-zero when `v` is mapped.
  const size_t numMatches = valueMap.count(v);
  return numMatches;
}
101104

@@ -135,6 +138,10 @@ class NvInferNetworkEncoder {
135138
/// and other temporary buffers.
136139
using WeightsMap = llvm::DenseMap<mlir::Attribute, std::vector<int8_t>>;
137140

141+
// Tracks the mapping of mlir::Operations to layers. Note that one operation
142+
// may map to multiple layers.
143+
using LayerMap = llvm::DenseMap<Operation *, std::vector<nvinfer1::ILayer *>>;
144+
138145
using NamesSet = llvm::StringSet<>;
139146

140147
/// Accessor returning a mutable reference to the encoder's `valueMap`.
TensorMap &getTensorMap() {
  return valueMap;
}
@@ -210,6 +217,9 @@ class NvInferNetworkEncoder {
210217
// build ends.
211218
SmallVector<NvInferPluginPtr> pluginReferences;
212219

220+
// Tracks the mapping between mlir::Operations and TensorRT ILayers.
221+
LayerMap layerMap;
222+
213223
/// Holds the set of strings currently assigned as names to TensorRT ILayers.
214224
/// This is required because we must make new names unique. The TensorRT API
215225
/// does not have a set object to query names.

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,13 @@ void NvInferNetworkEncoder::map(ValueRange from,
287287
valueMap.insert(v, l->getOutput(0));
288288
}
289289

290+
/// Record that `layer` was produced for `op`. An operation may map to several
/// layers, so each call appends `layer` to the list stored for `op` in
/// `layerMap` rather than replacing a previous entry.
void NvInferNetworkEncoder::map(Operation *op, nvinfer1::ILayer *layer) {
  // DenseMap::operator[] default-constructs an empty vector the first time
  // `op` is seen, so a separate existence check and explicit initialization
  // are unnecessary (and would cost extra hash lookups).
  layerMap[op].push_back(layer);
}
296+
290297
bool NvInferNetworkEncoder::isStronglyTyped() const {
291298
if (!usesStronglyTyped)
292299
return false;

mlir-tensorrt/tools/MlirTensorRtTblgen.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,11 @@ static bool emitLayerAddDefinitions(const llvm::RecordKeeper &recordKeeper,
397397
}
398398

399399
// Emit the body with substitutions.
400+
os << "auto const numLayersBefore = network->getNbLayers();\n";
400401
os << tblgen::tgfmt(expr, &ctx);
402+
os << "auto const numLayersAfter = network->getNbLayers();\n";
403+
os << "for (int64_t i = numLayersBefore; i < numLayersAfter; ++i) "
404+
"encoder.map(tensorrtOp, network->getLayer(i));\n";
401405
os << "return success();\n";
402406
os.unindent();
403407
os << "}\n";
@@ -603,4 +607,4 @@ int main(int argc, char **argv) {
603607
});
604608

605609
return mlir::MlirTblgenMain(argc, argv);
606-
}
610+
}

0 commit comments

Comments
 (0)