Skip to content

Commit f10feb0

Browse files
Adds a layer metadata callback API
- Adds a new API for setting a layer metadata callback. The callback is invoked for each MLIR operation to set metadata on the corresponding TensorRT network layers.
1 parent f72a7af commit f10feb0

File tree

11 files changed

+228
-77
lines changed

11 files changed

+228
-77
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,10 @@ typedef struct MTRT_StableHLOToExecutableOptions {
6060
void *ptr;
6161
} MTRT_StableHLOToExecutableOptions;
6262

63+
typedef void (*MTRT_MetadataCallback)(MlirOperation op,
64+
MlirStringCallback append,
65+
void *appendCtx, void *userData);
66+
6367
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsCreate(
6468
MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions *options,
6569
int32_t tensorRTBuilderOptLevel, bool tensorRTStronglyTyped);
@@ -77,6 +81,11 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
7781
const char **debugTypes, size_t debugTypeSizes,
7882
const char *dumpIrTreeDir = nullptr, const char *dumpTensorRTDir = nullptr);
7983

84+
MLIR_CAPI_EXPORTED MTRT_Status
85+
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
86+
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
87+
void *userData);
88+
8089
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
8190
MTRT_StableHLOToExecutableOptions options);
8291

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,6 @@
3333

3434
#include "mlir-executor/Runtime/API/API.h"
3535
#include "mlir-executor/Support/Status.h"
36-
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
3736
#include "mlir-tensorrt/Compiler/Client.h"
3837
#include "mlir-tensorrt/Compiler/Extension.h"
3938
#include "mlir-tensorrt/Compiler/Options.h"
@@ -125,11 +124,14 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
125124
/// Whether to disallow host tensors in TensorRT clusters.
126125
bool disallowHostTensorsInTensorRTClusters = false;
127126

128-
/// Entrypiont function name.
127+
/// Entrypoint function name.
129128
std::string entrypoint = "main";
130129

131130
DebugOptions debugOptions;
132131

132+
std::function<std::string(mlir::Operation *)> layerMetadataCallback =
133+
[](mlir::Operation *) { return ""; };
134+
133135
/// Base class for extensions associated with StableHloToExecutableTask.
134136
class ExtensionBase : public TaskExtensionBase {
135137
public:

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
3333
#include "mlir/CAPI/IR.h"
3434
#include "llvm/ADT/StringExtras.h"
35-
#include "llvm/Support/raw_ostream.h"
3635

3736
using namespace mlirtrt;
3837
using namespace mlirtrt::compiler;
@@ -199,6 +198,32 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
199198
return mtrtStatusGetOk();
200199
}
201200

201+
MTRT_Status
202+
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
203+
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
204+
void *userData) {
205+
StableHLOToExecutableOptions *cppOpts = unwrap(options);
206+
207+
// Construct the append callback which we will pass to the callback provided
208+
// by the user. We do it this way to avoid needing a string type in the C
209+
// API; the user's callback appends text through `append` instead.
210+
auto appendFunc = [](MlirStringRef str, void *appendCtx) {
211+
std::string &accum = *reinterpret_cast<std::string *>(appendCtx);
212+
accum += std::string(str.data, str.length);
213+
};
214+
215+
// Capture by value: capturing by reference would leave `callback` dangling
216+
// by the time this stored lambda is invoked.
217+
cppOpts->layerMetadataCallback = [=](Operation *op) {
218+
std::string accum;
219+
void *appendCtx = reinterpret_cast<void *>(&accum);
220+
callback(wrap(op), appendFunc, appendCtx, userData);
221+
return accum;
222+
};
223+
224+
return mtrtStatusGetOk();
225+
}
226+
202227
MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
203228
MTRT_StableHLOToExecutableOptions options) {
204229
delete reinterpret_cast<StableHLOToExecutableOptions *>(options.ptr);

mlir-tensorrt/compiler/lib/Compiler/TensorRTExtension/TensorRTExtension.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ void StableHLOToExecutableTensorRTExtension::populatePasses(
6464
auto &trtPM = pm.nest<tensorrt::TensorRTModuleOp>();
6565
tensorrt::buildTensorRTModuleTransformationPipeline(
6666
trtPM, translationOptions.enableStronglyTyped);
67-
trtPM.addPass(
68-
tensorrt::createTranslateTensorRTPass(nullptr, translationOptions));
67+
trtPM.addPass(tensorrt::createTranslateTensorRTPass(
68+
nullptr, options.layerMetadataCallback, translationOptions));
6969
return;
7070
}
7171

mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
#include "mlir/Bindings/Python/PybindAdaptors.h"
2020
#include "pybind11/pybind11.h"
2121
#include "llvm/Support/DynamicLibrary.h"
22+
#include <iostream>
2223
#include <pybind11/attr.h>
24+
#include <pybind11/functional.h>
2325

2426
#ifdef MLIR_TRT_TARGET_TENSORRT
2527
#include "mlir-tensorrt-dialect/Utils/NvInferAdaptor.h"
@@ -66,6 +68,9 @@ class PyStableHLOToExecutableOptions
6668
mtrtStableHloToExecutableOptionsDestroy,
6769
mtrtPythonCapsuleToStableHLOToExecutableOptions,
6870
mtrtPythonStableHLOToExecutableOptionsToCapsule};
71+
72+
// Keeps the Python callback alive; its address is passed to the C API as `userData`.
73+
std::function<std::string(MlirOperation)> callback;
6974
};
7075
} // namespace
7176

