Skip to content

Commit 9bc6a9a

Browse files
authored
Update model weight validation logic to handle special weight file naming (#13256)
1 parent 7cdaedb commit 9bc6a9a

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

scripts/ci/validate_and_download_models.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -202,22 +202,33 @@ def validate_model_shards(model_path: Path) -> Tuple[bool, Optional[str], List[P
202202
)
203203

204204
if not shard_files:
205-
# No sharded files - check for single model file
206-
single_files = list(model_path.glob("model.safetensors")) or list(
207-
model_path.glob("pytorch_model.bin")
208-
)
205+
# No sharded files - check for any safetensors or bin files
206+
# Exclude non-model files like tokenizer, config, optimizer, etc.
207+
all_safetensors = list(model_path.glob("*.safetensors"))
208+
all_bins = list(model_path.glob("*.bin"))
209+
210+
# Filter out non-model files
211+
excluded_prefixes = ["tokenizer", "optimizer", "training_", "config"]
212+
single_files = [
213+
f
214+
for f in (all_safetensors or all_bins)
215+
if not any(f.name.startswith(prefix) for prefix in excluded_prefixes)
216+
and not f.name.endswith(".index.json")
217+
]
218+
209219
if single_files:
210-
# Validate the single safetensors file if it exists
211-
if single_files[0].suffix == ".safetensors":
212-
is_valid, error_msg = validate_safetensors_file(single_files[0])
213-
if not is_valid:
214-
return (
215-
False,
216-
f"Corrupted file {single_files[0].name}: {error_msg}",
217-
[single_files[0]],
218-
)
220+
# Validate all safetensors files, not just the first one
221+
for model_file in single_files:
222+
if model_file.suffix == ".safetensors":
223+
is_valid, error_msg = validate_safetensors_file(model_file)
224+
if not is_valid:
225+
return (
226+
False,
227+
f"Corrupted file {model_file.name}: {error_msg}",
228+
[model_file],
229+
)
219230
return True, None, []
220-
return False, "No model files found (safetensors or bin)", []
231+
return False, "No model weight files found (safetensors or bin)", []
221232

222233
# Extract total shard count from any shard filename
223234
total_shards = None

0 commit comments

Comments
 (0)