System Info
- transformers version: 4.57.1
- Platform: Linux-5.10.233-223.887.amzn2.x86_64-x86_64-with-glibc2.26
- Python version: 3.10.19
- Huggingface_hub version: 0.36.0
- Safetensors version: 0.6.2
- Accelerate version: not installed
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.6.0+cu124 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA H100 80GB HBM3
Who can help?
@vasqu @ArthurZucker @Cyrilvallez

When using a custom attention function registered via the new AttentionInterface and selecting it with attn_implementation="<custom_name>", passing output_attentions=True to model.forward(...) triggers a UserWarning like:
UserWarning:
output_attentions=True is not supported with attn_implementation other than ['eager', 'eager_paged', 'flex_attention']. Please use model.set_attn_implementation('eager') to enable capturing attention outputs.
This warning is misleading for custom backends that do compute and return attention probabilities (same shape as eager). In addition, some models still set outputs.attentions=None unless the implementation name is exactly "eager", even though the custom backend returns (attn_output, attn_probs).
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
The code snippet below triggers the undesirable UserWarning.
import torch
import torch.nn as nn
from transformers.models.esm.modeling_esm import TransformersKwargs
from typing import Optional
from transformers import AutoModel, AttentionInterface
def eager_with_bias_attention_forward(
module: nn.Module,
query: torch.Tensor, # [B, H, T, D]
key: torch.Tensor, # [B, H, S, D]
value: torch.Tensor, # [B, H, S, D]
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
**kwargs: TransformersKwargs,
):
"""
Adds `attention_bias` (broadcastable to [B, H or 1, T, S]) to logits before softmax.
Pass it via model(..., attention_bias=your_bias).
"""
if scaling is None:
scaling = query.size(-1) ** -0.5
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling # [B, H, T, S]
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
# Add the bias matrix to the attention weights
attention_bias = kwargs.get("attention_bias", None)
if attention_bias is not None:
# allow [B, 1, T, S], [B, H, T, S], or [1, 1, T, S]; truncate S if needed
if attention_bias.size(-1) != key.shape[-2]:
attention_bias = attention_bias[..., : key.shape[-2]]
attention_bias = attention_bias.to(
dtype=attn_weights.dtype, device=attn_weights.device
)
attn_weights = attn_weights + attention_bias
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(
attn_weights, p=dropout, training=module.training
)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
# Register custom attention implementation
AttentionInterface.register("eager_with_bias", eager_with_bias_attention_forward)
# Load ESM2 model using the custom attention backend
model = AutoModel.from_pretrained(
"facebook/esm2_t33_650M_UR50D",
token_dropout=False,
local_files_only=True,
attn_implementation="eager_with_bias",
)
# --- dummy batch ---
B, T = 2, 64
hidden_size = model.config.hidden_size
H = model.config.num_attention_heads
# inputs_embeds must be [B, T, hidden_size]
emb = torch.randn(B, T, hidden_size, device=next(model.parameters()).device)
# attention_mask must be [B, T] with 1 for tokens you want to keep
attention_mask = torch.ones(B, T, dtype=torch.long, device=emb.device)
# bias should broadcast to [B, H, T, T]; using shared-across-heads:
attention_bias = torch.zeros(B, 1, T, T, device=emb.device)
# Triggers a UserWarning even though backend returns attention weights,
# and some models set outputs.attentions = None unless impl == "eager".
out = model(
inputs_embeds=emb,
attention_mask=attention_mask,
output_attentions=True,
attention_bias=attention_bias,
)
# Check if attention weights are being returned
assert (
out.attentions is not None and len(out.attentions) == model.config.num_hidden_layers
)
print("OK: got attention weights from custom backend")Expected behavior
Expected behavior
- If the selected attention backend returns attention probabilities, outputs.attentions should be populated and no warning should be emitted.
- The warning (or error) should trigger only when the chosen backend cannot provide attention probabilities.
Actual behavior
- A UserWarning is emitted whenever attn_implementation != "eager", regardless of whether the custom backend supports returning attention weights.
- In some models, outputs.attentions is None unless the implementation name is literally "eager".
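The behavior above boils down to a name-based check. A minimal, hypothetical re-creation of that kind of gating (illustrative only, not the actual transformers source):

import warnings

# Hypothetical, simplified re-creation of the name-based gating described above;
# not the actual transformers source.
ATTN_IMPLS_WITH_PROBS = ["eager", "eager_paged", "flex_attention"]

def maybe_warn(attn_implementation: str, output_attentions: bool) -> None:
    if output_attentions and attn_implementation not in ATTN_IMPLS_WITH_PROBS:
        warnings.warn(
            f"output_attentions=True is not supported with attn_implementation other than "
            f"{ATTN_IMPLS_WITH_PROBS}. Please use model.set_attn_implementation('eager') "
            "to enable capturing attention outputs."
        )

# A custom backend that does return attention probabilities still trips the warning,
# because the check only looks at the implementation name:
maybe_warn("eager_with_bias", output_attentions=True)  # warns
maybe_warn("eager", output_attentions=True)            # silent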
Where this comes from / related context
- There’s an “early-error if output_attentions=True and impl isn’t eager” change discussed in PR #38288 (config path).
- The Attention Interface docs show how to register/select custom implementations and say extra kwargs are forwarded to the attention function, but they don’t document a way to declare that a custom backend supports returning attentions.
Proposed solutions
1. Capability flag on backends
Extend AttentionInterface.register(name, fn, supports_attn_probs: bool = False) (or use a small descriptor object) so model code can check capability instead of name equality.
If supports_attn_probs=True, allow output_attentions=True without warnings and surface the returned probabilities.
2. Name-agnostic check
Replace impl != "eager" string checks with an interface query like AttentionInterface.supports_attn_probs(impl) to decide warning/error behavior, so custom backends that return weights aren’t penalized. A rough sketch combining this with the capability flag from proposal 1 appears after this list.
3. Documented workaround
If changing the check is not desirable, document an official way to declare a custom backend as “eager-compatible,” or provide a supported alias/registration API that treats a custom backend like "eager" for the purpose of attention-weight return (avoiding the need for users to override "eager" globally just to silence the warning).
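To make proposals 1 and 2 concrete, here is a rough sketch of how registration and the model-side check could look. Everything here is hypothetical: supports_attn_probs and the extra register() argument are proposed names, not existing transformers API.

import warnings
from typing import Callable, Dict

class AttentionInterfaceSketch:
    # Hypothetical sketch only: neither supports_attn_probs nor the extra
    # register() argument exist in transformers today.
    _functions: Dict[str, Callable] = {}
    _supports_probs: Dict[str, bool] = {}

    @classmethod
    def register(cls, name: str, fn: Callable, supports_attn_probs: bool = False) -> None:
        # Proposal 1: record a capability flag alongside the attention function.
        cls._functions[name] = fn
        cls._supports_probs[name] = supports_attn_probs

    @classmethod
    def supports_attn_probs(cls, name: str) -> bool:
        # Proposal 2: model code queries the capability instead of comparing the
        # implementation name against the literal string "eager".
        return cls._supports_probs.get(name, name in ("eager", "eager_paged", "flex_attention"))

def custom_backend(*args, **kwargs):
    # Stand-in for eager_with_bias_attention_forward from the reproduction above.
    raise NotImplementedError

AttentionInterfaceSketch.register("eager_with_bias", custom_backend, supports_attn_probs=True)

# The warning/None-attentions gating becomes name-agnostic:
impl, output_attentions = "eager_with_bias", True
if output_attentions and not AttentionInterfaceSketch.supports_attn_probs(impl):
    warnings.warn("This attention implementation cannot return attention probabilities.")

With something like this in place, the warning would only fire for backends that genuinely cannot materialize attention probabilities (e.g. fused flash-attention kernels), and custom eager-style backends like the one in the reproduction would have outputs.attentions populated.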