@@ -242,44 +242,55 @@ def test_dimension_input(self):
242242 expected = (inp_cp + inp_cp ).reshape ((- 1 , reshape_dim ))
243243 assert cp .array_equal (cp .from_dlpack (out ), expected )
244244
245- def test_compile_dict_input_info (self ):
246- """Test compilation with dictionary of InputInfo objects."""
247-
245+ def test_compile_nested_dict_input_info (self ):
248246 def func (data_dict ):
249- return data_dict ["a" ] + data_dict ["b" ]
247+ return data_dict ["a" ][ "inner" ] + data_dict ["b" ][ "list" ][ 0 ] + data_dict [ "b" ][ "list" ][ 1 ]
250248
251249 dict_input = {
252- "a" : tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 ),
253- "b" : tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 ),
250+ "a" : {
251+ "inner" : tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 ),
252+ },
253+ "b" : {
254+ "list" : [
255+ tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 ),
256+ tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 ),
257+ ],
258+ },
254259 }
255260 compiled_func = tp .compile (func , args = [dict_input ])
256261
257- test_dict = {"a" : tp .ones ((2 , 3 ), dtype = tp .float32 ).eval (), "b" : (tp .ones ((2 , 3 ), dtype = tp .float32 ) * 2 ).eval ()}
262+ test_dict = {
263+ "a" : {"inner" : tp .ones ((2 , 3 ), dtype = tp .float32 ).eval ()},
264+ "b" : {
265+ "list" : [
266+ (tp .ones ((2 , 3 ), dtype = tp .float32 ) * 2 ).eval (),
267+ (tp .ones ((2 , 3 ), dtype = tp .float32 ) * 3 ).eval (),
268+ ]
269+ },
270+ }
258271 result = compiled_func (test_dict )
259- expected = test_dict ["a" ] + test_dict ["b" ]
272+ expected = test_dict ["a" ][ "inner" ] + test_dict ["b" ][ "list" ][ 0 ] + test_dict [ "b" ][ "list" ][ 1 ]
260273 assert cp .array_equal (cp .from_dlpack (result ), cp .from_dlpack (expected ))
261274
262- def test_compile_nested_list_input_info (self ):
263- """Test compilation with nested list containers."""
264-
275+ def test_compile_nested_sequence_input_info (self ):
265276 def func (data_list ):
266277 return data_list [0 ] + data_list [1 ][0 ] + data_list [1 ][1 ]
267278
268279 list_input = [
269280 tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 ),
270- [ # Nested list
281+ [
271282 tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 ),
272- tp .ones ((2 , 3 ), dtype = tp .float32 ) * 2 , # Constant in nested list
283+ tp .ones ((2 , 3 ), dtype = tp .float32 ) * 2 ,
273284 ],
274285 ]
275286 compiled_func = tp .compile (func , args = [list_input ])
276287
277288 test_list = [
278289 tp .ones ((2 , 3 ), dtype = tp .float32 ).eval (),
279- [ # Nested list in test data
290+ (
280291 (tp .ones ((2 , 3 ), dtype = tp .float32 ) * 3 ).eval (),
281- tp .ones ((2 , 3 ), dtype = tp .float32 ) * 2 , # Should match baked constant
282- ] ,
292+ tp .ones ((2 , 3 ), dtype = tp .float32 ) * 2 ,
293+ ) ,
283294 ]
284295 result = compiled_func (test_list )
285296 expected = test_list [0 ] + test_list [1 ][0 ] + test_list [1 ][1 ]
@@ -288,24 +299,35 @@ def func(data_list):
288299 def test_compile_mixed_containers_and_constants (self ):
289300 """Test compilation with comprehensive mix: regular InputInfo, dict container, list container, and standalone constant."""
290301
291- def func (regular_input , data_dict , data_list , constant_value ):
292- return regular_input + data_dict ["x" ] + data_dict ["y" ] + data_list [0 ] + data_list [1 ] + constant_value
302+ def func (regular_input , data_dict , data_list , const_in_dict , const ):
303+ return (
304+ regular_input
305+ + data_dict ["x" ]
306+ + data_dict ["y" ]
307+ + data_list [0 ]
308+ + data_list [1 ]
309+ + const_in_dict ["z" ]
310+ + const
311+ )
293312
294313 regular_input = tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 )
295314 dict_input = {
296315 "x" : tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 ),
297- "y" : tp .zeros ((2 , 3 ), dtype = tp .float32 ), # Constant in dict
316+ "y" : tp .zeros ((2 , 3 ), dtype = tp .float32 ),
298317 }
299318 list_input = [tp .InputInfo (shape = (2 , 3 ), dtype = tp .float32 ), tp .ones ((2 , 3 ), dtype = tp .float32 ) * 3 ]
300- constant_value = tp .ones ((2 , 3 ), dtype = tp .float32 ) * 5
319+ const_in_dict = {"z" : tp .ones ((2 , 3 ), dtype = tp .float32 ) * 5 }
320+ const = tp .ones ((2 , 3 ), dtype = tp .float32 ) * 6
301321
302- compiled_func = tp .compile (func , args = [regular_input , dict_input , list_input , constant_value ])
322+ compiled_func = tp .compile (func , args = [regular_input , dict_input , list_input , const_in_dict , const ])
303323
304324 # Only InputInfo arguments should be in function signature
305325 test_regular = tp .ones ((2 , 3 ), dtype = tp .float32 ).eval ()
306326 test_dict = {"x" : (tp .ones ((2 , 3 ), dtype = tp .float32 ) * 2 ).eval (), "y" : tp .zeros ((2 , 3 ), dtype = tp .float32 )}
307327 test_list = [(tp .ones ((2 , 3 ), dtype = tp .float32 ) * 4 ).eval (), tp .ones ((2 , 3 ), dtype = tp .float32 ) * 3 ]
308328
309329 result = compiled_func (test_regular , test_dict , test_list )
310- expected = test_regular + test_dict ["x" ] + test_dict ["y" ] + test_list [0 ] + test_list [1 ] + constant_value
330+ expected = (
331+ test_regular + test_dict ["x" ] + test_dict ["y" ] + test_list [0 ] + test_list [1 ] + const_in_dict ["z" ] + const
332+ )
311333 assert cp .array_equal (cp .from_dlpack (result ), cp .from_dlpack (expected ))
0 commit comments