11import os
2- import subprocess
32from pathlib import Path
4- from PySide6 .QtCore import QObject , Signal
3+ from huggingface_hub import snapshot_download , HfApi
4+ import logging
5+ import threading
56
6- class ModelDownloadedSignal (QObject ):
7- downloaded = Signal (str , str )
7+ logging .getLogger ("transformers" ).setLevel (logging .ERROR )
88
9- model_downloaded_signal = ModelDownloadedSignal ()
10-
11- MODEL_DIRECTORIES = {
12- "vector" : "Vector" ,
13- "chat" : "Chat"
14- }
15-
16- class ModelDownloader :
17- def __init__ (self , model_name , model_type ):
18- self .model_name = model_name
19- self .model_type = model_type
20- self ._model_directory = None
21-
22- def get_model_directory_name (self ):
23- return self .model_name .replace ("/" , "--" )
24-
25- def get_model_directory (self ):
26- if not self ._model_directory :
27- model_type_dir = MODEL_DIRECTORIES .get (self .model_type , "" )
28- self ._model_directory = Path ("Models" ) / model_type_dir / self .get_model_directory_name ()
29- return self ._model_directory
30-
31- def get_model_url (self ):
32- return f"https://huggingface.co/{ self .model_name } "
33-
34- def download_model (self ):
35- model_url = self .get_model_url ()
36- target_directory = self .get_model_directory ()
37- print (f"Downloading { self .model_name } ..." )
38-
39- env = os .environ .copy ()
40- env ["GIT_CLONE_PROTECTION_ACTIVE" ] = "false"
9+ def download_model_files (repo_id , local_dir ):
10+ try :
11+ api = HfApi ()
12+ files_list = api .list_repo_files (repo_id )
4113
42- try :
43- subprocess .run (
44- ["git" , "clone" , "--depth" , "1" , model_url , str (target_directory )],
45- check = True ,
46- env = env
47- )
48- print ("\033 [92mModel downloaded and ready to use.\033 [0m" )
49- model_downloaded_signal .downloaded .emit (self .model_name , self .model_type )
50- except subprocess .CalledProcessError as e :
51- print (f"Command 'git clone' returned non-zero exit status { e .returncode } ." )
14+ top_level_files = [f for f in files_list if '/' not in f ]
15+ if not top_level_files :
16+ raise ValueError ("No top-level files found in the repository." )
17+ snapshot_download (
18+ repo_id ,
19+ local_dir = local_dir ,
20+ allow_patterns = top_level_files ,
21+ local_dir_use_symlinks = False ,
22+ )
23+ print (f"Downloaded top-level files from { repo_id } to { local_dir } " )
24+ return True
25+ except Exception as e :
26+ print (f"Failed to download model: { e } " )
27+ return False
28+
29+ def download_model (repo_id ):
30+ folder_name = repo_id .replace ('/' , '_' ) # CHANGED FROM TWO DASHES
31+ current_dir = Path (__file__ ).resolve ().parent
32+ models_dir = current_dir / "Models" / "vector"
33+ local_dir = models_dir / folder_name
34+
35+ os .makedirs (local_dir , exist_ok = True )
36+
37+ thread = threading .Thread (target = download_model_files , args = (repo_id , local_dir ))
38+ thread .start ()
39+ return thread
40+
41+ if __name__ == "__main__" :
42+ test_repo_id = "thenlper/gte-large"
43+ download_thread = download_model (test_repo_id )
44+ download_thread .join ()
0 commit comments