diff --git a/examples/tutorials/dermoscopy.ipynb b/examples/tutorials/dermoscopy.ipynb index e1d6b66..df222f5 100644 --- a/examples/tutorials/dermoscopy.ipynb +++ b/examples/tutorials/dermoscopy.ipynb @@ -235,10 +235,12 @@ " ]\n", " save_dir = os.path.join(model_args.root_save_dir, \"results\"+'_'.join([t.format(v) for (t, v) in setup]))\n", " \n", - " # Change map_location if training on GPU.\n", - " net = torch.load(os.path.join(save_dir, 'net.p'), map_location='cpu')\n", - " # Change to True if training on GPU.\n", - " net.cuda_available=False\n", + " if cuda_available:\n", + " net = torch.load(os.path.join(save_dir, 'net.p'), map_location='cpu')\n", + " net.cuda_available=False\n", + " else:\n", + " net = torch.load(os.path.join(save_dir, 'net.p'), map_location='cuda:0': 'cpu'})\n", + " net.cuda_available=True\n", " \n", " return net" ]