Closed
Description
Bug description
I get different errors when I try to load a saved session-based model back.
Error 1:
...
ValueError: Unable to restore custom object of type _tf_keras_metric. Please make sure that any custom layers are included in the `custom_objects` arg when calling `load_model()` and make sure that all layers implement `get_config` and `from_config`.
This error goes away if I add `import merlin.models.tf as mm` right after importing tensorflow, but then I get a different error:
Error 2:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[1], line 3
1 import tensorflow as tf
2 import merlin.models.tf as mm
----> 3 loaded_model = tf.keras.models.load_model('/models/examples/saved_model')
File /usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
67 filtered_tb = _process_traceback_frames(e.__traceback__)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:129, in SequentialBlock.build(self, input_shape)
121 """Builds the sequential block
122
123 Parameters
(...)
126 The input shape, by default None
127 """
128 self._maybe_propagate_context(input_shape)
--> 129 build_sequentially(self, self.layers, input_shape)
File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:859, in build_sequentially(self, layers, input_shape)
854 v = TypeError(
855 f"Couldn't build {layer}, "
856 f"did you forget to add aggregation to {last_layer}?"
857 )
858 six.reraise(t, v, tb)
--> 859 input_shape = layer.compute_output_shape(input_shape)
860 last_layer = layer
861 self.built = True
File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/blocks/mlp.py:256, in _Dense.compute_output_shape(self, input_shape)
253 agg = tabular_aggregation_registry.parse(self.pre_aggregation)
254 input_shape = agg.compute_output_shape(input_shape)
--> 256 return self.dense.compute_output_shape(input_shape)
ValueError: The last dimension of the input shape of a Dense layer should be defined. Found None. Received: input_shape=(None, None)
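The last frame shows `_Dense.compute_output_shape` applying its pre-aggregation (if one is configured) and then handing `(None, None)` to the wrapped Dense layer, so the feature width was already unknown when the reloaded model was being rebuilt. The framework-free sketch below assumes a concat-style aggregation (an assumption; the traceback does not say which aggregation applies) and uses hypothetical helpers `concat_output_shape` and `dense_output_shape` to show how a single input of unknown width is enough to produce exactly this message:

```python
# Hypothetical, framework-free sketch of static shape propagation. It mimics
# (in simplified form) a concat pre-aggregation followed by a Dense layer.
# It is NOT Merlin's or Keras' code.
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]


def concat_output_shape(shapes: Sequence[Shape]) -> Shape:
    """Concatenate along the last axis; one unknown width makes the sum unknown."""
    widths = [s[-1] for s in shapes]
    total = None if any(w is None for w in widths) else sum(widths)
    return tuple(shapes[0][:-1]) + (total,)


def dense_output_shape(input_shape: Shape, units: int) -> Shape:
    """Dense needs a known last dimension to size its kernel."""
    if input_shape[-1] is None:
        raise ValueError(
            "The last dimension of the input shape of a Dense layer should be "
            f"defined. Found None. Received: input_shape={input_shape}"
        )
    return tuple(input_shape[:-1]) + (units,)


# With every feature width known, the MLP input is fully defined.
ok = concat_output_shape([(None, 48), (None, 48)])
print(dense_output_shape(ok, 128))  # (None, 128)

# If even one input arrives with an unknown width, Dense cannot be built.
bad = concat_output_shape([(None, 48), (None, None)])
print(bad)  # (None, None)
```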
Steps/Code to reproduce bug
import os
import itertools
import numpy as np
import tensorflow as tf
import merlin.models.tf as mm
from merlin.dataloader.ops.embeddings import EmbeddingOperator
from merlin.io import Dataset
from merlin.schema import Tags
from merlin.datasets.synthetic import generate_data
sequence_testing_data = generate_data("sequence-testing", num_rows=100)
sequence_testing_data.schema = sequence_testing_data.schema.select_by_tag(
Tags.SEQUENCE
).select_by_tag(Tags.CATEGORICAL)
seq_schema = sequence_testing_data.schema
item_id_name = seq_schema.select_by_tag(Tags.ITEM).first.properties['domain']['name']
target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]
query_schema = seq_schema
output_schema = seq_schema.select_by_name(target)
d_model = 48
BATCH_SIZE = 32
dmodel = int(os.environ.get("dmodel", '48'))
input_block = mm.InputBlockV2(
    query_schema,
    embeddings=mm.Embeddings(
        seq_schema.select_by_tag(Tags.CATEGORICAL),
        sequence_combiner=None,
        dim=dmodel
    ))
xlnet_block = mm.XLNetBlock(d_model=dmodel, n_head=2, n_layer=2)
def get_output_block(schema, input_block=None):
    # Reuse the item-id embedding table as the output layer (weight tying).
    candidate_table = input_block["categorical"][item_id_name]
    to_call = candidate_table
    outputs = mm.CategoricalOutput(to_call=to_call)
    return outputs
output_block = get_output_block(seq_schema, input_block=input_block)
projection = mm.MLPBlock(
    [128, output_block.to_call.table.dim],
    no_activation_last_layer=True,
)
session_encoder = mm.Encoder(
    input_block,
    mm.MLPBlock([128, dmodel], no_activation_last_layer=True),
    xlnet_block,
    projection,
)
model = mm.RetrievalModelV2(query=session_encoder, output=output_block)
optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.005,
)
loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True
)
model.compile(
    run_eagerly=False,
    optimizer=optimizer,
    loss=loss,
    metrics=mm.TopKMetricsAggregator.default_metrics(top_ks=[10])
)
model.fit(
    sequence_testing_data,
    batch_size=32,
    epochs=1,
    pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3, transformer=xlnet_block)
)
model.save('./saved_model')
After saving the model, restart the kernel and load it back with the following script:
import tensorflow as tf
import merlin.models.tf as mm
loaded_model = tf.keras.models.load_model('./saved_model')
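The reproduction depends on loading in a fresh process (the kernel restart), because anything imported during training is still loaded in the old one. The hypothetical helpers `run_in_fresh_interpreter` and `compare_loads` below sketch how the two load variants could be compared from a script using only the standard library; the snippet strings assume the model was saved to `./saved_model`, as in the reproduction code:

```python
# Hypothetical helpers (not part of Merlin or Keras): run a snippet in a
# brand-new Python process, mimicking the "restart the kernel" step.
import subprocess
import sys
import textwrap
from typing import List, Tuple


def run_in_fresh_interpreter(code: str, timeout: float = 600.0) -> subprocess.CompletedProcess:
    """Execute `code` with the current interpreter in a separate process."""
    return subprocess.run(
        [sys.executable, "-c", textwrap.dedent(code)],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


# Both snippets assume the model was saved to ./saved_model.
LOAD_ONLY_TF = """
    import tensorflow as tf
    tf.keras.models.load_model('./saved_model')
"""

LOAD_WITH_MERLIN = """
    import tensorflow as tf
    import merlin.models.tf as mm  # imported for its side effects only
    tf.keras.models.load_model('./saved_model')
"""


def compare_loads() -> List[Tuple[str, int, str]]:
    """Return (label, return code, last stderr line) for each load variant."""
    results = []
    for label, snippet in [("tf only", LOAD_ONLY_TF), ("tf + merlin", LOAD_WITH_MERLIN)]:
        proc = run_in_fresh_interpreter(snippet)
        lines = proc.stderr.strip().splitlines()
        # The final stderr line of a traceback is usually the exception message.
        results.append((label, proc.returncode, lines[-1] if lines else ""))
    return results
```

Calling `compare_loads()` from the directory that contains `./saved_model` should report a non-zero return code and the final exception line for whichever variant fails; running each load in its own interpreter keeps the result independent of whatever the training session had already imported.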
Expected behavior
We should be able to load the saved model back and then run offline evaluation or predictions with it.
Environment details
- Merlin version:
- Platform: Docker image
- Python version:
- PyTorch version (GPU?):
- Tensorflow version (GPU?): merlin-tensorflow:23.04 (I am on the dev branch of all the repos).