Skip to content

Commit 60c38a2

Browse files
adds py bindings
1 parent 75bd2b1 commit 60c38a2

File tree

7 files changed

+36
-16
lines changed

7 files changed

+36
-16
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
7979

8080
MLIR_CAPI_EXPORTED MTRT_Status
8181
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
82-
MTRT_StableHLOToExecutableOptions options,
83-
const char *(*callback)(MlirOperation));
82+
MTRT_StableHLOToExecutableOptions options, void *callback);
8483

8584
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
8685
MTRT_StableHLOToExecutableOptions options);

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,10 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
201201

202202
MTRT_Status
203203
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
204-
MTRT_StableHLOToExecutableOptions options,
205-
const char *(*callback)(MlirOperation)) {
204+
MTRT_StableHLOToExecutableOptions options, void *callback) {
206205
StableHLOToExecutableOptions *cppOpts = unwrap(options);
207-
cppOpts->layerMetadataCallback = callback;
206+
cppOpts->layerMetadataCallback =
207+
reinterpret_cast<std::string (*)(MlirOperation)>(callback);
208208
return mtrtStatusGetOk();
209209
}
210210

mlir-tensorrt/compiler/lib/Compiler/TensorRTExtension/TensorRTExtension.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ void StableHLOToExecutableTensorRTExtension::populatePasses(
6565
tensorrt::buildTensorRTModuleTransformationPipeline(
6666
trtPM, translationOptions.enableStronglyTyped);
6767
trtPM.addPass(tensorrt::createTranslateTensorRTPass(
68-
nullptr, translationOptions, options.layerMetadataCallback));
68+
nullptr, options.layerMetadataCallback, translationOptions));
6969
return;
7070
}
7171

mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
#include "pybind11/pybind11.h"
2121
#include "llvm/Support/DynamicLibrary.h"
2222
#include <pybind11/attr.h>
23+
#include <pybind11/functional.h>
24+
#include <stdexcept>
2325

2426
#ifdef MLIR_TRT_TARGET_TENSORRT
2527
#include "mlir-tensorrt-dialect/Utils/NvInferAdaptor.h"
28+
#include "mlir-tensorrt-dialect/Utils/Types.h"
2629
#endif
2730

2831
namespace py = pybind11;
@@ -270,7 +273,27 @@ PYBIND11_MODULE(_api, m) {
270273
py::arg("enabled"),
271274
py::arg("debug_types") = std::vector<std::string>{},
272275
py::arg("dump_ir_tree_dir") = py::none(),
273-
py::arg("dump_tensorrt_dir") = py::none());
276+
py::arg("dump_tensorrt_dir") = py::none())
277+
278+
#ifdef MLIR_TRT_TARGET_TENSORRT
279+
.def(
280+
"set_tensorrt_translation_metadata_callback",
281+
[](PyStableHLOToExecutableOptions &self, MetadataCallbackT callback) {
282+
auto *ptr = callback.target<MetadataCallbackT>();
283+
284+
if (!ptr)
285+
throw std::runtime_error{
286+
"Metadata callback has incorrect signature. Expected a "
287+
"function that accepts an MLIR operation and returns a "
288+
"string."};
289+
290+
THROW_IF_MTRT_ERROR(
291+
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
292+
self, reinterpret_cast<void *>(ptr)));
293+
},
294+
py::arg("callback"))
295+
#endif
296+
;
274297

275298
m.def(
276299
"compiler_stablehlo_to_executable",
@@ -308,4 +331,4 @@ PYBIND11_MODULE(_api, m) {
308331
bindTensorRTPluginAdaptorObjects(m);
309332
#endif
310333
#endif
311-
}
334+
}

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TranslateToTensorRT.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,9 @@ FailureOr<TensorRTEngineResult> buildFunction(
223223
/// TensorRTBuilderContext.
224224
std::unique_ptr<mlir::Pass> createTranslateTensorRTPass(
225225
std::shared_ptr<tensorrt::TensorRTBuilderContext> context,
226+
mlirtrt::MetadataCallbackT layerMetadataCallback,
226227
TensorRTTranslationOptions options =
227-
TensorRTTranslationOptions::fromCLFlags(),
228-
// TODO: Add a sane default here:
229-
mlirtrt::MetadataCallbackT layerMetadataCallback = [](MlirOperation op) {
230-
return "";
231-
});
228+
TensorRTTranslationOptions::fromCLFlags());
232229

233230
/// Register llvm::cl opts related to TensorRT translation. This should be
234231
/// called before having LLVM parse CL options.

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Utils/Types.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
#include "mlir/IR/Operation.h"
55
#include <functional>
6+
#include <string>
67

78
namespace mlirtrt {
8-
using MetadataCallbackT = std::function<const char *(MlirOperation)>;
9+
using MetadataCallbackT = std::function<std::string(MlirOperation)>;
910
} // namespace mlirtrt
1011

1112
#endif // MLIR_TENSORRT_UTILS_TYPES_H

mlir-tensorrt/tensorrt/lib/Target/TranslateToTensorRT.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -830,8 +830,8 @@ class TranslateToTensorRTEnginePass
830830

831831
std::unique_ptr<mlir::Pass> tensorrt::createTranslateTensorRTPass(
832832
std::shared_ptr<tensorrt::TensorRTBuilderContext> context,
833-
TensorRTTranslationOptions options,
834-
mlirtrt::MetadataCallbackT layerMetadataCallback) {
833+
mlirtrt::MetadataCallbackT layerMetadataCallback,
834+
TensorRTTranslationOptions options) {
835835
return std::make_unique<TranslateToTensorRTEnginePass>(context, options,
836836
layerMetadataCallback);
837837
}

0 commit comments

Comments
 (0)