Skip to content

Commit 95438de

Browse files
working
1 parent 15f7673 commit 95438de

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
212212
accum += std::string(str.data, str.length);
213213
};
214214

215-
cppOpts->layerMetadataCallback = [&](Operation *op) {
215+
// Capturing by reference here will cause `callback` to point to the wrong
216+
// place at the time this callback is invoked.
217+
cppOpts->layerMetadataCallback = [=](Operation *op) {
216218
std::string accum;
217219
void *appendCtx = reinterpret_cast<void *>(&accum);
218220
callback(wrap(op), appendFunc, appendCtx, userData);

mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ class PyStableHLOToExecutableOptions
6868
mtrtPythonCapsuleToStableHLOToExecutableOptions,
6969
mtrtPythonStableHLOToExecutableOptionsToCapsule};
7070

71-
std::function<void(MlirOperation op, MlirStringCallback append,
72-
void *appendCtx, void *userData)>
73-
callback;
71+
std::function<std::string(MlirOperation)> callback;
7472
};
7573
} // namespace
7674

@@ -228,6 +226,25 @@ static void bindTensorRTPluginAdaptorObjects(py::module m) {
228226
#endif
229227
#endif
230228

229+
void layerMetadataCallback(MlirOperation op, MlirStringCallback append,
230+
void *appendCtx, void *userDataVoid) {
231+
// auto pyCallback =
232+
// *static_cast<std::function<std::string(MlirOperation)>
233+
// *>(userDataVoid);
234+
235+
// std::string result;
236+
// try {
237+
// result = pyCallback(op);
238+
// } catch (const std::exception &e) {
239+
// // TODO: What to do here? Change callback to return a status
240+
// // instead of void?
241+
// // return mtrtStatusCreate(
242+
// // MTRT_StatusCode::MTRT_StatusCode_Unknown, e.what());
243+
// }
244+
245+
// append(MlirStringRef{result.data(), result.size()}, appendCtx);
246+
}
247+
231248
PYBIND11_MODULE(_api, m) {
232249

233250
populateCommonBindingsInModule(m);
@@ -285,7 +302,7 @@ PYBIND11_MODULE(_api, m) {
285302
// Since we're constructing a C callback, our closures must not
286303
// capture. We can pass in the Python callback via the userData
287304
// argument.
288-
self.callback = [](MlirOperation op, MlirStringCallback append,
305+
auto callback = [](MlirOperation op, MlirStringCallback append,
289306
void *appendCtx, void *userDataVoid) {
290307
auto pyCallback =
291308
*static_cast<std::function<std::string(MlirOperation)> *>(
@@ -304,13 +321,10 @@ PYBIND11_MODULE(_api, m) {
304321
append(MlirStringRef{result.data(), result.size()}, appendCtx);
305322
};
306323

324+
self.callback = pyCallback;
307325
THROW_IF_MTRT_ERROR(
308326
mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback(
309-
self,
310-
self.callback.target<void(
311-
MlirOperation op, MlirStringCallback append,
312-
void *appendCtx, void *userData)>(),
313-
reinterpret_cast<void *>(&pyCallback)));
327+
self, callback, reinterpret_cast<void *>(&self.callback)));
314328
},
315329
py::arg("callback"), py::keep_alive<1, 2>{})
316330
#endif

0 commit comments

Comments
 (0)