
sdpa_mask_recent_torch can cause Torch Compile C++ Errors #42320

@KyleMylonakisProtopia

Description

System Info

  • transformers version: 4.57.1
  • Platform: macOS-26.1-arm64-arm-64bit-Mach-O
  • Python version: 3.13.5
  • Huggingface_hub version: 0.36.0
  • Safetensors version: 0.6.2
  • Accelerate version: 1.11.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0 (NA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: N/A

Who can help?

@vasqu @ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When compiling the forward pass of a Mistral or Llama 3.2 1B model that uses SDPA attention with torch.compile(model_forward, dynamic=True), I get the following error on subsequent calls to the forward pass:

E Exception: CppCompileError: C++ compile error
E
E Command:
E clang++ /var/folders/nm/msp2wznx7tx3pn36whjlwjb80000gp/T/torchinductor_kylemylonakis/po/cpoqa5seijx2x3dxbaxk2ahmo37h5m6yfuy2mj5apcq5475ikek3.main.cpp -D TORCH_INDUCTOR_CPP_WRAPPER -D STANDALONE_TORCH_HEADER -D C10_USING_CUSTOM_GENERATED_MACROS -D CPU_CAPABILITY_NEON -D AT_BUILD_ARM_VEC256_WITH_SLEEF -shared -fPIC -undefined dynamic_lookup -O3 -DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -Werror=ignored-optimization-argument -Xclang -fopenmp -include /var/folders/nm/msp2wznx7tx3pn36whjlwjb80000gp/T/torchinductor_kylemylonakis/precompiled_headers/cgdc5o2of4y2qxukfmhmxeqi43bvdknueojypevyeprzopw4bssy.h -I/opt/homebrew/opt/python@3.13/Frameworks/Python.framework/Versions/3.13/include/python3.13 -I/Users/kylemylonakis/stained-glass-proxy/.venv/lib/python3.13/site-packages/torch/include -I/Users/kylemylonakis/stained-glass-proxy/.venv/lib/python3.13/site-packages/torch/include/torch/csrc/api/include -I/opt/homebrew/opt/libomp/include -D_GLIBCXX_USE_CXX11_ABI=1 -o /var/folders/nm/msp2wznx7tx3pn36whjlwjb80000gp/T/torchinductor_kylemylonakis/po/cpoqa5seijx2x3dxbaxk2ahmo37h5m6yfuy2mj5apcq5475ikek3.main.so -lomp -L/opt/homebrew/opt/python@3.13/Frameworks/Python.framework/Versions/3.13/lib -L/Users/kylemylonakis/stained-glass-proxy/.venv/lib/python3.13/site-packages/torch/lib -L/opt/homebrew/opt/libomp/lib
E
E Output:
E /var/folders/nm/msp2wznx7tx3pn36whjlwjb80000gp/T/torchinductor_kylemylonakis/po/cpoqa5seijx2x3dxbaxk2ahmo37h5m6yfuy2mj5apcq5475ikek3.main.cpp:95:66: error: use of undeclared identifier 'tmp2'
E 95 | TORCH_CHECK((at::vec::VecMask<int64_t,2>(tmp2 < at::vec::VectorizedN<int64_t,2>(ks1))).all_masked(), "index out of bounds: tmp2 < ks1");
E | ^
E /var/folders/nm/msp2wznx7tx3pn36whjlwjb80000gp/T/torchinductor_kylemylonakis/po/cpoqa5seijx2x3dxbaxk2ahmo37h5m6yfuy2mj5apcq5475ikek3.main.cpp:134:41: error: use of undeclared identifier 'tmp1'
E 134 | TORCH_CHECK(tmp1 < ks1, "index out of bounds: tmp1 < ks1");
E | ^
E 2 errors generated.
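
For context, the compile-and-call pattern looks roughly like the sketch below. This is illustrative only (the model id, inputs, and loop are placeholders), not the minimal repro; as noted further down, the error has so far only surfaced under concurrent pytest-asyncio runs.

# Hedged sketch of the setup, not a confirmed repro: model id and inputs are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # also observed with a Mistral model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
model.eval()

compiled_forward = torch.compile(model.forward, dynamic=True)

with torch.no_grad():
    # The first call triggers compilation; the C++ error surfaces on later calls
    # (different sequence lengths exercise the dynamic-shape guards in the generated kernel).
    for prompt in ["Hello world", "A longer prompt that forces a different sequence length"]:
        inputs = tokenizer(prompt, return_tensors="pt")
        compiled_forward(**inputs)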

After a significant deep dive into torch.compile/Dynamo/Inductor, I found that this is a bug in torch.compile that has been surfaced by recent changes to sdpa_mask_recent_torch. With Copilot's help, I was able to work around the issue by changing how the padding mask is applied:

def sdpa_mask_recent_torch(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Callable = causal_mask_function,
    attention_mask: torch.Tensor | None = None,
    local_size: int | None = None,
    allow_is_causal_skip: bool = True,
    **kwargs,
) -> torch.Tensor | None:
    q_length = cache_position.shape[0]
    # Potentially pad the 2D mask, and slice it correctly
    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)

    # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
    if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
        return None

    # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
    # but without data-dependent slicing (i.e. torch.compile friendly)
    kv_arange = torch.arange(kv_length, device=cache_position.device)
    kv_arange += kv_offset

    # Potentially add the padding 2D mask
    compiling = is_torchdynamo_compiling()
    if padding_mask is not None and not compiling:
        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

    batch_arange = torch.arange(batch_size, device=cache_position.device)
    head_arange = torch.arange(1, device=cache_position.device)
    # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor
    # from a scalar tensor (it internally calls `.item()`, which vmap does not allow, but this context works around it).
    # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
    with TransformGetItemToIndex():
        causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
        if compiling and padding_mask is not None:
            # padding_mask is [batch, s], broadcast to [batch,1,q,s]
            causal_mask = causal_mask & padding_mask[:, None, None, :]

    return causal_mask
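
The main behavioral change is the compile-time branch: instead of folding the padding mask into mask_function via and_masks, the padding mask is AND-ed onto the vmapped causal mask afterwards via broadcasting. A small standalone check of that broadcast (shapes and values picked arbitrarily for illustration, not taken from the failing run):

import torch

batch, q_len, kv_len = 2, 3, 5

# Stand-in boolean causal mask [batch, 1, q_len, kv_len] and 2D padding mask [batch, kv_len]
causal_mask = torch.ones(batch, 1, q_len, kv_len).tril(diagonal=kv_len - q_len).bool()
padding_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)

# Same broadcast as in the patch: [batch, kv_len] -> [batch, 1, 1, kv_len]
combined = causal_mask & padding_mask[:, None, None, :]

assert combined.shape == (batch, 1, q_len, kv_len)
# The padded key positions of the first batch element are masked out for every query position
assert not combined[0, :, :, 3:].any()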

This resolves the torch.compile issue, but I'm not sure how robust the solution is. Since this is fairly important code, I was hoping you could review it for correctness or improve on it.

This C++ torch.compile error is very hard to reproduce. I have only seen it when running multiple pytest-asyncio tests concurrently; even then, running a single test suite at a time does not surface the bug. Multiple tests have to be launched at once for the error to occur. I am currently trying to produce a minimal reproducible example, but I am not sure I will manage to.

Expected behavior

No torch.compile C++ error should be raised.
