diff --git a/models/networks.py b/models/networks.py index 476ccb3d..2a9488f1 100755 --- a/models/networks.py +++ b/models/networks.py @@ -113,8 +113,11 @@ def __call__(self, input, target_is_real): class VGGLoss(nn.Module): def __init__(self, gpu_ids): - super(VGGLoss, self).__init__() - self.vgg = Vgg19().cuda() + super(VGGLoss, self).__init__() + if len(gpu_ids)>0: + self.vgg = Vgg19().cuda() + else: + self.vgg = Vgg19() self.criterion = nn.L1Loss() self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]