Commit 9980689

sararb, rnyak, and edknv authored
Add support for transformer-based retrieval models (#1128)

* add ragged support in the top-k block
* extend candidate embeddings extraction to CategoricalOutput
* make the bias term optional in the weight-tying class
* add a comment noting that top-k only works for the last item in the session

Co-authored-by: rnyak <[email protected]>
Co-authored-by: edknv <[email protected]>
1 parent ed45657 · commit 9980689

File tree

merlin/models/tf/models/base.py
merlin/models/tf/outputs/classification.py
merlin/models/tf/outputs/topk.py
tests/unit/tf/outputs/test_classification.py
tests/unit/tf/outputs/test_topk.py

5 files changed: +35 −10 lines

merlin/models/tf/models/base.py

Lines changed: 2 additions & 1 deletion

@@ -57,6 +57,7 @@
 from merlin.models.tf.metrics.topk import TopKMetricsAggregator, filter_topk_metrics, split_metrics
 from merlin.models.tf.models.utils import parse_prediction_blocks
 from merlin.models.tf.outputs.base import ModelOutput, ModelOutputType
+from merlin.models.tf.outputs.classification import CategoricalOutput
 from merlin.models.tf.outputs.contrastive import ContrastiveOutput
 from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask
 from merlin.models.tf.transforms.features import PrepareFeatures, expected_input_cols_from_schema

@@ -2374,7 +2375,7 @@ def candidate_embeddings(

             return candidate.encode(dataset, index=index, **kwargs)

-        if isinstance(self.last, ContrastiveOutput):
+        if isinstance(self.last, (ContrastiveOutput, CategoricalOutput)):
            return self.last.to_dataset()

        raise Exception(...)
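
For orientation, a plain-TensorFlow sketch of what candidate embedding extraction means for a weight-tied categorical output (illustrative names only; this is not the Merlin to_dataset() implementation): the rows of the shared item embedding table double as the candidate vectors that a top-k block later scores against.

import tensorflow as tf

vocab_size, dim = 1000, 64

# Shared item embedding table: with weight tying it is both the input
# embedding and the output projection of the categorical head.
item_table = tf.Variable(tf.random.normal([vocab_size, dim]))

# One candidate per row: row ids are the item ids, rows are their embeddings.
candidate_ids = tf.range(vocab_size)                   # [vocab_size]
candidate_vectors = tf.convert_to_tensor(item_table)   # [vocab_size, dim]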

merlin/models/tf/outputs/classification.py

Lines changed: 12 additions & 7 deletions

@@ -304,6 +304,8 @@ class EmbeddingTablePrediction(Layer):
         The embedding table to use as the weight matrix
     bias_initializer : str, optional
         Initializer for the bias vector, by default "zeros"
+    use_bias : bool, optional
+        Whether to add a bias term to weight-tying, by default False

     References:
     ----------

@@ -312,18 +314,20 @@ class EmbeddingTablePrediction(Layer):
        arXiv:1611.01462 (2016).
    """

-    def __init__(self, table: EmbeddingTable, bias_initializer="zeros", **kwargs):
+    def __init__(self, table: EmbeddingTable, bias_initializer="zeros", use_bias=False, **kwargs):
         self.table = table
         self.num_classes = table.input_dim
         self.bias_initializer = bias_initializer
+        self.use_bias = use_bias
         super().__init__(**kwargs)

     def build(self, input_shape):
-        self.bias = self.add_weight(
-            name="output_layer_bias",
-            shape=(self.num_classes,),
-            initializer=self.bias_initializer,
-        )
+        if self.use_bias:
+            self.bias = self.add_weight(
+                name="output_layer_bias",
+                shape=(self.num_classes,),
+                initializer=self.bias_initializer,
+            )
         self.table.build(input_shape)
         return super().build(input_shape)

@@ -333,7 +337,8 @@ def call(self, inputs, training=False, **kwargs) -> tf.Tensor:
             original_inputs = inputs
             inputs = inputs.flat_values
         logits = tf.matmul(inputs, self.table.table.embeddings, transpose_b=True)
-        logits = tf.nn.bias_add(logits, self.bias)
+        if self.use_bias:
+            logits = tf.nn.bias_add(logits, self.bias)
         if is_ragged:
             logits = original_inputs.with_flat_values(logits)
         return logits
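
As a standalone illustration of the weight-tying computation this layer performs, here is the same pattern in plain TensorFlow (names such as session_repr are made up for this sketch; use_bias=False mirrors the new default):

import tensorflow as tf

vocab_size, dim, use_bias = 1000, 64, False

embeddings = tf.Variable(tf.random.normal([vocab_size, dim]))   # shared item embedding table
session_repr = tf.random.normal([32, dim])                      # transformer/session output

# Tied output layer: score every item by a dot product with its embedding row.
logits = tf.matmul(session_repr, embeddings, transpose_b=True)  # [32, vocab_size]
if use_bias:
    bias = tf.Variable(tf.zeros([vocab_size]), name="output_layer_bias")
    logits = tf.nn.bias_add(logits, bias)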

merlin/models/tf/outputs/topk.py

Lines changed: 10 additions & 0 deletions

@@ -206,6 +206,11 @@ def call(
                 "You should call the `index` method first to " "set the _candidates index."
             )

+        if isinstance(inputs, tf.RaggedTensor):
+            # Evaluates on the session's last item only
+            # (which is also the default mode during inference).
+            # TODO: extend top-k generation to other items in the input session.
+            inputs = tf.squeeze(inputs.to_tensor(), axis=1)
         tf.assert_equal(
             tf.shape(inputs)[1],
             tf.shape(self._candidates)[1],

@@ -220,6 +225,11 @@ def call(
             assert targets is not None, ValueError(
                 "Targets should be provided during the evaluation mode"
             )
+            if isinstance(targets, tf.RaggedTensor):
+                targets = tf.ragged.boolean_mask(
+                    targets, targets._keras_mask.with_row_splits_dtype(targets.row_splits.dtype)
+                )
+                targets = targets.to_tensor()
             targets = tf.cast(tf.squeeze(targets), tf.int32)
             targets = tf.cast(tf.expand_dims(targets, -1) == top_ids, tf.float32)
             targets = tf.reshape(targets, tf.shape(top_scores))
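
A small, self-contained TensorFlow sketch of the ragged handling added above (toy values; nothing here is Merlin-specific): a ragged query holding exactly one item per session is densified and its sequence axis dropped, and ragged targets are densified before being matched against the top-k ids.

import tensorflow as tf

# Ragged query with one "last item" vector per session: shape [batch, None, dim].
ragged_inputs = tf.ragged.constant([[[0.1, 0.2]], [[0.3, 0.4]]], ragged_rank=1)

# Densify and drop the sequence axis so only the last item gets scored.
dense_inputs = tf.squeeze(ragged_inputs.to_tensor(), axis=1)    # shape [2, 2]

# Ragged targets (one target id per session) are densified the same way.
ragged_targets = tf.ragged.constant([[7], [42]])
dense_targets = tf.cast(tf.squeeze(ragged_targets.to_tensor()), tf.int32)  # shape [2]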

tests/unit/tf/outputs/test_classification.py

Lines changed: 3 additions & 2 deletions

@@ -98,7 +98,8 @@ def test_categorical_output(sequence_testing_data: Dataset, run_eagerly):


 @pytest.mark.parametrize("run_eagerly", [True, False])
-def test_last_item_prediction(sequence_testing_data: Dataset, run_eagerly):
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_last_item_prediction(sequence_testing_data: Dataset, run_eagerly, use_bias):
     dataloader, schema = testing_utils.loader_for_last_item_prediction(sequence_testing_data)
     embeddings = mm.Embeddings(
         schema,

@@ -110,7 +111,7 @@ def test_last_item_prediction(sequence_testing_data: Dataset, run_eagerly):
         schema["item_id_seq"],
         CategoricalTarget(schema["item_id_seq"]),
         embeddings["item_id_seq"],
-        EmbeddingTablePrediction(embeddings["item_id_seq"]),
+        EmbeddingTablePrediction(embeddings["item_id_seq"], use_bias=use_bias),
     ]

     for target in predictions:
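
For reference, stacking parametrize markers as above makes pytest run the test body for every combination of the two flags; a minimal standalone illustration (toy test name, not part of the Merlin suite):

import pytest

@pytest.mark.parametrize("run_eagerly", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_flag_combinations(run_eagerly, use_bias):
    # pytest expands stacked parametrize markers into the cross product:
    # (True, True), (True, False), (False, True), (False, False).
    assert isinstance(run_eagerly, bool) and isinstance(use_bias, bool)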

tests/unit/tf/outputs/test_topk.py

Lines changed: 8 additions & 0 deletions

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import numpy as np
 import pytest
 import tensorflow as tf


@@ -28,6 +29,10 @@ def test_brute_force_layer():
     candidates = tf.random.uniform(shape=(num_candidates, 4), dtype=tf.float32)
     query = tf.random.uniform(shape=(num_queries, 4), dtype=tf.float32)

+    # Create a ragged query
+    elements = np.random.rand(num_queries, 1, 4)
+    ragged_query = tf.ragged.constant(elements)
+
     wrong_candidates_rank = tf.random.uniform(shape=(num_candidates,), dtype=tf.float32)
     wrong_query_dim = tf.random.uniform(shape=(num_queries, 8), dtype=tf.float32)
     wrong_identifiers_shape = tf.range(num_candidates + 1, dtype=tf.int32)

@@ -60,6 +65,9 @@ def test_brute_force_layer():
     assert list(topk_output.scores.shape) == [num_queries, top_k]
     assert list(topk_output.identifiers.shape) == [num_queries, top_k]
     assert isinstance(topk_output, TopKPrediction)
+    assert list(topk_output.scores.shape) == [num_queries, top_k]
+    ragged_topk_output = brute_force(ragged_query)
+    assert list(ragged_topk_output.scores.shape) == [num_queries, top_k]

     with pytest.raises(Exception) as excinfo:
         brute_force(query, targets=None, testing=True)
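
A rough, framework-free sketch of what the new ragged path in this test exercises (plain TensorFlow/NumPy; brute-force scoring is reduced here to a dot product plus tf.math.top_k and is not the layer under test):

import numpy as np
import tensorflow as tf

num_queries, num_candidates, dim, top_k = 5, 100, 4, 10

candidates = tf.random.uniform((num_candidates, dim), dtype=tf.float32)
ragged_query = tf.ragged.constant(np.random.rand(num_queries, 1, dim).tolist())

# Same squeeze as the top-k block: keep only the single (last) item per session.
query = tf.cast(tf.squeeze(ragged_query.to_tensor(), axis=1), tf.float32)  # [5, 4]

# Brute-force scoring: dot product against every candidate, then take the top k.
scores = tf.matmul(query, candidates, transpose_b=True)        # [5, 100]
top_scores, top_ids = tf.math.top_k(scores, k=top_k)           # each [5, 10]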
