Skip to content

Commit 928b541

Browse files
Renames load_from_state_dict to load_state_dict
1 parent: 342dec5 · commit: 928b541

File tree

4 files changed: +24 additions, -24 deletions (48 lines changed)

tripy/examples/nanogpt/weight_loader.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#
2-
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2024-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -52,7 +52,7 @@ def load_weights_from_hf(model, model_type, dtype):
5252
param = tp.Parameter(weight)
5353
tripy_state_dict[key] = param
5454

55-
model.load_from_state_dict(tripy_state_dict)
55+
model.load_state_dict(tripy_state_dict)
5656

5757

5858
def load_quant_weights_from_hf(model, model_type, dtype, quant_mode):
@@ -112,5 +112,5 @@ def get_submodule(module, attr_name):
112112
param = tp.Parameter(weight.contiguous())
113113
tripy_state_dict[key] = param
114114

115-
model.load_from_state_dict(tripy_state_dict)
115+
model.load_state_dict(tripy_state_dict)
116116
print("Loaded weights to tripy model.")

tripy/tests/frontend/module/test_linear.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,6 @@ def test_load_quantized_params_from_state_dict(self):
6565
weight_quant_dim=0,
6666
)
6767

68-
qlinear.load_from_state_dict(
68+
qlinear.load_state_dict(
6969
{"weight_scale": tp.Parameter(tp.ones((20,))), "input_scale": tp.Parameter(tp.ones((20,)))}
7070
)

tripy/tests/frontend/module/test_module.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -63,23 +63,23 @@ def test_state_dict(self, network):
6363
"dummy2.nested.param": network.dummy2.nested.param,
6464
}
6565

