diff --git a/conifer/backends/cpp/writer.py b/conifer/backends/cpp/writer.py index df6cdc5d..3f7e1b32 100644 --- a/conifer/backends/cpp/writer.py +++ b/conifer/backends/cpp/writer.py @@ -24,6 +24,12 @@ class CPPModel(ModelBase): def __init__(self, ensembleDict, config, metadata=None): super(CPPModel, self).__init__(ensembleDict, config, metadata) self.config = CPPConfig(config) + + def load_shared_library(self, model_json, shared_library): + import importlib + spec = importlib.util.spec_from_file_location(os.path.basename(shared_library).split(".so")[0], shared_library) + self.bridge = importlib.util.module_from_spec(spec).BDT(model_json) + spec.loader.exec_module(self.bridge) @copydocstring(ModelBase.write) def write(self): @@ -99,10 +105,7 @@ def compile(self): try: logger.debug(f'Importing conifer_bridge_{self._stamp} from conifer_bridge_{self._stamp}.so') - import importlib.util - spec = importlib.util.spec_from_file_location(f'conifer_bridge_{self._stamp}', f'./conifer_bridge_{self._stamp}.so') - self.bridge = importlib.util.module_from_spec(spec).BDT(f"{cfg.project_name}.json") - spec.loader.exec_module(self.bridge) + self.load_shared_library(f"{cfg.project_name}.json", f"./conifer_bridge_{self._stamp}.so") except ImportError: os.chdir(curr_dir) raise Exception("Can't import pybind11 bridge, is it compiled?") diff --git a/conifer/backends/xilinxhls/writer.py b/conifer/backends/xilinxhls/writer.py index e9b28f4c..9b188978 100644 --- a/conifer/backends/xilinxhls/writer.py +++ b/conifer/backends/xilinxhls/writer.py @@ -512,6 +512,12 @@ def decision_function(self, X, trees=False): y = y.reshape(y.shape[0]) return y + def load_shared_library(self, model_json, shared_library): + import importlib + spec = importlib.util.spec_from_file_location(os.path.basename(shared_library).split(".so")[0], shared_library) + self.bridge = importlib.util.module_from_spec(spec) + spec.loader.exec_module(self.bridge) + @copydocstring(ModelBase.compile) def compile(self): self.write() @@ -534,10 +540,7 @@ def compile(self): try: logger.debug(f'Importing conifer_bridge_{self._stamp} from conifer_bridge_{self._stamp}.so') - import importlib.util - spec = importlib.util.spec_from_file_location(f'conifer_bridge_{self._stamp}', f'./conifer_bridge_{self._stamp}.so') - self.bridge = importlib.util.module_from_spec(spec) - spec.loader.exec_module(self.bridge) + self.load_shared_library(f"{cfg.project_name}.json", f"./conifer_bridge_{self._stamp}.so") except ImportError: os.chdir(curr_dir) raise Exception("Can't import pybind11 bridge, is it compiled?") diff --git a/conifer/model.py b/conifer/model.py index 2812fe60..4b16c0aa 100644 --- a/conifer/model.py +++ b/conifer/model.py @@ -495,6 +495,9 @@ def _profile(self, what : Literal["scores", "thresholds"], ax=None): return ax + def load_shared_library(self, model_json, shared_library): + pass + class ModelMetaData: def __init__(self): self.version = version @@ -532,7 +535,7 @@ def make_model(ensembleDict, config=None): backend = get_backend(backend) return backend.make_model(ensembleDict, config) -def load_model(filename, new_config=None): +def load_model(filename, new_config=None, shared_library=True): ''' Load a Model from JSON file @@ -542,6 +545,14 @@ def load_model(filename, new_config=None): filename to load from new_config: dictionary (optional) if provided, override the configuration specified in the JSON file + shared_library: string|bool (optional) + If True, the shared library will be looked for in the same directory as the JSON file, using the timestamp of the last metadata entry available + If False, the shared library will not be loaded + If a string, it could be: + - path to the shared library to load + - path to the directory where to look for the .so file, using the timestamp of the last metadata entry available + + No shared library will be loaded if a new configuration is provided ''' with open(filename, 'r') as json_file: js = json.load(json_file) @@ -561,4 +572,25 @@ def load_model(filename, new_config=None): model = make_model(js, config) model._metadata = metadata + model._metadata + + if new_config is None and shared_library is not False: + shared_library_path=None + if isinstance(shared_library, str) and shared_library.endswith(".so"): + shared_library_path=shared_library + else: + from glob import glob + shared_library_dirpath=os.path.abspath(os.path.dirname(filename)) if shared_library is True else os.path.abspath(shared_library) + timestamps=[int(md._to_dict()["time"]) for md in model._metadata[-2::-1]] + so_files=glob(os.path.join(shared_library_dirpath, 'conifer_bridge_*.so')) + so_files=[os.path.basename(so_file) for so_file in so_files] + for timestamp in timestamps: + if f"conifer_bridge_{timestamp}.so" in so_files: + shared_library_path=os.path.join(shared_library_dirpath, f'conifer_bridge_{timestamp}.so') + break + + try: + model.load_shared_library(filename, shared_library_path) + except Exception: + print("An existing shared library was either not found or could not be loaded. Run model.compile()") + return model diff --git a/tests/test_save_load.py b/tests/test_save_load.py index c617ec8a..734a9ee5 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -2,6 +2,7 @@ import numpy as np import conifer import json +import os ''' Test conifer's model saving and loading functionality by loading some models and checking the predictions @@ -16,6 +17,29 @@ def test_hls_save_load(hls_convert, train_skl): y_hls_0, y_hls_1 = util.predict_skl(orig_model, X, y, load_model) np.testing.assert_array_equal(y_hls_0, y_hls_1) +def test_hls_reload_last_shared_library(hls_convert, train_skl): + clf, X, y = train_skl + initial_model = conifer.model.load_model(f'{hls_convert.config.output_dir}/{hls_convert.config.project_name}.json', shared_library = False) + initial_model.config.output_dir += '_loaded' + initial_model.compile() + # Re-load without recompiling to check if the shared library is loaded correctly + reload_model = conifer.model.load_model(f'{hls_convert.config.output_dir}_loaded/{hls_convert.config.project_name}.json', shared_library=True) + y_hls, y_hls_reload = util.predict_skl(initial_model, X, y, reload_model) + np.testing.assert_array_equal(y_hls, y_hls_reload) + assert os.path.basename(initial_model.bridge.__file__) == os.path.basename(reload_model.bridge.__file__), "Loaded two different shared libraries" + +def test_hls_reload_manual_shared_library(hls_convert, train_skl): + clf, X, y = train_skl + initial_model = conifer.model.load_model(f'{hls_convert.config.output_dir}/{hls_convert.config.project_name}.json', shared_library = False) + initial_model.config.output_dir += '_loaded' + initial_model.compile() + so_path = os.path.basename(initial_model.bridge.__file__) # manually get the shared library path + # Re-load without recompiling to check if the shared library is loaded correctly + reload_model = conifer.model.load_model(f'{hls_convert.config.output_dir}_loaded/{hls_convert.config.project_name}.json', shared_library=so_path) # pass the shared library path manually + y_hls, y_hls_reload = util.predict_skl(initial_model, X, y, reload_model) + np.testing.assert_array_equal(y_hls, y_hls_reload) + assert os.path.basename(initial_model.bridge.__file__) == os.path.basename(reload_model.bridge.__file__), "Loaded two different shared libraries" + def test_hdl_save_load(vhdl_convert, train_skl): orig_model = vhdl_convert clf, X, y = train_skl