Skip to content

Commit 6d4ff73

Browse files
pranavm-nvidiayizhuoz004
authored andcommitted
[TensorRT] Sets TRT layer metadata and nvtx profiling verbosity
For now we query an environment variable `MTRT_TENSORRT_NVTX` to set the nvtx profiling verbosity. This is not ideal because it cannot support per-engine profiling verbosity. We will change that with a runtime option for TRT module.
1 parent 6584c93 commit 6d4ff73

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,22 @@
4242
using namespace mlirtrt;
4343
using namespace mlirtrt::runtime;
4444

45+
static const char *kNvtxVerbosityEnvVariable = "MTRT_TENSORRT_NVTX";
46+
47+
/// Helper method that gets nvtx verbosity from environment value
48+
nvinfer1::ProfilingVerbosity getNvtxVerbosity() {
49+
const char *verbosity_str = std::getenv(kNvtxVerbosityEnvVariable);
50+
if (!verbosity_str)
51+
return nvinfer1::ProfilingVerbosity::kLAYER_NAMES_ONLY;
52+
if (std::string_view(verbosity_str) == "NONE")
53+
return nvinfer1::ProfilingVerbosity::kNONE;
54+
if (std::string_view(verbosity_str) == "DETAILED")
55+
return nvinfer1::ProfilingVerbosity::kDETAILED;
56+
return nvinfer1::ProfilingVerbosity::kLAYER_NAMES_ONLY;
57+
}
58+
59+
static const nvinfer1::ProfilingVerbosity gNvtxVerbosity = getNvtxVerbosity();
60+
4561
namespace {
4662
/// A simple logger that implements TensorRT's logging interface. Errors and
4763
/// warnings are reported through TensorRT's diagnostic system, everything else
@@ -611,6 +627,8 @@ static Status enqueueV3Wrapper(AllocTracker &tracker,
611627
return getStatusWithMsg(StatusCode::InternalError,
612628
"failed to set input-consumed event");
613629

630+
context->setNvtxVerbosity(gNvtxVerbosity);
631+
614632
if (!context->enqueueV3(stream))
615633
return getStatusWithMsg(StatusCode::InternalError,
616634
"failed to enqueue engine execution on stream");
@@ -650,6 +668,8 @@ static Status enqueueAllocV3Wrapper(AllocTracker &tracker,
650668
// Number of results are known in advance.
651669
int64_t nbResults = outputDesc.getNumberOfResults();
652670

671+
context->setNvtxVerbosity(gNvtxVerbosity);
672+
653673
if (!context->enqueueV3(stream))
654674
return getStatusWithMsg(StatusCode::InternalError,
655675
"failed to enqueue engine execution on stream");

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@ void NvInferNetworkEncoder::setMetadata(nvinfer1::ILayer *layer,
278278
Operation *sourceOp) {
279279
std::string name = createName(namesSet, sourceOp);
280280
layer->setName(name.c_str());
281+
282+
if (auto metadataAttr = sourceOp->getAttrOfType<StringAttr>("metadata")) {
283+
layer->setMetadata(metadataAttr.getValue().str().c_str());
284+
}
281285
}
282286

283287
nvinfer1::ITensor *NvInferNetworkEncoder::lookup(Value v) const {

mlir-tensorrt/tensorrt/lib/Target/TranslateToTensorRT.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -522,11 +522,10 @@ tensorrt::buildFunction(mlir::FunctionOpInterface op,
522522
<< "failed to set timing cache";
523523
}
524524

525-
// If created, engines and their layer information are
526-
// with detailed description.
527-
if (!opts.saveTensorRTEnginesToDirectory.empty() ||
528-
!opts.saveTensorRTLayerInfoDirectory.empty())
529-
config->setProfilingVerbosity(nvinfer1::ProfilingVerbosity::kDETAILED);
525+
// Enable kDETAILED verbosity unconditionally, then use
526+
// `IExecutionContext::setNvtxVerbosity` to change the verbosity at runtime
527+
// (lower verbosity performs better generally).
528+
config->setProfilingVerbosity(nvinfer1::ProfilingVerbosity::kDETAILED);
530529

531530
setBuilderOptimizationLevel(config.get(), opts.tensorrtBuilderOptLevel,
532531
builderContext.getTensorRTVersion());

0 commit comments

Comments
 (0)