@@ -68,9 +68,7 @@ class PyStableHLOToExecutableOptions
6868 mtrtPythonCapsuleToStableHLOToExecutableOptions,
6969 mtrtPythonStableHLOToExecutableOptionsToCapsule};
7070
71- std::function<void (MlirOperation op, MlirStringCallback append,
72- void *appendCtx, void *userData)>
73- callback;
71+ std::function<std::string(MlirOperation)> callback;
7472};
7573} // namespace
7674
@@ -228,6 +226,25 @@ static void bindTensorRTPluginAdaptorObjects(py::module m) {
228226#endif
229227#endif
230228
229+ void layerMetadataCallback (MlirOperation op, MlirStringCallback append,
230+ void *appendCtx, void *userDataVoid) {
231+ // auto pyCallback =
232+ // *static_cast<std::function<std::string(MlirOperation)>
233+ // *>(userDataVoid);
234+
235+ // std::string result;
236+ // try {
237+ // result = pyCallback(op);
238+ // } catch (const std::exception &e) {
239+ // // TODO: What to do here? Change callback to return a status
240+ // // instead of void?
241+ // // return mtrtStatusCreate(
242+ // // MTRT_StatusCode::MTRT_StatusCode_Unknown, e.what());
243+ // }
244+
245+ // append(MlirStringRef{result.data(), result.size()}, appendCtx);
246+ }
247+
231248PYBIND11_MODULE (_api, m) {
232249
233250 populateCommonBindingsInModule (m);
@@ -285,7 +302,7 @@ PYBIND11_MODULE(_api, m) {
285302 // Since we're constructing a C callback, our closures must not
286303 // capture. We can pass in the Python callback via the userData
287304 // argument.
288- self. callback = [](MlirOperation op, MlirStringCallback append,
305+ auto callback = [](MlirOperation op, MlirStringCallback append,
289306 void *appendCtx, void *userDataVoid) {
290307 auto pyCallback =
291308 *static_cast <std::function<std::string (MlirOperation)> *>(
@@ -304,13 +321,10 @@ PYBIND11_MODULE(_api, m) {
304321 append (MlirStringRef{result.data (), result.size ()}, appendCtx);
305322 };
306323
324+ self.callback = pyCallback;
307325 THROW_IF_MTRT_ERROR (
308326 mtrtStableHloToExecutableOptionsSetTensorRTTranslationMetadataCallback (
309- self,
310- self.callback .target <void (
311- MlirOperation op, MlirStringCallback append,
312- void *appendCtx, void *userData)>(),
313- reinterpret_cast <void *>(&pyCallback)));
327+ self, callback, reinterpret_cast <void *>(&self.callback )));
314328 },
315329 py::arg (" callback" ), py::keep_alive<1 , 2 >{})
316330#endif
0 commit comments