10 changes: 5 additions & 5 deletions src/transformers/masking_utils.py
@@ -340,9 +340,6 @@ def sdpa_mask(
allow_is_causal_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
`torch.sdpa` instead. Default to `True`.
- allow_torch_fix (`bool`, optional):
- Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
- versions. We need an arg to skip it when using eager. By default `True`.
Comment on lines -343 to -345 (Contributor Author):

Doubled comment; unrelated, but it should be removed nonetheless. Slipped through.
allow_is_bidirectional_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
i.e. full attention without any padding. Default to `False`.
@@ -480,6 +477,7 @@ def eager_mask(
mask_function: Callable = causal_mask_function,
attention_mask: Optional[torch.Tensor] = None,
dtype: torch.dtype = torch.float32,
+ allow_is_bidirectional_skip: bool = False,
use_vmap: bool = False,
**kwargs,
) -> torch.Tensor:
@@ -503,13 +501,15 @@
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
dtype (`torch.dtype`, optional):
The dtype to use for the mask. By default, `torch.float32`.
+ allow_is_bidirectional_skip (`bool`, optional):
+ Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
+ i.e. full attention without any padding. Default to `False`.
use_vmap (`bool`, optional):
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
index-based (for the cost of speed performance). By default `False`.
"""
# The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
_ = kwargs.pop("allow_is_causal_skip", None)
_ = kwargs.pop("allow_is_bidirectional_skip", None)
_ = kwargs.pop("allow_torch_fix", None)
mask = sdpa_mask(
batch_size=batch_size,
@@ -519,7 +519,7 @@
mask_function=mask_function,
attention_mask=attention_mask,
allow_is_causal_skip=False,
- allow_is_bidirectional_skip=False,
+ allow_is_bidirectional_skip=allow_is_bidirectional_skip,
allow_torch_fix=False,
use_vmap=use_vmap,
**kwargs,
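
For context on what forwarding `allow_is_bidirectional_skip` buys, here is a minimal, self-contained sketch. It is not the actual `transformers` implementation; `toy_sdpa_mask` and `toy_eager_mask` are hypothetical stand-ins that only mirror the behavior described in the diff: the eager mask is the boolean SDPA-style mask cast to 0 / dtype-min, and when the flag is allowed and attention is fully bidirectional with no padding, the mask builder can return `None` instead of materializing an all-zeros bias.

```python
# Minimal sketch (assumption-labeled, NOT the transformers implementation) of how an
# eager additive mask can be derived from a boolean SDPA-style mask, and how the
# `allow_is_bidirectional_skip` flag lets the eager path skip building the mask entirely.
from typing import Optional

import torch


def toy_sdpa_mask(
    batch_size: int,
    q_len: int,
    kv_len: int,
    attention_mask: Optional[torch.Tensor] = None,  # (batch_size, kv_len), 1 = keep
    allow_is_bidirectional_skip: bool = False,
) -> Optional[torch.Tensor]:
    """Toy stand-in: return a boolean keep-mask, or None when no bias is needed."""
    no_padding = attention_mask is None or bool(attention_mask.all())
    if allow_is_bidirectional_skip and no_padding:
        # Full (bidirectional) attention without padding: no bias needed at all.
        return None
    keep = torch.ones(batch_size, 1, q_len, kv_len, dtype=torch.bool)
    if attention_mask is not None:
        # Mask out padded key/value positions.
        keep &= attention_mask[:, None, None, :].bool()
    return keep


def toy_eager_mask(
    batch_size: int,
    q_len: int,
    kv_len: int,
    attention_mask: Optional[torch.Tensor] = None,
    dtype: torch.dtype = torch.float32,
    allow_is_bidirectional_skip: bool = False,
) -> Optional[torch.Tensor]:
    """Boolean mask from the sdpa helper, cast to 0 (keep) and dtype-min (mask out)."""
    keep = toy_sdpa_mask(
        batch_size, q_len, kv_len, attention_mask,
        allow_is_bidirectional_skip=allow_is_bidirectional_skip,
    )
    if keep is None:
        # With the flag forwarded, the skip now reaches the eager path too.
        return None
    min_value = torch.finfo(dtype).min
    return torch.where(
        keep,
        torch.zeros((), dtype=dtype),
        torch.full((), min_value, dtype=dtype),
    )


if __name__ == "__main__":
    # No padding + skip allowed -> None (previously a tensor of zeros would be built).
    print(toy_eager_mask(2, 4, 4, allow_is_bidirectional_skip=True))
    # With padding, a real additive mask is still produced.
    pad = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
    print(toy_eager_mask(2, 4, 4, attention_mask=pad).shape)
```

The design point mirrors the diff above: `eager_mask` no longer hard-codes `allow_is_bidirectional_skip=False` when delegating to `sdpa_mask`, so callers that know no bias is needed can avoid materializing the mask.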