File tree Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Expand file tree Collapse file tree 1 file changed +11
-4
lines changed Original file line number Diff line number Diff line change 1+ import os
12import torch
23import tempfile
34import warnings
@@ -121,10 +122,16 @@ def __call__(self,
121122def init_jit_model (model_url : str ,
122123 device : torch .device = torch .device ('cpu' )):
123124 torch .set_grad_enabled (False )
124- with tempfile .NamedTemporaryFile ('wb' , suffix = '.model' ) as f :
125+
126+ model_dir = os .path .join (os .path .dirname (__file__ ), "model" )
127+ os .makedirs (model_dir , exist_ok = True )
128+ model_path = os .path .join (model_dir , os .path .basename (model_url ))
129+
130+ if not os .path .isfile (model_path ):
125131 torch .hub .download_url_to_file (model_url ,
126- f . name ,
132+ model_path ,
127133 progress = True )
128- model = torch .jit .load (f .name , map_location = device )
129- model .eval ()
134+
135+ model = torch .jit .load (model_path , map_location = device )
136+ model .eval ()
130137 return model , Decoder (model .labels )
You can’t perform that action at this time.
0 commit comments