18 changes: 10 additions & 8 deletions src/transformers/masking_utils.py
@@ -340,9 +340,6 @@ def sdpa_mask(
allow_is_causal_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
`torch.sdpa` instead. Default to `True`.
allow_torch_fix (`bool`, optional):
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
versions. We need an arg to skip it when using eager. By default `True`.
Comment on lines -343 to -345 (Contributor Author):
Doubled comment, unrelated but should be removed nonetheless. Slipped through.
allow_is_bidirectional_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
i.e. full attention without any padding. Default to `False`.
@@ -480,6 +477,7 @@ def eager_mask(
mask_function: Callable = causal_mask_function,
attention_mask: Optional[torch.Tensor] = None,
dtype: torch.dtype = torch.float32,
allow_is_bidirectional_skip: bool = False,
use_vmap: bool = False,
**kwargs,
) -> torch.Tensor:
@@ -503,13 +501,15 @@ def eager_mask(
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
dtype (`torch.dtype`, optional):
The dtype to use for the mask. By default, `torch.float32`.
allow_is_bidirectional_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
i.e. full attention without any padding. Default to `False`.
use_vmap (`bool`, optional):
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
index-based (for the cost of speed performance). By default `False`.
"""
# The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
_ = kwargs.pop("allow_is_causal_skip", None)
_ = kwargs.pop("allow_is_bidirectional_skip", None)
_ = kwargs.pop("allow_torch_fix", None)
mask = sdpa_mask(
batch_size=batch_size,
@@ -519,14 +519,16 @@ def eager_mask(
mask_function=mask_function,
attention_mask=attention_mask,
allow_is_causal_skip=False,
allow_is_bidirectional_skip=False,
allow_is_bidirectional_skip=allow_is_bidirectional_skip,
allow_torch_fix=False,
use_vmap=use_vmap,
**kwargs,
)
min_dtype = torch.finfo(dtype).min
# we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
# only bidirectional masks can be skipped, otherwise we convert bool -> float
if mask is not None:
min_dtype = torch.finfo(dtype).min
# we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
return mask


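For context on the change above: with `allow_is_bidirectional_skip` now accepted and forwarded, `eager_mask` may return `None` instead of a float bias when attention is fully bidirectional with no padding. Below is a minimal sketch (not transformers' implementation) of how a caller would consume such a possibly-`None` mask; the function name is hypothetical.

```python
# Hedged sketch, not the library's code: how an eager attention step can
# consume a bias mask that may now be None (i.e. nothing to add).
import torch

def eager_attention_weights(query, key, attn_bias=None):
    # query, key: (batch, num_heads, seq_len, head_dim)
    scores = query @ key.transpose(-1, -2) / key.shape[-1] ** 0.5
    if attn_bias is not None:
        # attn_bias holds 0.0 for visible positions and the dtype minimum elsewhere,
        # matching the float mask produced when the skip is not possible.
        scores = scores + attn_bias
    return torch.softmax(scores, dim=-1)
```
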
59 changes: 39 additions & 20 deletions tests/utils/test_masking_utils.py
@@ -18,7 +18,6 @@
cleanup,
is_torch_available,
require_torch,
require_torch_accelerator,
torch_device,
)

@@ -262,30 +261,25 @@ def test_chunked_mask_with_left_padding_decoding(self):

self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())

@require_torch_accelerator
def test_bidirectional_mask_cudagraphs(self):
"""
Checks whether the bidirectional mask creation is compatible with cuda graphs, i.e. we do not run into any error
during this test.
"""

@staticmethod
def _run_bidirectional_mask(mask_fn, attn_implementation):
def run_mask_creation(mask_fn, config, input_embeds, encoder_mask, cross_mask, encoder_hidden_states):
_ = mask_fn(
encoder_attn_mask = mask_fn(
config=config,
input_embeds=input_embeds,
attention_mask=encoder_mask,
)

_ = mask_fn(
cross_attn_mask = mask_fn(
config=config,
input_embeds=input_embeds,
attention_mask=cross_mask,
encoder_hidden_states=encoder_hidden_states,
)
return encoder_attn_mask, cross_attn_mask

# We use llama but could be also bert/bart --> we only need the `_attn_implementation` here
config = LlamaConfig()
config._attn_implementation = "sdpa"
config._attn_implementation = attn_implementation

# Meta data
batch_size = 2
@@ -298,19 +292,17 @@ def run_mask_creation(mask_fn, config, input_embeds, encoder_mask, cross_mask, e
encoder_mask = torch.ones_like(input_embeds)[..., 0]
cross_mask = torch.ones_like(encoder_hidden_states)[..., 0]

mask_creation_function = torch.compile(create_bidirectional_mask, mode="reduce-overhead")

# Case 1: Full mask
run_mask_creation(
mask_fn=mask_creation_function,
full_mask_encoder_1, full_mask_cross_1 = run_mask_creation(
mask_fn=mask_fn,
config=config,
input_embeds=input_embeds,
encoder_mask=encoder_mask,
cross_mask=cross_mask,
encoder_hidden_states=encoder_hidden_states,
)
run_mask_creation(
mask_fn=mask_creation_function,
full_mask_encoder_2, full_mask_cross_2 = run_mask_creation(
mask_fn=mask_fn,
config=config,
input_embeds=input_embeds,
encoder_mask=None,
@@ -322,11 +314,38 @@ def run_mask_creation(mask_fn, config, input_embeds, encoder_mask, cross_mask, e
cross_mask[:, -1] = 0
encoder_mask[:, -1] = 0

run_mask_creation(
mask_fn=mask_creation_function,
padded_mask_encoder, padded_mask_cross = run_mask_creation(
mask_fn=mask_fn,
config=config,
input_embeds=input_embeds,
encoder_mask=encoder_mask,
cross_mask=cross_mask,
encoder_hidden_states=encoder_hidden_states,
)

full_masks = (full_mask_encoder_1, full_mask_encoder_2), (full_mask_cross_1, full_mask_cross_2)
padded_masks = (padded_mask_encoder, padded_mask_cross)
return full_masks, padded_masks

def test_bidirectional_mask_cudagraphs(self):
"""
Checks whether the bidirectional mask creation is compatible with cuda graphs, i.e. we do not run into any error
during this test.
"""
mask_creation_function = torch.compile(create_bidirectional_mask, mode="reduce-overhead")
self._run_bidirectional_mask(mask_fn=mask_creation_function, attn_implementation="sdpa")

def test_bidirectional_mask_skip_eager(self):
"""
Checks whether the bidirectional mask creation can skip the mask creation if we have a full mask.
"""
full_masks, padded_mask = self._run_bidirectional_mask(
mask_fn=create_bidirectional_mask, attn_implementation="eager"
)

for alternative_masks in full_masks:
self.assertTrue(alternative_masks[0] is None)
self.assertTrue(alternative_masks[1] is None)

self.assertTrue(padded_mask[0] is not None)
self.assertTrue(padded_mask[1] is not None)
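
As a usage-level illustration of what `test_bidirectional_mask_skip_eager` asserts, the sketch below mirrors the test: under the `"eager"` implementation, a fully-visible bidirectional mask is expected to come back as `None`, while a padded one still yields a mask. This is a sketch under the assumption that `create_bidirectional_mask` and `LlamaConfig` are importable as in this test module; the tensor sizes are illustrative, not taken from the test.

```python
# Hedged sketch mirroring the new test, assuming the same imports as the test module.
import torch
from transformers import LlamaConfig
from transformers.masking_utils import create_bidirectional_mask

config = LlamaConfig()
config._attn_implementation = "eager"

batch_size, seq_len, hidden_dim = 2, 5, 16  # illustrative sizes
input_embeds = torch.randn(batch_size, seq_len, hidden_dim)
encoder_mask = torch.ones_like(input_embeds)[..., 0]  # full attention, no padding

# Full bidirectional attention without padding -> mask creation can be skipped.
full_mask = create_bidirectional_mask(
    config=config,
    input_embeds=input_embeds,
    attention_mask=encoder_mask,
)
assert full_mask is None

# Mask out the last position -> a real float mask must be returned.
encoder_mask[:, -1] = 0
padded_mask = create_bidirectional_mask(
    config=config,
    input_embeds=input_embeds,
    attention_mask=encoder_mask,
)
assert padded_mask is not None
```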