Skip to content

Commit d02a06e

Browse files
authored
fix mistake
The correct ```download_model.py``` was not uploaded prior to last release, thus preventing the transcriber and image summary generator from working properly. Also removed a de-bugging print statement from ```gui.py```.
1 parent abe02d8 commit d02a06e

File tree

2 files changed

+40
-47
lines changed

2 files changed

+40
-47
lines changed

src/download_model.py

Lines changed: 39 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,44 @@
11
import os
2-
import subprocess
32
from pathlib import Path
4-
from PySide6.QtCore import QObject, Signal
3+
from huggingface_hub import snapshot_download, HfApi
4+
import logging
5+
import threading
56

6-
class ModelDownloadedSignal(QObject):
7-
downloaded = Signal(str, str)
7+
logging.getLogger("transformers").setLevel(logging.ERROR)
88

9-
model_downloaded_signal = ModelDownloadedSignal()
10-
11-
MODEL_DIRECTORIES = {
12-
"vector": "Vector",
13-
"chat": "Chat"
14-
}
15-
16-
class ModelDownloader:
17-
def __init__(self, model_name, model_type):
18-
self.model_name = model_name
19-
self.model_type = model_type
20-
self._model_directory = None
21-
22-
def get_model_directory_name(self):
23-
return self.model_name.replace("/", "--")
24-
25-
def get_model_directory(self):
26-
if not self._model_directory:
27-
model_type_dir = MODEL_DIRECTORIES.get(self.model_type, "")
28-
self._model_directory = Path("Models") / model_type_dir / self.get_model_directory_name()
29-
return self._model_directory
30-
31-
def get_model_url(self):
32-
return f"https://huggingface.co/{self.model_name}"
33-
34-
def download_model(self):
35-
model_url = self.get_model_url()
36-
target_directory = self.get_model_directory()
37-
print(f"Downloading {self.model_name}...")
38-
39-
env = os.environ.copy()
40-
env["GIT_CLONE_PROTECTION_ACTIVE"] = "false"
9+
def download_model_files(repo_id, local_dir):
10+
try:
11+
api = HfApi()
12+
files_list = api.list_repo_files(repo_id)
4113

42-
try:
43-
subprocess.run(
44-
["git", "clone", "--depth", "1", model_url, str(target_directory)],
45-
check=True,
46-
env=env
47-
)
48-
print("\033[92mModel downloaded and ready to use.\033[0m")
49-
model_downloaded_signal.downloaded.emit(self.model_name, self.model_type)
50-
except subprocess.CalledProcessError as e:
51-
print(f"Command 'git clone' returned non-zero exit status {e.returncode}.")
14+
top_level_files = [f for f in files_list if '/' not in f]
15+
if not top_level_files:
16+
raise ValueError("No top-level files found in the repository.")
17+
snapshot_download(
18+
repo_id,
19+
local_dir=local_dir,
20+
allow_patterns=top_level_files,
21+
local_dir_use_symlinks=False,
22+
)
23+
print(f"Downloaded top-level files from {repo_id} to {local_dir}")
24+
return True
25+
except Exception as e:
26+
print(f"Failed to download model: {e}")
27+
return False
28+
29+
def download_model(repo_id):
30+
folder_name = repo_id.replace('/', '_') # CHANGED FROM TWO DASHES
31+
current_dir = Path(__file__).resolve().parent
32+
models_dir = current_dir / "Models" / "vector"
33+
local_dir = models_dir / folder_name
34+
35+
os.makedirs(local_dir, exist_ok=True)
36+
37+
thread = threading.Thread(target=download_model_files, args=(repo_id, local_dir))
38+
thread.start()
39+
return thread
40+
41+
if __name__ == "__main__":
42+
test_repo_id = "thenlper/gte-large"
43+
download_thread = download_model(test_repo_id)
44+
download_thread.join()

src/gui.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from utilities import list_theme_files, make_theme_changer, load_stylesheet
1818

1919
# Print the current working directory
20-
print(f"Current working directory: {os.getcwd()}")
20+
# print(f"Current working directory: {os.getcwd()}")
2121

2222
# Check if we can write to the current directory
2323
try:

0 commit comments

Comments
 (0)