@@ -270,7 +275,40 @@ PYBIND11_MODULE(_api, m) {
270275
py::arg("enabled"),
271276
py::arg("debug_types") = std::vector<std::string>{},
272277
py::arg("dump_ir_tree_dir") = py::none(),
273-
py::arg("dump_tensorrt_dir") = py::none());
278+
py::arg("dump_tensorrt_dir") = py::none())
279+
280+
#ifdef MLIR_TRT_TARGET_TENSORRT
281+
.def(
282+
"set_tensorrt_translation_metadata_callback",
283+
[](PyStableHLOToExecutableOptions &self,
284+
std::function<std::string(MlirOperation)> pyCallback) {
285+
// Since we're constructing a C callback, our closures must not
286+
// capture. We can pass in the Python callback via the userData
287+
// argument.
288+
auto callback = [](MlirOperation op, MlirStringCallback append,
289+
void *appendCtx, void *userDataVoid) {
290+
auto pyCallback =
291+
*static_cast<std::function<std::string(MlirOperation)> *>(
292+
userDataVoid);
293+
294+
std::string result;
295+
try {
296+
result = pyCallback(op);
297+
} catch (const std::exception &e) {
298+
std::cerr << e.what() << std::endl;
299+
}
300+
301+
append(MlirStringRef{result.data(), result.size()}, appendCtx);
302+
};
303+
304+
self.callback = pyCallback;
305+
THROW_IF_MTRT_ERROR(
306+
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
307+
self, callback, reinterpret_cast<void *>(&self.callback)));
308+
},
309+
py::arg("callback"), py::keep_alive<1, 2>{})
310+
#endif
311+
;
274312

275313
m.def(
276314
"compiler_stablehlo_to_executable",
@@ -308,4 +346,4 @@ PYBIND11_MODULE(_api, m) {
308346
bindTensorRTPluginAdaptorObjects(m);
309347
#endif
310348
#endif
311-
}
349+
}

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,14 @@ static constexpr nvinfer1::Weights kNullWeights =
7474

7575
class NvInferNetworkEncoder {
7676
public:
77-
NvInferNetworkEncoder(nvinfer1::INetworkDefinition *network,
78-
nvinfer1::IOptimizationProfile *profile,
79-
TensorRTVersion version, bool usesStronglyTyped)
77+
NvInferNetworkEncoder(
78+
nvinfer1::INetworkDefinition *network,
79+
nvinfer1::IOptimizationProfile *profile, TensorRTVersion version,
80+
bool usesStronglyTyped,
81+
std::function<std::string(Operation *)> metadataCallback)
8082
: network(network), profile(profile), version(std::move(version)),
81-
usesStronglyTyped(usesStronglyTyped) {}
83+
usesStronglyTyped(usesStronglyTyped),
84+
layerMetadataCallback(std::move(metadataCallback)) {}
8285

8386
/// Lookup the TRT ITensor* equivalent of a Value.
8487
nvinfer1::ITensor *lookup(Value v) const;
@@ -141,7 +144,7 @@ class NvInferNetworkEncoder {
141144

142145
/// Set the name of `layer` to a unique string that contains the op
143146
/// name and location information from `sourceOp`.
144-
void setName(nvinfer1::ILayer *layer, Operation *sourceOp);
147+
void setMetadata(nvinfer1::ILayer *layer, Operation *sourceOp);
145148

146149
// Check if network uses fp16 types.
147150
bool hasFp16Usage() const { return usesFp16; }
@@ -238,6 +241,8 @@ class NvInferNetworkEncoder {
238241
bool hasQDQOps{false};
239242

240243
PluginManager pluginMgr;
244+
245+
std::function<std::string(Operation *)> layerMetadataCallback;
241246
};
242247

243248
//===----------------------------------------------------------------------===//

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TranslateToTensorRT.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
#ifdef MLIR_TRT_TARGET_TENSORRT
2424
#include "mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h"
25-
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
2625
#include "mlir-tensorrt-dialect/Utils/Options.h"
2726
#include "mlir/Support/LogicalResult.h"
2827
#include "llvm/Support/raw_ostream.h"
@@ -208,17 +207,19 @@ class TensorRTSerializedTimingCache {
208207
/// `tensorrt.shape_profile` arguments have been populated for each argument
209208
/// that has unknown dimensions.
210209
/// TODO(cbate): add additional options here for builder configuration.
211-
FailureOr<TensorRTEngineResult>
212-
buildFunction(mlir::FunctionOpInterface op,
213-
TensorRTBuilderContext &builderContext,
214-
TensorRTSerializedTimingCache &serializedTimingCache,
215-
const TensorRTTranslationOptions &options =
216-
TensorRTTranslationOptions::fromCLFlags());
210+
FailureOr<TensorRTEngineResult> buildFunction(
211+
mlir::FunctionOpInterface op, TensorRTBuilderContext &builderContext,
212+
TensorRTSerializedTimingCache &serializedTimingCache,
213+
const TensorRTTranslationOptions &options =
214+
TensorRTTranslationOptions::fromCLFlags(),
215+
std::function<std::string(Operation *)> layerMetadataCallback =
216+
[](Operation *op) { return ""; });
217217

218218
/// Create an instance of a translate-to-tensorrt pass using an existing
219219
/// TensorRTBuilderContext.
220220
std::unique_ptr<mlir::Pass> createTranslateTensorRTPass(
221221
std::shared_ptr<tensorrt::TensorRTBuilderContext> context,
222+
std::function<std::string(Operation *)> layerMetadataCallback,
222223
TensorRTTranslationOptions options =
223224
TensorRTTranslationOptions::fromCLFlags());
224225

0 commit comments

Comments
 (0)