
Commit 69b372a

Fix missing case for spatial-only attention bias in layoutlmv3_eager_attention_forward
1 parent 7f74c31 commit 69b372a

File tree

1 file changed: 2 additions, 0 deletions

src/transformers/models/layoutlmv3/modeling_layoutlmv3.py

Lines changed: 2 additions & 0 deletions

@@ -254,6 +254,8 @@ def layoutlmv3_eager_attention_forward(
         attention_scores = attention_scores + (rel_pos + rel_2d_pos) / math.sqrt(query.size(-1))
     elif module.has_relative_attention_bias and rel_pos is not None:
         attention_scores = attention_scores + rel_pos / math.sqrt(query.size(-1))
+    elif module.has_spatial_attention_bias and rel_2d_pos is not None:
+        attention_scores = attention_scores + rel_2d_pos / math.sqrt(query.size(-1))
 
     if attention_mask is not None:
         # Apply the attention mask
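For context, here is a minimal standalone sketch of the branching logic this commit completes. The helper name, signature, and the condition on the first branch (which sits above the hunk and is not shown in the diff) are illustrative assumptions, not the actual transformers API; only the variable names (attention_scores, query, rel_pos, rel_2d_pos) and the elif branches come from the diff itself. Before the fix, a model configured with a spatial (2D) bias but no 1D relative bias fell through all branches and rel_2d_pos was silently ignored.

    import math
    import torch

    def apply_relative_biases(attention_scores, query, has_rel_bias, has_spatial_bias,
                              rel_pos=None, rel_2d_pos=None):
        # Illustrative mirror of the bias-application branches after this commit.
        # Each bias is scaled by 1/sqrt(head_dim), matching the raw score scaling.
        scale = math.sqrt(query.size(-1))
        if has_rel_bias and has_spatial_bias and rel_pos is not None and rel_2d_pos is not None:
            # Both 1D relative-position and 2D spatial biases (condition assumed,
            # since the original `if` line is outside the diff hunk).
            attention_scores = attention_scores + (rel_pos + rel_2d_pos) / scale
        elif has_rel_bias and rel_pos is not None:
            # Only the 1D relative-position bias.
            attention_scores = attention_scores + rel_pos / scale
        elif has_spatial_bias and rel_2d_pos is not None:
            # Previously missing branch: spatial-only (2D) bias.
            attention_scores = attention_scores + rel_2d_pos / scale
        return attention_scores

    # Example exercising the spatial-only case (shapes chosen arbitrarily):
    scores = torch.zeros(1, 12, 4, 4)    # (batch, heads, seq, seq)
    q = torch.zeros(1, 12, 4, 64)        # head_dim = 64
    bias_2d = torch.randn(1, 12, 4, 4)
    scores = apply_relative_biases(scores, q, has_rel_bias=False,
                                   has_spatial_bias=True, rel_2d_pos=bias_2d)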
