
Commit 26aa046

Add SDPA and FlashAttention-2 support for LayoutLMv3
- Implement unified attention interface following BERT pattern
- Add layoutlmv3_eager_attention_forward with support for relative position bias and spatial attention bias
- Add support flags _supports_flash_attn and _supports_sdpa
- Update attention classes to use unified interface
- Automatically set _attn_implementation='eager' when relative/spatial biases are enabled in config
- Fix test configurations to use eager attention by default
- Override incompatible SDPA/FlashAttention tests with skipTest
- Fix missing case for spatial-only attention bias handling
- Fix position_ids expansion to support inputs_embeds
- Replace get_extended_attention_mask with create_bidirectional_mask

Fixes #35467
1 parent 66d5711 commit 26aa046
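For context, a hedged usage sketch (not part of the diff) of how the change behaves from the user's side: checkpoints that enable the relative/spatial attention biases, such as microsoft/layoutlmv3-base, fall back to eager attention automatically, while explicitly requesting SDPA or FlashAttention-2 with those biases enabled is rejected at forward time.

```python
# Hypothetical usage sketch; exact behavior depends on the checkpoint's config values.
from transformers import LayoutLMv3Model

# Both bias flags are enabled in this checkpoint's config, so the model
# silently selects eager attention (see the config change below).
model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
print(model.config._attn_implementation)  # expected: "eager"

# Requesting attn_implementation="sdpa" while the biases are enabled raises a
# ValueError inside LayoutLMv3SelfAttention.forward, per the check added in this commit.
```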

3 files changed, +284 -72 lines

src/transformers/models/layoutlmv3/configuration_layoutlmv3.py

Lines changed: 7 additions & 0 deletions
@@ -174,5 +174,12 @@ def __init__(
        self.patch_size = patch_size
        self.classifier_dropout = classifier_dropout

+        # LayoutLMv3's relative position bias and spatial attention bias are incompatible with SDPA/FlashAttention
+        # Automatically set eager attention when these biases are enabled, unless explicitly set by user
+        if has_relative_attention_bias or has_spatial_attention_bias:
+            # Only set if not already explicitly set via kwargs (attn_implementation is processed in super().__init__)
+            if self._attn_implementation is None:
+                self._attn_implementation = "eager"
+

__all__ = ["LayoutLMv3Config"]
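
A minimal sketch of the fallback added above, assuming the config is constructed directly (the `attn_implementation` kwarg is consumed by `super().__init__()`, as the comment in the diff notes):

```python
from transformers import LayoutLMv3Config

# Either bias enabled and no implementation requested -> eager is selected automatically.
cfg = LayoutLMv3Config(has_relative_attention_bias=True, has_spatial_attention_bias=True)
print(cfg._attn_implementation)  # expected: "eager"

# An explicit request is left untouched at config time; with the biases still enabled it
# will instead fail later with a ValueError inside LayoutLMv3SelfAttention.forward.
cfg = LayoutLMv3Config(
    has_relative_attention_bias=True,
    has_spatial_attention_bias=True,
    attn_implementation="sdpa",
)
print(cfg._attn_implementation)  # "sdpa"
```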

src/transformers/models/layoutlmv3/modeling_layoutlmv3.py

Lines changed: 124 additions & 71 deletions
@@ -16,6 +16,7 @@

import collections
import math
+from collections.abc import Callable
from typing import Optional, Union

import torch
@@ -25,16 +26,19 @@

from ... import initialization as init
from ...activations import ACT2FN
+from ...masking_utils import create_bidirectional_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
+    TransformersKwargs,
    auto_docstring,
    logging,
    torch_int,
@@ -203,6 +207,8 @@ class LayoutLMv3PreTrainedModel(PreTrainedModel):
    config: LayoutLMv3Config
    base_model_prefix = "layoutlmv3"
    input_modalities = ["image", "text"]
+    _supports_flash_attn = True
+    _supports_sdpa = True

    @torch.no_grad()
    def _init_weights(self, module):
@@ -214,18 +220,80 @@ def _init_weights(self, module):
            init.zeros_(module.pos_embed)


+def layoutlmv3_eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    rel_pos: Optional[torch.Tensor] = None,
+    rel_2d_pos: Optional[torch.Tensor] = None,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    """
+    LayoutLMv3 eager attention with support for relative position bias and spatial attention bias.
+    Based on the CogView attention trick for training stability.
+    """
+    if scaling is None:
+        scaling = 1.0 / math.sqrt(query.size(-1))
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
+    # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
+    attention_scores = torch.matmul(query / math.sqrt(query.size(-1)), key.transpose(-1, -2))
+
+    # Add relative position bias and spatial attention bias if available
+    if (
+        module.has_relative_attention_bias
+        and module.has_spatial_attention_bias
+        and rel_pos is not None
+        and rel_2d_pos is not None
+    ):
+        attention_scores = attention_scores + (rel_pos + rel_2d_pos) / math.sqrt(query.size(-1))
+    elif module.has_relative_attention_bias and rel_pos is not None:
+        attention_scores = attention_scores + rel_pos / math.sqrt(query.size(-1))
+    elif module.has_spatial_attention_bias and rel_2d_pos is not None:
+        attention_scores = attention_scores + rel_2d_pos / math.sqrt(query.size(-1))
+
+    if attention_mask is not None:
+        # Apply the attention mask
+        attention_scores = attention_scores + attention_mask
+
+    # Normalize the attention scores to probabilities.
+    # Use the trick of the CogView paper to stabilize training
+    # https://huggingface.co/papers/2105.13290 Section 2.4
+    alpha = 32
+    scaled_attention_scores = attention_scores / alpha
+    max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
+    new_attention_scores = (scaled_attention_scores - max_value) * alpha
+    attention_probs = nn.functional.softmax(new_attention_scores, dim=-1)
+
+    # This is actually dropping out entire tokens to attend to, which might
+    # seem a bit unusual, but is taken from the original Transformer paper.
+    attention_probs = nn.functional.dropout(attention_probs, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attention_probs, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attention_probs
+
+
class LayoutLMv3SelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
+        self.config = config

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
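
The stabilization in `layoutlmv3_eager_attention_forward` relies on softmax being shift-invariant: `(scores / alpha - max) * alpha` equals `scores - alpha * max`, a constant shift per row, so the probabilities are unchanged while intermediate values stay small enough for fp16. A standalone sketch (not from the repo) that checks this numerically:

```python
import torch

def pb_relax_softmax(scores: torch.Tensor, alpha: float = 32.0) -> torch.Tensor:
    # Same arithmetic as the CogView trick above: scale down, subtract the row max, scale back up.
    scaled = scores / alpha
    shifted = (scaled - scaled.amax(dim=-1, keepdim=True)) * alpha
    return torch.softmax(shifted, dim=-1)

scores = torch.randn(2, 12, 16, 16) * 50  # deliberately large, as raw QK^T scores can be
print(torch.allclose(pb_relax_softmax(scores), torch.softmax(scores, dim=-1), atol=1e-6))  # True in fp32
```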
@@ -234,18 +302,7 @@ def __init__(self, config):
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.has_relative_attention_bias = config.has_relative_attention_bias
        self.has_spatial_attention_bias = config.has_spatial_attention_bias
-
-    def cogview_attention(self, attention_scores, alpha=32):
-        """
-        https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
-        (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs
-        will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs,
-        cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better.
-        """
-        scaled_attention_scores = attention_scores / alpha
-        max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
-        new_attention_scores = (scaled_attention_scores - max_value) * alpha
-        return nn.Softmax(dim=-1)(new_attention_scores)
+        self.layer_idx = layer_idx

    def forward(
        self,
@@ -254,54 +311,45 @@ def forward(
        output_attentions=False,
        rel_pos=None,
        rel_2d_pos=None,
+        **kwargs: Unpack[TransformersKwargs],
    ):
-        batch_size, seq_length, _ = hidden_states.shape
-        query_layer = (
-            self.query(hidden_states)
-            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
-            .transpose(1, 2)
-        )
-        key_layer = (
-            self.key(hidden_states)
-            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
-            .transpose(1, 2)
-        )
-        value_layer = (
-            self.value(hidden_states)
-            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
-            .transpose(1, 2)
-        )
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
-        # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
-        attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
-
-        if self.has_relative_attention_bias and self.has_spatial_attention_bias:
-            attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
-        elif self.has_relative_attention_bias:
-            attention_scores += rel_pos / math.sqrt(self.attention_head_size)
-
-        if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
-            attention_scores = attention_scores + attention_mask
-
-        # Normalize the attention scores to probabilities.
-        # Use the trick of the CogView paper to stabilize training
-        attention_probs = self.cogview_attention(attention_scores)
-
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.dropout(attention_probs)
-
-        context_layer = torch.matmul(attention_probs, value_layer)
-
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.attention_head_size)
+
+        # Get query, key, value projections
+        query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
+        key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
+        value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
+
+        # Determine attention implementation
+        attention_interface: Callable = layoutlmv3_eager_attention_forward
+        use_eager = self.config._attn_implementation == "eager"
+
+        if not use_eager:
+            # SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
+            if self.has_relative_attention_bias or self.has_spatial_attention_bias:
+                raise ValueError(
+                    f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
+                    "relative position bias and spatial attention bias are not compatible with it. "
+                    'Please load the model with `attn_implementation="eager"`.'
+                )
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

-        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout.p,
+            scaling=self.scaling,
+            rel_pos=rel_pos,
+            rel_2d_pos=rel_2d_pos,
+            **kwargs,
+        )
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()

+        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs


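The rewritten forward follows the unified attention-interface pattern: every backend is a callable with the same signature, and the module picks one by the config key. A standalone sketch of that idea under assumed names (the real registry is `ALL_ATTENTION_FUNCTIONS`; the `scale=` argument of `F.scaled_dot_product_attention` needs PyTorch 2.1+):

```python
from typing import Callable, Optional
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0):
    # Plain matmul + softmax attention; returns the attention weights as well.
    scores = torch.matmul(q, k.transpose(-1, -2)) * scaling
    if mask is not None:
        scores = scores + mask
    probs = F.dropout(torch.softmax(scores, dim=-1), p=dropout)
    return torch.matmul(probs, v), probs

def sdpa_attention(q, k, v, mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout, scale=scaling)
    return out, None  # SDPA does not return attention weights

ATTENTION_BACKENDS: dict[str, Callable] = {"eager": eager_attention, "sdpa": sdpa_attention}

q = k = v = torch.randn(1, 12, 16, 64)  # (batch, heads, seq, head_dim)
out, _ = ATTENTION_BACKENDS["sdpa"](q, k, v, mask=None, scaling=64 ** -0.5)
```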
@@ -320,11 +368,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
        return hidden_states


-# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
+# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
class LayoutLMv3Attention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
        super().__init__()
-        self.self = LayoutLMv3SelfAttention(config)
+        self.self = LayoutLMv3SelfAttention(config, layer_idx=layer_idx)
        self.output = LayoutLMv3SelfOutput(config)

    def forward(
@@ -334,26 +382,28 @@ def forward(
        output_attentions=False,
        rel_pos=None,
        rel_2d_pos=None,
+        **kwargs: Unpack[TransformersKwargs],
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            output_attentions,
            rel_pos=rel_pos,
            rel_2d_pos=rel_2d_pos,
+            **kwargs,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


-# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
+# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
class LayoutLMv3Layer(GradientCheckpointingLayer):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
-        self.attention = LayoutLMv3Attention(config)
+        self.attention = LayoutLMv3Attention(config, layer_idx=layer_idx)
        self.intermediate = LayoutLMv3Intermediate(config)
        self.output = LayoutLMv3Output(config)

@@ -364,13 +414,15 @@ def forward(
        output_attentions=False,
        rel_pos=None,
        rel_2d_pos=None,
+        **kwargs: Unpack[TransformersKwargs],
    ):
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            rel_pos=rel_pos,
            rel_2d_pos=rel_2d_pos,
+            **kwargs,
        )
        attention_output = self_attention_outputs[0]

@@ -393,9 +445,8 @@ class LayoutLMv3Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
-        self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([LayoutLMv3Layer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
-
        self.has_relative_attention_bias = config.has_relative_attention_bias
        self.has_spatial_attention_bias = config.has_spatial_attention_bias

@@ -803,18 +854,20 @@ def forward(
                final_bbox = bbox
            if self.config.has_relative_attention_bias:
                position_ids = self.embeddings.position_ids[:, : input_shape[1]]
-                position_ids = position_ids.expand_as(input_ids)
+                position_ids = position_ids.expand(input_shape)
                final_position_ids = position_ids

-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-            attention_mask, None, device, dtype=embedding_output.dtype
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=embedding_output,
+            attention_mask=attention_mask,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            bbox=final_bbox,
            position_ids=final_position_ids,
-            attention_mask=extended_attention_mask,
+            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,

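The last hunk swaps `get_extended_attention_mask` for `create_bidirectional_mask`. Conceptually, both turn a `[batch, seq]` padding mask into an additive bias that attention scores can simply be summed with; a standalone sketch of that idea (not the library implementation, whose exact output format depends on the selected attention backend):

```python
import torch

def additive_bidirectional_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [batch, seq] -> [batch, 1, 1, seq]: 0 where tokens may be attended, a large negative where padded.
    mask = attention_mask[:, None, None, :].to(dtype)
    return (1.0 - mask) * torch.finfo(dtype).min

padding = torch.tensor([[1, 1, 1, 0]])
print(additive_bidirectional_mask(padding, torch.float32))
```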