@@ -2392,7 +2392,7 @@ def load_model(
     backend : str, optional
         Backend to use. Either "transformers" or "openai-whisper".
     download_root : str, optional
-        Root folder to download the model to. If None, use the default download root.
+        Root folder to download the model to. If None, use the default download root (typically ~/.cache).
     in_memory : bool, optional
         Whether to preload the model weights into host memory.
     """
@@ -2405,11 +2405,12 @@ def load_model(
         name = f"openai/whisper-{name}"
         # TODO: use download_root
         # TODO: does in_memory make sense?
+        cache_dir = os.path.join(download_root, "huggingface", "hub") if download_root else None
         try:
-            generation_config = transformers.GenerationConfig.from_pretrained(name)
+            generation_config = transformers.GenerationConfig.from_pretrained(name, cache_dir=cache_dir)
         except OSError:
-            generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny")
-        processor = transformers.WhisperProcessor.from_pretrained(name)
+            generation_config = transformers.GenerationConfig.from_pretrained("openai/whisper-tiny", cache_dir=cache_dir)
+        processor = transformers.WhisperProcessor.from_pretrained(name, cache_dir=cache_dir)
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         precision = torch.float32
@@ -2421,6 +2422,7 @@ def load_model(
             # torch_dtype=torch.bfloat16,
             # attn_implementation="flash_attention_2",
             # attn_implementation="sdpa",
+            cache_dir=cache_dir,
         )
         # model = model.to_bettertransformer()
 
@@ -2433,7 +2435,12 @@ def load_model(
     extension = os.path.splitext(name)[-1] if os.path.isfile(name) else None
 
     if name in whisper.available_models() or extension == ".pt":
-        return whisper.load_model(name, device=device, download_root=download_root, in_memory=in_memory)
+        return whisper.load_model(
+            name,
+            device=device,
+            download_root=os.path.join(download_root, "whisper") if download_root else None,
+            in_memory=in_memory,
+        )
 
     # Otherwise, assume transformers
     if extension in [".ckpt", ".bin"]:
@@ -2446,7 +2453,11 @@ def load_model(
         raise ImportError(f"If you are trying to download a HuggingFace model with {name}, please first install the transformers library")
     from transformers.utils import cached_file
 
-    kwargs = dict(cache_dir=download_root, use_auth_token=None, revision=None)
+    kwargs = dict(
+        cache_dir=os.path.join(download_root, "huggingface", "hub") if download_root else None,
+        use_auth_token=None,
+        revision=None,
+    )
     try:
         model_path = cached_file(name, "pytorch_model.bin", **kwargs)
     except OSError as err:
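
Net effect of this change: a single download_root now fans out into one cache
subfolder per backend, instead of being passed through verbatim (or, in the
transformers path, ignored). A minimal sketch of the resulting layout,
assuming the subfolder names used in the diff above (the helper function here
is hypothetical, for illustration only):

    import os

    def expected_cache_dirs(download_root):
        # Mirrors the paths built in load_model above: openai-whisper
        # checkpoints and Hugging Face hub files get separate subfolders.
        # When download_root is None, each backend keeps its own default
        # cache (typically somewhere under ~/.cache).
        if download_root is None:
            return {"whisper": None, "huggingface": None}
        return {
            "whisper": os.path.join(download_root, "whisper"),
            "huggingface": os.path.join(download_root, "huggingface", "hub"),
        }

    print(expected_cache_dirs("/data/models"))
    # {'whisper': '/data/models/whisper', 'huggingface': '/data/models/huggingface/hub'}

Keeping the two backends in separate subfolders avoids file collisions, and
the "huggingface/hub" suffix mirrors the default ~/.cache/huggingface/hub
layout that Hugging Face libraries use for cache_dir.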