diff --git a/transformers4rec/torch/features/tabular.py b/transformers4rec/torch/features/tabular.py index 83845b80c..3b1896468 100644 --- a/transformers4rec/torch/features/tabular.py +++ b/transformers4rec/torch/features/tabular.py @@ -168,6 +168,7 @@ def from_schema( # type: ignore if continuous_soft_embeddings: maybe_continuous_module = cls.SOFT_EMBEDDING_MODULE_CLASS.from_schema( schema, + max_sequence_length=max_sequence_length, tags=continuous_tags, **kwargs, ) @@ -177,7 +178,7 @@ def from_schema( # type: ignore ) if categorical_tags: maybe_categorical_module = cls.EMBEDDING_MODULE_CLASS.from_schema( - schema, tags=categorical_tags, **kwargs + schema, max_sequence_length=max_sequence_length, tags=categorical_tags, **kwargs ) if pretrained_embeddings_tags: maybe_pretrained_module = cls.PRETRAINED_EMBEDDING_MODULE_CLASS.from_schema(