 
 import collections
 import math
+from collections.abc import Callable
 from typing import Optional, Union
 
 import torch
 
 from ... import initialization as init
 from ...activations import ACT2FN
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
 from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
+    TransformersKwargs,
     auto_docstring,
     logging,
     torch_int,
@@ -203,6 +207,8 @@ class LayoutLMv3PreTrainedModel(PreTrainedModel):
     config: LayoutLMv3Config
     base_model_prefix = "layoutlmv3"
     input_modalities = ["image", "text"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
 
     @torch.no_grad()
     def _init_weights(self, module):
@@ -214,18 +220,80 @@ def _init_weights(self, module):
             init.zeros_(module.pos_embed)
 
 
+def layoutlmv3_eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    rel_pos: Optional[torch.Tensor] = None,
+    rel_2d_pos: Optional[torch.Tensor] = None,
+    **kwargs: Unpack[TransformersKwargs],
+):
+    """
+    LayoutLMv3 eager attention with support for relative position bias and spatial attention bias.
+    Based on the CogView attention trick for training stability.
+    """
+    if scaling is None:
+        scaling = 1.0 / math.sqrt(query.size(-1))
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    # The attention scores QᵀK/√d could be significantly larger than input elements, and result in overflow.
+    # Changing the computational order into Qᵀ(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
+    attention_scores = torch.matmul(query / math.sqrt(query.size(-1)), key.transpose(-1, -2))
+
+    # Add relative position bias and spatial attention bias if available
+    if (
+        module.has_relative_attention_bias
+        and module.has_spatial_attention_bias
+        and rel_pos is not None
+        and rel_2d_pos is not None
+    ):
+        attention_scores = attention_scores + (rel_pos + rel_2d_pos) / math.sqrt(query.size(-1))
+    elif module.has_relative_attention_bias and rel_pos is not None:
+        attention_scores = attention_scores + rel_pos / math.sqrt(query.size(-1))
+    elif module.has_spatial_attention_bias and rel_2d_pos is not None:
+        attention_scores = attention_scores + rel_2d_pos / math.sqrt(query.size(-1))
+
+    if attention_mask is not None:
+        # Apply the attention mask
+        attention_scores = attention_scores + attention_mask
+
+    # Normalize the attention scores to probabilities.
+    # Use the trick of the CogView paper to stabilize training
+    # https://huggingface.co/papers/2105.13290 Section 2.4
+    alpha = 32
+    scaled_attention_scores = attention_scores / alpha
+    max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
+    new_attention_scores = (scaled_attention_scores - max_value) * alpha
+    attention_probs = nn.functional.softmax(new_attention_scores, dim=-1)
+
+    # This is actually dropping out entire tokens to attend to, which might
+    # seem a bit unusual, but is taken from the original Transformer paper.
+    attention_probs = nn.functional.dropout(attention_probs, p=dropout, training=module.training)
+
+    attn_output = torch.matmul(attention_probs, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attention_probs
+
+
 class LayoutLMv3SelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads})"
             )
+        self.config = config
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.scaling = self.attention_head_size**-0.5
 
         self.query = nn.Linear(config.hidden_size, self.all_head_size)
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
@@ -234,18 +302,7 @@ def __init__(self, config):
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.has_relative_attention_bias = config.has_relative_attention_bias
         self.has_spatial_attention_bias = config.has_spatial_attention_bias
-
-    def cogview_attention(self, attention_scores, alpha=32):
-        """
-        https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
-        (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs
-        will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs,
-        cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better.
-        """
-        scaled_attention_scores = attention_scores / alpha
-        max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
-        new_attention_scores = (scaled_attention_scores - max_value) * alpha
-        return nn.Softmax(dim=-1)(new_attention_scores)
+        self.layer_idx = layer_idx
 
     def forward(
         self,
@@ -254,54 +311,45 @@ def forward(
         output_attentions=False,
         rel_pos=None,
         rel_2d_pos=None,
+        **kwargs: Unpack[TransformersKwargs],
     ):
-        batch_size, seq_length, _ = hidden_states.shape
-        query_layer = (
-            self.query(hidden_states)
-            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
-            .transpose(1, 2)
-        )
-        key_layer = (
-            self.key(hidden_states)
-            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
-            .transpose(1, 2)
-        )
-        value_layer = (
-            self.value(hidden_states)
-            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
-            .transpose(1, 2)
-        )
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        # The attention scores QᵀK/√d could be significantly larger than input elements, and result in overflow.
-        # Changing the computational order into Qᵀ(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
-        attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
-
-        if self.has_relative_attention_bias and self.has_spatial_attention_bias:
-            attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
-        elif self.has_relative_attention_bias:
-            attention_scores += rel_pos / math.sqrt(self.attention_head_size)
-
-        if attention_mask is not None:
-            # Apply the attention mask (precomputed for all layers in RobertaModel forward() function)
-            attention_scores = attention_scores + attention_mask
-
-        # Normalize the attention scores to probabilities.
-        # Use the trick of the CogView paper to stabilize training
-        attention_probs = self.cogview_attention(attention_scores)
-
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.dropout(attention_probs)
-
-        context_layer = torch.matmul(attention_probs, value_layer)
-
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.attention_head_size)
+
+        # Get query, key, value projections
+        query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
+        key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
+        value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
+
+        # Determine attention implementation
+        attention_interface: Callable = layoutlmv3_eager_attention_forward
+        use_eager = self.config._attn_implementation == "eager"
+
+        if not use_eager:
+            # SDPA and Flash Attention don't support custom relative position bias and spatial attention bias
+            if self.has_relative_attention_bias or self.has_spatial_attention_bias:
+                raise ValueError(
+                    f"You are using {self.config._attn_implementation} as attention type. However, LayoutLMv3's "
+                    "relative position bias and spatial attention bias are not compatible with it. "
+                    'Please load the model with `attn_implementation="eager"`.'
+                )
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout.p,
+            scaling=self.scaling,
+            rel_pos=rel_pos,
+            rel_2d_pos=rel_2d_pos,
+            **kwargs,
+        )
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
 
+        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
         return outputs
 
 
@@ -320,11 +368,11 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to
         return hidden_states
 
 
-# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
+# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
 class LayoutLMv3Attention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
-        self.self = LayoutLMv3SelfAttention(config)
+        self.self = LayoutLMv3SelfAttention(config, layer_idx=layer_idx)
         self.output = LayoutLMv3SelfOutput(config)
 
     def forward(
@@ -334,26 +382,28 @@ def forward(
         output_attentions=False,
         rel_pos=None,
         rel_2d_pos=None,
+        **kwargs: Unpack[TransformersKwargs],
     ):
         self_outputs = self.self(
             hidden_states,
             attention_mask,
             output_attentions,
             rel_pos=rel_pos,
             rel_2d_pos=rel_2d_pos,
+            **kwargs,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
 
 
-# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
+# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
 class LayoutLMv3Layer(GradientCheckpointingLayer):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-        self.attention = LayoutLMv3Attention(config)
+        self.attention = LayoutLMv3Attention(config, layer_idx=layer_idx)
         self.intermediate = LayoutLMv3Intermediate(config)
         self.output = LayoutLMv3Output(config)
 
@@ -364,13 +414,15 @@ def forward(
         output_attentions=False,
         rel_pos=None,
         rel_2d_pos=None,
+        **kwargs: Unpack[TransformersKwargs],
     ):
         self_attention_outputs = self.attention(
             hidden_states,
             attention_mask,
             output_attentions=output_attentions,
             rel_pos=rel_pos,
             rel_2d_pos=rel_2d_pos,
+            **kwargs,
         )
         attention_output = self_attention_outputs[0]
 
@@ -393,9 +445,8 @@ class LayoutLMv3Encoder(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([LayoutLMv3Layer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
         self.gradient_checkpointing = False
-
         self.has_relative_attention_bias = config.has_relative_attention_bias
         self.has_spatial_attention_bias = config.has_spatial_attention_bias
 
@@ -803,18 +854,20 @@ def forward(
                 final_bbox = bbox
             if self.config.has_relative_attention_bias:
                 position_ids = self.embeddings.position_ids[:, : input_shape[1]]
-                position_ids = position_ids.expand_as(input_ids)
+                position_ids = position_ids.expand(input_shape)
                 final_position_ids = position_ids
 
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-            attention_mask, None, device, dtype=embedding_output.dtype
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=embedding_output,
+            attention_mask=attention_mask,
         )
 
         encoder_outputs = self.encoder(
             embedding_output,
             bbox=final_bbox,
             position_ids=final_position_ids,
-            attention_mask=extended_attention_mask,
+            attention_mask=attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
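
A minimal usage sketch of the behaviour this diff introduces (assumptions: the `microsoft/layoutlmv3-base` checkpoint and the standard `AutoConfig` / `from_pretrained` API; the snippet is not part of the patch). With the default configuration, which enables both the relative position bias and the spatial attention bias, only the eager path is valid; SDPA or FlashAttention can be selected only once both bias flags are disabled, otherwise the new check in `LayoutLMv3SelfAttention.forward` raises a `ValueError`.

```python
from transformers import AutoConfig, AutoModel

# Default config keeps both bias terms, so the eager kernel is required.
model_eager = AutoModel.from_pretrained("microsoft/layoutlmv3-base", attn_implementation="eager")

# SDPA (or flash_attention_2) only becomes usable once both bias terms are turned off;
# the corresponding pretrained bias weights are then simply not instantiated/loaded.
config = AutoConfig.from_pretrained("microsoft/layoutlmv3-base")
config.has_relative_attention_bias = False
config.has_spatial_attention_bias = False
model_sdpa = AutoModel.from_pretrained(
    "microsoft/layoutlmv3-base", config=config, attn_implementation="sdpa"
)
```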