@@ -202,22 +202,33 @@ def validate_model_shards(model_path: Path) -> Tuple[bool, Optional[str], List[P
202202 )
203203
204204 if not shard_files :
205- # No sharded files - check for single model file
206- single_files = list (model_path .glob ("model.safetensors" )) or list (
207- model_path .glob ("pytorch_model.bin" )
208- )
205+ # No sharded files - check for any safetensors or bin files
206+ # Exclude non-model files like tokenizer, config, optimizer, etc.
207+ all_safetensors = list (model_path .glob ("*.safetensors" ))
208+ all_bins = list (model_path .glob ("*.bin" ))
209+
210+ # Filter out non-model files
211+ excluded_prefixes = ["tokenizer" , "optimizer" , "training_" , "config" ]
212+ single_files = [
213+ f
214+ for f in (all_safetensors or all_bins )
215+ if not any (f .name .startswith (prefix ) for prefix in excluded_prefixes )
216+ and not f .name .endswith (".index.json" )
217+ ]
218+
209219 if single_files :
210- # Validate the single safetensors file if it exists
211- if single_files [0 ].suffix == ".safetensors" :
212- is_valid , error_msg = validate_safetensors_file (single_files [0 ])
213- if not is_valid :
214- return (
215- False ,
216- f"Corrupted file { single_files [0 ].name } : { error_msg } " ,
217- [single_files [0 ]],
218- )
220+ # Validate all safetensors files, not just the first one
221+ for model_file in single_files :
222+ if model_file .suffix == ".safetensors" :
223+ is_valid , error_msg = validate_safetensors_file (model_file )
224+ if not is_valid :
225+ return (
226+ False ,
227+ f"Corrupted file { model_file .name } : { error_msg } " ,
228+ [model_file ],
229+ )
219230 return True , None , []
220- return False , "No model files found (safetensors or bin)" , []
231+ return False , "No model weight files found (safetensors or bin)" , []
221232
222233 # Extract total shard count from any shard filename
223234 total_shards = None
0 commit comments