diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 80656de2fe90..6d08831482c4 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4215,14 +4215,31 @@ def _load_pretrained_model(
             pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")")
             if sharded_metadata is None:
                 k_v_iterator = dict.fromkeys(
-                    safe_open(checkpoint_files[0], framework="pt").keys(), checkpoint_files[0].rsplit("/", 1)[1]
+                    safe_open(checkpoint_files[0], framework="pt").keys(), os.path.basename(checkpoint_files[0])
                 ).items()
             else:
                 k_v_iterator = sharded_metadata["weight_map"].items()
 
-            merged_state_dict = {}
+            # Create a mapping from filename to full path for all checkpoint files
+            filename_to_path = {os.path.basename(f): f for f in checkpoint_files}
+
+            # Group weights by file to load sequentially and avoid keeping too many files open
+            weights_by_file = {}
             for k, v in k_v_iterator:
-                match = pattern.match(k)
+                if v not in weights_by_file:
+                    weights_by_file[v] = []
+                weights_by_file[v].append(k)
+
+            merged_state_dict = {}
+            # Load each file sequentially
+            for filename, weight_keys in weights_by_file.items():
+                # Use the mapping to get the correct file path instead of joining paths
+                # This handles symbolic links on Windows correctly
+                shard_file_path = filename_to_path.get(
+                    filename, os.path.join(os.path.dirname(checkpoint_files[0]), filename)
+                )
+
+                match = pattern.match(weight_keys[0])
                 if match and match.group(1) != "":
                     device = device_map[match.group(1)]
                 else:
@@ -4231,11 +4248,11 @@ def _load_pretrained_model(
                     device = device.index  # safetensors only
                 if device == "disk":
                     device = "cpu"  # we read to cpu to then write to disk
-                file_pointer = safe_open(
-                    os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device
-                )
-                all_pointer.add(file_pointer)
-                merged_state_dict[k] = file_pointer.get_slice(k)  # don't materialize yet
+
+                with safe_open(shard_file_path, framework="pt", device=device) as f:
+                    for k in weight_keys:
+                        # Materialize the tensor immediately instead of keeping a lazy slice
+                        merged_state_dict[k] = f.get_tensor(k)
         elif state_dict is not None:
             merged_state_dict = state_dict
         elif checkpoint_files is not None:
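
For readers skimming the patch, here is a minimal, standalone sketch of the loading pattern it introduces: group weight names by shard file, resolve each shard's full path through a basename-to-path map built from the checkpoint list, and open each shard once with safe_open as a context manager. The temporary directory, shard names, and tensor names below are made up for illustration and are not part of the patch.

import os
import tempfile
from collections import defaultdict

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Fabricated two-shard checkpoint, standing in for `checkpoint_files` and
# `sharded_metadata["weight_map"]` in the patched code.
tmp_dir = tempfile.mkdtemp()
checkpoint_files = [
    os.path.join(tmp_dir, "model-00001-of-00002.safetensors"),
    os.path.join(tmp_dir, "model-00002-of-00002.safetensors"),
]
save_file({"embed.weight": torch.zeros(2, 2)}, checkpoint_files[0])
save_file({"lm_head.weight": torch.ones(2, 2)}, checkpoint_files[1])
weight_map = {
    "embed.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}

# Resolve shard filenames to the paths we were actually given, rather than
# re-joining the bare filename onto the first file's directory.
filename_to_path = {os.path.basename(f): f for f in checkpoint_files}

# Group weight names by shard so each file is opened exactly once.
weights_by_file = defaultdict(list)
for name, filename in weight_map.items():
    weights_by_file[filename].append(name)

state_dict = {}
for filename, names in weights_by_file.items():
    # The context manager closes the shard as soon as its tensors are read,
    # so at most one file handle stays open at a time.
    with safe_open(filename_to_path[filename], framework="pt", device="cpu") as f:
        for name in names:
            state_dict[name] = f.get_tensor(name)

print(sorted(state_dict))  # ['embed.weight', 'lm_head.weight']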