diff --git a/bark/generation.py b/bark/generation.py index 54f98709..a3f48194 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -209,7 +209,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"): if not os.path.exists(ckpt_path): logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.") _download(model_info["repo_id"], model_info["file_name"]) - checkpoint = torch.load(ckpt_path, map_location=device) + checkpoint = torch.load(ckpt_path, weights_only=False, map_location=device) # this is a hack model_args = checkpoint["model_args"] if "input_vocab_size" not in model_args: