diff --git a/keras_hub/src/layers/modeling/cached_multi_head_attention.py b/keras_hub/src/layers/modeling/cached_multi_head_attention.py
index 0441e71845..cf8a4d3f19 100644
--- a/keras_hub/src/layers/modeling/cached_multi_head_attention.py
+++ b/keras_hub/src/layers/modeling/cached_multi_head_attention.py
@@ -14,11 +14,11 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
 
     - No cache, same as regular multi-head attention.
     - Static cache (`cache_update_index` is None). In this case, the
-        cached key/value projections will be used and the input values will
-        be ignored.
+      cached key/value projections will be used and the input values will
+      be ignored.
     - Updated cache (`cache_update_index` is not None). In this case, new
-        key/value projections are computed using the input, and spliced into
-        the cache at the specified index.
+      key/value projections are computed using the input, and spliced into
+      the cache at the specified index.
 
     Note that caching is useful only during inference and should not be used
     during training.
@@ -56,12 +56,11 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
             training mode or in inference mode.
 
     Returns:
-        An `(attention_output, cache)` tuple. `attention_output` is the result
-        of the computation, of shape `(B, T, dim)`, where `T` is for target
-        sequence shapes and `dim` is the query input last dimension if
-        `output_shape` is `None`. Otherwise, the multi-head outputs are
-        projected to the shape specified by `output_shape`. `cache` is the
-        updated cache.
+        Depending on `return_attention_scores` and `cache`, one of:
+        - `attention_output`
+        - `(attention_output, attention_scores)`
+        - `(attention_output, cache)`
+        - `(attention_output, attention_scores, cache)`
     """
 
     def call(
@@ -72,6 +71,7 @@ def call(
         attention_mask=None,
         cache=None,
         cache_update_index=None,
+        return_attention_scores=False,
         training=None,
     ):
         if key is None:
@@ -79,12 +79,6 @@ def call(
 
         query = self._query_dense(query)
 
-        # If cache is not `None`, we will use the cache to compute the final key
-        # and value tensors. If `cache_update_index` is not None, we will first
-        # update the cache before use. To do this, we first call the
-        # `_key_dense` and `_value_dense` layers, and copy the outputs into the
-        # cache at the specified index. `cache = None` handles the training
-        # case, where we don't use the cache at all.
         if cache is not None:
             key_cache = cache[:, 0, ...]
             value_cache = cache[:, 1, ...]
@@ -118,6 +112,11 @@ def call(
 
         attention_output = self._output_dense(attention_output)
 
-        if cache is not None:
+        if return_attention_scores and cache is not None:
+            return attention_output, attention_scores, cache
+        elif return_attention_scores:
+            return attention_output, attention_scores
+        elif cache is not None:
             return attention_output, cache
-        return attention_output
+        else:
+            return attention_output
\ No newline at end of file
diff --git a/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py b/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py
index 6690667ade..e5351828ae 100644
--- a/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py
+++ b/keras_hub/src/layers/modeling/cached_multi_head_attention_test.py
@@ -101,3 +101,21 @@ def test_training_propagation(self):
         attention_output = layer._output_dense(attention_output)
 
         self.assertAllClose(outputs, attention_output, atol=1e-5)
+
+    def test_returns_attention_scores(self):
+        batch_size = 2
+        seq_len = 4
+        num_heads = 2
+        key_dim = 4
+        hidden_dim = num_heads * key_dim
+
+        query = random.uniform(shape=(batch_size, seq_len, hidden_dim))
+        value = random.uniform(shape=(batch_size, seq_len, hidden_dim))
+
+        layer = CachedMultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
+        output, scores = layer(query, value, return_attention_scores=True)
+
+        self.assertEqual(output.shape, (batch_size, seq_len, hidden_dim))
+        self.assertIsNotNone(scores)
+        self.assertEqual(scores.shape[0], batch_size)
+        self.assertEqual(len(scores.shape), 4)  # Expected: (B, H, T, S)
\ No newline at end of file
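Not part of the patch: a minimal usage sketch of the new `return_attention_scores` argument, assuming the diff above is applied. The `keras_hub.layers.CachedMultiHeadAttention` import path and the `(batch, 2, seq_len, num_heads, key_dim)` cache layout follow the existing layer; the shapes and values below are illustrative only.

# Illustrative sketch, not part of the PR.
from keras import ops, random

from keras_hub.layers import CachedMultiHeadAttention

batch_size, seq_len, num_heads, key_dim = 2, 4, 2, 16
hidden_dim = num_heads * key_dim

layer = CachedMultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
x = random.uniform(shape=(batch_size, seq_len, hidden_dim))

# Without a cache, requesting scores returns a pair.
output, scores = layer(x, x, return_attention_scores=True)

# Cached decoding: the cache stacks key/value projections along axis 1,
# i.e. shape (batch, 2, seq_len, num_heads, key_dim).
cache = ops.zeros((batch_size, 2, seq_len, num_heads, key_dim))

# With the patch applied, requesting scores while also passing a cache
# returns a three-tuple: (output, scores, updated cache).
output, scores, cache = layer(
    query=x,
    value=x,
    cache=cache,
    cache_update_index=0,
    return_attention_scores=True,
)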