[FEA] Simplify the processing of the context variable features #978

@sararb

🚀 Feature request

This feature request aims to improve the processing and passing of the context variable features in MM blocks by providing clear documentation and improving how this variable is passed around the model's blocks.

Motivation

Limitations of TF serving signature:

  • The call method of a servable block should only take the fixed arguments that are available at serving time: {inputs, targets=None, training=False, testing=False, output_context=False}. (I could have missed other parameters?)
  • As a result, the context variable features should not be an argument of the call method of blocks like BaseModel and Encoder.
  • When list columns are passed as Ragged, the Model SavedSignature replaces the feature names with args_0, args_1, etc. This means that inputs to the model should not be ragged tensors (an exception is raised otherwise).
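As a rough illustration of the first two points, the check below flags explicit call arguments that fall outside the fixed serving set. It is a plain-Python sketch, not Merlin Models or TensorFlow code: `SERVING_ARGS` simply copies the list above, and `unservable_call_args`, `ServableBlock` and `BlockWithFeatures` are hypothetical names introduced for the example.

```python
import inspect

# Arguments the issue lists as available at serving time.
SERVING_ARGS = {"inputs", "targets", "training", "testing", "output_context"}


def unservable_call_args(block_cls):
    """Return the explicit `call` parameters that are not in SERVING_ARGS."""
    params = inspect.signature(block_cls.call).parameters
    skipped_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    return sorted(
        name
        for name, param in params.items()
        if name != "self"
        and param.kind not in skipped_kinds
        and name not in SERVING_ARGS
    )


class ServableBlock:  # hypothetical stand-in, not a Merlin Models class
    def call(self, inputs, targets=None, training=False, testing=False, output_context=False):
        return inputs


class BlockWithFeatures:  # hypothetical: takes the context as an explicit argument
    def call(self, inputs, features, training=False):
        return inputs


print(unservable_call_args(ServableBlock))      # []
print(unservable_call_args(BlockWithFeatures))  # ['features']
```

A check along these lines could run in a unit test, so that a stray `features` argument is caught before serving time.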

Current Processing of the context variable features:

  • The context features variable is created and processed in the model's call here. The inputs are kept as they are and passed to the model's blocks.
  • In a servable model block (such as the Encoder class), the argument features is removed from **kwargs (here) since it is not passed at serving time. This parameter is then redefined inside the call method (here) from the inputs (assuming features == inputs).
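The pattern described in the second point can be sketched roughly as follows. This is a plain-Python stand-in using dictionaries instead of tensors; `ToyEncoder` and the column names are hypothetical, and the real Encoder's call does more than this.

```python
class ToyEncoder:  # hypothetical stand-in for a servable block
    def call(self, inputs, training=False, **kwargs):
        # `features` is supplied by the model during training but is absent
        # at serving time, so it is dropped from kwargs ...
        kwargs.pop("features", None)
        # ... and redefined from the block's own inputs (features == inputs).
        features = inputs
        return {"outputs": inputs, "features_seen": sorted(features)}


encoder = ToyEncoder()
user_inputs = {"user_id": [1, 2], "user_age": [30, 41]}
model_context = {**user_inputs, "item_id": [7, 9]}

# Training-style call: the model passes its context as `features`.
train_out = encoder.call(user_inputs, training=True, features=model_context)
# Serving-style call: only the fixed arguments are available.
serve_out = encoder.call(user_inputs)

print(train_out == serve_out)  # True
```

Both calls behave identically because the context passed by the model is discarded inside the block.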

Limitations of the current processing:

  • The features passed to the Encoder will only contain the filtered inputs passed to that block (for example, in a User tower, the features parameter will only contain variables tagged as USER). This means that the user cannot access a context variable that is not part of the Encoder's inputs.
  • Each time a user creates a servable model block, they need to make sure to remove features from **kwargs and redefine it inside the call method, if needed. If it is not removed, this could lead to missing arguments in the call method signature that are only discovered at serving time.
  • If the context features contain list variables, the user also needs to call ProcessList(..)(features) to convert them into the right format for MM blocks.
  • The conversion of list features to ragged/dense tensors (using ProcessList()) happens in three different places:
    1. When creating the context variable inside the model's call.
    2. Inside the input block, when preparing the inputs.
    3. Inside the Encoder's call, to redefine the context features from inputs.
    This repeated conversion can impact the performance of the model.
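To make the cost of the repeated conversion concrete, here is a minimal sketch that counts how often a conversion runs in one forward pass. `CountingProcessList` is a hypothetical stand-in for ProcessList (it only copies lists into tuples), and `forward_repeated` / `forward_once` are illustrative functions, not Merlin Models code.

```python
class CountingProcessList:
    """Hypothetical stand-in for ProcessList that counts how often it runs."""

    def __init__(self):
        self.calls = 0

    def __call__(self, features):
        self.calls += 1
        # Placeholder conversion: copy each list column into a tuple.
        return {name: tuple(values) for name, values in features.items()}


def forward_repeated(inputs, process_list):
    """Mimics the three conversion sites listed above."""
    context = process_list(inputs)   # 1. the model's call builds the context
    prepared = process_list(inputs)  # 2. the input block prepares its inputs
    features = process_list(inputs)  # 3. the encoder redefines features
    return context, prepared, features


def forward_once(inputs, process_list):
    """One possible alternative: convert once and share the result."""
    converted = process_list(inputs)
    return converted, converted, converted


inputs = {"item_id_list": [3, 5, 8], "category_list": [1, 1, 2]}

repeated = CountingProcessList()
out_repeated = forward_repeated(inputs, repeated)

single = CountingProcessList()
out_once = forward_once(inputs, single)

print(repeated.calls, single.calls)  # 3 1
```

Both variants produce the same converted values; only the number of conversions differs. Whether a shared conversion is feasible in MM depends on the serving constraints described above.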

Your contribution

  • Repeating the conversion of inputs/features in different blocks is mainly due to not being able to serve a model trained with tf.RaggedTensor inputs. There is ongoing work (discussed here) in TensorFlow to support serving ragged tensors. It is expected to be finished in 2023-Q2, so it might be useful to track this work and include it in MM once it is released.

Labels

P1, documentation, enhancement