
Commit 4cb41ad

[tests] re-enable aria fast tests (#40846)
* rise from the dead
* test
1 parent ef05393 commit 4cb41ad


5 files changed: +44, -73 lines changed


src/transformers/models/idefics3/modeling_idefics3.py

Lines changed: 1 addition & 0 deletions
@@ -481,6 +481,7 @@ def forward(
         self,
         pixel_values,
         patch_attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, BaseModelOutput]:
         batch_size = pixel_values.size(0)
         if patch_attention_mask is None:

src/transformers/models/smolvlm/modeling_smolvlm.py

Lines changed: 1 addition & 0 deletions
@@ -368,6 +368,7 @@ def forward(
         self,
         pixel_values,
         patch_attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, BaseModelOutput]:
         batch_size = pixel_values.size(0)
         if patch_attention_mask is None:
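
Both hunks above make the same change: the vision encoder's `forward` now accepts arbitrary keyword arguments typed as `Unpack[TransformersKwargs]`, so flags like `output_attentions` can flow through without widening the explicit signature. Below is a minimal, self-contained sketch of that signature pattern; `ForwardKwargs` is a stand-in for `TransformersKwargs` (not the transformers source) and `typing_extensions` is assumed to be installed.

```python
# Sketch of the `**kwargs: Unpack[...]` signature pattern used in the diff above.
from typing import Optional

from typing_extensions import TypedDict, Unpack


class ForwardKwargs(TypedDict, total=False):  # stand-in for TransformersKwargs
    output_attentions: Optional[bool]
    output_hidden_states: Optional[bool]


def forward(pixel_values, patch_attention_mask=None, **kwargs: Unpack[ForwardKwargs]):
    # Extra flags are accepted without changing the positional signature,
    # so outer wrappers can always pass them down.
    return {"output_attentions": kwargs.get("output_attentions", False)}


print(forward(pixel_values=[[0.0]], output_attentions=True))  # {'output_attentions': True}
```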

src/transformers/utils/generic.py

Lines changed: 18 additions & 4 deletions
@@ -813,12 +813,11 @@ def wrapper(*args, **kwargs):

 class TransformersKwargs(TypedDict, total=False):
     """
-    Keyword arguments to be passed to the loss function
+    Keyword arguments to be passed to the forward pass of a `PreTrainedModel`.

     Attributes:
         num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
-            Number of items in the batch. It is recommended to pass it when
-            you are doing gradient accumulation.
+            Number of items in the batch. It is recommended to pass it when you are doing gradient accumulation.
         output_hidden_states (`Optional[bool]`, *optional*):
             Most of the models support outputting all hidden states computed during the forward pass.
         output_attentions (`Optional[bool]`, *optional*):
@@ -1059,7 +1058,22 @@ def wrapped_forward(*args, **kwargs):
                 module.forward = make_capture_wrapper(module, original_forward, key, specs.index)
                 monkey_patched_layers.append((module, original_forward))

-        outputs = func(self, *args, **kwargs)
+        try:
+            outputs = func(self, *args, **kwargs)
+        except TypeError as original_exception:
+            # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
+            # Get a TypeError even after removing the recordable kwargs -> re-raise the original exception
+            # Otherwise -> we're probably missing `**kwargs` in the decorated function
+            kwargs_without_recordable = {k: v for k, v in kwargs.items() if k not in recordable_keys}
+            try:
+                outputs = func(self, *args, **kwargs_without_recordable)
+            except TypeError:
+                raise original_exception
+            raise TypeError(
+                "Missing `**kwargs` in the signature of the `@check_model_inputs`-decorated function "
+                f"({func.__qualname__})"
+            )
+
         # Restore original forward methods
         for module, original_forward in monkey_patched_layers:
             module.forward = original_forward
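
The new block retries the decorated `forward` without the recordable kwargs when the first call raises a `TypeError`: if the retry also fails, the error was unrelated and the original exception is re-raised; if the retry succeeds, the decorated function is missing `**kwargs` and a pointed error names it. Here is a hedged, standalone sketch of that retry-and-diagnose pattern; the decorator name and `RECORDABLE_KEYS` set are illustrative, not the real `check_model_inputs` machinery.

```python
# Standalone sketch of the retry-and-diagnose logic added above (toy decorator, not transformers').
import functools

RECORDABLE_KEYS = {"output_attentions", "output_hidden_states"}  # illustrative only


def check_inputs(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except TypeError as original_exception:
            # Retry without the recordable kwargs to find out what actually broke.
            stripped = {k: v for k, v in kwargs.items() if k not in RECORDABLE_KEYS}
            try:
                func(*args, **stripped)
            except TypeError:
                # Still failing: the error was unrelated, keep the original exception.
                raise original_exception
            # Succeeds only without them: the wrapped function is missing `**kwargs`.
            raise TypeError(f"Missing `**kwargs` in the signature of {func.__qualname__}")

    return wrapper


@check_inputs
def forward_without_kwargs(x):  # deliberately lacks **kwargs
    return x


try:
    forward_without_kwargs(1, output_attentions=True)
except TypeError as err:
    print(err)  # Missing `**kwargs` in the signature of forward_without_kwargs
```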

tests/generation/test_utils.py

Lines changed: 3 additions & 3 deletions
@@ -2131,12 +2131,12 @@ def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, c
     def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
         self.assertIsInstance(decoder_past_key_values, (tuple, Cache))

-        # (batch, head, seq_length, head_features)
+        # (batch, # kv heads, seq_length, head_features)
         expected_shape = (
             batch_size,
-            config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
+            getattr(config, "num_key_value_heads", None) or config.num_attention_heads,
             cache_length,
-            config.hidden_size // config.num_attention_heads,
+            getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads,
         )

         if isinstance(decoder_past_key_values, Cache):
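
The updated shape check switches from `hasattr` to `getattr(..., None) or fallback`, which also covers configs that define `num_key_value_heads` or `head_dim` but leave them set to `None`, and prefers an explicit `head_dim` over `hidden_size // num_attention_heads` when one exists. A small illustration with a toy config object (the values below are made up, not from a real transformers config):

```python
# `getattr(cfg, name, None) or default` handles both a missing attribute and one set to None.
from types import SimpleNamespace

cfg = SimpleNamespace(num_attention_heads=8, num_key_value_heads=None, hidden_size=64, head_dim=None)

num_kv_heads = getattr(cfg, "num_key_value_heads", None) or cfg.num_attention_heads
head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads

# (batch, kv heads, cache length, head dim)
print((2, num_kv_heads, 5, head_dim))  # (2, 8, 5, 8)
```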

tests/models/aria/test_modeling_aria.py

Lines changed: 21 additions & 66 deletions
@@ -15,7 +15,6 @@

 import unittest

-import pytest
 import requests

 from transformers import (
@@ -61,6 +60,10 @@ class AriaVisionText2TextModelTester:
     def __init__(
         self,
         parent,
+        batch_size=13,
+        num_channels=3,
+        image_size=16,
+        num_image_tokens=4,
         ignore_index=-100,
         image_token_index=9,
         projector_hidden_act="gelu",
@@ -83,32 +86,32 @@ def __init__(
             num_choices=4,
             pad_token_id=1,
             hidden_size=32,
-            intermediate_size=64,
+            intermediate_size=16,
             max_position_embeddings=60,
             model_type="aria_moe_lm",
             moe_intermediate_size=4,
-            moe_num_experts=4,
+            moe_num_experts=3,
             moe_topk=2,
-            num_attention_heads=8,
+            num_attention_heads=2,
             num_experts_per_tok=3,
             num_hidden_layers=2,
-            num_key_value_heads=8,
+            num_key_value_heads=2,
             rope_theta=5000000,
             vocab_size=99,
             eos_token_id=2,
             head_dim=4,
         ),
         is_training=True,
         vision_config=Idefics3VisionConfig(
-            image_size=358,
-            patch_size=10,
+            image_size=16,
+            patch_size=8,
             num_channels=3,
             is_training=True,
             hidden_size=32,
-            projection_dim=20,
+            projection_dim=4,
             num_hidden_layers=2,
-            num_attention_heads=16,
-            intermediate_size=10,
+            num_attention_heads=2,
+            intermediate_size=4,
             dropout=0.1,
             attention_dropout=0.1,
             initializer_range=0.02,
@@ -130,11 +133,14 @@ def __init__(
         self.num_attention_heads = text_config.num_attention_heads
         self.is_training = is_training

-        self.batch_size = 10
-        self.num_channels = 3
-        self.image_size = 358
-        self.num_image_tokens = 128
+        self.batch_size = batch_size
+        self.num_channels = num_channels
+        self.image_size = image_size
+        self.num_image_tokens = num_image_tokens
         self.seq_length = seq_length + self.num_image_tokens
+        self.projector_patch_to_query_dict = {
+            vision_config.image_size**2 // vision_config.patch_size**2: vision_config.projection_dim
+        }

     def get_config(self):
         return AriaConfig(
@@ -146,6 +152,7 @@ def get_config(self):
             vision_feature_select_strategy=self.vision_feature_select_strategy,
             vision_feature_layer=self.vision_feature_layer,
             eos_token_id=self.eos_token_id,
+            projector_patch_to_query_dict=self.projector_patch_to_query_dict,
         )

     def prepare_config_and_inputs(self):
@@ -176,7 +183,6 @@ def prepare_config_and_inputs_for_common(self):
         return config, inputs_dict


-@slow
 @require_torch
 class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     """
@@ -193,61 +199,10 @@ def setUp(self):
         self.model_tester = AriaVisionText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False)

-    @unittest.skip(
-        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
-
-    @unittest.skip(reason="Compile not yet supported because in LLava models")
-    @pytest.mark.torch_compile_test
-    def test_sdpa_can_compile_dynamic(self):
-        pass
-
-    @unittest.skip(reason="Compile not yet supported because in LLava models")
-    def test_sdpa_can_dispatch_on_flash(self):
-        pass
-
-    @unittest.skip(reason="Feedforward chunking is not yet supported")
-    def test_feed_forward_chunking(self):
-        pass
-
     @unittest.skip(reason="Unstable test")
     def test_initialization(self):
         pass

-    @unittest.skip(reason="Dynamic control flow due to MoE")
-    def test_generate_with_static_cache(self):
-        pass
-
-    @unittest.skip(reason="Dynamic control flow due to MoE")
-    def test_generate_from_inputs_embeds_with_static_cache(self):
-        pass
-
-    @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
-    def test_cpu_offload(self):
-        pass
-
-    @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
-    def test_disk_offload_bin(self):
-        pass
-
-    @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
-    def test_disk_offload_safetensors(self):
-        pass
-

 SKIP = False
 torch_accelerator_module = getattr(torch, torch_device)
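
With the shrunken vision config above (`image_size=16`, `patch_size=8`, `projection_dim=4`), the new `projector_patch_to_query_dict` maps the tiny vision tower's patch count to its projection size, so `AriaConfig` no longer falls back to the full-size defaults. The arithmetic, spelled out with the values copied from the diff:

```python
# Number of vision patches -> projector query count, per the tester code above.
image_size, patch_size, projection_dim = 16, 8, 4

num_patches = image_size**2 // patch_size**2  # 256 // 64 = 4
projector_patch_to_query_dict = {num_patches: projection_dim}
print(projector_patch_to_query_dict)  # {4: 4}
```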

0 commit comments
