Skip to content

Commit 0609d7c

Browse files
committed
clarify cache folder
1 parent f7e6fff commit 0609d7c

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

whisper_timestamped/transcribe.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,7 +2392,7 @@ def load_model(
23922392
backend : str, optional
23932393
Backend to use. Either "transformers" or "openai-whisper".
23942394
download_root : str, optional
2395-
Root folder to download the model to. If None, use the default download root.
2395+
Root folder to download the model to. If None, use the default download root (typically: ~/.cache).
23962396
in_memory : bool, optional
23972397
Whether to preload the model weights into host memory.
23982398
"""
@@ -2405,11 +2405,12 @@ def load_model(
24052405
name = f"openai/whisper-{name}"
24062406
# TODO: use download_root
24072407
# TODO: does in_memory make sense?
2408+
cache_dir = os.path.join(download_root, "huggingface", "hub") if download_root else None
24082409
try:
2409-
generation_config = transformers.GenerationConfig.from_pretrained(name)
2410+
generation_config = transformers.GenerationConfig.from_pretrained(name, cache_dir=cache_dir)
24102411
except OSError:
2411-
generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny")
2412-
processor = transformers.WhisperProcessor.from_pretrained(name)
2412+
generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny", cache_dir=cache_dir)
2413+
processor = transformers.WhisperProcessor.from_pretrained(name, cache_dir=cache_dir)
24132414
if device is None:
24142415
device = "cuda" if torch.cuda.is_available() else "cpu"
24152416
precision = torch.float32
@@ -2421,6 +2422,7 @@ def load_model(
24212422
# torch_dtype=torch.bfloat16,
24222423
# attn_implementation="flash_attention_2",
24232424
# attn_implementation="sdpa",
2425+
cache_dir=cache_dir,
24242426
)
24252427
# model = model.to_bettertransformer()
24262428

@@ -2433,7 +2435,12 @@ def load_model(
24332435
extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None
24342436

24352437
if name in whisper.available_models() or extension == ".pt":
2436-
return whisper.load_model(name, device=device, download_root=download_root, in_memory=in_memory)
2438+
return whisper.load_model(
2439+
name,
2440+
device=device,
2441+
download_root=os.path.join(download_root, "whisper") if download_root else None,
2442+
in_memory=in_memory
2443+
)
24372444

24382445
# Otherwise, assume transformers
24392446
if extension in [".ckpt", ".bin"]:
@@ -2446,7 +2453,11 @@ def load_model(
24462453
raise ImportError(f"If you are trying to download a HuggingFace model with {name}, please install first the transformers library")
24472454
from transformers.utils import cached_file
24482455

2449-
kwargs = dict(cache_dir=download_root, use_auth_token=None, revision=None)
2456+
kwargs = dict(
2457+
cache_dir=os.path.join(download_root, "huggingface", "hub") if download_root else None,
2458+
use_auth_token=None,
2459+
revision=None,
2460+
)
24502461
try:
24512462
model_path = cached_file(name, "pytorch_model.bin", **kwargs)
24522463
except OSError as err:

0 commit comments

Comments
 (0)