Skip to content

Commit 8ba77d2

Browse files
finishes implementation, adds test
1 parent 1ef8e4a commit 8ba77d2

File tree

10 files changed

+142
-63
lines changed

10 files changed

+142
-63
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "mlir-c/Support.h"
2929
#include "mlir-executor-c/Common/Common.h"
3030
#include "mlir-executor-c/Support/Status.h"
31+
#include "mlir-tensorrt-dialect/Utils/Types.h"
32+
#include "mlir/CAPI/IR.h"
3133

3234
#ifdef __cplusplus
3335
extern "C" {
@@ -60,6 +62,12 @@ typedef struct MTRT_StableHLOToExecutableOptions {
6062
void *ptr;
6163
} MTRT_StableHLOToExecutableOptions;
6264

65+
typedef struct MTRT_MetadataCallback {
66+
void *ptr;
67+
} MTRT_MetadataCallback;
68+
69+
DEFINE_C_API_PTR_METHODS(MTRT_MetadataCallback, mlirtrt::MetadataCallbackT)
70+
6371
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsCreate(
6472
MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions *options,
6573
int32_t tensorRTBuilderOptLevel, bool tensorRTStronglyTyped);
@@ -79,7 +87,7 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
7987

8088
MLIR_CAPI_EXPORTED MTRT_Status
8189
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
82-
MTRT_StableHLOToExecutableOptions options, void *callback);
90+
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback);
8391

8492
MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
8593
MTRT_StableHLOToExecutableOptions options);

mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StableHloToExecutable.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,9 @@ struct StableHLOToExecutableOptions : public mlir::OptionsContext {
132132
DebugOptions debugOptions;
133133

134134
// TODO: Add a sane default here:
135-
MetadataCallbackT layerMetadataCallback = [](MlirOperation op) { return ""; };
135+
MetadataCallbackT layerMetadataCallback = [](const MlirOperation &op) {
136+
return "";
137+
};
136138

137139
/// Base class for extensions associated with StableHloToExecutableTask.
138140
class ExtensionBase : public TaskExtensionBase {

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,9 @@ MTRT_Status mtrtStableHloToExecutableOptionsSetDebugOptions(
201201

202202
MTRT_Status
203203
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
204-
MTRT_StableHLOToExecutableOptions options, void *callback) {
204+
MTRT_StableHLOToExecutableOptions options, MTRT_MetadataCallback callback) {
205205
StableHLOToExecutableOptions *cppOpts = unwrap(options);
206-
cppOpts->layerMetadataCallback =
207-
reinterpret_cast<std::string (*)(MlirOperation)>(callback);
206+
cppOpts->layerMetadataCallback = *unwrap(callback);
208207
return mtrtStatusGetOk();
209208
}
210209

mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -279,17 +279,20 @@ PYBIND11_MODULE(_api, m) {
279279
.def(
280280
"set_tensorrt_translation_metadata_callback",
281281
[](PyStableHLOToExecutableOptions &self, MetadataCallbackT callback) {
282-
auto *ptr = callback.target<MetadataCallbackT>();
282+
// auto ptr = callback.target<MetadataCallbackT>();
283283

284-
if (!ptr)
285-
throw std::runtime_error{
286-
"Metadata callback has incorrect signature. Expected a "
287-
"function that accepts an MLIR operation and returns a "
288-
"string."};
284+
// if (!ptr)
285+
// throw std::runtime_error{
286+
// "Metadata callback has incorrect signature. Expected a "
287+
// "function that accepts an MLIR operation and returns a "
288+
// "string."};
289+
// [](PyStableHLOToExecutableOptions &self, py::object callback) {
290+
291+
// auto ptr = callback.ptr();
289292

290293
THROW_IF_MTRT_ERROR(
291294
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
292-
self, reinterpret_cast<void *>(ptr)));
295+
self, wrap(&callback)));
293296
},
294297
py::arg("callback"))
295298
#endif

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class NvInferNetworkEncoder {
145145

146146
/// Set the name of the `trtLayer` to a unique string that contains the op
147147
/// name and location information from `sourceOp`.
148-
void setName(nvinfer1::ILayer *layer, Operation *sourceOp);
148+
void setMetadata(nvinfer1::ILayer *layer, Operation *sourceOp);
149149

150150
// Check if network uses fp16 types.
151151
bool hasFp16Usage() const { return usesFp16; }
@@ -245,7 +245,7 @@ class NvInferNetworkEncoder {
245245

246246
// TODO: Where to use this? encodeOp doesn't have a way for us to access the
247247
// layers.
248-
std::function<std::string(MlirOperation)> layerMetadataCallback;
248+
mlirtrt::MetadataCallbackT layerMetadataCallback;
249249
};
250250

251251
//===----------------------------------------------------------------------===//

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TranslateToTensorRT.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,8 @@ FailureOr<TensorRTEngineResult> buildFunction(
215215
const TensorRTTranslationOptions &options =
216216
TensorRTTranslationOptions::fromCLFlags(),
217217
// TODO: Add a sane default here:
218-
mlirtrt::MetadataCallbackT layerMetadataCallback = [](MlirOperation op) {
219-
return "";
220-
});
218+
mlirtrt::MetadataCallbackT layerMetadataCallback =
219+
[](const MlirOperation &op) { return ""; });
221220

222221
/// Create an instance of a translate-to-tensorrt pass using an existing
223222
/// TensorRTBuilderContext.

0 commit comments

Comments
 (0)