diff --git a/.gitignore b/.gitignore
index 32e42a39..02c07d82 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+# IDE related
+.vscode/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/README.md b/README.md
index 4209c2d0..9e6dcc8f 100644
--- a/README.md
+++ b/README.md
@@ -34,9 +34,19 @@
 | | | └── B # Contains domain B images (i.e. Batman)
 
 ### 2. Train!
+
+Before training, make sure to start up the visdom server in another terminal. Otherwise, you will get HTTPConnection errors. It's as simple as:
+
 ```
-./train --dataroot datasets/<dataset_name>/ --cuda
+visdom
 ```
+
+Next, you can launch the actual training script. Since it is now invoked through `python` directly, it does not rely on a `/usr/bin/python3` shebang (useful if your Python lives elsewhere, e.g., in a conda environment):
+
+```
+python train.py --dataroot datasets/<dataset_name>/ --cuda
+```
+
-This command will start a training session using the images under the *dataroot/train* directory with the hyperparameters that showed best results according to CycleGAN authors. You are free to change those hyperparameters, see ```./train --help``` for a description of those.
+This command will start a training session using the images under the *dataroot/train* directory with the hyperparameters that showed best results according to CycleGAN authors. You are free to change those hyperparameters; see ```python train.py --help``` for a description of those.
 
 Both generators and discriminators weights will be saved under the output directory.
@@ -53,7 +63,7 @@ You can also view the training progress as well as live output images by running
 
 ## Testing
 ```
-./test --dataroot datasets/<dataset_name>/ --cuda
+python test.py --dataroot datasets/<dataset_name>/ --cuda
 ```
 
-This command will take the images under the *dataroot/test* directory, run them through the generators and save the output under the *output/A* and *output/B* directories. As with train, some parameters like the weights to load, can be tweaked, see ```./test --help``` for more information.
+This command will take the images under the *dataroot/test* directory, run them through the generators and save the output under the *output/A* and *output/B* directories. As with train, some parameters, like the weights to load, can be tweaked; see ```python test.py --help``` for more information.
diff --git a/dataset.py b/dataset.py
new file mode 100755
index 00000000..93d460f4
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,45 @@
+import glob
+from collections.abc import Iterable
+import os
+import random
+from typing import Callable, List, Optional
+
+from PIL import Image
+
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+
+
+class ImageDataset(Dataset):
+    def __init__(
+        self,
+        root: str,
+        transforms_: Optional[List[Callable]] = None,
+        unaligned: bool = True,
+        mode: str = 'train',
+        grayscale: bool = False
+    ) -> None:
+        self.transform: Callable[[Image.Image], torch.Tensor] = transforms.Compose(transforms_)
+        self.unaligned: bool = unaligned
+        self.grayscale: bool = grayscale
+
+        self.files_A: Iterable[str] = sorted(glob.glob(os.path.join(root, f'{mode}/A') + '/*.*'))
+        self.files_B: Iterable[str] = sorted(glob.glob(os.path.join(root, f'{mode}/B') + '/*.*'))
+
+    def __getitem__(self, index: int):
+        idx_A: int = index % len(self.files_A)
+        item_A: Image.Image = Image.open(self.files_A[idx_A])
+
+        idx_B: int = random.randint(0, len(self.files_B) - 1) \
+            if self.unaligned else (index % len(self.files_B))
+        item_B: Image.Image = Image.open(self.files_B[idx_B])
+
+        if self.grayscale:
+            item_A = item_A.convert('L')
+            item_B = item_B.convert('L')
+
+        return dict(A=self.transform(item_A), B=self.transform(item_B))
+
+    def __len__(self):
+        return max(len(self.files_A), len(self.files_B))
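As a quick sanity check of the new `ImageDataset`, a sketch along these lines should work, assuming the `datasets/<dataset_name>/train/{A,B}` layout described in the README (the `horse2zebra` path is only the scripts' default, not something this patch ships):

```
import torchvision.transforms as transforms

from dataset import ImageDataset

# Single-channel normalization stats to match the grayscale (1-channel) setup.
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]

dataset = ImageDataset('datasets/horse2zebra/', transforms_=transforms_,
                       unaligned=True, mode='train', grayscale=True)
sample = dataset[0]
print(sample['A'].shape, sample['B'].shape)  # e.g. torch.Size([1, 256, 256]) each
```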
diff --git a/datasets.py b/datasets.py
deleted file mode 100755
index 8498ebf1..00000000
--- a/datasets.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import glob
-import random
-import os
-
-from torch.utils.data import Dataset
-from PIL import Image
-import torchvision.transforms as transforms
-
-class ImageDataset(Dataset):
-    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
-        self.transform = transforms.Compose(transforms_)
-        self.unaligned = unaligned
-
-        self.files_A = sorted(glob.glob(os.path.join(root, '%s/A' % mode) + '/*.*'))
-        self.files_B = sorted(glob.glob(os.path.join(root, '%s/B' % mode) + '/*.*'))
-
-    def __getitem__(self, index):
-        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
-
-        if self.unaligned:
-            item_B = self.transform(Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)]))
-        else:
-            item_B = self.transform(Image.open(self.files_B[index % len(self.files_B)]))
-
-        return {'A': item_A, 'B': item_B}
-
-    def __len__(self):
-        return max(len(self.files_A), len(self.files_B))
\ No newline at end of file
diff --git a/models.py b/models.py
index b94e4893..433f2d00 100755
--- a/models.py
+++ b/models.py
@@ -1,85 +1,107 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+
 class ResidualBlock(nn.Module):
     def __init__(self, in_features):
         super(ResidualBlock, self).__init__()
 
-        conv_block = [  nn.ReflectionPad2d(1),
-                        nn.Conv2d(in_features, in_features, 3),
-                        nn.InstanceNorm2d(in_features),
-                        nn.ReLU(inplace=True),
-                        nn.ReflectionPad2d(1),
-                        nn.Conv2d(in_features, in_features, 3),
-                        nn.InstanceNorm2d(in_features)  ]
+        conv_block = [
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            nn.InstanceNorm2d(in_features),
+            nn.ReLU(inplace=True),
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            nn.InstanceNorm2d(in_features)
+        ]
 
         self.conv_block = nn.Sequential(*conv_block)
 
     def forward(self, x):
         return x + self.conv_block(x)
 
+
 class Generator(nn.Module):
     def __init__(self, input_nc, output_nc, n_residual_blocks=9):
         super(Generator, self).__init__()
 
         # Initial convolution block
-        model = [   nn.ReflectionPad2d(3),
-                    nn.Conv2d(input_nc, 64, 7),
-                    nn.InstanceNorm2d(64),
-                    nn.ReLU(inplace=True) ]
+        model = [
+            nn.ReflectionPad2d(3),
+            nn.Conv2d(input_nc, 64, 7),
+            nn.InstanceNorm2d(64),
+            nn.ReLU(inplace=True)
+        ]
 
         # Downsampling
         in_features = 64
-        out_features = in_features*2
+        out_features = in_features * 2
         for _ in range(2):
-            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
-                        nn.InstanceNorm2d(out_features),
-                        nn.ReLU(inplace=True) ]
+            model += [
+                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                nn.InstanceNorm2d(out_features),
+                nn.ReLU(inplace=True)
+            ]
+
             in_features = out_features
-            out_features = in_features*2
+            out_features = in_features * 2
 
         # Residual blocks
         for _ in range(n_residual_blocks):
             model += [ResidualBlock(in_features)]
 
         # Upsampling
-        out_features = in_features//2
+        out_features = in_features // 2
         for _ in range(2):
-            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
-                        nn.InstanceNorm2d(out_features),
-                        nn.ReLU(inplace=True) ]
+            model += [
+                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
+                nn.InstanceNorm2d(out_features),
+                nn.ReLU(inplace=True)
+            ]
             in_features = out_features
-            out_features = in_features//2
+            out_features = in_features // 2
 
         # Output layer
-        model += [  nn.ReflectionPad2d(3),
-                    nn.Conv2d(64, output_nc, 7),
-                    nn.Tanh() ]
+        model += [
+            nn.ReflectionPad2d(3),
+            nn.Conv2d(64, output_nc, 7),
+            nn.Tanh()
+        ]
 
         self.model = nn.Sequential(*model)
 
     def forward(self, x):
         return self.model(x)
 
+
 class Discriminator(nn.Module):
     def __init__(self, input_nc):
         super(Discriminator, self).__init__()
 
         # A bunch of convolutions one after another
-        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
-                    nn.LeakyReLU(0.2, inplace=True) ]
-
-        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
-                    nn.InstanceNorm2d(128),
-                    nn.LeakyReLU(0.2, inplace=True) ]
-
-        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
-                    nn.InstanceNorm2d(256),
-                    nn.LeakyReLU(0.2, inplace=True) ]
-
-        model += [  nn.Conv2d(256, 512, 4, padding=1),
-                    nn.InstanceNorm2d(512),
-                    nn.LeakyReLU(0.2, inplace=True) ]
+        model = [
+            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
+            nn.LeakyReLU(0.2, inplace=True)
+        ]
+
+        model += [
+            nn.Conv2d(64, 128, 4, stride=2, padding=1),
+            nn.InstanceNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True)
+        ]
+
+        model += [
+            nn.Conv2d(128, 256, 4, stride=2, padding=1),
+            nn.InstanceNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True)
+        ]
+
+        model += [
+            nn.Conv2d(256, 512, 4, padding=1),
+            nn.InstanceNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True)
+        ]
 
         # FCN classification layer
         model += [nn.Conv2d(512, 1, 4, padding=1)]
@@ -87,6 +109,6 @@ def __init__(self, input_nc):
         self.model = nn.Sequential(*model)
 
     def forward(self, x):
-        x = self.model(x)
+        x = self.model(x)
         # Average pooling and flatten
-        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
\ No newline at end of file
+        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
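The `forward` above collapses the PatchGAN patch map to a single score per image via average pooling. A minimal shape check, assuming the 1-channel, 256x256 configuration the scripts now default to:

```
import torch

from models import Discriminator, Generator

x = torch.randn(2, 1, 256, 256)
netG = Generator(input_nc=1, output_nc=1)
netD = Discriminator(input_nc=1)

print(netG(x).shape)  # torch.Size([2, 1, 256, 256]) -- same spatial size as the input
print(netD(x).shape)  # torch.Size([2, 1]) -- one averaged patch score per image
```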
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..a2be9c0e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+numpy==1.22.1
+Pillow==9.1.0
+torch==1.10.2
+torchvision==0.11.3
+visdom==0.1.8.9
diff --git a/test b/test.py
similarity index 52%
rename from test
rename to test.py
index 7bde6501..865ba7cd 100755
--- a/test
+++ b/test.py
@@ -1,35 +1,41 @@
-#!/usr/bin/python3
-
 import argparse
-import sys
 import os
+import sys
+import torch
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
 import torchvision.transforms as transforms
 from torchvision.utils import save_image
-from torch.utils.data import DataLoader
-from torch.autograd import Variable
-import torch
-from models import Generator
-from datasets import ImageDataset
+from dataset import ImageDataset
+from models import Generator
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
-parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/', help='root directory of the dataset')
-parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
-parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
-parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
-parser.add_argument('--cuda', action='store_true', help='use GPU computation')
-parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
-parser.add_argument('--generator_A2B', type=str, default='output/netG_A2B.pth', help='A2B generator checkpoint file')
-parser.add_argument('--generator_B2A', type=str, default='output/netG_B2A.pth', help='B2A generator checkpoint file')
+parser.add_argument('--batchSize', type=int, default=1,
+                    help='size of the batches')
+parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/',
+                    help='root directory of the dataset')
+parser.add_argument('--input_nc', type=int, default=1,
+                    help='number of channels of input data')
+parser.add_argument('--output_nc', type=int, default=1,
+                    help='number of channels of output data')
+parser.add_argument('--size', type=int, default=256,
+                    help='size of the data (squared assumed)')
+parser.add_argument('--cuda', action='store_true',
+                    help='use GPU computation')
+parser.add_argument('--n_cpu', type=int, default=8,
+                    help='number of cpu threads to use during batch generation')
+parser.add_argument('--generator_A2B', type=str, default='output/netG_A2B.pth',
+                    help='A2B generator checkpoint file')
+parser.add_argument('--generator_B2A', type=str, default='output/netG_B2A.pth',
+                    help='B2A generator checkpoint file')
 opt = parser.parse_args()
 print(opt)
 
 if torch.cuda.is_available() and not opt.cuda:
     print("WARNING: You have a CUDA device, so you should probably run with --cuda")
 
-###### Definition of variables ######
 # Networks
 netG_A2B = Generator(opt.input_nc, opt.output_nc)
 netG_B2A = Generator(opt.output_nc, opt.input_nc)
@@ -52,13 +58,13 @@
 input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
 
 # Dataset loader
-transforms_ = [ transforms.ToTensor(),
-                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
-dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, mode='test'),
-                        batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu)
-###################################
+transforms_ = [
+    transforms.ToTensor(),
+    transforms.Normalize(tuple([0.5] * opt.output_nc), tuple([0.5] * opt.output_nc))
+]
 
-###### Testing######
+dataset = ImageDataset(opt.dataroot, transforms_=transforms_, mode='test', grayscale=(opt.input_nc == 1))
+dataloader = DataLoader(dataset, batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu)
 
 # Create output dirs if they don't exist
 if not os.path.exists('output/A'):
@@ -72,14 +78,13 @@
     real_B = Variable(input_B.copy_(batch['B']))
 
     # Generate output
-    fake_B = 0.5*(netG_A2B(real_A).data + 1.0)
-    fake_A = 0.5*(netG_B2A(real_B).data + 1.0)
+    fake_B = 0.5 * (netG_A2B(real_A).data + 1.0)
+    fake_A = 0.5 * (netG_B2A(real_B).data + 1.0)
 
     # Save image files
-    save_image(fake_A, 'output/A/%04d.png' % (i+1))
-    save_image(fake_B, 'output/B/%04d.png' % (i+1))
+    save_image(fake_A, f'output/A/{(i + 1):04d}.png')
+    save_image(fake_B, f'output/B/{(i + 1):04d}.png')
 
-    sys.stdout.write('\rGenerated images %04d of %04d' % (i+1, len(dataloader)))
+    sys.stdout.write(f'\rGenerated images {(i + 1):04d} of {len(dataloader):04d}')
 
 sys.stdout.write('\n')
-###################################
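The `0.5 * (x + 1.0)` step above is the inverse of the `Normalize(0.5, 0.5)` preprocessing: the generators end in `Tanh`, so their outputs live in [-1, 1] and must be mapped back to [0, 1] before `save_image`. A tiny illustration:

```
import torch

x = torch.tensor([0.0, 0.25, 0.5, 1.0])  # pixel values in [0, 1]
normalized = (x - 0.5) / 0.5              # what transforms.Normalize(0.5, 0.5) does
restored = 0.5 * (normalized + 1.0)       # what test.py does before saving
print(torch.allclose(x, restored))        # True
```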
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 00000000..7a40b718
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,11 @@
+[flake8]
+ignore = W293, C901
+max-line-length = 110
+exclude =
+    # No need to traverse these directories
+    .git
+    # Jupyter Notebooks
+    *.ipynb
+    # __init__ files break rule F401
+    */__init__.py
+max-complexity = 10
\ No newline at end of file
diff --git a/train b/train.py
similarity index 52%
rename from train
rename to train.py
index 95160b51..b09ef2a4 100755
--- a/train
+++ b/train.py
@@ -1,42 +1,49 @@
-#!/usr/bin/python3
-
 import argparse
 import itertools
-import torchvision.transforms as transforms
-from torch.utils.data import DataLoader
-from torch.autograd import Variable
-from PIL import Image
 import torch
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+import torchvision.transforms as transforms
+from dataset import ImageDataset
 from models import Generator
 from models import Discriminator
 from utils import ReplayBuffer
 from utils import LambdaLR
 from utils import Logger
 from utils import weights_init_normal
-from datasets import ImageDataset
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
-parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
-parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
-parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/', help='root directory of the dataset')
-parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate')
-parser.add_argument('--decay_epoch', type=int, default=100, help='epoch to start linearly decaying the learning rate to 0')
-parser.add_argument('--size', type=int, default=256, help='size of the data crop (squared assumed)')
-parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
-parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
-parser.add_argument('--cuda', action='store_true', help='use GPU computation')
-parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
+parser.add_argument('--epoch', type=int, default=0,
+                    help='starting epoch')
+parser.add_argument('--n_epochs', type=int, default=200,
+                    help='number of epochs of training')
+parser.add_argument('--batchSize', type=int, default=1,
+                    help='size of the batches')
+parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/',
+                    help='root directory of the dataset')
+parser.add_argument('--lr', type=float, default=0.0002,
+                    help='initial learning rate')
+parser.add_argument('--decay_epoch', type=int, default=100,
+                    help='epoch to start linearly decaying the learning rate to 0')
+parser.add_argument('--size', type=int, default=256,
+                    help='size of the data crop (squared assumed)')
+parser.add_argument('--input_nc', type=int, default=1,
+                    help='number of channels of input data')
+parser.add_argument('--output_nc', type=int, default=1,
+                    help='number of channels of output data')
+parser.add_argument('--cuda', action='store_true',
+                    help='use GPU computation')
+parser.add_argument('--n_cpu', type=int, default=8,
+                    help='number of cpu threads to use during batch generation')
 opt = parser.parse_args()
 print(opt)
 
 if torch.cuda.is_available() and not opt.cuda:
     print("WARNING: You have a CUDA device, so you should probably run with --cuda")
-
-###### Definition of variables ######
-# Networks
+
+# Creating the 4 networks
 netG_A2B = Generator(opt.input_nc, opt.output_nc)
 netG_B2A = Generator(opt.output_nc, opt.input_nc)
 netD_A = Discriminator(opt.input_nc)
@@ -48,25 +55,34 @@
     netD_A.cuda()
     netD_B.cuda()
 
+# Weight initialization
 netG_A2B.apply(weights_init_normal)
 netG_B2A.apply(weights_init_normal)
 netD_A.apply(weights_init_normal)
 netD_B.apply(weights_init_normal)
 
-# Lossess
+# Losses initialization
 criterion_GAN = torch.nn.MSELoss()
 criterion_cycle = torch.nn.L1Loss()
 criterion_identity = torch.nn.L1Loss()
 
 # Optimizers & LR schedulers
-optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
-                               lr=opt.lr, betas=(0.5, 0.999))
+optimizer_G = torch.optim.Adam(
+    itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
+    lr=opt.lr, betas=(0.5, 0.999)
+)
 optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=opt.lr, betas=(0.5, 0.999))
 optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=opt.lr, betas=(0.5, 0.999))
 
-lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
-lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
-lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
+lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
+    optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
+)
+lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
+    optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
+)
+lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
+    optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
+)
 
 # Inputs & targets memory allocation
 Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
@@ -79,35 +95,39 @@
 fake_B_buffer = ReplayBuffer()
 
 # Dataset loader
-transforms_ = [ transforms.Resize(int(opt.size*1.12), Image.BICUBIC),
-                transforms.RandomCrop(opt.size),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
-dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True),
-                        batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)
-
-# Loss plot
+transforms_ = [
+    transforms.Resize(int(opt.size * 1.12), interpolation=transforms.InterpolationMode.BICUBIC),
+    transforms.RandomCrop(opt.size),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize(tuple([0.5] * opt.output_nc), tuple([0.5] * opt.output_nc))
+]
+
+dataset = ImageDataset(opt.dataroot, transforms_=transforms_, unaligned=True, grayscale=(opt.input_nc == 1))
+dataloader = DataLoader(dataset, batch_size=opt.batchSize, shuffle=True, num_workers=opt.n_cpu)
+
+# Create logger for loss plots in Visdom
 logger = Logger(opt.n_epochs, len(dataloader))
-###################################
 
-###### Training ######
+# Training
 for epoch in range(opt.epoch, opt.n_epochs):
     for i, batch in enumerate(dataloader):
         # Set model input
         real_A = Variable(input_A.copy_(batch['A']))
         real_B = Variable(input_B.copy_(batch['B']))
-
-        ###### Generators A2B and B2A ######
+
+        # ######################
+        # Generators A2B and B2A
+        # ######################
         optimizer_G.zero_grad()
 
         # Identity loss
         # G_A2B(B) should equal B if real B is fed
         same_B = netG_A2B(real_B)
-        loss_identity_B = criterion_identity(same_B, real_B)*5.0
+        loss_identity_B = criterion_identity(same_B, real_B) * 5.0
         # G_B2A(A) should equal A if real A is fed
         same_A = netG_B2A(real_A)
-        loss_identity_A = criterion_identity(same_A, real_A)*5.0
+        loss_identity_A = criterion_identity(same_A, real_A) * 5.0
 
         # GAN loss
         fake_B = netG_A2B(real_A)
@@ -120,19 +140,21 @@
 
         # Cycle loss
         recovered_A = netG_B2A(fake_B)
-        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0
+        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0
 
         recovered_B = netG_A2B(fake_A)
-        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0
+        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0
 
         # Total loss
-        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
+        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B \
+            + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
         loss_G.backward()
 
         optimizer_G.step()
-        ###################################
 
-        ###### Discriminator A ######
+        # ######################
+        # Discriminator A
+        # ######################
         optimizer_D_A.zero_grad()
 
         # Real loss
@@ -145,13 +167,14 @@
         loss_D_fake = criterion_GAN(pred_fake, target_fake)
 
         # Total loss
-        loss_D_A = (loss_D_real + loss_D_fake)*0.5
+        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
         loss_D_A.backward()
 
         optimizer_D_A.step()
-        ###################################
 
-        ###### Discriminator B ######
+        # ######################
+        # Discriminator B
+        # ######################
         optimizer_D_B.zero_grad()
 
         # Real loss
@@ -164,16 +187,29 @@
         loss_D_fake = criterion_GAN(pred_fake, target_fake)
 
         # Total loss
-        loss_D_B = (loss_D_real + loss_D_fake)*0.5
+        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
         loss_D_B.backward()
 
         optimizer_D_B.step()
-        ###################################
 
-        # Progress report (http://localhost:8097)
-        logger.log({'loss_G': loss_G, 'loss_G_identity': (loss_identity_A + loss_identity_B), 'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
-                    'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB), 'loss_D': (loss_D_A + loss_D_B)},
-                    images={'real_A': real_A, 'real_B': real_B, 'fake_A': fake_A, 'fake_B': fake_B})
+        # ######################
+        # Progress report
+        # ######################
+        logger.log(
+            {
+                'loss_G': loss_G,
+                'loss_G_identity': (loss_identity_A + loss_identity_B),
+                'loss_G_GAN': (loss_GAN_A2B + loss_GAN_B2A),
+                'loss_G_cycle': (loss_cycle_ABA + loss_cycle_BAB),
+                'loss_D': (loss_D_A + loss_D_B)
+            },
+            images={
+                'real_A': real_A,
+                'real_B': real_B,
+                'fake_A': fake_A,
+                'fake_B': fake_B
+            }
+        )
 
     # Update learning rates
     lr_scheduler_G.step()
@@ -185,4 +221,3 @@
     torch.save(netG_B2A.state_dict(), 'output/netG_B2A.pth')
     torch.save(netD_A.state_dict(), 'output/netD_A.pth')
     torch.save(netD_B.state_dict(), 'output/netD_B.pth')
-###################################
diff --git a/utils.py b/utils.py
index 54ab8357..5ef66aed 100755
--- a/utils.py
+++ b/utils.py
@@ -1,118 +1,152 @@
-import random
-import time
 import datetime
+import random
 import sys
+import time
+from typing import Dict, List, Optional
 
-from torch.autograd import Variable
-import torch
-from visdom import Visdom
 import numpy as np
+from visdom import Visdom
+
+import torch
+from torch import nn
+from torch.autograd import Variable
+
+
+def tensor2image(tensor: torch.Tensor) -> np.ndarray:
+
+    image: np.ndarray = 127.5 * (tensor[0].cpu().float().numpy() + 1.0)
 
-def tensor2image(tensor):
-    image = 127.5*(tensor[0].cpu().float().numpy() + 1.0)
     if image.shape[0] == 1:
-        image = np.tile(image, (3,1,1))
+        image = np.tile(image, (3, 1, 1))
+
     return image.astype(np.uint8)
 
+
 class Logger():
-    def __init__(self, n_epochs, batches_epoch):
-        self.viz = Visdom()
-        self.n_epochs = n_epochs
-        self.batches_epoch = batches_epoch
-        self.epoch = 1
-        self.batch = 1
-        self.prev_time = time.time()
-        self.mean_period = 0
-        self.losses = {}
-        self.loss_windows = {}
-        self.image_windows = {}
+    def __init__(self, n_epochs: int, batches_epoch: int):
+
+        self.viz: Visdom = Visdom(server='127.0.0.1', port=8097)
+        self.n_epochs: int = n_epochs
+        self.batches_epoch: int = batches_epoch
+        self.epoch: int = 1
+        self.batch: int = 1
+        self.prev_time: float = time.time()
+        self.mean_period: float = 0.0
+        self.losses: Dict = {}
+        self.loss_windows: Dict = {}
+        self.image_windows: Dict = {}
 
-    def log(self, losses=None, images=None):
+    def log(self, losses: Optional[Dict] = None, images: Optional[Dict] = None) -> None:
+
         self.mean_period += (time.time() - self.prev_time)
         self.prev_time = time.time()
 
-        sys.stdout.write('\rEpoch %03d/%03d [%04d/%04d] -- ' % (self.epoch, self.n_epochs, self.batch, self.batches_epoch))
+        sys.stdout.write(f'\rEpoch {self.epoch:03d}/{self.n_epochs:03d} '
+                         f'[{self.batch:04d}/{self.batches_epoch:04d}] -- ')
 
         for i, loss_name in enumerate(losses.keys()):
-            if loss_name not in self.losses:
-                self.losses[loss_name] = losses[loss_name].data[0]
-            else:
-                self.losses[loss_name] += losses[loss_name].data[0]
 
-            if (i+1) == len(losses.keys()):
-                sys.stdout.write('%s: %.4f -- ' % (loss_name, self.losses[loss_name]/self.batch))
+            if loss_name not in self.losses:
+                self.losses[loss_name] = losses[loss_name].data.item()
             else:
-                sys.stdout.write('%s: %.4f | ' % (loss_name, self.losses[loss_name]/self.batch))
+                self.losses[loss_name] += losses[loss_name].data.item()
+
+            sys.stdout.write(f'{loss_name}: {(self.losses[loss_name] / self.batch):.4f}')
+            sys.stdout.write(' -- ' if (i + 1) == len(losses.keys()) else ' | ')
 
-        batches_done = self.batches_epoch*(self.epoch - 1) + self.batch
-        batches_left = self.batches_epoch*(self.n_epochs - self.epoch) + self.batches_epoch - self.batch
-        sys.stdout.write('ETA: %s' % (datetime.timedelta(seconds=batches_left*self.mean_period/batches_done)))
+        batches_done = self.batches_epoch * (self.epoch - 1) + self.batch
+        batches_left = self.batches_epoch * (self.n_epochs - self.epoch) + self.batches_epoch - self.batch
+
+        eta = datetime.timedelta(seconds=batches_left * self.mean_period / batches_done)
+        sys.stdout.write(f'ETA: {eta}')
 
         # Draw images
         for image_name, tensor in images.items():
             if image_name not in self.image_windows:
-                self.image_windows[image_name] = self.viz.image(tensor2image(tensor.data), opts={'title':image_name})
+                self.image_windows[image_name] = self.viz.image(tensor2image(tensor.data),
+                                                                opts=dict(title=image_name))
             else:
-                self.viz.image(tensor2image(tensor.data), win=self.image_windows[image_name], opts={'title':image_name})
+                self.viz.image(tensor2image(tensor.data), win=self.image_windows[image_name],
+                               opts=dict(title=image_name))
 
         # End of epoch
         if (self.batch % self.batches_epoch) == 0:
             # Plot losses
             for loss_name, loss in self.losses.items():
                 if loss_name not in self.loss_windows:
-                    self.loss_windows[loss_name] = self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]),
-                                                                    opts={'xlabel': 'epochs', 'ylabel': loss_name, 'title': loss_name})
+                    self.loss_windows[loss_name] = self.viz.line(
+                        X=np.array([self.epoch]),
+                        Y=np.array([loss / self.batch]),
+                        opts=dict(xlabel='epochs', ylabel=loss_name, title=loss_name)
+                    )
                 else:
-                    self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]), win=self.loss_windows[loss_name], update='append')
+                    self.viz.line(
+                        X=np.array([self.epoch]),
+                        Y=np.array([loss / self.batch]),
+                        win=self.loss_windows[loss_name], update='append'
+                    )
+
+                # Reset losses for next epoch
                 self.losses[loss_name] = 0.0
 
             self.epoch += 1
             self.batch = 1
             sys.stdout.write('\n')
+
         else:
             self.batch += 1
 
-
 class ReplayBuffer():
-    def __init__(self, max_size=50):
+
+    def __init__(self, max_size: int = 50):
         assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
-        self.max_size = max_size
+        self.max_size: int = max_size
         self.data = []
 
-    def push_and_pop(self, data):
-        to_return = []
+    def push_and_pop(self, data: torch.Tensor) -> Variable:
+
+        to_return: List[torch.Tensor] = []
+
         for element in data.data:
+
             element = torch.unsqueeze(element, 0)
+
             if len(self.data) < self.max_size:
                 self.data.append(element)
                 to_return.append(element)
             else:
-                if random.uniform(0,1) > 0.5:
-                    i = random.randint(0, self.max_size-1)
+                if random.uniform(0, 1) > 0.5:
+                    i = random.randint(0, self.max_size - 1)
                     to_return.append(self.data[i].clone())
                     self.data[i] = element
                 else:
                     to_return.append(element)
+
         return Variable(torch.cat(to_return))
 
+
 class LambdaLR():
-    def __init__(self, n_epochs, offset, decay_start_epoch):
-        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
-        self.n_epochs = n_epochs
-        self.offset = offset
-        self.decay_start_epoch = decay_start_epoch
+    def __init__(self, n_epochs: int, offset: int, decay_start_epoch: int) -> None:
+
+        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
+        self.n_epochs: int = n_epochs
+        self.offset: int = offset
+        self.decay_start_epoch: int = decay_start_epoch
 
-    def step(self, epoch):
-        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)
+    def step(self, epoch: int) -> float:
+        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) \
+            / (self.n_epochs - self.decay_start_epoch)
+
+
+def weights_init_normal(m: nn.Module) -> None:
+
+    classname: str = m.__class__.__name__
 
-def weights_init_normal(m):
-    classname = m.__class__.__name__
     if classname.find('Conv') != -1:
-        torch.nn.init.normal(m.weight.data, 0.0, 0.02)
+        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
+
     elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant(m.bias.data, 0.0)
-
+        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+        torch.nn.init.constant_(m.bias.data, 0.0)
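For reference, `LambdaLR.step` implements the linear decay controlled by `--decay_epoch`: the multiplier stays at 1.0 until `decay_start_epoch` and then falls linearly to 0 by the last epoch. A small sketch with the training defaults (`n_epochs=200`, `decay_epoch=100`):

```
from utils import LambdaLR

schedule = LambdaLR(n_epochs=200, offset=0, decay_start_epoch=100)

for epoch in (0, 100, 150, 199):
    print(epoch, schedule.step(epoch))
# Prints 1.0, 1.0, 0.5, ~0.01; the optimizer's lr is scaled by this factor.
```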