@@ -15,7 +15,6 @@
 
 import unittest
 
-import pytest
 import requests
 
 from transformers import (
@@ -61,6 +60,10 @@ class AriaVisionText2TextModelTester:
     def __init__(
         self,
         parent,
+        batch_size=13,
+        num_channels=3,
+        image_size=16,
+        num_image_tokens=4,
         ignore_index=-100,
         image_token_index=9,
         projector_hidden_act="gelu",
@@ -83,32 +86,32 @@ def __init__(
             num_choices=4,
             pad_token_id=1,
             hidden_size=32,
-            intermediate_size=64,
+            intermediate_size=16,
             max_position_embeddings=60,
             model_type="aria_moe_lm",
             moe_intermediate_size=4,
-            moe_num_experts=4,
+            moe_num_experts=3,
             moe_topk=2,
-            num_attention_heads=8,
+            num_attention_heads=2,
             num_experts_per_tok=3,
             num_hidden_layers=2,
-            num_key_value_heads=8,
+            num_key_value_heads=2,
             rope_theta=5000000,
             vocab_size=99,
             eos_token_id=2,
             head_dim=4,
         ),
         is_training=True,
         vision_config=Idefics3VisionConfig(
-            image_size=358,
-            patch_size=10,
+            image_size=16,
+            patch_size=8,
             num_channels=3,
             is_training=True,
             hidden_size=32,
-            projection_dim=20,
+            projection_dim=4,
             num_hidden_layers=2,
-            num_attention_heads=16,
-            intermediate_size=10,
+            num_attention_heads=2,
+            intermediate_size=4,
             dropout=0.1,
             attention_dropout=0.1,
             initializer_range=0.02,
@@ -130,11 +133,14 @@ def __init__(
         self.num_attention_heads = text_config.num_attention_heads
         self.is_training = is_training
 
-        self.batch_size = 10
-        self.num_channels = 3
-        self.image_size = 358
-        self.num_image_tokens = 128
+        self.batch_size = batch_size
+        self.num_channels = num_channels
+        self.image_size = image_size
+        self.num_image_tokens = num_image_tokens
         self.seq_length = seq_length + self.num_image_tokens
+        self.projector_patch_to_query_dict = {
+            vision_config.image_size**2 // vision_config.patch_size**2: vision_config.projection_dim
+        }
 
     def get_config(self):
        return AriaConfig(
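The new `projector_patch_to_query_dict` entry ties the shrunken vision config to the projector: with the test values in this diff (`image_size=16`, `patch_size=8`, `projection_dim=4`), the dict maps 4 image patches to 4 query tokens. A minimal sanity check of that arithmetic, using only the values visible above (not part of the test file itself):

```python
# Quick check of the patch-to-query mapping implied by the new test values.
# Values are taken from the diff; this snippet is illustrative only.
image_size = 16
patch_size = 8
projection_dim = 4

num_patches = image_size**2 // patch_size**2  # (16 * 16) // (8 * 8) = 4
projector_patch_to_query_dict = {num_patches: projection_dim}

assert projector_patch_to_query_dict == {4: 4}
```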
@@ -146,6 +152,7 @@ def get_config(self):
             vision_feature_select_strategy=self.vision_feature_select_strategy,
             vision_feature_layer=self.vision_feature_layer,
             eos_token_id=self.eos_token_id,
+            projector_patch_to_query_dict=self.projector_patch_to_query_dict,
         )
 
     def prepare_config_and_inputs(self):
@@ -176,7 +183,6 @@ def prepare_config_and_inputs_for_common(self):
         return config, inputs_dict
 
 
-@slow
 @require_torch
 class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     """
@@ -193,61 +199,10 @@ def setUp(self):
         self.model_tester = AriaVisionText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False)
 
-    @unittest.skip(
-        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
-
-    @unittest.skip(
-        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
-
-    @unittest.skip(reason="Compile not yet supported because in LLava models")
-    @pytest.mark.torch_compile_test
-    def test_sdpa_can_compile_dynamic(self):
-        pass
-
-    @unittest.skip(reason="Compile not yet supported because in LLava models")
-    def test_sdpa_can_dispatch_on_flash(self):
-        pass
-
-    @unittest.skip(reason="Feedforward chunking is not yet supported")
-    def test_feed_forward_chunking(self):
-        pass
-
     @unittest.skip(reason="Unstable test")
     def test_initialization(self):
         pass
 
-    @unittest.skip(reason="Dynamic control flow due to MoE")
-    def test_generate_with_static_cache(self):
-        pass
-
-    @unittest.skip(reason="Dynamic control flow due to MoE")
-    def test_generate_from_inputs_embeds_with_static_cache(self):
-        pass
-
-    @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
-    def test_cpu_offload(self):
-        pass
-
-    @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
-    def test_disk_offload_bin(self):
-        pass
-
-    @unittest.skip(reason="Aria uses nn.MHA which is not compatible with offloading")
-    def test_disk_offload_safetensors(self):
-        pass
-
 
 SKIP = False
 torch_accelerator_module = getattr(torch, torch_device)