Skip to content

Commit f0a824f

Browse files
Adds a layer metadata callback API
- Adds a new API for registering a layer metadata callback; the callback is invoked for each MLIR operation to set the metadata on the corresponding TensorRT network layers.
1 parent f8b5db3 commit f0a824f

File tree

15 files changed

+349
-113
lines changed

15 files changed

+349
-113
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ typedef struct MTRT_StableHLOToExecutableOptions {
6060
void *ptr;
6161
} MTRT_StableHLOToExecutableOptions;
6262

63+
typedef void (*MTRT_MetadataCallback)(MlirOperation op,
64+
MlirStringCallback append,
65+
void *appendCtx, void *userData);
66+
6367
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsCreate(
6468
MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions *options,
6569
int32_t tensorRTBuilderOptLevel, bool tensorRTStronglyTyped);
@@ -77,6 +81,11 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
7781
const char **debugTypes, size_t debugTypeSizes,
7882
const char *dumpIrTreeDir = nullptr, const char *dumpTensorRTDir = nullptr);
7983

84+
MLIR_CAPI_EXPORTED MTRT_Status
85+
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
86+
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
87+
void *userData);
88+
8089
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
8190
MTRT_StableHLOToExecutableOptions options);
8291

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/Client.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ class CompilerClient {
106106
auto key = std::make_pair(mlir::TypeID::get<CompilationTaskType>(),
107107
options.getHash());
108108
auto it = cachedPassManagers.find(key);
109-
if (it == cachedPassManagers.end()) {
109+
if (it == cachedPassManagers.end() || options.shouldInvalidateCache()) {
110110
auto pm = std::make_unique<CompilationTaskType>(context, options);
111111
setupPassManagerLogging(*pm, options.debugOptions);
112112
auto *ptr = pm.get();
113-
cachedPassManagers.insert(std::make_pair(key, std::move(pm)));
113+
cachedPassManagers[key] = std::move(pm);
114114
return *ptr;
115115
}
116116
return *it->second;

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
#include "mlir-executor/Runtime/API/API.h"
3535
#include "mlir-executor/Support/Status.h"
36-
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
3736
#include "mlir-tensorrt/Compiler/Client.h"
3837
#include "mlir-tensorrt/Compiler/Extension.h"
3938
#include "mlir-tensorrt/Compiler/Options.h"
@@ -106,6 +105,15 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
106105
/// Get the mutable DebugOptions.
107106
DebugOptions &getDebugOptions() { return debugOptions; }
108107

108+
llvm::hash_code getHash() const override;
109+
110+
bool shouldInvalidateCache() const override {
111+
// If a callback is provided, we have no way of verifying whether it is
112+
// equivalent to a callback from another set of options. Therefore, we are
113+
// forced to invalidate the cache entry if it is present at all.
114+
return static_cast<bool>(layerMetadataCallback);
115+
}
116+
109117
/// The host index bit-width.
110118
int64_t executorIndexBitwidth{64};
111119

@@ -125,11 +133,13 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
125133
/// Whether to disallow host tensors in TensorRT clusters.
126134
bool disallowHostTensorsInTensorRTClusters = false;
127135

128-
/// Entrypiont function name.
136+
/// Entrypoint function name.
129137
std::string entrypoint = "main";
130138

131139
DebugOptions debugOptions;
132140

141+
std::function<std::string(mlir::Operation *)> layerMetadataCallback{nullptr};
142+
133143
/// Base class for extensions associated with StableHloToExecutableTask.
134144
class ExtensionBase : public TaskExtensionBase {
135145
public:

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
3333
#include "mlir/CAPI/IR.h"
3434
#include "llvm/ADT/StringExtras.h"
35-
#include "llvm/Support/raw_ostream.h"
3635

3736
using namespace mlirtrt;
3837
using namespace mlirtrt::compiler;
@@ -199,6 +198,32 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
199198
return mtrtStatusGetOk();
200199
}
201200

201+
MTRT_Status
202+
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
203+
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback,
204+
void *userData) {
205+
StableHLOToExecutableOptions *cppOpts = unwrap(options);
206+
207+
// Construct the append callback which we will pass to the callback provided
208+
// by the user. We do it this way to avoid needing a string construct in the C
209+
// API.
210+
auto appendFunc = [](MlirStringRef str, void *appendCtx) {
211+
std::string &accum = *reinterpret_cast<std::string *>(appendCtx);
212+
accum += std::string(str.data, str.length);
213+
};
214+
215+
// Capturing by reference here will cause `callback` to point to the wrong
216+
// place at the time this callback is invoked.
217+
cppOpts->layerMetadataCallback = [=](Operation *op) {
218+
std::string accum;
219+
void *appendCtx = reinterpret_cast<void *>(&accum);
220+
callback(wrap(op), appendFunc, appendCtx, userData);
221+
return accum;
222+
};
223+
224+
return mtrtStatusGetOk();
225+
}
226+
202227
MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
203228
MTRT_StableHLOToExecutableOptions options) {
204229
delete reinterpret_cast<StableHLOToExecutableOptions *>(options.ptr);

mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,13 @@ Status StableHLOToExecutableOptions::inferDeviceOptionsFromHost() {
271271
return Status::getOk();
272272
}
273273

274+
llvm::hash_code StableHLOToExecutableOptions::getHash() const {
275+
llvm::hash_code hash = OptionsContext::getHash();
276+
if (layerMetadataCallback)
277+
return llvm::hash_combine(hash, &layerMetadataCallback);
278+
return hash;
279+
}
280+
274281
//===----------------------------------------------------------------------===//
275282
// StableHloToExecutableTask
276283
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/lib/Compiler/TensorRTExtension/TensorRTExtension.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ void StableHLOToExecutableTensorRTExtension::populatePasses(
6464
auto &trtPM = pm.nest<tensorrt::TensorRTModuleOp>();
6565
tensorrt::buildTensorRTModuleTransformationPipeline(
6666
trtPM, translationOptions.enableStronglyTyped);
67-
trtPM.addPass(
68-
tensorrt::createTranslateTensorRTPass(nullptr, translationOptions));
67+
trtPM.addPass(tensorrt::createTranslateTensorRTPass(
68+
nullptr, options.layerMetadataCallback, translationOptions));
6969
return;
7070
}
7171

mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
#include "mlir/Bindings/Python/PybindAdaptors.h"
2020
#include "pybind11/pybind11.h"
2121
#include "llvm/Support/DynamicLibrary.h"
22+
#include "llvm/Support/raw_ostream.h"
2223
#include <pybind11/attr.h>
24+
#include <pybind11/functional.h>
2325

2426
#ifdef MLIR_TRT_TARGET_TENSORRT
2527
#include "mlir-tensorrt-dialect/Utils/NvInferAdaptor.h"
@@ -66,6 +68,11 @@ class PyStableHLOToExecutableOptions
6668
mtrtStableHloToExecutableOptionsDestroy,
6769
mtrtPythonCapsuleToStableHLOToExecutableOptions,
6870
mtrtPythonStableHLOToExecutableOptionsToCapsule};
71+
72+
// We need this member so we can keep the Python callback alive long enough.
73+
std::function<std::string(MlirOperation)> callback;
74+
75+
~PyStableHLOToExecutableOptions() { callback = nullptr; }
6976
};
7077
} // namespace
7178

