
Commit 9f50d7f

Add EmbedReduce layer. (#83)
This layer looks up multiple embeddings per row and then combines them to create a single embedding per row. Also renamed `types.TensorShape` to `types.Shape`.
1 parent 9cad6b7 · commit 9f50d7f
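
A rough usage sketch of the new layer (illustrative values only; the layer is exported as `keras_rs.layers.EmbedReduce` per the `api/layers/__init__.py` change below):

    import keras
    import keras_rs

    # Each row holds several item ids; every id is embedded and the embeddings
    # in a row are combined (with "mean" by default) into one vector per row.
    layer = keras_rs.layers.EmbedReduce(input_dim=100, output_dim=16)

    item_ids = keras.ops.convert_to_tensor([[3, 7, 12], [5, 9, 11]])  # shape (2, 3)
    pooled = layer(item_ids)  # shape (2, 16)

    # Optional per-item weights are applied before the reduction.
    weights = keras.ops.convert_to_tensor([[1.0, 2.0, 1.0], [0.5, 1.0, 1.0]])
    pooled_weighted = layer(item_ids, weights)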

11 files changed: +398 −15 lines


keras_rs/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@
 since your modifications would be overwritten.
 """
 
+from keras_rs.src.layers.embedding.embed_reduce import (
+    EmbedReduce as EmbedReduce,
+)
 from keras_rs.src.layers.feature_interaction.dot_interaction import (
     DotInteraction as DotInteraction,
 )

keras_rs/src/layers/embedding/__init__.py

Whitespace-only changes.

keras_rs/src/layers/embedding/embed_reduce.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
from typing import Any, Optional

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.utils.keras_utils import check_shapes_compatible

SUPPORTED_COMBINERS = ("mean", "sum", "sqrtn")


@keras_rs_export("keras_rs.layers.EmbedReduce")
class EmbedReduce(keras.layers.Embedding):
    """An embedding layer that reduces with a combiner.

    This layer embeds inputs and then applies a reduction to combine a set of
    embeddings into a single embedding. This is typically used to embed a
    sequence of items as a single embedding.

    If the inputs passed to `__call__` are 1D, no reduction is applied. If the
    inputs are 2D, dimension 1 is reduced using the combiner so that the result
    is of shape `(batch_size, output_dim)`. Inputs of rank 3 and higher are not
    allowed. Weights can optionally be passed to the `__call__` method to
    apply weights to different samples before reduction.

    This layer supports sparse inputs and ragged inputs with backends that
    support them. The output after reduction is dense. For ragged inputs, the
    ragged dimension must be 1 as it is the dimension that is reduced.

    Args:
        input_dim: Integer. Size of the vocabulary, maximum integer index + 1.
        output_dim: Integer. Dimension of the dense embedding.
        embeddings_initializer: Initializer for the `embeddings` matrix (see
            `keras.initializers`).
        embeddings_regularizer: Regularizer function applied to the
            `embeddings` matrix (see `keras.regularizers`).
        embeddings_constraint: Constraint function applied to the `embeddings`
            matrix (see `keras.constraints`).
        mask_zero: Boolean, whether or not the input value 0 is a special
            "padding" value that should be masked out. This is useful when
            using recurrent layers which may take variable length input. If
            this is `True`, then all subsequent layers in the model need to
            support masking or an exception will be raised. If `mask_zero` is
            set to `True`, as a consequence, index 0 cannot be used in the
            vocabulary (`input_dim` should equal size of vocabulary + 1).
        weights: Optional floating-point matrix of size
            `(input_dim, output_dim)`. The initial embeddings values to use.
        combiner: Specifies how to reduce if there are multiple entries in a
            single row. Currently `mean`, `sqrtn` and `sum` are supported.
            `mean` is the default. `sqrtn` often achieves good accuracy, in
            particular with bag-of-words columns.
        **kwargs: Additional keyword arguments passed to `Embedding`.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        embeddings_initializer: types.InitializerLike = "uniform",
        embeddings_regularizer: Optional[types.RegularizerLike] = None,
        embeddings_constraint: Optional[types.ConstraintLike] = None,
        mask_zero: bool = False,
        weights: types.Tensor = None,
        combiner: str = "mean",
        **kwargs: Any,
    ) -> None:
        super().__init__(
            input_dim,
            output_dim,
            embeddings_initializer=embeddings_initializer,
            embeddings_regularizer=embeddings_regularizer,
            embeddings_constraint=embeddings_constraint,
            mask_zero=mask_zero,
            weights=weights,
            **kwargs,
        )
        if combiner not in SUPPORTED_COMBINERS:
            raise ValueError(
                f"Invalid `combiner`: '{combiner}', "
                f"use one of {', '.join(SUPPORTED_COMBINERS)}."
            )
        self.combiner = combiner

    def call(
        self,
        inputs: types.Tensor,
        weights: Optional[types.Tensor] = None,
    ) -> types.Tensor:
        """Apply embedding and reduction.

        Args:
            inputs: 1D tensor to embed or 2D tensor to embed and reduce.
            weights: Optional tensor of weights to apply before reduction,
                which can be 1D or 2D and must match the first dimension of
                `inputs` (1D case) or match the shape of `inputs` (2D case).

        Returns:
            A dense 2D tensor of shape `(batch_size, output_dim)`.
        """
        x = super().call(inputs)
        unreduced_rank = len(x.shape)

        # Check that weights has a compatible shape.
        if weights is not None:
            weights_rank = len(weights.shape)
            if weights_rank > unreduced_rank or not check_shapes_compatible(
                x.shape[0:weights_rank], weights.shape
            ):
                raise ValueError(
                    f"The shape of `weights`: {weights.shape} is not compatible"
                    f" with the shape of `inputs` after embedding: {x.shape}."
                )

        dtype = (
            x.dtype
            if weights is None
            else keras.backend.result_type(x.dtype, weights.dtype)
        )

        # When `weights` is `None`:
        # - For ragged inputs, after embedding, we get a ragged result that has
        #   a ragged dimension of 1, but when we do the "mean" or "sqrtn", we
        #   need to divide by the number of items in each row. However, there
        #   is no explicit cross-backend API to get the row length. `ones_like`
        #   gives us a ragged tensor that is ragged in the same way as the
        #   inputs. When we do `ops.sum(weights, axis=-2)`, it gives us the
        #   number of items per row.
        # - For sparse inputs, after embedding, we get a dense tensor, not a
        #   sparse tensor. Missing values use embedding 0. These are bogus
        #   embeddings and should be ignored. `ones_like` gives us a sparse
        #   tensor with the exact same missing values. Later, we do
        #   `x = ops.multiply(x, weights)`, which masks the bogus values (note
        #   that `weights` has been densified beforehand). Additionally, when
        #   we do `ops.sum(weights, axis=-2)`, it gives us the number of items
        #   per row.
        #
        # When `unreduced_rank <= 2`, the inputs were 1D and dense, so there is
        # only one embedding per row and no real reduction is going on.
        # - For mean: `result = weights * x / weights = x`, so we don't need
        #   `weights`.
        # - For sqrtn: `result = weights * x / sqrt(square(weights)) = x`, so
        #   we don't need `weights`.
        # - For sum however: `result = weights * x`, so we do need `weights`.
        # So for mean and sqrtn we don't need the weights, we use ones instead.
        # This avoids divisions by zero and improves the precision.
        if weights is None or (unreduced_rank <= 2 and self.combiner != "sum"):
            # Discard the weights if there were some and create a mask for
            # ragged and sparse tensors to mask the result correctly (sparse
            # only) and to apply the reduction correctly (ragged and sparse).
            weights = ops.ones_like(inputs, dtype=dtype)
        else:
            weights = ops.cast(weights, dtype)

        # When looking up using sparse indices, the result is dense but
        # contains values that should be ignored as all missing values use
        # index 0. We use `weights` as a mask, but it needs to be densified as
        # `expand_dims` and broadcasting a sparse tensor does not produce the
        # expected result.
        weights = ops.convert_to_tensor(weights, sparse=False)

        # Make weights and the unreduced embeddings have the same rank.
        weights_rank = len(weights.shape)
        if weights_rank < unreduced_rank:
            weights = ops.expand_dims(
                weights, axis=tuple(range(weights_rank, unreduced_rank))
            )

        # Note that `x` and `weights` are:
        # - ragged if `inputs` was ragged and `weights` was ragged or None
        # - dense otherwise (even if `inputs` and `weights` were sparse).
        x = ops.multiply(x, weights)

        if unreduced_rank <= 2:
            # No reduction is applied.
            return x

        # After this reduction, `x` is always dense as we reduce the ragged
        # dimension in the ragged case.
        x = ops.sum(x, axis=-2)

        # Apply the right divisor for the combiner.
        # Where we use `weights` in the divisor, we use
        # `ops.sum(weights, axis=-2)`, which always makes it dense as we reduce
        # the ragged dimension in the ragged case.
        if self.combiner == "mean":
            return ops.divide_no_nan(x, ops.sum(weights, axis=-2))
        elif self.combiner == "sum":
            return x
        elif self.combiner == "sqrtn":
            return ops.divide_no_nan(
                x, ops.sqrt(ops.sum(ops.square(weights), axis=-2))
            )

    def get_config(self) -> dict[str, Any]:
        config: dict[str, Any] = super().get_config()

        config.update(
            {
                "combiner": self.combiner,
            }
        )

        return config
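
To make the combiner semantics concrete, here is a small standalone sketch (plain NumPy, made-up values, not part of the commit) of the reduction performed on one already-embedded row: the weighted embeddings are summed, then divided by the sum of the weights for `mean` or by the L2 norm of the weights for `sqrtn`:

    import numpy as np

    # Two embedded vectors for one row and their per-entry weights (made up).
    e = np.array([[1.0, 2.0], [3.0, 4.0]])  # shape (num_entries, output_dim)
    w = np.array([3.0, 4.0])                # shape (num_entries,)

    weighted_sum = (w[:, None] * e).sum(axis=0)

    out_sum = weighted_sum                            # combiner="sum"
    out_mean = weighted_sum / w.sum()                 # divide by 3 + 4 = 7
    out_sqrtn = weighted_sum / np.sqrt((w**2).sum())  # divide by sqrt(9 + 16) = 5

The divisors 7 and 5 are the same ones the test below expects for the weighted dense 2D case.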

keras_rs/src/layers/embedding/embed_reduce_test.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
import math

import keras
from absl.testing import parameterized
from keras import ops
from keras.layers import deserialize
from keras.layers import serialize

from keras_rs.src import testing
from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce


class EmbedReduceTest(testing.TestCase, parameterized.TestCase):
    @parameterized.named_parameters(
        [
            (
                (
                    f"{combiner}_{input_type}_{input_rank}d"
                    f"{'_weights' if use_weights else ''}"
                ),
                combiner,
                input_type,
                input_rank,
                use_weights,
            )
            for combiner in ("sum", "mean", "sqrtn")
            for input_type, input_rank in (
                ("dense", 1),
                ("dense", 2),
                ("ragged", 2),
                ("sparse", 2),
            )
            for use_weights in (False, True)
        ]
    )
    def test_call(self, combiner, input_type, input_rank, use_weights):
        if input_type == "ragged" and keras.backend.backend() != "tensorflow":
            self.skipTest(f"ragged not supported on {keras.backend.backend()}")
        if input_type == "sparse" and keras.backend.backend() not in (
            "jax",
            "tensorflow",
        ):
            self.skipTest(f"sparse not supported on {keras.backend.backend()}")

        if input_type == "dense" and input_rank == 1:
            inputs = ops.convert_to_tensor([1, 2])
            weights = ops.convert_to_tensor([1.0, 2.0])
        elif input_type == "dense" and input_rank == 2:
            inputs = ops.convert_to_tensor([[1, 2], [3, 4]])
            weights = ops.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]])
        elif input_type == "ragged" and input_rank == 2:
            import tensorflow as tf

            inputs = tf.ragged.constant([[1], [2, 3, 4, 5]])
            weights = tf.ragged.constant([[1.0], [1.0, 2.0, 3.0, 4.0]])
        elif input_type == "sparse" and input_rank == 2:
            indices = [[0, 0], [1, 0], [1, 1], [1, 2], [1, 3]]

            if keras.backend.backend() == "tensorflow":
                import tensorflow as tf

                inputs = tf.sparse.reorder(
                    tf.SparseTensor(indices, [1, 2, 3, 4, 5], (2, 4))
                )
                weights = tf.sparse.reorder(
                    tf.SparseTensor(indices, [1.0, 1.0, 2.0, 3.0, 4.0], (2, 4))
                )
            elif keras.backend.backend() == "jax":
                from jax.experimental import sparse as jax_sparse

                inputs = jax_sparse.BCOO(
                    ([1, 2, 3, 4, 5], indices),
                    shape=(2, 4),
                    unique_indices=True,
                )
                weights = jax_sparse.BCOO(
                    ([1.0, 1.0, 2.0, 3.0, 4.0], indices),
                    shape=(2, 4),
                    unique_indices=True,
                )

        if not use_weights:
            weights = None

        layer = EmbedReduce(10, 20, combiner=combiner)
        res = layer(inputs, weights)

        self.assertEqual(res.shape, (2, 20))

        e = layer.embeddings
        if input_type == "dense" and input_rank == 1:
            if combiner == "sum" and use_weights:
                expected = [e[1], e[2] * 2.0]
            else:
                expected = [e[1], e[2]]
        elif input_type == "dense" and input_rank == 2:
            if use_weights:
                expected = [e[1] + e[2] * 2.0, e[3] * 3.0 + e[4] * 4.0]
            else:
                expected = [e[1] + e[2], e[3] + e[4]]

            if combiner == "mean":
                expected[0] /= 3.0 if use_weights else 2.0
                expected[1] /= 7.0 if use_weights else 2.0
            elif combiner == "sqrtn":
                expected[0] /= math.sqrt(5.0 if use_weights else 2.0)
                expected[1] /= math.sqrt(25.0 if use_weights else 2.0)
        else:  # ragged, sparse and input_rank == 2
            if use_weights:
                expected = [e[1], e[2] + e[3] * 2.0 + e[4] * 3.0 + e[5] * 4.0]
            else:
                expected = [e[1], e[2] + e[3] + e[4] + e[5]]

            if combiner == "mean":
                expected[1] /= 10.0 if use_weights else 4.0
            elif combiner == "sqrtn":
                expected[1] /= math.sqrt(30.0 if use_weights else 4.0)

        self.assertAllClose(res, expected)

    def test_predict(self):
        input = keras.random.randint((5, 7), minval=0, maxval=10)
        model = keras.models.Sequential([EmbedReduce(10, 20)])
        model.predict(input, batch_size=2)

    def test_serialization(self):
        layer = EmbedReduce(10, 20, combiner="sqrtn")
        restored = deserialize(serialize(layer))
        self.assertDictEqual(layer.get_config(), restored.get_config())

    def test_model_saving(self):
        input = keras.random.randint((5, 7), minval=0, maxval=10)
        model = keras.models.Sequential([EmbedReduce(10, 20)])

        self.run_model_saving_test(
            model=model,
            input_data=input,
        )

keras_rs/src/layers/feature_interaction/dot_interaction.py

Lines changed: 2 additions & 2 deletions
@@ -205,8 +205,8 @@ def call(self, inputs: list[types.Tensor]) -> types.Tensor:
         return activations
 
     def compute_output_shape(
-        self, input_shape: list[types.TensorShape]
-    ) -> types.TensorShape:
+        self, input_shape: list[types.Shape]
+    ) -> types.Shape:
         num_features = len(input_shape)
         batch_size = input_shape[0][0]

keras_rs/src/layers/feature_interaction/feature_cross.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def __init__(
                 f"`diag_scale={self.diag_scale}`"
             )
 
-    def build(self, input_shape: types.TensorShape) -> None:
+    def build(self, input_shape: types.Shape) -> None:
         last_dim = input_shape[-1]
 
         if self.projection_dim is not None:

keras_rs/src/metrics/ranking_metrics_utils.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 
 
 def get_shuffled_indices(
-    shape: types.TensorShape,
+    shape: types.Shape,
     mask: Optional[types.Tensor] = None,
     shuffle_ties: bool = True,
    seed: Optional[Union[int, keras.random.SeedGenerator]] = None,
