
Commit adad2c0

* Add a `robust_masking` argument to the `Softmax` layer to enable better numerical handling of the mask (currently, if the mask violates any of the layer's assumptions, it silently produces numerically nonsensical results).
* Plumb an argument through the official Keras `MultiHeadAttention` layer and the model garden `TransformerEncoderBlock` layer that opts into the new softmax masking behavior.

PiperOrigin-RevId: 831896060
1 parent 924b2d0 commit adad2c0
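
For orientation, a minimal sketch of the two opt-in flags this commit introduces; both default to False, so existing behavior is unchanged (assumes a standalone tf_keras install and the import alias shown):

import tf_keras as keras

# Softmax gains robust_masking, which thresholds the mask and applies it
# with tf.where instead of adding a large negative value to the logits.
softmax = keras.layers.Softmax(robust_masking=True)

# MultiHeadAttention gains softmax_robust_masking, which is forwarded to
# the Softmax layer it builds internally in _build_attention.
mha = keras.layers.MultiHeadAttention(
    num_heads=4, key_dim=32, softmax_robust_masking=True
)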

File tree

7 files changed: 46 additions & 16 deletions

tf_keras/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'softmax_robust_masking\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "add_loss"

tf_keras/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], "
+    argspec: "args=[\'self\', \'axis\', \'robust_masking\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'False\'], "
   }
   member_method {
     name: "add_loss"

tf_keras/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'softmax_robust_masking\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "add_loss"

tf_keras/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], "
+    argspec: "args=[\'self\', \'axis\', \'robust_masking\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'False\'], "
   }
   member_method {
     name: "add_loss"

tf_keras/layers/activation/softmax.py

Lines changed: 26 additions & 11 deletions
@@ -70,6 +70,8 @@ class Softmax(Layer):
     Args:
         axis: Integer, or list of Integers, axis along which the softmax
             normalization is applied.
+        robust_masking: Bool, if true will use a more robust implementation when
+            dealing with masks.
     Call arguments:
         inputs: The inputs, or logits to the softmax layer.
         mask: A boolean mask of the same shape as `inputs`. The mask
@@ -80,23 +82,34 @@ class Softmax(Layer):
         Softmaxed output with the same shape as `inputs`.
     """
 
-    def __init__(self, axis=-1, **kwargs):
+    def __init__(self, axis=-1, robust_masking=False, **kwargs):
         super().__init__(**kwargs)
         self.supports_masking = True
+        self.robust_masking = robust_masking
         self.axis = axis
 
     def call(self, inputs, mask=None):
         if mask is not None:
-            # Since mask is 1.0 for positions we want to keep and 0.0 for masked
-            # positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and -1e.9 for masked positions.
-            adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
-                _large_compatible_negative(inputs.dtype)
-            )
-
-            # Since we are adding it to the raw scores before the softmax, this
-            # is effectively the same as removing these entirely.
-            inputs += adder
+            if self.robust_masking:
+                # We keep the positions where the mask is True or > 0.5, and set
+                # the other (masked) positions to -1e.9.
+                if mask.dtype is not tf.bool:
+                    mask = tf.greater(mask, tf.constant(0.5, dtype=mask.dtype))
+                inputs = tf.where(
+                    mask, inputs, _large_compatible_negative(inputs.dtype)
+                )
+            else:
+                # Since mask is 1.0 for positions we want to keep and 0.0 for
+                # masked positions, this operation will create a tensor which is
+                # 0.0 for positions we want to attend and -1e.9 for masked
+                # positions.
+                adder = (1.0 - tf.cast(mask, inputs.dtype)) * (
+                    _large_compatible_negative(inputs.dtype)
+                )
+
+                # Since we are adding it to the raw scores before the softmax, this
+                # is effectively the same as removing these entirely.
+                inputs += adder
         if isinstance(self.axis, (tuple, list)):
             if len(self.axis) > 1:
                 return tf.exp(
@@ -109,6 +122,8 @@ def call(self, inputs, mask=None):
 
     def get_config(self):
         config = {"axis": self.axis}
+        if self.robust_masking:
+            config["robust_masking"] = True
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
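
To illustrate the difference the new flag makes, here is a minimal sketch with a float mask that violates the 0/1 assumption (the 1.5 value is a deliberately malformed "keep" entry; the tf_keras import alias and eager execution are assumptions):

import tensorflow as tf
import tf_keras as keras

logits = tf.constant([[1.0, 2.0, 3.0]])
# A mask that violates the 0/1 assumption: 1.5 at a position we want to keep.
mask = tf.constant([[1.0, 1.5, 0.0]])

# Legacy path: adder = (1 - 1.5) * -1e9 = +5e8 is silently *added* to the
# second logit, so that position swallows essentially all probability mass.
legacy = keras.layers.Softmax()(logits, mask=mask)

# Robust path: the mask is thresholded at 0.5 and applied with tf.where, so
# 1.5 simply means "keep" and the surviving logits are left untouched.
robust = keras.layers.Softmax(robust_masking=True)(logits, mask=mask)

print(legacy.numpy())  # roughly [[0., 1., 0.]]
print(robust.numpy())  # roughly [[0.27, 0.73, 0.]]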

tf_keras/layers/activation/softmax_test.py

Lines changed: 8 additions & 0 deletions
@@ -31,6 +31,14 @@ def test_softmax(self):
             supports_masking=True,
         )
 
+    def test_softmax_robust_masking(self):
+        test_utils.layer_test(
+            keras.layers.Softmax,
+            kwargs={"axis": 1, "robust_masking": True},
+            input_shape=(2, 3, 4),
+            supports_masking=True,
+        )
+
 
 if __name__ == "__main__":
     tf.test.main()

tf_keras/layers/attention/multi_head_attention.py

Lines changed: 8 additions & 1 deletion
@@ -198,6 +198,8 @@ class MultiHeadAttention(Layer):
         activity_regularizer: Regularizer for dense layer activity.
         kernel_constraint: Constraint for dense layer kernels.
         bias_constraint: Constraint for dense layer kernels.
+        softmax_robust_masking: If true will use a more numerically robust
+            masking impl.
 
     Call arguments:
         query: Query `Tensor` of shape `(B, T, dim)`.
@@ -247,6 +249,7 @@ def __init__(
         activity_regularizer=None,
         kernel_constraint=None,
         bias_constraint=None,
+        softmax_robust_masking=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -264,6 +267,7 @@ def __init__(
         self._activity_regularizer = regularizers.get(activity_regularizer)
         self._kernel_constraint = constraints.get(kernel_constraint)
         self._bias_constraint = constraints.get(bias_constraint)
+        self._softmax_robust_masking = softmax_robust_masking
         if attention_axes is not None and not isinstance(
             attention_axes, collections.abc.Sized
         ):
@@ -298,6 +302,7 @@ def get_config(self):
             "query_shape": self._query_shape,
             "key_shape": self._key_shape,
             "value_shape": self._value_shape,
+            "softmax_robust_masking": self._softmax_robust_masking,
         }
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -476,7 +481,9 @@ def _build_attention(self, rank):
             )
         )
         self._softmax = activation.Softmax(
-            axis=norm_axes, dtype=self._dtype_policy
+            axis=norm_axes,
+            robust_masking=self._softmax_robust_masking,
+            dtype=self._dtype_policy,
         )
         self._dropout_layer = regularization.Dropout(
             rate=self._dropout, dtype=self._dtype_policy
