66import fnmatch
77import humanfriendly
88import atexit
9+ import yaml
10+ import functools
911
1012class ModelDownloadedSignal (QObject ):
11- downloaded = Signal (str , str )
13+ downloaded = Signal (str , str )
1214
1315model_downloaded_signal = ModelDownloadedSignal ()
1416
1517MODEL_DIRECTORIES = {
16- "vector" : "vector" ,
17- "chat" : "chat" ,
18- "tts" : "tts" ,
19- "jeeves" : "jeeves" ,
20- "ocr" : "ocr"
18+ "vector" : "vector" ,
19+ "chat" : "chat" ,
20+ "tts" : "tts" ,
21+ "jeeves" : "jeeves" ,
22+ "ocr" : "ocr"
2123}
2224
25+ @functools .lru_cache (maxsize = 1 )
26+ def get_hf_token ():
27+ config_path = Path ("config.yaml" )
28+ if config_path .exists ():
29+ try :
30+ with open (config_path , 'r' , encoding = 'utf-8' ) as config_file :
31+ config = yaml .safe_load (config_file )
32+ return config .get ('hf_access_token' )
33+ except Exception as e :
34+ print (f"Warning: Could not load config: { e } " )
35+ return None
36+
2337class ModelDownloader (QObject ):
24- def __init__ (self , model_info , model_type ):
25- super ().__init__ ()
26- self .model_info = model_info
27- self .model_type = model_type
28- self ._model_directory = None
29- self .api = HfApi ()
30- self .api .timeout = 60 # increase timeout
31- disable_progress_bars ()
32- self .local_dir = self .get_model_directory ()
33-
34- def cleanup_incomplete_download (self ):
35- if self .local_dir .exists ():
36- import shutil
37- shutil .rmtree (self .local_dir )
38-
39- def get_model_url (self ):
40- if isinstance (self .model_info , dict ):
41- return self .model_info ['repo_id' ]
42- else :
43- return self .model_info
44-
45- def check_repo_type (self , repo_id ):
46- try :
47- repo_info = self .api .repo_info (repo_id , timeout = 60 ) # increase timeout
48- if repo_info .private :
49- return "private"
50- elif getattr (repo_info , 'gated' , False ):
51- return "gated"
52- else :
53- return "public"
54- except GatedRepoError :
55- return "gated"
56- except RepositoryNotFoundError :
57- return "not_found"
58- except Exception as e :
59- return f"error: { str (e )} "
60-
61- def get_model_directory_name (self ):
62- if isinstance (self .model_info , dict ):
63- return self .model_info ['cache_dir' ]
64- else :
65- return self .model_info .replace ("/" , "--" )
66-
67- def get_model_directory (self ):
68- return Path ("Models" ) / self .model_type / self .get_model_directory_name ()
69-
70- def download_model (self , allow_patterns = None , ignore_patterns = None ):
38+ def __init__ (self , model_info , model_type ):
39+ super ().__init__ ()
40+ self .model_info = model_info
41+ self .model_type = model_type
42+ self ._model_directory = None
43+
44+ self .hf_token = get_hf_token ()
45+
46+ self .api = HfApi (token = False )
47+ self .api .timeout = 60
48+ disable_progress_bars ()
49+ self .local_dir = self .get_model_directory ()
50+
51+ def cleanup_incomplete_download (self ):
52+ if self .local_dir .exists ():
53+ import shutil
54+ shutil .rmtree (self .local_dir )
55+
56+ def get_model_url (self ):
57+ if isinstance (self .model_info , dict ):
58+ return self .model_info ['repo_id' ]
59+ else :
60+ return self .model_info
61+
62+ def check_repo_type (self , repo_id ):
63+ try :
64+ repo_info = self .api .repo_info (repo_id , timeout = 60 , token = False )
65+ if repo_info .private :
66+ return "private"
67+ elif getattr (repo_info , 'gated' , False ):
68+ return "gated"
69+ else :
70+ return "public"
71+ except Exception as e :
72+ if self .hf_token and ("401" in str (e ) or "Unauthorized" in str (e )):
73+ try :
74+ api_with_token = HfApi (token = self .hf_token )
75+ repo_info = api_with_token .repo_info (repo_id , timeout = 60 )
76+ if repo_info .private :
77+ return "private"
78+ elif getattr (repo_info , 'gated' , False ):
79+ return "gated"
80+ else :
81+ return "public"
82+ except Exception as e2 :
83+ return f"error: { str (e2 )} "
84+ elif "404" in str (e ):
85+ return "not_found"
86+ else :
87+ return f"error: { str (e )} "
88+
89+ def get_model_directory_name (self ):
90+ if isinstance (self .model_info , dict ):
91+ return self .model_info ['cache_dir' ]
92+ else :
93+ return self .model_info .replace ("/" , "--" )
94+
95+ def get_model_directory (self ):
96+ return Path ("Models" ) / self .model_type / self .get_model_directory_name ()
97+
98+ def download_model (self , allow_patterns = None , ignore_patterns = None ):
7199 repo_id = self .get_model_url ()
72100
73- # only download if repo is public
74- # https://huggingface.co/docs/hub/models-gated#access-gated-models-as-a-user
75- # https://huggingface.co/docs/hub/en/enterprise-hub-tokens-management
76101 repo_type = self .check_repo_type (repo_id )
77- if repo_type != "public" :
102+ if repo_type not in [ "public" , "gated" ] :
78103 if repo_type == "private" :
79- print (f"Repository { repo_id } is private and requires a token. Aborting download." )
80- elif repo_type == "gated" :
81- print (f"Repository { repo_id } is gated. Please request access through the web interface. Aborting download." )
104+ print (f"Repository { repo_id } is private and requires a token." )
105+ if not self .hf_token :
106+ print ("No Hugging Face token found. Please add one through the credentials menu." )
107+ return
82108 elif repo_type == "not_found" :
83109 print (f"Repository { repo_id } not found. Aborting download." )
110+ return
84111 else :
85112 print (f"Error checking repository { repo_id } : { repo_type } . Aborting download." )
113+ return
114+
115+ if repo_type == "gated" and not self .hf_token :
116+ print (f"Repository { repo_id } is gated. Please add a Hugging Face token and request access through the web interface." )
86117 return
87118
88119 local_dir = self .get_model_directory ()
@@ -91,13 +122,12 @@ def download_model(self, allow_patterns=None, ignore_patterns=None):
91122 atexit .register (self .cleanup_incomplete_download )
92123
93124 try :
94- repo_files = list (self .api .list_repo_tree (repo_id , recursive = True ))
95- """
96- allow_patterns: If provided, only matching files are downloaded (ignore_patterns is disregarded)
97- ignore_patterns: If provided alone, matching files are excluded
98- neither: Uses default ignore patterns (.gitattributes, READMEs, etc.) with smart model file filtering
99- both: Behaves same as allow_patterns only
100- """
125+ if repo_type == "gated" and self .hf_token :
126+ api_for_listing = HfApi (token = self .hf_token )
127+ repo_files = list (api_for_listing .list_repo_tree (repo_id , recursive = True ))
128+ else :
129+ repo_files = list (self .api .list_repo_tree (repo_id , recursive = True , token = False ))
130+
101131 if allow_patterns is not None :
102132 final_ignore_patterns = None
103133 elif ignore_patterns is not None :
@@ -154,14 +184,21 @@ def download_model(self, allow_patterns=None, ignore_patterns=None):
154184 print (f"- { file } " )
155185 print (f"\n Downloading to { local_dir } ..." )
156186
157- snapshot_download (
158- repo_id = repo_id ,
159- local_dir = str (local_dir ),
160- max_workers = 4 ,
161- ignore_patterns = final_ignore_patterns ,
162- allow_patterns = allow_patterns ,
163- etag_timeout = 60 # increase timeout
164- )
187+ download_kwargs = {
188+ 'repo_id' : repo_id ,
189+ 'local_dir' : str (local_dir ),
190+ 'max_workers' : 4 ,
191+ 'ignore_patterns' : final_ignore_patterns ,
192+ 'allow_patterns' : allow_patterns ,
193+ 'etag_timeout' : 60
194+ }
195+
196+ if repo_type == "gated" and self .hf_token :
197+ download_kwargs ['token' ] = self .hf_token
198+ elif repo_type == "public" :
199+ download_kwargs ['token' ] = False
200+
201+ snapshot_download (** download_kwargs )
165202
166203 print ("\033 [92mModel downloaded and ready to use.\033 [0m" )
167204 atexit .unregister (self .cleanup_incomplete_download )
0 commit comments