@@ -270,7 +277,43 @@ PYBIND11_MODULE(_api, m) {
270277
py::arg("enabled"),
271278
py::arg("debug_types") = std::vector<std::string>{},
272279
py::arg("dump_ir_tree_dir") = py::none(),
273-
py::arg("dump_tensorrt_dir") = py::none());
280+
py::arg("dump_tensorrt_dir") = py::none())
281+
282+
#ifdef MLIR_TRT_TARGET_TENSORRT
283+
.def(
284+
"set_tensorrt_translation_metadata_callback",
285+
[](PyStableHLOToExecutableOptions &self,
286+
std::function<std::string(MlirOperation)> pyCallback) {
287+
// Since we're constructing a C callback, our closures must not
288+
// capture. We can pass in the Python callback via the userData
289+
// argument.
290+
auto callback = [](MlirOperation op, MlirStringCallback append,
291+
void *appendCtx, void *userDataVoid) {
292+
auto &pyCallback =
293+
*static_cast<std::function<std::string(MlirOperation)> *>(
294+
userDataVoid);
295+
296+
if (!pyCallback)
297+
return;
298+
299+
std::string result;
300+
try {
301+
result = pyCallback(op);
302+
} catch (const std::exception &e) {
303+
llvm::errs() << e.what() << '\n';
304+
}
305+
306+
append(MlirStringRef{result.data(), result.size()}, appendCtx);
307+
};
308+
309+
self.callback = pyCallback;
310+
THROW_IF_MTRT_ERROR(
311+
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
312+
self, callback, reinterpret_cast<void *>(&self.callback)));
313+
},
314+
py::arg("callback"), py::keep_alive<1, 2>{})
315+
#endif
316+
;
274317

275318
m.def(
276319
"compiler_stablehlo_to_executable",
@@ -308,4 +351,4 @@ PYBIND11_MODULE(_api, m) {
308351
bindTensorRTPluginAdaptorObjects(m);
309352
#endif
310353
#endif
311-
}
354+
}

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,25 +74,31 @@ static constexpr nvinfer1::Weights kNullWeights =
7474

