Skip to content

Commit eddd51e

Browse files
tohtana and 3outeille
authored
Fix checkpoint loading with DeepSpeed ZeRO3 (#42201)
fix checkpoint loading with DeepSpeed ZeRO3 Signed-off-by: Masahiro Tanaka <[email protected]> Co-authored-by: Ferdinand Mom <[email protected]>
1 parent 7607d80 commit eddd51e

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/transformers/modeling_utils.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4233,6 +4233,19 @@ def _load_pretrained_model(
42334233
error_msgs = []
42344234

42354235
if is_deepspeed_zero3_enabled() and not is_quantized:
4236+
if state_dict is None:
4237+
if checkpoint_files is None:
4238+
raise ValueError(
4239+
"DeepSpeed ZeRO-3 initialization requires a state_dict or checkpoint files to load from."
4240+
)
4241+
merged_state_dict = {}
4242+
for ckpt_file in checkpoint_files:
4243+
merged_state_dict.update(
4244+
load_state_dict(
4245+
ckpt_file, is_quantized=is_quantized, map_location="cpu", weights_only=weights_only
4246+
)
4247+
)
4248+
state_dict = merged_state_dict
42364249
error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
42374250
# This is not true but for now we assume only best-case scenario with deepspeed, i.e. perfectly matching checkpoints
42384251
missing_keys, unexpected_keys, mismatched_keys, misc = set(), set(), set(), set()

0 commit comments

Comments (0)