
Commit 7258ea4

Fix loading logic flaw with regards to unexpected and missing keys (#40850)
* Unexpected keys should be ignored at load with device map
* remove them all
* fix logic flaw
* fix
* simplify
* style
* fix
* revert caching allocator change
* add other test
* add nice doc

---------

Co-authored-by: Cyril Vallez <[email protected]>
1 parent 2c4caa1 commit 7258ea4
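Below is a minimal, stand-alone sketch (plain Python, not the transformers API itself) of the regex filtering that this commit moves to the very end of `_load_pretrained_model`: the per-model ignore patterns now only silence warnings about missing and unexpected keys, instead of hiding those keys from the earlier loading logic. The pattern values and key names are illustrative, borrowed from the new test models in the diff below.

import re

# Illustrative values only -- they mirror the patterns used by the new test models in this commit
_keys_to_ignore_on_load_missing = [r"^linear"]
_keys_to_ignore_on_load_unexpected = [r"^mtp.*"]

missing_keys = ["linear.weight", "head.weight"]
unexpected_keys = ["mtp.weight", "mtp.bias"]

# Same filtering as the moved code: drop keys matching an ignore pattern before warning about them
for pattern in _keys_to_ignore_on_load_missing:
    missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
for pattern in _keys_to_ignore_on_load_unexpected:
    unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]

print(missing_keys)     # ['head.weight'] -> the only missing key that would still be reported
print(unexpected_keys)  # [] -> nothing left to warn about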

File tree

2 files changed: +137 -28 lines changed

  src/transformers/modeling_utils.py
  tests/utils/test_modeling_utils.py


src/transformers/modeling_utils.py

Lines changed: 24 additions & 27 deletions
@@ -738,8 +738,6 @@ def _load_state_dict_into_meta_model(
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)

     for param_name, empty_param in state_dict.items():
-        if param_name not in expected_keys:  # when loading from ckpt, we skip param if doesnt exist in modeling
-            continue
         # we need to use serialized_param_name as file pointer is untouched
         if is_meta_state_dict:
             # This is the name of the parameter as it appears on disk file
@@ -1414,7 +1412,6 @@ def _get_device_map(


 def _find_missing_and_unexpected_keys(
-    cls,
     model: "PreTrainedModel",
     original_checkpoint_keys: list[str],
     checkpoint_keys: list[str],
@@ -1444,12 +1441,6 @@ def _find_missing_and_unexpected_keys(
     model_buffers = {n for n, _ in model.named_buffers()}
     unexpected_keys = sorted(unexpected_keys - model_buffers)

-    # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
-    # (so the buffer name has changed). Remove them in such a case
-    has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
-    if has_inv_freq_buffers:
-        unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
-
     tied_params = find_tied_parameters(model)
     for group in tied_params:
         missing_in_group = [k for k in missing_keys if k in group]
@@ -1460,15 +1451,6 @@ def _find_missing_and_unexpected_keys(
         missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
         unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)

-    # Model-specific exceptions for missing and unexpected keys (e.g. if the modeling change over time, or any other reason...)
-    if cls._keys_to_ignore_on_load_missing is not None:
-        for pattern in cls._keys_to_ignore_on_load_missing:
-            missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
-
-    if cls._keys_to_ignore_on_load_unexpected is not None:
-        for pattern in cls._keys_to_ignore_on_load_unexpected:
-            unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
-
     return missing_keys, unexpected_keys

@@ -5320,12 +5302,7 @@ def _load_pretrained_model(

         # Find missing and unexpected keys from the state dict
         missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
-            cls,
-            model,
-            original_checkpoint_keys,
-            checkpoint_keys,
-            loading_base_model_from_task_state_dict,
-            hf_quantizer,
+            model, original_checkpoint_keys, checkpoint_keys, loading_base_model_from_task_state_dict, hf_quantizer
         )
         # Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the
         # same way as missing keys)
@@ -5339,8 +5316,10 @@ def _load_pretrained_model(
             weights_only,
         )

-        # We need to update both the mapping and the list of checkpoint keys to remove the mismatched ones
-        key_renaming_mapping = {k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys}
+        # We need to update both the mapping and the list of checkpoint keys to remove the mismatched and unexpected ones
+        key_renaming_mapping = {
+            k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys and v not in unexpected_keys
+        }
         checkpoint_keys = list(key_renaming_mapping.values())

         # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
@@ -5366,14 +5345,15 @@ def _load_pretrained_model(
             # in the submodule
             key_renaming_mapping = {k: v[len(_prefix) :] for k, v in key_renaming_mapping.items()}
             checkpoint_keys = list(key_renaming_mapping.values())
+            unexpected_keys = [k[len(_prefix) :] if k.startswith(_prefix) else k for k in unexpected_keys]
             # We need to update the device map as well
             if device_map is not None:
                 device_map = {k[len(_prefix) :] if k.startswith(_prefix) else k: v for k, v in device_map.items()}
             # small sanity check: the base model should not contain task-specific head keys
             task_specific_expected_keys = [s for s in model.state_dict() if not s.startswith(_prefix)]
             base_model_expected_keys = list(model_to_load.state_dict().keys())
             if any(
-                key in task_specific_expected_keys and key not in base_model_expected_keys for key in checkpoint_keys
+                key in task_specific_expected_keys and key not in base_model_expected_keys for key in unexpected_keys
             ):
                 raise ValueError(
                     "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
@@ -5555,6 +5535,23 @@ def _load_pretrained_model(
                 device_mesh,
             )

+        # Model-specific exceptions for missing and unexpected keys (e.g. if the modeling change over time, or any other reason...)
+        # We should remove them here to avoid raising warnings if they are present in the lists
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pattern in cls._keys_to_ignore_on_load_missing:
+                missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
+
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pattern in cls._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
+
+        # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
+        # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
+        # `_keys_to_ignore_on_load_unexpected` as it touches many models
+        has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer, _ in model.named_buffers())
+        if has_inv_freq_buffers:
+            unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
+
         # All potential warnings/infos
         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
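The core of the fix is visible in the hunks above: unexpected keys are now dropped from `key_renaming_mapping` itself, so the later steps (device-map dispatch, disk offload) never try to place tensors the model has no parameter for, while the ignore patterns are applied only once, right before warnings are emitted. A toy, stand-alone illustration of that filtering step follows; all names and values are made up for the example.

mismatched_keys = {"linear.weight"}  # e.g. a key whose checkpoint shape does not match the model
unexpected_keys = {"mtp.weight"}     # e.g. a key with no corresponding parameter in the model

key_renaming_mapping = {
    "model.linear.weight": "linear.weight",
    "model.linear_2.weight": "linear_2.weight",
    "model.mtp.weight": "mtp.weight",
}

# Mirror of the new dict comprehension: keep only keys that are neither mismatched nor unexpected
key_renaming_mapping = {
    k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys and v not in unexpected_keys
}
checkpoint_keys = list(key_renaming_mapping.values())
print(checkpoint_keys)  # ['linear_2.weight']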

tests/utils/test_modeling_utils.py

Lines changed: 113 additions & 1 deletion
@@ -29,7 +29,7 @@

 import pytest
 import requests
-from huggingface_hub import HfApi, HfFolder
+from huggingface_hub import HfApi, HfFolder, split_torch_state_dict_into_shards
 from parameterized import parameterized
 from pytest import mark
 from requests.exceptions import HTTPError
@@ -139,6 +139,32 @@ def __init__(self, config):
     def forward(self, x):
         return self.linear_2(self.linear(x))

+class BaseModelWithUnexpectedKeys(PreTrainedModel):
+    base_model_prefix = "base"
+    config_class = PretrainedConfig
+    _keys_to_ignore_on_load_unexpected = [r"^mtp.*"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.linear = nn.Linear(50, 50)
+        self.linear_2 = nn.Linear(50, 50)
+
+    def forward(self, x):
+        return self.linear_2(self.linear(x))
+
+class BaseModelWithMissingKeys(PreTrainedModel):
+    base_model_prefix = "base"
+    config_class = PretrainedConfig
+    _keys_to_ignore_on_load_missing = [r"^linear"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.linear = nn.Linear(50, 50)
+        self.linear_2 = nn.Linear(50, 50)
+
+    def forward(self, x):
+        return self.linear_2(self.linear(x))
+
 class BaseModelWithTiedWeights(PreTrainedModel):
     config_class = PretrainedConfig

@@ -2028,6 +2054,92 @@ class MyModelD(MyModelA):
         self.assertIs(MyModelC.config_class, MyConfigC)
         self.assertIs(MyModelD.config_class, MyConfigA)

+    def test_ignore_missing_key_works(self):
+        """Test that if a parameter (not buffer) is specified in `_keys_to_ignore_on_load_missing` and is actually
+        missing from the checkpoint, it will still be moved to cpu and initialized"""
+        temp = tempfile.TemporaryDirectory()
+        # Create dummy model
+        model = BaseModelWithMissingKeys(PretrainedConfig())
+
+        # Save the config
+        model.config.save_pretrained(temp.name)
+        # Get the state dict to save
+        state_dict = model.state_dict()
+        # Remove the layer that we should ignore if missing
+        del state_dict["linear.weight"], state_dict["linear.bias"]
+        # Save the state dict as a single shard
+        safe_save_file(state_dict, Path(temp.name) / "model.safetensors", metadata={"format": "pt"})
+
+        # Try loading back, with the missing key not present in the state_dict
+        model = BaseModelWithMissingKeys.from_pretrained(temp.name)
+
+        # Make sure the skipped missing key is not still on meta device!
+        for k, v in model.state_dict().items():
+            self.assertTrue(v.device.type == "cpu", f"{k} is not on cpu!")
+
+    def test_device_map_works_with_unexpected_keys(self):
+        """Test that if a parameter is specified in `_keys_to_ignore_on_load_unexpected` and is actually
+        present in the checkpoint, it will correctly be removed from the weights we load, especially those
+        we use if the device map has offloading"""
+        temp = tempfile.TemporaryDirectory()
+
+        # Create dummy model
+        model = BaseModelWithUnexpectedKeys(PretrainedConfig())
+
+        # Save the config
+        model.config.save_pretrained(temp.name)
+
+        # Get the state dict to save
+        state_dict = model.state_dict()
+        # Add a layer that is in the "_keys_to_ignore_on_load_unexpected" list to ignore
+        state_dict["mtp"] = torch.randn(12, 12)
+        # Save the state dict as a single shard
+        safe_save_file(state_dict, Path(temp.name) / "model.safetensors", metadata={"format": "pt"})
+
+        # Load the model with entire shards placed on disk in order to trigger `get_disk_only_shard_files`.
+        # Unexpected keys (mtp) should be removed from the state dict, therefore this should not error out.
+        BaseModelWithUnexpectedKeys.from_pretrained(temp.name, device_map={"linear": "cpu", "linear_2": "disk"})
+
+    def test_device_map_works_with_unexpected_keys_sharded(self):
+        """Test that if a parameter is specified in `_keys_to_ignore_on_load_unexpected` and is actually
+        present in the checkpoint, it will correctly be removed from the weights we load, especially those
+        we use if the device map has offloading"""
+        temp = tempfile.TemporaryDirectory()
+
+        # Create dummy model
+        model = BaseModelWithUnexpectedKeys(PretrainedConfig())
+
+        # Save the config
+        model.config.save_pretrained(temp.name)
+
+        # Get the state dict to save
+        state_dict = model.state_dict()
+
+        # Add a layer that is in the "_keys_to_ignore_on_load_unexpected" list to ignore
+        state_dict["mtp"] = torch.randn(50, 50)
+
+        # Split the state dict in shards, save the index and the shards
+        shards = split_torch_state_dict_into_shards(state_dict, max_shard_size="1kb")
+        index = {
+            "metadata": {"total_parameters": model.num_parameters(), **shards.metadata},
+            "weight_map": shards.tensor_to_filename,
+        }
+        with open(Path(temp.name) / SAFE_WEIGHTS_INDEX_NAME, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+
+        # Save each shard
+        filename_to_tensors = shards.filename_to_tensors.items()
+        for shard_file, tensors in filename_to_tensors:
+            shard = {}
+            for tensor in tensors:
+                shard[tensor] = state_dict[tensor].contiguous()
+            safe_save_file(shard, Path(temp.name) / shard_file, metadata={"format": "pt"})
+
+        # Load the model with entire shards placed on disk in order to trigger `get_disk_only_shard_files`.
+        # Unexpected keys (mtp) should be removed from the state dict, therefore this should not error out.
+        BaseModelWithUnexpectedKeys.from_pretrained(temp.name, device_map={"linear": "cpu", "linear_2": "disk"})
+

 @slow
 @require_torch
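As a usage note, continuing from the sharded-checkpoint setup in the test above (so `temp` and `BaseModelWithUnexpectedKeys` are assumed to already exist): `from_pretrained` accepts `output_loading_info=True`, which returns the lists of missing and unexpected keys after the ignore patterns have been applied. This gives a convenient way to check that the `mtp` tensor is silently dropped rather than breaking the offloaded load.

model, loading_info = BaseModelWithUnexpectedKeys.from_pretrained(
    temp.name, device_map={"linear": "cpu", "linear_2": "disk"}, output_loading_info=True
)
# The `mtp` tensor matches `_keys_to_ignore_on_load_unexpected`, so it should not be reported here
print(loading_info["unexpected_keys"])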
