 import pytest
 import requests
-from huggingface_hub import HfApi, HfFolder
+from huggingface_hub import HfApi, HfFolder, split_torch_state_dict_into_shards
 from parameterized import parameterized
 from pytest import mark
 from requests.exceptions import HTTPError
@@ -139,6 +139,32 @@ def __init__(self, config):
         def forward(self, x):
             return self.linear_2(self.linear(x))
 
+    class BaseModelWithUnexpectedKeys(PreTrainedModel):
+        base_model_prefix = "base"
+        config_class = PretrainedConfig
+        _keys_to_ignore_on_load_unexpected = [r"^mtp.*"]
+
+        def __init__(self, config):
+            super().__init__(config)
+            self.linear = nn.Linear(50, 50)
+            self.linear_2 = nn.Linear(50, 50)
+
+        def forward(self, x):
+            return self.linear_2(self.linear(x))
+
+    class BaseModelWithMissingKeys(PreTrainedModel):
+        base_model_prefix = "base"
+        config_class = PretrainedConfig
+        _keys_to_ignore_on_load_missing = [r"^linear"]
+
+        def __init__(self, config):
+            super().__init__(config)
+            self.linear = nn.Linear(50, 50)
+            self.linear_2 = nn.Linear(50, 50)
+
+        def forward(self, x):
+            return self.linear_2(self.linear(x))
+
     class BaseModelWithTiedWeights(PreTrainedModel):
         config_class = PretrainedConfig
 
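For readers less familiar with these class attributes: `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_load_unexpected` are lists of regex patterns that `from_pretrained` matches against checkpoint key names, so that matching missing keys are not reported and matching unexpected keys are dropped before the weights are dispatched. A minimal sketch of that pattern matching, independent of the actual `transformers` loading code (the helper names below are illustrative, not the library's API):

```python
import re

# Illustrative helpers only -- not the transformers implementation.
UNEXPECTED_PATTERNS = [r"^mtp.*"]  # mirrors BaseModelWithUnexpectedKeys above
MISSING_PATTERNS = [r"^linear"]    # mirrors BaseModelWithMissingKeys above


def drop_unexpected_keys(state_dict: dict, patterns=UNEXPECTED_PATTERNS) -> dict:
    """Remove checkpoint entries whose key matches any 'unexpected' pattern."""
    compiled = [re.compile(p) for p in patterns]
    return {k: v for k, v in state_dict.items() if not any(p.search(k) for p in compiled)}


def is_ignorable_missing_key(key: str, patterns=MISSING_PATTERNS) -> bool:
    """True if a key absent from the checkpoint should not be reported as missing."""
    return any(re.search(p, key) for p in patterns)


# Example: "mtp" is silently dropped, "linear.weight" may be absent without a warning.
print(drop_unexpected_keys({"mtp": 0, "linear.weight": 1}))  # {'linear.weight': 1}
print(is_ignorable_missing_key("linear.bias"))               # True
```

The tests added below also check that a parameter covered by `_keys_to_ignore_on_load_missing` is still materialized (moved off the meta device and initialized) even though it is not reported as missing.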
@@ -2028,6 +2054,92 @@ class MyModelD(MyModelA):
         self.assertIs(MyModelC.config_class, MyConfigC)
         self.assertIs(MyModelD.config_class, MyConfigA)
 
+    def test_ignore_missing_key_works(self):
+        """Test that if a parameter (not a buffer) is specified in `_keys_to_ignore_on_load_missing` and is actually
+        missing from the checkpoint, it is still moved to cpu and initialized."""
+        temp = tempfile.TemporaryDirectory()
+        # Create a dummy model
+        model = BaseModelWithMissingKeys(PretrainedConfig())
+
+        # Save the config
+        model.config.save_pretrained(temp.name)
+        # Get the state dict to save
+        state_dict = model.state_dict()
+        # Remove the layer that we should ignore if missing
+        del state_dict["linear.weight"], state_dict["linear.bias"]
+        # Save the state dict as a single shard
+        safe_save_file(state_dict, Path(temp.name) / "model.safetensors", metadata={"format": "pt"})
+
+        # Try loading back, with the missing keys not present in the state dict
+        model = BaseModelWithMissingKeys.from_pretrained(temp.name)
+
+        # Make sure the skipped missing keys are not still on the meta device!
+        for k, v in model.state_dict().items():
+            self.assertTrue(v.device.type == "cpu", f"{k} is not on cpu!")
+
+    def test_device_map_works_with_unexpected_keys(self):
+        """Test that if a parameter matches `_keys_to_ignore_on_load_unexpected` and is actually
+        present in the checkpoint, it is correctly removed from the weights we load, in particular
+        from the ones used when the device map offloads to disk."""
+        temp = tempfile.TemporaryDirectory()
+
+        # Create a dummy model
+        model = BaseModelWithUnexpectedKeys(PretrainedConfig())
+
+        # Save the config
+        model.config.save_pretrained(temp.name)
+
+        # Get the state dict to save
+        state_dict = model.state_dict()
+        # Add a tensor whose key matches the "_keys_to_ignore_on_load_unexpected" patterns
+        state_dict["mtp"] = torch.randn(12, 12)
+        # Save the state dict as a single shard
+        safe_save_file(state_dict, Path(temp.name) / "model.safetensors", metadata={"format": "pt"})
+
+        # Load the model with entire shards placed on disk in order to trigger `get_disk_only_shard_files`.
+        # Unexpected keys (mtp) should be removed from the state dict, so this should not error out.
+        BaseModelWithUnexpectedKeys.from_pretrained(temp.name, device_map={"linear": "cpu", "linear_2": "disk"})
+
+    def test_device_map_works_with_unexpected_keys_sharded(self):
+        """Test that if a parameter matches `_keys_to_ignore_on_load_unexpected` and is actually
+        present in the checkpoint, it is correctly removed from the weights we load, in particular
+        from the ones used when the device map offloads to disk. Same as above, but with a sharded checkpoint."""
+        temp = tempfile.TemporaryDirectory()
+
+        # Create a dummy model
+        model = BaseModelWithUnexpectedKeys(PretrainedConfig())
+
+        # Save the config
+        model.config.save_pretrained(temp.name)
+
+        # Get the state dict to save
+        state_dict = model.state_dict()
+
+        # Add a tensor whose key matches the "_keys_to_ignore_on_load_unexpected" patterns
+        state_dict["mtp"] = torch.randn(50, 50)
+
+        # Split the state dict into shards, then save the index and the shards
+        shards = split_torch_state_dict_into_shards(state_dict, max_shard_size="1kb")
+        index = {
+            "metadata": {"total_parameters": model.num_parameters(), **shards.metadata},
+            "weight_map": shards.tensor_to_filename,
+        }
+        with open(Path(temp.name) / SAFE_WEIGHTS_INDEX_NAME, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+
+        # Save each shard
+        filename_to_tensors = shards.filename_to_tensors.items()
+        for shard_file, tensors in filename_to_tensors:
+            shard = {}
+            for tensor in tensors:
+                shard[tensor] = state_dict[tensor].contiguous()
+            safe_save_file(shard, Path(temp.name) / shard_file, metadata={"format": "pt"})
+
+        # Load the model with entire shards placed on disk in order to trigger `get_disk_only_shard_files`.
+        # Unexpected keys (mtp) should be removed from the state dict, so this should not error out.
+        BaseModelWithUnexpectedKeys.from_pretrained(temp.name, device_map={"linear": "cpu", "linear_2": "disk"})
+
 
 @slow
 @require_torch
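As a side note, the sharded checkpoint produced in `test_device_map_works_with_unexpected_keys_sharded` follows the usual `huggingface_hub` splitting pattern. A standalone sketch of that save path is shown below, assuming the `model.safetensors.index.json` index name (the value of `SAFE_WEIGHTS_INDEX_NAME`) and an arbitrary output directory; the function name and defaults are illustrative, not taken from the PR:

```python
import json
from pathlib import Path

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file


def save_sharded_checkpoint(state_dict: dict, save_dir: str, max_shard_size: str = "1kb") -> None:
    """Split a state dict into safetensors shards and write the accompanying index."""
    out = Path(save_dir)
    out.mkdir(parents=True, exist_ok=True)

    shards = split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size)

    # The index maps every tensor name to the shard file that contains it.
    index = {"metadata": shards.metadata, "weight_map": shards.tensor_to_filename}
    (out / "model.safetensors.index.json").write_text(json.dumps(index, indent=2, sort_keys=True) + "\n")

    # Write each shard; tensors must be contiguous for safetensors.
    for shard_file, tensor_names in shards.filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensor_names}
        save_file(shard, out / shard_file, metadata={"format": "pt"})


# Example with a throwaway state dict small enough to force splitting at a 1kb shard size.
save_sharded_checkpoint({"linear.weight": torch.randn(50, 50), "linear.bias": torch.randn(50)}, "sharded_ckpt")
```

The tests then load such a folder with `device_map={"linear": "cpu", "linear_2": "disk"}`, which routes the offloaded shard through `get_disk_only_shard_files`; as the inline comments note, the point of the change is that an unexpected `mtp` key in those shards no longer makes that path error out.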