Skip to content

Commit 987087c

Browse files
- Fix bug: In windows, PermissionError as model file opens in 'wb' mode not 'wb+'.
- Avoid model download at every run.
1 parent c2eba60 commit 987087c

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import torch
23
import tempfile
34
import warnings
@@ -121,10 +122,16 @@ def __call__(self,
121122
def init_jit_model(model_url: str,
122123
device: torch.device = torch.device('cpu')):
123124
torch.set_grad_enabled(False)
124-
with tempfile.NamedTemporaryFile('wb', suffix='.model') as f:
125+
126+
model_dir = os.path.join(os.path.dirname(__file__), "model")
127+
os.makedirs(model_dir, exist_ok=True)
128+
model_path = os.path.join(model_dir, os.path.basename(model_url))
129+
130+
if not os.path.isfile(model_path):
125131
torch.hub.download_url_to_file(model_url,
126-
f.name,
132+
model_path,
127133
progress=True)
128-
model = torch.jit.load(f.name, map_location=device)
129-
model.eval()
134+
135+
model = torch.jit.load(model_path, map_location=device)
136+
model.eval()
130137
return model, Decoder(model.labels)

0 commit comments

Comments
 (0)