[BUG] Getting error when loading back a saved session-based model #1132

@rnyak

Bug description

I get two different errors when I try to load a saved session-based model.

Error 1:

...
ValueError: Unable to restore custom object of type _tf_keras_metric. Please make sure that any custom layers are included in the `custom_objects` arg when calling `load_model()` and make sure that all layers implement `get_config` and `from_config`.
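
For context on the message above: Keras resolves custom classes in a saved model by their registered name, and a class is typically only known to the registry once the module that defines it has been imported (or it is passed through `custom_objects`). A minimal sketch of that mechanism, using a hypothetical `DemoMean` metric and a made-up `demo` package name; this is not Merlin's code, and it only assumes the standard `tf.keras.utils` registration helpers:

```python
import tensorflow as tf


# Hypothetical custom metric, registered under a made-up "demo" package name.
# The decorator runs at import time, which is why importing the defining
# module is enough to make the class resolvable during loading.
@tf.keras.utils.register_keras_serializable(package="demo")
class DemoMean(tf.keras.metrics.Mean):
    pass


# The registry maps "<package>><ClassName>" to the class object.
registered_name = tf.keras.utils.get_registered_name(DemoMean)
resolved_cls = tf.keras.utils.get_registered_object(registered_name)
print(registered_name, resolved_cls is DemoMean)
```

If the defining module is never imported in the loading process, the lookup has nothing to return, which would be consistent with Error 1.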

This error goes away if I add `import merlin.models.tf as mm` after importing tensorflow, but then I get a second error:

Error 2:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 3
      1 import tensorflow as tf
      2 import merlin.models.tf as mm
----> 3 loaded_model = tf.keras.models.load_model('/models/examples/saved_model')

File /usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:129, in SequentialBlock.build(self, input_shape)
    121 """Builds the sequential block
    122 
    123 Parameters
   (...)
    126     The input shape, by default None
    127 """
    128 self._maybe_propagate_context(input_shape)
--> 129 build_sequentially(self, self.layers, input_shape)

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:859, in build_sequentially(self, layers, input_shape)
    854             v = TypeError(
    855                 f"Couldn't build {layer}, "
    856                 f"did you forget to add aggregation to {last_layer}?"
    857             )
    858         six.reraise(t, v, tb)
--> 859     input_shape = layer.compute_output_shape(input_shape)
    860     last_layer = layer
    861 self.built = True

File /usr/local/lib/python3.8/dist-packages/merlin/models/tf/blocks/mlp.py:256, in _Dense.compute_output_shape(self, input_shape)
    253     agg = tabular_aggregation_registry.parse(self.pre_aggregation)
    254     input_shape = agg.compute_output_shape(input_shape)
--> 256 return self.dense.compute_output_shape(input_shape)

ValueError: The last dimension of the input shape of a Dense layer should be defined. Found None. Received: input_shape=(None, None)
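
To illustrate what the traceback seems to show, here is a toy, pure-Python sketch of sequential shape propagation. `ToyDense`, `ToyUnknownWidth` and `build_sequentially_toy` are hypothetical stand-ins written for this example and are not the Merlin or Keras implementations; the assumption is only that each layer's output shape is fed to the next one, so a single layer that returns an undefined last dimension makes a later dense layer fail.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]


class ToyDense:
    """Hypothetical stand-in for a dense layer; it only tracks shapes."""

    def __init__(self, units: int) -> None:
        self.units = units

    def compute_output_shape(self, input_shape: Shape) -> Shape:
        # A dense kernel has shape (last_dim, units), so last_dim must be known.
        if input_shape[-1] is None:
            raise ValueError(
                "The last dimension of the input shape should be defined. "
                f"Found None. Received: input_shape={input_shape}"
            )
        return input_shape[:-1] + (self.units,)


class ToyUnknownWidth:
    """Hypothetical layer whose output loses the static last dimension."""

    def compute_output_shape(self, input_shape: Shape) -> Shape:
        return input_shape[:-1] + (None,)


def build_sequentially_toy(layers: Sequence, input_shape: Shape) -> Shape:
    """Feed each layer's output shape into the next layer."""
    for layer in layers:
        input_shape = layer.compute_output_shape(input_shape)
    return input_shape


# Works: every layer sees a defined last dimension.
ok_shape = build_sequentially_toy([ToyDense(128), ToyDense(48)], (None, 20))

# Fails: the dense layer receives (None, None), as in Error 2.
try:
    build_sequentially_toy([ToyUnknownWidth(), ToyDense(48)], (None, 20))
    failed = False
except ValueError:
    failed = True
```

In the real model the open question is which block reports `(None, None)` when the saved model is rebuilt at load time; the sketch only shows why the dense layer is where the failure surfaces.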

Steps/Code to reproduce bug

import os
import itertools

import numpy as np
import tensorflow as tf

import merlin.models.tf as mm
from merlin.dataloader.ops.embeddings import EmbeddingOperator
from merlin.io import Dataset
from merlin.schema import Tags
from merlin.datasets.synthetic import generate_data

sequence_testing_data = generate_data("sequence-testing", num_rows=100)
sequence_testing_data.schema = sequence_testing_data.schema.select_by_tag(
    Tags.SEQUENCE
).select_by_tag(Tags.CATEGORICAL)
seq_schema = sequence_testing_data.schema

item_id_name = seq_schema.select_by_tag(Tags.ITEM).first.properties['domain']['name']

target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]

query_schema = seq_schema
output_schema = seq_schema.select_by_name(target)

BATCH_SIZE = 32

# Embedding / transformer dimension; can be overridden with the `dmodel` environment variable.
dmodel = int(os.environ.get("dmodel", "48"))

input_block = mm.InputBlockV2(
    query_schema,
    embeddings=mm.Embeddings(
        seq_schema.select_by_tag(Tags.CATEGORICAL),
        sequence_combiner=None,
        dim=dmodel,
    ),
)

xlnet_block = mm.XLNetBlock(d_model=dmodel, n_head=2, n_layer=2)

def get_output_block(schema, input_block=None):
    # Reuse the item-id embedding table from the input block for the output layer.
    candidate_table = input_block["categorical"][item_id_name]
    outputs = mm.CategoricalOutput(to_call=candidate_table)
    return outputs

output_block = get_output_block(seq_schema, input_block=input_block)

projection = mm.MLPBlock(
    [128, output_block.to_call.table.dim],
    no_activation_last_layer=True,
)
session_encoder = mm.Encoder(
    input_block,
    mm.MLPBlock([128, dmodel], no_activation_last_layer=True),
    xlnet_block,
    projection,
)
model = mm.RetrievalModelV2(query=session_encoder, output=output_block)
optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.005,
)

loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True
)
model.compile(
    run_eagerly=False, 
    optimizer=optimizer, 
    loss=loss,
    metrics=mm.TopKMetricsAggregator.default_metrics(top_ks=[10])
)
model.fit(
    sequence_testing_data,
    batch_size=BATCH_SIZE,
    epochs=1,
    pre=mm.SequenceMaskRandom(
        schema=seq_schema, target=target, masking_prob=0.3, transformer=xlnet_block
    ),
)

model.save('./saved_model')

Once the model is saved, restart the kernel and load the model with the following script:

import tensorflow as tf
import merlin.models.tf as mm
loaded_model = tf.keras.models.load_model('./saved_model')

Expected behavior

We should be able to load the saved model and then run offline evaluation or predictions with it.

Environment details

  • Merlin version:
  • Platform: Docker image
  • Python version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?): merlin-tensorflow:23.04 (I am on dev branch of all the repos).

Additional context

Labels

P1, bug (Something isn't working)
