Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
if 'resize' in opt.resize_or_crop:
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
# torchvision says we should use transforms.InterpolationMode, but does not export this properly
# so we have to rely on the automatic conversion from PIL.Image enum, with warning
transform_list.append(transforms.Resize(osize, interpolation=method))
elif 'scale_width' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

Expand Down
2 changes: 1 addition & 1 deletion data/custom_dataset_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def CreateDataset(opt):
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()

print("dataset [%s] was created" % (dataset.name()))
#print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset

Expand Down
2 changes: 1 addition & 1 deletion data/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
#print(data_loader.name())
data_loader.initialize(opt)
return data_loader
2 changes: 1 addition & 1 deletion models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def load_network(self, network, network_label, epoch_label, save_dir=''):
if not os.path.isfile(save_path):
print('%s not exists yet!' % save_path)
if network_label == 'G':
raise('Generator must exist!')
raise(Exception('Generator must exist!'))
else:
#network.load_state_dict(torch.load(save_path))
try:
Expand Down
6 changes: 3 additions & 3 deletions models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_glo
elif netG == 'encoder':
netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
else:
raise('generator not implemented!')
print(netG)
raise(Exception('generator not implemented!'))
#print(netG)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netG.cuda(gpu_ids[0])
Expand All @@ -46,7 +46,7 @@ def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_glo
def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
norm_layer = get_norm_layer(norm_type=norm)
netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
print(netD)
#print(netD)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netD.cuda(gpu_ids[0])
Expand Down
53 changes: 29 additions & 24 deletions models/pix2pixHD_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
Expand Down Expand Up @@ -108,35 +107,38 @@ def initialize(self, opt):
params = list(self.netD.parameters())
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None):
if self.opt.label_nc == 0:
input_label = label_map.data.cuda()
input_label = label_map
else:
# create one-hot vector for label map
if len(self.gpu_ids):
label_map = label_map.cuda()
# create one-hot vector for label map
size = label_map.size()
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
input_label = self.Tensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.long(), 1.0)
if self.opt.data_type == 16:
input_label = input_label.half()
if len(self.gpu_ids):
input_label = input_label.cuda()
inst_map = inst_map.cuda()

# get edges from instance map
if not self.opt.no_instance:
inst_map = inst_map.data.cuda()
edge_map = self.get_edges(inst_map)
input_label = torch.cat((input_label, edge_map), dim=1)
input_label = Variable(input_label, volatile=infer)
input_label = torch.cat((input_label, edge_map), dim=1)

# real images for training
if real_image is not None:
real_image = Variable(real_image.data.cuda())
if real_image is not None and len(self.gpu_ids):
real_image = real_image.cuda()

# instance map for feature encoding
if self.use_features:
# get precomputed feature maps
if self.opt.load_features:
feat_map = Variable(feat_map.data.cuda())
if self.opt.label_feat:
if self.opt.load_features and len(self.gpu_ids):
feat_map = feat_map.cuda()
if self.opt.label_feat and len(self.gpu_ids):
inst_map = label_map.cuda()

return input_label, inst_map, real_image, feat_map
Expand Down Expand Up @@ -194,8 +196,8 @@ def forward(self, label, inst, image, feat, infer=False):

def inference(self, label, inst, image=None):
# Encode Inputs
image = Variable(image) if image is not None else None
input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)
image = image if image is not None else None
input_label, inst_map, real_image, _ = self.encode_input(label, inst, image)

# Fake Generation
if self.use_features:
Expand All @@ -209,10 +211,7 @@ def inference(self, label, inst, image=None):
else:
input_concat = input_label

if torch.__version__.startswith('0.4'):
with torch.no_grad():
fake_image = self.netG.forward(input_concat)
else:
with torch.no_grad():
fake_image = self.netG.forward(input_concat)
return fake_image

Expand All @@ -238,11 +237,14 @@ def sample_features(self, inst):
return feat_map

def encode_features(self, image, inst):
image = Variable(image.cuda(), volatile=True)
if len(self.gpu_ids):
image = image.cuda()
inst = inst.cuda()
with torch.no_grad():
feat_map = self.netE.forward(image, inst).cpu()
feat_num = self.opt.feat_num
h, w = inst.size()[2], inst.size()[3]
block_num = 32
feat_map = self.netE.forward(image, inst.cuda())
inst_np = inst.cpu().numpy().astype(int)
feature = {}
for i in range(self.opt.label_nc):
Expand All @@ -254,13 +256,16 @@ def encode_features(self, image, inst):
idx = idx[num//2,:]
val = np.zeros((1, feat_num+1))
for k in range(feat_num):
val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].item()
val[0, feat_num] = float(num) / (h * w // block_num)
feature[label] = np.append(feature[label], val, axis=0)
return feature

def get_edges(self, t):
edge = torch.cuda.ByteTensor(t.size()).zero_()
edge = torch.ByteTensor(t.size())
if len(self.gpu_ids):
edge = edge.cuda()
edge = edge.zero_()
edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
Expand Down
13 changes: 7 additions & 6 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def initialize(self):

self.initialized = True

def parse(self, save=True):
def parse(self, args=None, save=True, silent=False):
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args()
self.opt = self.parser.parse_args(args=args)
self.opt.isTrain = self.isTrain # train or test

str_ids = self.opt.gpu_ids.split(',')
Expand All @@ -81,10 +81,11 @@ def parse(self, save=True):

args = vars(self.opt)

print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
if not silent:
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')

# save to the disk
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
Expand Down