Skip to content

Commit 45bb9be

Browse files
pranavm-nvidiayizhuoz004
authored andcommitted
[TensorRT] Sets TRT layer metadata and nvtx profiling verbosity
For now we query an environment variable `MTRT_TENSORRT_NVTX` to set the nvtx profiling verbosity. This is not ideal because it cannot support per-engine profiling verbosity. We will change that with a runtime option for TRT module.
1 parent 6584c93 commit 45bb9be

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,25 @@
4242
using namespace mlirtrt;
4343
using namespace mlirtrt::runtime;
4444

45+
static const char *kNvtxVerbosityEnvVariable = "MTRT_TENSORRT_NVTX";
46+
47+
/// Helper method that gets nvtx verbosity from environment value
48+
nvinfer1::ProfilingVerbosity getNvtxVerbosity() {
49+
const char *verbosity_str = std::getenv(kNvtxVerbosityEnvVariable);
50+
if (!verbosity_str)
51+
return nvinfer1::ProfilingVerbosity::kLAYER_NAMES_ONLY;
52+
switch (std::string_view(verbosity_str)) {
53+
case "NONE":
54+
return nvinfer1::ProfilingVerbosity::kNONE;
55+
case "DETAILED":
56+
return nvinfer1::ProfilingVerbosity::kDETAILED;
57+
default:
58+
return nvinfer1::ProfilingVerbosity::kLAYER_NAMES_ONLY;
59+
}
60+
}
61+
62+
static const nvinfer1::ProfilingVerbosity gNvtxVerbosity = getNvtxVerbosity();
63+
4564
namespace {
4665
/// A simple logger that implements TensorRT's logging interface. Errors and
4766
/// warnings are reported through TensorRT's diagnostic system, everything else
@@ -611,6 +630,8 @@ static Status enqueueV3Wrapper(AllocTracker &tracker,
611630
return getStatusWithMsg(StatusCode::InternalError,
612631
"failed to set input-consumed event");
613632

633+
context->setNvtxVerbosity(gNvtxVerbosity);
634+
614635
if (!context->enqueueV3(stream))
615636
return getStatusWithMsg(StatusCode::InternalError,
616637
"failed to enqueue engine execution on stream");
@@ -650,6 +671,8 @@ static Status enqueueAllocV3Wrapper(AllocTracker &tracker,
650671
// Number of results are known in advance.
651672
int64_t nbResults = outputDesc.getNumberOfResults();
652673

674+
context->setNvtxVerbosity(gNvtxVerbosity);
675+
653676
if (!context->enqueueV3(stream))
654677
return getStatusWithMsg(StatusCode::InternalError,
655678
"failed to enqueue engine execution on stream");

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@ void NvInferNetworkEncoder::setMetadata(nvinfer1::ILayer *layer,
278278
Operation *sourceOp) {
279279
std::string name = createName(namesSet, sourceOp);
280280
layer->setName(name.c_str());
281+
282+
if (auto metadataAttr = sourceOp->getAttrOfType<StringAttr>("metadata")) {
283+
layer->setMetadata(metadataAttr.getValue().str().c_str());
284+
}
281285
}
282286

283287
nvinfer1::ITensor *NvInferNetworkEncoder::lookup(Value v) const {

mlir-tensorrt/tensorrt/lib/Target/TranslateToTensorRT.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -522,11 +522,10 @@ tensorrt::buildFunction(mlir::FunctionOpInterface op,
522522
<< "failed to set timing cache";
523523
}
524524

525-
// If created, engines and their layer information are
526-
// with detailed description.
527-
if (!opts.saveTensorRTEnginesToDirectory.empty() ||
528-
!opts.saveTensorRTLayerInfoDirectory.empty())
529-
config->setProfilingVerbosity(nvinfer1::ProfilingVerbosity::kDETAILED);
525+
// Enable kDETAILED verbosity unconditionally, then use
526+
// `IExecutionContext::setNvtxVerbosity` to change the verbosity at runtime
527+
// (lower verbosity performs better generally).
528+
config->setProfilingVerbosity(nvinfer1::ProfilingVerbosity::kDETAILED);
530529

531530
setBuilderOptimizationLevel(config.get(), opts.tensorrtBuilderOptLevel,
532531
builderContext.getTensorRTVersion());

0 commit comments

Comments
 (0)