
Commit b344e42

matthewfl and Matthew Francis-Landau authored
[mlir-tensorrt] fix the way that layer names are set for TensorRT (#702)
This PR fixes the way that layer names are set for TensorRT. ~~Previously there was a bug in the `getUniqueName` function that caused it to generate names like `name_0_1_2_3_4_5_6_7` instead of `name_7`.~~ This was not a problem upstream. It also fixes the way that call sites are translated to layer names. Previously, if a `Location` tracked the entire stack trace (function `a` calls `b` calls `c` calls `d`), the name would be set to `a:LINE_NUMBER`. Now the layer name is set to `b:LINE -> c:LINE -> d:LINE` to provide enough information to identify where a layer comes from.

Co-authored-by: Matthew Francis-Landau <[email protected]>
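To make the new naming rule concrete, here is a minimal standalone sketch of the flattening described above. It does not use MLIR: `Loc`, `leaf`, `callSite`, and `layerNameFromLoc` are hypothetical stand-ins for `Location`, `CallSiteLoc`, and `translateLocation`, and the `name:LINE` strings are illustrative. The caller-then-callee traversal order and the three-location cap are taken from the diff.

```cpp
#include <cstddef>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an MLIR Location: either a leaf ("func:LINE")
// or a call site that holds a callee location and a caller location.
struct Loc {
  std::string leaf;            // only meaningful when this is not a call site
  std::shared_ptr<Loc> callee; // set for call sites
  std::shared_ptr<Loc> caller; // set for call sites
  bool isCallSite() const { return callee && caller; }
};
using LocPtr = std::shared_ptr<Loc>;

LocPtr leaf(std::string name) {
  auto l = std::make_shared<Loc>();
  l->leaf = std::move(name);
  return l;
}

LocPtr callSite(LocPtr callee, LocPtr caller) {
  auto l = std::make_shared<Loc>();
  l->callee = std::move(callee);
  l->caller = std::move(caller);
  return l;
}

// Flatten nested call sites, visiting the caller before the callee so the
// outermost frame comes first (same order as getCallSiteLocs in the diff).
void flatten(const LocPtr &loc, std::vector<std::string> &out) {
  if (loc->isCallSite()) {
    flatten(loc->caller, out);
    flatten(loc->callee, out);
  } else {
    out.push_back(loc->leaf);
  }
}

// Join at most the last `maxLocs` flattened locations with " -> ".
std::string layerNameFromLoc(const LocPtr &loc, std::size_t maxLocs = 3) {
  std::vector<std::string> locs;
  flatten(loc, locs);
  std::size_t start = locs.size() > maxLocs ? locs.size() - maxLocs : 0;
  std::ostringstream os;
  for (std::size_t i = start; i < locs.size(); ++i) {
    os << locs[i];
    if (i + 1 < locs.size())
      os << " -> ";
  }
  return os.str();
}
```

In this sketch, a stack `a` calls `b` calls `c` calls `d` is modelled as `callSite(leaf("d:4"), callSite(leaf("c:3"), callSite(leaf("b:2"), leaf("a:1"))))`, and the outermost frame `a` is the one dropped by the cap.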
1 parent d1a8447 commit b344e42

1 file changed, +22 -2 lines changed

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 22 additions & 2 deletions
@@ -233,9 +233,28 @@ static std::string getUniqueName(NvInferNetworkEncoder::NamesSet &names,
 /// an open system of location attributes, there may be some location types that
 /// we cannot handle. We do not use the location's builtin printer because it
 /// could be extremely verbose.
+static void getCallSiteLocs(Location loc, SmallVector<Location> &locs) {
+  if (auto callLoc = dyn_cast<CallSiteLoc>(loc)) {
+    getCallSiteLocs(callLoc.getCaller(), locs);
+    getCallSiteLocs(callLoc.getCallee(), locs);
+  } else {
+    locs.push_back(loc);
+  }
+}
+
 static void translateLocation(Location loc, llvm::raw_ostream &os) {
   if (auto callLoc = dyn_cast<CallSiteLoc>(loc)) {
-    translateLocation(callLoc.getCaller(), os);
+    SmallVector<Location> locs;
+    getCallSiteLocs(callLoc, locs);
+    // only include the last 3 locations in the names as this should be
+    // sufficient to identify the call site for an op
+    for (size_t i = locs.size() > 3 ? locs.size() - 3 : 0; i < locs.size();
+         i++) {
+      translateLocation(locs[i], os);
+      if (i < locs.size() - 1) {
+        os << " -> ";
+      }
+    }
     return;
   }
   if (auto fileLoc = dyn_cast<FileLineColLoc>(loc)) {
@@ -274,7 +293,8 @@ static std::string createName(NvInferNetworkEncoder::NamesSet &names,
     ss.flush();
   }
   // Truncate to TRT limit.
-  static constexpr size_t kLayerNameSizeLimit = 2048;
+  static constexpr size_t kLayerNameSizeLimit =
+      2048 - 6; // -6 to give some space for UniqueName
   if (name.size() > kLayerNameSizeLimit)
     name = name.substr(0, kLayerNameSizeLimit);
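
The arithmetic behind the `2048 - 6` headroom can be sketched as follows. This is an illustration under assumptions, not the code in `NetworkEncoder.cpp`: `truncateBase` and `withSuffix` are hypothetical helpers, and the `_N` suffix format is inferred from the `name_7` example in the commit message rather than from `getUniqueName` itself, which this diff does not show.

```cpp
#include <cstddef>
#include <string>

// Values taken from the diff: a 2048-character limit with 6 characters
// reserved for the uniquifying suffix.
constexpr std::size_t kTrtNameLimit = 2048;
constexpr std::size_t kSuffixReserve = 6;

// Hypothetical helper: truncate the base name so a suffix still fits.
std::string truncateBase(std::string name) {
  constexpr std::size_t limit = kTrtNameLimit - kSuffixReserve;
  if (name.size() > limit)
    name = name.substr(0, limit);
  return name;
}

// Hypothetical suffix scheme (assumed): "_" followed by a decimal counter.
std::string withSuffix(const std::string &base, unsigned counter) {
  return base + "_" + std::to_string(counter);
}
```

Under that assumed scheme, six reserved characters cover an underscore plus a counter of up to five digits, so a truncated name stays within 2048 characters for counters up to 99999.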
