
Performance regression: allow_is_causal_skip incorrectly disabled when use_cache=False #41856

@williamsnell

Description

System Info

  • transformers version: 4.54.0.dev0
  • Platform: Linux-6.6.105+-x86_64-with-glibc2.35
  • Python version: 3.12.12
  • Huggingface_hub version: 0.35.3
  • Safetensors version: 0.6.2
  • Accelerate version: 1.11.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0+cu126 (CUDA)
  • Tensorflow version (GPU?): 2.19.0 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.10.7 (gpu)
  • Jax version: 0.7.2
  • JaxLib version: 0.7.2
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: Tesla T4

Who can help?

@Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Full colab reproduction: https://colab.research.google.com/drive/1vQ_UgSKMIlKMZGI7AM4DjqbWtYRTunIE#scrollTo=SuMxmbBBj4zx

Otherwise:

  1. Install transformers <= 4.53.0.
  2. Run pretraining with use_cache=False on, for example, gpt_neox (using SDPA as the attention backend).
  3. Compare iteration time against transformers >= 4.53.1. There should be a ~10-30% performance regression, depending on the exact model, context length, etc. (a minimal timing-loop sketch follows this list).
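
For convenience, here is a minimal sketch of the kind of timing loop I used (assumptions on my side: EleutherAI/pythia-14m, a synthetic batch, and a single CUDA device; the colab above is the authoritative reproduction):

    import time
    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "EleutherAI/pythia-14m", attn_implementation="sdpa"
    ).to("cuda")
    model.train()

    # Synthetic pretraining-style batch; no attention_mask is passed, so the
    # model builds its own causal mask internally.
    input_ids = torch.randint(0, model.config.vocab_size, (8, 1024), device="cuda")

    def step():
        out = model(input_ids=input_ids, labels=input_ids, use_cache=False)
        out.loss.backward()
        model.zero_grad(set_to_none=True)

    for _ in range(3):  # warm-up
        step()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(20):
        step()
    torch.cuda.synchronize()
    print(f"{(time.perf_counter() - t0) / 20 * 1000:.1f} ms/iter")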

Expected behavior

I noticed a ~25% drop in throughput when pretraining a variant of EleutherAI/pythia-14m after upgrading transformers. I bisected the regression to commit 0cf27916, which introduced packed sequence masks.

The issue seems to be that allow_is_causal_skip gets set to False whenever use_cache=False. This prevents SDPA from using its fast path (calling scaled_dot_product_attention with is_causal=True instead of a materialized mask) when no attention mask is supplied and attention is purely causal.
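
To illustrate what the skip buys us, this is the underlying PyTorch call pattern (illustration only, not transformers code): with the skip allowed, no mask is materialized and SDPA can select a fused causal kernel; with it disabled, an explicit mask is built and passed even though it encodes exactly the same pattern.

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(8, 12, 1024, 64, device="cuda", dtype=torch.float16)

    # Fast path: no materialized mask, fused causal kernels are eligible.
    out_fast = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Slow path: an explicit boolean mask encoding the same causal pattern
    # forces the more general kernel.
    causal = torch.ones(1024, 1024, dtype=torch.bool, device="cuda").tril()
    out_slow = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)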

Looking at create_causal_mask, allow_is_causal_skip is disabled whenever packed_sequence_mask is not None:

# ---------- def create_causal_mask - masking_utils.py:877 -------------
    # If we detected packing format
    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
        allow_is_causal_skip = False

packed_sequence_mask comes from _preprocess_mask_arguments, where we see the following:

  • find_packed_sequence_indices always returns a Tensor, never None
  • If we enter this block, we therefore always end up disabling allow_is_causal_skip, regardless of whether a packed sequence was actually detected (see the small example after the excerpt below)

(source: _preprocess_mask_arguments in masking_utils.py)

    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
    # and we don't have past_key_values, i.e. generally a training setup)
    packed_sequence_mask = None
    if position_ids is not None and attention_mask is None and past_key_values is None:
        batch_size = input_embeds.shape[0]
        # The position ids are sometimes just unsqueezed, without being expanded
        if batch_size != position_ids.shape[0]:
            position_ids = position_ids.expand(batch_size, -1)
        packed_sequence_mask = find_packed_sequence_indices(position_ids)

    return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
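
To make the first point concrete, here is what I believe find_packed_sequence_indices produces for an ordinary batch versus a genuinely packed one (values hand-derived from the implementation, so treat them as illustrative):

    import torch
    from transformers.masking_utils import find_packed_sequence_indices

    # Ordinary training batch: position ids simply count up. Nothing is packed,
    # yet a Tensor of zeros is returned rather than None.
    plain = torch.arange(8).unsqueeze(0)        # [[0, 1, 2, ..., 7]]
    print(find_packed_sequence_indices(plain))  # [[0, 0, 0, 0, 0, 0, 0, 0]]

    # Packed batch: positions restart mid-row, so the sequence index increments.
    packed = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]])
    print(find_packed_sequence_indices(packed))  # [[0, 0, 0, 1, 1, 2, 2, 2]]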

I'm happy to put in a pull request with the following change:

    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
    # and we don't have past_key_values, i.e. generally a training setup)
    packed_sequence_mask = None
    if position_ids is not None and attention_mask is None and past_key_values is None:
        batch_size = input_embeds.shape[0]
        # The position ids are sometimes just unsqueezed, without being expanded
        if batch_size != position_ids.shape[0]:
            position_ids = position_ids.expand(batch_size, -1)
        packed_sequence_mask = find_packed_sequence_indices(position_ids)
        
+       # Only return the mask if we detected any packed sequences.
+       if (packed_sequence_mask[:, -1] == 0).all():
+           packed_sequence_mask = None

However, the reason this is an Issue and not a PR is that in the source code, there's this comment:

    # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
    # but it causes issues with export
    return packed_sequence_mask

I presume my proposed change would also run afoul of export, and it's not clear to me how to resolve this.
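
One direction that might be worth discussing (purely a sketch on my side; I haven't verified it against the export tests) is to only take the data-dependent early return in eager mode, guarding it with is_torchdynamo_compiling() from transformers.utils:

    from transformers.utils import is_torchdynamo_compiling

    packed_sequence_mask = find_packed_sequence_indices(position_ids)
    # The check below is data-dependent, so skip it while compiling/tracing,
    # where it would otherwise cause the export issues mentioned above.
    if not is_torchdynamo_compiling() and (packed_sequence_mask[:, -1] == 0).all():
        packed_sequence_mask = None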

I'd be very happy to submit this as a PR.

Thanks!
