Skip to content

Commit 3c32574

Browse files
author
Matthew Francis-Landau
committed
[mlir-tensorrt] fix the way that layer names are set for TensorRT
1 parent d1a8447 commit 3c32574

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,20 +222,39 @@ nvinfer1::Permutation tensorrt::getNvInferPermutation(ArrayRef<int64_t> array) {
222222
static std::string getUniqueName(NvInferNetworkEncoder::NamesSet &names,
223223
std::string name) {
224224
static unsigned i = 0;
225-
std::string uniqueName = name;
226-
while (names.contains(uniqueName))
227-
uniqueName = name + "_" + std::to_string(i++);
228-
names.insert(uniqueName);
229-
return uniqueName;
225+
std::string newName = name;
226+
while (names.contains(newName))
227+
newName = name + "_" + std::to_string(i++);
228+
names.insert(newName);
229+
return newName;
230230
}
231231

232232
/// Print a representation of the given location to the string. Since MLIR has
233233
/// an open system of location attributes, there may be some location types that
234234
/// we cannot handle. We do not use the location's builtin printer because it
235235
/// could be extremely verbose.
236+
static void getCallSiteLocs(Location loc, SmallVector<Location> &locs) {
237+
if (auto callLoc = dyn_cast<CallSiteLoc>(loc)) {
238+
getCallSiteLocs(callLoc.getCaller(), locs);
239+
getCallSiteLocs(callLoc.getCallee(), locs);
240+
} else {
241+
locs.push_back(loc);
242+
}
243+
}
244+
236245
static void translateLocation(Location loc, llvm::raw_ostream &os) {
237246
if (auto callLoc = dyn_cast<CallSiteLoc>(loc)) {
238-
translateLocation(callLoc.getCaller(), os);
247+
SmallVector<Location> locs;
248+
getCallSiteLocs(callLoc, locs);
249+
// only include the last 3 locations in the names as this should be
250+
// sufficient to identify the call site for an op
251+
for (size_t i = locs.size() > 3 ? locs.size() - 3 : 0; i < locs.size();
252+
i++) {
253+
translateLocation(locs[i], os);
254+
if (i < locs.size() - 1) {
255+
os << " -> ";
256+
}
257+
}
239258
return;
240259
}
241260
if (auto fileLoc = dyn_cast<FileLineColLoc>(loc)) {
@@ -274,7 +293,8 @@ static std::string createName(NvInferNetworkEncoder::NamesSet &names,
274293
ss.flush();
275294
}
276295
// Truncate to TRT limit.
277-
static constexpr size_t kLayerNameSizeLimit = 2048;
296+
static constexpr size_t kLayerNameSizeLimit =
297+
2048 - 6; // -6 to give some space for UniqueName
278298
if (name.size() > kLayerNameSizeLimit)
279299
name = name.substr(0, kLayerNameSizeLimit);
280300

0 commit comments

Comments
 (0)