Skip to content

Commit 1c8829d

Browse files
Matthew Francis-Landaumatthewfl
authored andcommitted
[mlir-tensorrt] fix the way that layer names are set for TensorRT
Previously the layer name was set to the root caller and ignored callee. This change updates the layer name to capture the last 3 callees to enable better tracking of locations in the generated TensorRT engine.
1 parent d1a8447 commit 1c8829d

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,28 @@ static std::string getUniqueName(NvInferNetworkEncoder::NamesSet &names,
233233
/// an open system of location attributes, there may be some location types that
234234
/// we cannot handle. We do not use the location's builtin printer because it
235235
/// could be extremely verbose.
236+
static void getCallSiteLocs(Location loc, SmallVector<Location> &locs) {
237+
if (auto callLoc = dyn_cast<CallSiteLoc>(loc)) {
238+
getCallSiteLocs(callLoc.getCaller(), locs);
239+
getCallSiteLocs(callLoc.getCallee(), locs);
240+
} else {
241+
locs.push_back(loc);
242+
}
243+
}
244+
236245
static void translateLocation(Location loc, llvm::raw_ostream &os) {
237246
if (auto callLoc = dyn_cast<CallSiteLoc>(loc)) {
238-
translateLocation(callLoc.getCaller(), os);
247+
SmallVector<Location> locs;
248+
getCallSiteLocs(callLoc, locs);
249+
// only include the last 3 locations in the names as this should be
250+
// sufficient to identify the call site for an op
251+
for (size_t i = locs.size() > 3 ? locs.size() - 3 : 0; i < locs.size();
252+
i++) {
253+
translateLocation(locs[i], os);
254+
if (i < locs.size() - 1) {
255+
os << " -> ";
256+
}
257+
}
239258
return;
240259
}
241260
if (auto fileLoc = dyn_cast<FileLineColLoc>(loc)) {
@@ -274,7 +293,8 @@ static std::string createName(NvInferNetworkEncoder::NamesSet &names,
274293
ss.flush();
275294
}
276295
// Truncate to TRT limit.
277-
static constexpr size_t kLayerNameSizeLimit = 2048;
296+
static constexpr size_t kLayerNameSizeLimit =
297+
2048 - 6; // -6 to give some space for UniqueName
278298
if (name.size() > kLayerNameSizeLimit)
279299
name = name.substr(0, kLayerNameSizeLimit);
280300

0 commit comments

Comments
 (0)