diff --git a/coremltools/converters/_converters_entry.py b/coremltools/converters/_converters_entry.py
index 1169bc2d2..6598c518a 100644
--- a/coremltools/converters/_converters_entry.py
+++ b/coremltools/converters/_converters_entry.py
@@ -64,6 +64,7 @@ def convert(
     compute_units=_ComputeUnit.ALL,
     package_dir=None,
     debug=False,
+    **kwargs
 ):
     """
     Convert a TensorFlow or PyTorch model to the Core ML model format as either
@@ -363,6 +364,9 @@ def skip_real_div_ops(op):
        - For Tensorflow conversion, it will cause to display extra logging
          and visualizations.
 
+    Note that for a TensorFlow SavedModel that contains more than one tag set,
+    ``tags: list[str]`` can be used to specify which tag set to load.
+
     Returns
     -------
@@ -458,6 +462,7 @@ def skip_real_div_ops(op):
         package_dir=package_dir,
         debug=debug,
         specification_version=specification_version,
+        **kwargs
     )
 
     if exact_target == 'milinternal':
diff --git a/coremltools/converters/mil/frontend/tensorflow/load.py b/coremltools/converters/mil/frontend/tensorflow/load.py
index 3dd38cbec..b770b21ad 100644
--- a/coremltools/converters/mil/frontend/tensorflow/load.py
+++ b/coremltools/converters/mil/frontend/tensorflow/load.py
@@ -62,7 +62,8 @@ def load(self):
         logging.info("Loading TensorFlow model '{}'".format(self.model))
         outputs = self.kwargs.get("outputs", None)
         output_names = get_output_names(outputs)
-        self._graph_def = self._graph_def_from_model(output_names)
+        tags = self.kwargs.get("tags", None)
+        self._graph_def = self._graph_def_from_model(output_names, tags)
 
         if self._graph_def is not None and len(self._graph_def.node) == 0:
             msg = "tf.Graph should have at least 1 node, Got empty graph."
@@ -88,7 +89,7 @@ def load(self):
         return program
 
     # @abstractmethod
-    def _graph_def_from_model(self, output_names=None):
+    def _graph_def_from_model(self, output_names=None, tags=None):
         """Load TensorFlow model into GraphDef.
         Overwrite for different TF versions."""
         pass
@@ -139,7 +140,7 @@ def __init__(self, model, debug=False, **kwargs):
         """
         TFLoader.__init__(self, model, debug, **kwargs)
 
-    def _graph_def_from_model(self, output_names=None):
+    def _graph_def_from_model(self, output_names=None, tags=None):
         """Overwrites TFLoader._graph_def_from_model()"""
         msg = "Expected model format: [tf.Graph | .pb | SavedModel | tf.keras.Model | .h5], got {}"
         if isinstance(self.model, tf.Graph) and hasattr(self.model, "as_graph_def"):
@@ -170,7 +171,7 @@ def _graph_def_from_model(self, output_names=None):
             graph_def = self._from_tf_keras_model(self.model)
             return self.extract_sub_graph(graph_def, output_names)
         elif os.path.isdir(str(self.model)):
-            graph_def = self._from_saved_model(self.model)
+            graph_def = self._from_saved_model(self.model, tags=tags)
             return self.extract_sub_graph(graph_def, output_names)
         else:
             raise NotImplementedError(msg.format(self.model))
diff --git a/coremltools/converters/mil/frontend/tensorflow2/load.py b/coremltools/converters/mil/frontend/tensorflow2/load.py
index 546c6ee75..50e6d6937 100644
--- a/coremltools/converters/mil/frontend/tensorflow2/load.py
+++ b/coremltools/converters/mil/frontend/tensorflow2/load.py
@@ -96,7 +96,7 @@ def __init__(self, model, debug=False, **kwargs):
             fuse_dilation_conv,
         ]
 
-    def _get_concrete_functions_and_graph_def(self):
+    def _get_concrete_functions_and_graph_def(self, tags=None):
         msg = (
             "Expected model format: [SavedModel | [concrete_function] | "
             "tf.keras.Model | .h5], got {}"
@@ -120,7 +120,7 @@ def _get_concrete_functions_and_graph_def(self):
                 and (self.model.endswith(".h5") or self.model.endswith(".hdf5")):
             cfs = self._concrete_fn_from_tf_keras_or_h5(self.model)
         elif _os_path.isdir(self.model):
-            saved_model = _tf.saved_model.load(self.model)
+            saved_model = _tf.saved_model.load(self.model, tags=tags)
             sv = saved_model.signatures.values()
             cfs = sv if isinstance(sv, list) else list(sv)
         else:
@@ -132,9 +132,9 @@ def _get_concrete_functions_and_graph_def(self):
 
         return cfs, graph_def
 
-    def _graph_def_from_model(self, output_names=None):
+    def _graph_def_from_model(self, output_names=None, tags=None):
         """Overwrites TFLoader._graph_def_from_model()"""
-        _, graph_def = self._get_concrete_functions_and_graph_def()
+        _, graph_def = self._get_concrete_functions_and_graph_def(tags=tags)
         return self.extract_sub_graph(graph_def, output_names)
 
     def _tf_ssa_from_graph_def(self, fn_name="main"):
diff --git a/coremltools/test/api/test_api_examples.py b/coremltools/test/api/test_api_examples.py
index 9177ca981..94512426e 100644
--- a/coremltools/test/api/test_api_examples.py
+++ b/coremltools/test/api/test_api_examples.py
@@ -302,6 +302,42 @@ def test_convert_from_saved_model_dir():
         mlmodel = ct.convert("./saved_model")
         mlmodel.save("./model.mlmodel")
 
+    @staticmethod
+    def test_convert_from_two_tags_saved_model_dir(tmpdir):
+        import tensorflow as tf
+        from tensorflow.compat.v1.saved_model import build_tensor_info
+        from tensorflow.compat.v1.saved_model import signature_constants
+        from tensorflow.compat.v1.saved_model import signature_def_utils
+
+        @tf.function
+        def add(a, b):
+            return a + b
+
+        c = add.get_concrete_function(tf.constant(21.0), tf.constant(21.0))
+
+        save_path = str(tmpdir)
+        builder = tf.compat.v1.saved_model.Builder(save_path)
+
+        with tf.compat.v1.Session(graph=c.graph) as sess:
+            tensor_info_a = build_tensor_info(c.graph.inputs[0])
+            tensor_info_b = build_tensor_info(c.graph.inputs[1])
+            tensor_info_y = build_tensor_info(c.graph.outputs[0])
+
+            prediction_signature = signature_def_utils.build_signature_def(
+                inputs={'a': tensor_info_a, 'b': tensor_info_b},
+                outputs={'output': tensor_info_y},
+                method_name=signature_constants.PREDICT_METHOD_NAME)
+
+            builder.add_meta_graph_and_variables(sess, ["serve"],
+                                                 signature_def_map={
+                                                     signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+                                                         prediction_signature,
+                                                 })
+
+        builder.add_meta_graph(["serve", "tpu"])
+        builder.save()
+
+        ct.convert(save_path, source="tensorflow", tags=["serve"])
 
     @staticmethod
     def test_keras_custom_layer_model():
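For reference, a minimal usage sketch of the new ``tags`` argument through the public ``ct.convert`` API. The SavedModel path and tag names below are illustrative, assuming a directory exported with more than one tag set (as in the test above):

    import coremltools as ct

    # Illustrative path: a SavedModel directory exported with two tag sets,
    # e.g. ["serve"] and ["serve", "tpu"], as in the test above.
    saved_model_dir = "./two_tag_saved_model"

    # With more than one MetaGraph in the SavedModel, pass `tags` so the
    # TensorFlow loader knows which tag set to load and convert.
    mlmodel = ct.convert(saved_model_dir, source="tensorflow", tags=["serve"])
    mlmodel.save("./model.mlmodel")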