@@ -63,23 +63,23 @@ def test_state_dict(self, network):
6363 "dummy2.nested.param" : network .dummy2 .nested .param ,
6464 }
6565
66- def test_load_from_state_dict_top_level_param (
66+ def test_load_state_dict_top_level_param (
6767 self ,
6868 network ,
6969 ):
7070 state_dict = {"param" : tp .Parameter (tp .zeros ((2 ,), dtype = tp .float32 ))}
71- network .load_from_state_dict (state_dict )
71+ network .load_state_dict (state_dict )
7272 assert network .param is state_dict ["param" ]
7373
74- def test_load_from_state_dict_nested_param (
74+ def test_load_state_dict_nested_param (
7575 self ,
7676 network ,
7777 ):
7878 state_dict = {"dummy1.nested.param" : tp .Parameter (tp .arange (2 , dtype = tp .float32 ))}
79- network .load_from_state_dict (state_dict )
79+ network .load_state_dict (state_dict )
8080 assert network .dummy1 .nested .param is state_dict ["dummy1.nested.param" ]
8181
82- def test_load_from_state_dict_with_different_shapes_fails (
82+ def test_load_state_dict_with_different_shapes_fails (
8383 self ,
8484 network ,
8585 ):
@@ -88,9 +88,9 @@ def test_load_from_state_dict_with_different_shapes_fails(
8888 with helper .raises (
8989 tp .TripyException , match = r"New parameter shape: \[3\] is not compatible with current shape: \[2\]"
9090 ):
91- network .load_from_state_dict (state_dict )
91+ network .load_state_dict (state_dict )
9292
93- def test_load_from_state_dict_with_different_dtype_fails (
93+ def test_load_state_dict_with_different_dtype_fails (
9494 self ,
9595 network ,
9696 ):
@@ -99,7 +99,7 @@ def test_load_from_state_dict_with_different_dtype_fails(
9999 with helper .raises (
100100 tp .TripyException , match = "New parameter dtype: float16 is not compatible with current dtype: float32"
101101 ):
102- network .load_from_state_dict (state_dict )
102+ network .load_state_dict (state_dict )
103103
104104 def test_mixed_collections_not_registered (self , network ):
105105 network .mix_param_list = [True , tp .Parameter (1 )]
@@ -153,20 +153,20 @@ def test_state_dict(self, list_network):
153153 "dummy_list.1.nested.param" : list_network .dummy_list [1 ].nested .param ,
154154 }
155155
156- def test_load_from_state_dict_top_level_param (
156+ def test_load_state_dict_top_level_param (
157157 self ,
158158 list_network ,
159159 ):
160160 state_dict = {"params.0" : tp .Parameter (tp .zeros ((2 ,), dtype = tp .float32 ))}
161- list_network .load_from_state_dict (state_dict )
161+ list_network .load_state_dict (state_dict )
162162 assert list_network .params [0 ] is state_dict ["params.0" ]
163163
164- def test_load_from_state_dict_nested_param (
164+ def test_load_state_dict_nested_param (
165165 self ,
166166 list_network ,
167167 ):
168168 state_dict = {"dummy_list.0.nested.param" : tp .Parameter (tp .arange (2 , dtype = tp .float32 ))}
169- list_network .load_from_state_dict (state_dict )
169+ list_network .load_state_dict (state_dict )
170170 assert list_network .dummy_list [0 ].nested .param is state_dict ["dummy_list.0.nested.param" ]
171171
172172 def test_modify_list_param (self , list_network ):
@@ -205,20 +205,20 @@ def test_state_dict(self, dict_network):
205205 "dummy_dict.op1.nested.param" : dict_network .dummy_dict ["op1" ].nested .param ,
206206 }
207207
208- def test_load_from_state_dict_top_level_param (
208+ def test_load_state_dict_top_level_param (
209209 self ,
210210 dict_network ,
211211 ):
212212 state_dict = {"params.param" : tp .Parameter (tp .zeros ((2 ,), dtype = tp .float32 ))}
213- dict_network .load_from_state_dict (state_dict )
213+ dict_network .load_state_dict (state_dict )
214214 assert dict_network .params ["param" ] is state_dict ["params.param" ]
215215
216- def test_load_from_state_dict_nested_param (
216+ def test_load_state_dict_nested_param (
217217 self ,
218218 dict_network ,
219219 ):
220220 state_dict = {"dummy_dict.op0.nested.param" : tp .Parameter (tp .arange (2 , dtype = tp .float32 ))}
221- dict_network .load_from_state_dict (state_dict )
221+ dict_network .load_state_dict (state_dict )
222222 assert dict_network .dummy_dict ["op0" ].nested .param is state_dict ["dummy_dict.op0.nested.param" ]
223223
224224 def test_modify_dict_param (self , dict_network ):
@@ -255,7 +255,7 @@ def test_state_dict(self, mixed_network):
255255 "mixed_dict.dummy.nested.param" : tensor ,
256256 "mixed_dict.dummy_nested.param" : tensor ,
257257 }
258- module .load_from_state_dict (external_state_dict )
258+ module .load_state_dict (external_state_dict )
259259 assert module .mixed_list [0 ].nested .param is tensor
260260 assert module .mixed_list [1 ].param is tensor
261261 assert module .mixed_dict ["dummy" ].nested .param is tensor
@@ -279,7 +279,7 @@ def test_state_dict(self, complex_network):
279279 "nets.list_net.params.0" : tensor ,
280280 "nets.list_net.dummy_list.0.nested.param" : tensor ,
281281 }
282- module .load_from_state_dict (external_state_dict )
282+ module .load_state_dict (external_state_dict )
283283 assert module .nets ["dict_net" ].params ["param" ] is tensor
284284 assert module .nets ["list_net" ].params [0 ] is tensor
285285 assert module .nets ["dict_net" ].dummy_dict ["op0" ].nested .param is tensor
0 commit comments