66-
def test_load_from_state_dict_top_level_param(
66+
def test_load_state_dict_top_level_param(
6767
self,
6868
network,
6969
):
7070
state_dict = {"param": tp.Parameter(tp.zeros((2,), dtype=tp.float32))}
71-
network.load_from_state_dict(state_dict)
71+
network.load_state_dict(state_dict)
7272
assert network.param is state_dict["param"]
7373

74-
def test_load_from_state_dict_nested_param(
74+
def test_load_state_dict_nested_param(
7575
self,
7676
network,
7777
):
7878
state_dict = {"dummy1.nested.param": tp.Parameter(tp.arange(2, dtype=tp.float32))}
79-
network.load_from_state_dict(state_dict)
79+
network.load_state_dict(state_dict)
8080
assert network.dummy1.nested.param is state_dict["dummy1.nested.param"]
8181

82-
def test_load_from_state_dict_with_different_shapes_fails(
82+
def test_load_state_dict_with_different_shapes_fails(
8383
self,
8484
network,
8585
):
@@ -88,9 +88,9 @@ def test_load_from_state_dict_with_different_shapes_fails(
8888
with helper.raises(
8989
tp.TripyException, match=r"New parameter shape: \[3\] is not compatible with current shape: \[2\]"
9090
):
91-
network.load_from_state_dict(state_dict)
91+
network.load_state_dict(state_dict)
9292

93-
def test_load_from_state_dict_with_different_dtype_fails(
93+
def test_load_state_dict_with_different_dtype_fails(
9494
self,
9595
network,
9696
):
@@ -99,7 +99,7 @@ def test_load_from_state_dict_with_different_dtype_fails(
9999
with helper.raises(
100100
tp.TripyException, match="New parameter dtype: float16 is not compatible with current dtype: float32"
101101
):
102-
network.load_from_state_dict(state_dict)
102+
network.load_state_dict(state_dict)
103103

104104
def test_mixed_collections_not_registered(self, network):
105105
network.mix_param_list = [True, tp.Parameter(1)]
@@ -153,20 +153,20 @@ def test_state_dict(self, list_network):
153153
"dummy_list.1.nested.param": list_network.dummy_list[1].nested.param,
154154
}
155155

156-
def test_load_from_state_dict_top_level_param(
156+
def test_load_state_dict_top_level_param(
157157
self,
158158
list_network,
159159
):
160160
state_dict = {"params.0": tp.Parameter(tp.zeros((2,), dtype=tp.float32))}
161-
list_network.load_from_state_dict(state_dict)
161+
list_network.load_state_dict(state_dict)
162162
assert list_network.params[0] is state_dict["params.0"]
163163

164-
def test_load_from_state_dict_nested_param(
164+
def test_load_state_dict_nested_param(
165165
self,
166166
list_network,
167167
):
168168
state_dict = {"dummy_list.0.nested.param": tp.Parameter(tp.arange(2, dtype=tp.float32))}
169-
list_network.load_from_state_dict(state_dict)
169+
list_network.load_state_dict(state_dict)
170170
assert list_network.dummy_list[0].nested.param is state_dict["dummy_list.0.nested.param"]
171171

172172
def test_modify_list_param(self, list_network):
@@ -205,20 +205,20 @@ def test_state_dict(self, dict_network):
205205
"dummy_dict.op1.nested.param": dict_network.dummy_dict["op1"].nested.param,
206206
}
207207

208-
def test_load_from_state_dict_top_level_param(
208+
def test_load_state_dict_top_level_param(
209209
self,
210210
dict_network,
211211
):
212212
state_dict = {"params.param": tp.Parameter(tp.zeros((2,), dtype=tp.float32))}
213-
dict_network.load_from_state_dict(state_dict)
213+
dict_network.load_state_dict(state_dict)
214214
assert dict_network.params["param"] is state_dict["params.param"]
215215

216-
def test_load_from_state_dict_nested_param(
216+
def test_load_state_dict_nested_param(
217217
self,
218218
dict_network,
219219
):
220220
state_dict = {"dummy_dict.op0.nested.param": tp.Parameter(tp.arange(2, dtype=tp.float32))}
221-
dict_network.load_from_state_dict(state_dict)
221+
dict_network.load_state_dict(state_dict)
222222
assert dict_network.dummy_dict["op0"].nested.param is state_dict["dummy_dict.op0.nested.param"]
223223

224224
def test_modify_dict_param(self, dict_network):
@@ -255,7 +255,7 @@ def test_state_dict(self, mixed_network):
255255
"mixed_dict.dummy.nested.param": tensor,
256256
"mixed_dict.dummy_nested.param": tensor,
257257
}
258-
module.load_from_state_dict(external_state_dict)
258+
module.load_state_dict(external_state_dict)
259259
assert module.mixed_list[0].nested.param is tensor
260260
assert module.mixed_list[1].param is tensor
261261
assert module.mixed_dict["dummy"].nested.param is tensor
@@ -279,7 +279,7 @@ def test_state_dict(self, complex_network):
279279
"nets.list_net.params.0": tensor,
280280
"nets.list_net.dummy_list.0.nested.param": tensor,
281281
}
282-
module.load_from_state_dict(external_state_dict)
282+
module.load_state_dict(external_state_dict)
283283
assert module.nets["dict_net"].params["param"] is tensor
284284
assert module.nets["list_net"].params[0] is tensor
285285
assert module.nets["dict_net"].dummy_dict["op0"].nested.param is tensor

tripy/tripy/frontend/module/module.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,7 @@ def __init__(self):
147147

148148
return state_dict
149149

150-
def load_from_state_dict(self, state_dict: Dict[str, Parameter]) -> None:
150+
def load_state_dict(self, state_dict: Dict[str, Parameter]) -> None:
151151
r"""
152152
Loads parameters from the provided ``state_dict`` into the current module.
153153
This will recurse over any nested child modules.
@@ -174,7 +174,7 @@ def __init__(self): # doc: omit
174174
print(f"Before: {module.param}")
175175
176176
state_dict["param"] = tp.Parameter(tp.zeros((2,), dtype=tp.float32))
177-
module.load_from_state_dict(state_dict)
177+
module.load_state_dict(state_dict)
178178
179179
print(f"After: {module.param}")
180180

0 commit comments

Comments
 (0)