7575
class NvInferNetworkEncoder {
7676
public:
77-
NvInferNetworkEncoder(nvinfer1::INetworkDefinition *network,
78-
nvinfer1::IOptimizationProfile *profile,
79-
TensorRTVersion version, bool usesStronglyTyped)
77+
NvInferNetworkEncoder(
78+
nvinfer1::INetworkDefinition *network,
79+
nvinfer1::IOptimizationProfile *profile, TensorRTVersion version,
80+
bool usesStronglyTyped,
81+
std::function<std::string(Operation *)> metadataCallback)
8082
: network(network), profile(profile), version(std::move(version)),
81-
usesStronglyTyped(usesStronglyTyped) {}
83+
usesStronglyTyped(usesStronglyTyped),
84+
layerMetadataCallback(std::move(metadataCallback)) {}
8285

8386
/// Lookup the TRT ITensor* equivalent of a Value.
8487
nvinfer1::ITensor *lookup(Value v) const;
8588

8689
/// Lookup the TRT ITensor* equivalents of a ValueRange.
8790
SmallVector<nvinfer1::ITensor *> lookupValues(ValueRange values);
8891

89-
/// Add a map from a Value to a TRT ITEnsor*.
92+
/// Add a map from a Value to a TRT ITensor*.
9093
void map(Value from, nvinfer1::ITensor *to);
9194

9295
/// Remap values in `from` to each layer in `to` using the output at index 0
9396
/// for each layer.
9497
void map(ValueRange from, ArrayRef<nvinfer1::ILayer *> to);
9598

99+
// Add a map from an Operation to a TRT ILayer*
100+
void map(Operation *op, nvinfer1::ILayer *layer);
101+
96102
/// Check whether the value map contains `v`.
97103
size_t contains(Value v) { return valueMap.count(v); }
98104

@@ -133,6 +139,10 @@ class NvInferNetworkEncoder {
133139
/// and other temporary buffers.
134140
using WeightsMap = llvm::DenseMap<mlir::Attribute, std::vector<int8_t>>;
135141

142+
// Tracks the mapping of mlir::Operations to layers. Note that one operation
143+
// may map to multiple layers.
144+
using LayerMap = llvm::DenseMap<Operation *, std::vector<nvinfer1::ILayer *>>;
145+
136146
using NamesSet = llvm::StringSet<>;
137147

138148
TensorMap &getTensorMap() { return valueMap; }
@@ -142,7 +152,7 @@ class NvInferNetworkEncoder {
142152

143153
/// Set the name of the `trtLayer` to a unique string that contains the op
144154
/// name and location information from `sourceOp`.
145-
void setName(nvinfer1::ILayer *layer, Operation *sourceOp);
155+
void setMetadata(nvinfer1::ILayer *layer, Operation *sourceOp);
146156

147157
// Check if network uses fp16 types.
148158
bool hasFp16Usage() const { return usesFp16; }
@@ -208,6 +218,9 @@ class NvInferNetworkEncoder {
208218
// build ends.
209219
SmallVector<NvInferPluginPtr> pluginReferences;
210220

221+
// Tracks the mapping between mlir::Operations and TensorRT ILayers.
222+
LayerMap layerMap;
223+
211224
/// Holds the set of strings currently assigned as names to TensorRT ILayers.
212225
/// This is required because we must make new names unique. The TensorRT API
213226
/// does not have a set object to query names.
@@ -239,6 +252,8 @@ class NvInferNetworkEncoder {
239252
bool hasQDQOps{false};
240253

241254
PluginManager pluginMgr;
255+
256+
std::function<std::string(Operation *)> layerMetadataCallback;
242257
};
243258

244259
//===----------------------------------------------------------------------===//

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TranslateToTensorRT.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
#ifdef MLIR_TRT_TARGET_TENSORRT
2424
#include "mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h"
25-
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
2625
#include "mlir-tensorrt-dialect/Utils/Options.h"
2726
#include "mlir/Support/LogicalResult.h"
2827
#include "llvm/Support/raw_ostream.h"
@@ -229,17 +228,18 @@ class TensorRTSerializedTimingCache {
229228
/// `tensorrt.shape_profile` arguments have been populated for each argument
230229
/// that has unknown dimensions.
231230
/// TODO(cbate): add additional options here for builder configuration.
232-
FailureOr<TensorRTEngineResult>
233-
buildFunction(mlir::FunctionOpInterface op,
234-
TensorRTBuilderContext &builderContext,
235-
TensorRTSerializedTimingCache &serializedTimingCache,
236-
const TensorRTTranslationOptions &options =
237-
TensorRTTranslationOptions::fromCLFlags());
231+
FailureOr<TensorRTEngineResult> buildFunction(
232+
mlir::FunctionOpInterface op, TensorRTBuilderContext &builderContext,
233+
TensorRTSerializedTimingCache &serializedTimingCache,
234+
const TensorRTTranslationOptions &options =
235+
TensorRTTranslationOptions::fromCLFlags(),
236+
std::function<std::string(Operation *)> layerMetadataCallback = nullptr);
238237

239238
/// Create an instance of a translate-to-tensorrt pass using an existing
240239
/// TensorRTBuilderContext.
241240
std::unique_ptr<mlir::Pass> createTranslateTensorRTPass(
242241
std::shared_ptr<tensorrt::TensorRTBuilderContext> context,
242+
std::function<std::string(Operation *)> layerMetadataCallback,
243243
TensorRTTranslationOptions options =
244244
TensorRTTranslationOptions::fromCLFlags());
245245

0 commit comments

Comments
 (0)