Commit 8e9080e

Initial commit
0 parents  commit 8e9080e

8 files changed, +494 -0 lines changed


compute_mean_std.py

Lines changed: 59 additions & 0 deletions

import torch
from torchvision import transforms, datasets
import numpy as np


def compute_mean_std(path_dataset):
    """
    Compute the per-channel mean and standard deviation of an image dataset.
    Acknowledgment: http://forums.fast.ai/t/image-normalization-in-pytorch/7534
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor()
    ])

    dataset = datasets.ImageFolder(root=path_dataset,
                                   transform=transform)
    # Use a large batch size for a better approximation.
    # Ideally, load the whole dataset into memory at once.
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4096, shuffle=False, num_workers=4)

    pop_mean = []
    pop_std = []

    for i, data in enumerate(data_loader, 0):
        # shape (batch_size, 3, height, width)
        numpy_image = data[0].numpy()

        # shape (3,) -> 3 channels
        batch_mean = np.mean(numpy_image, axis=(0, 2, 3))
        batch_std = np.std(numpy_image, axis=(0, 2, 3))

        pop_mean.append(batch_mean)
        pop_std.append(batch_std)

    # shape (num_iterations, 3) -> (mean across 0th axis) -> shape (3,)
    pop_mean = np.array(pop_mean).mean(axis=0)
    pop_std = np.array(pop_std).mean(axis=0)

    values = {
        'mean': pop_mean,
        'std': pop_std
    }

    return values


def main():
    mean_std = {}
    for dataset in ['amazon', 'dslr', 'webcam']:
        # Construct the path to the dataset's images
        dataset_path = './data/%s/images' % dataset
        values = compute_mean_std(dataset_path)
        # Add the per-channel statistics to the dict
        mean_std[dataset] = values

    print(mean_std)


if __name__ == '__main__':
    main()
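
The helper can also be called directly for a single domain folder; a minimal sketch, assuming the same ./data/<name>/images layout used by main():

from compute_mean_std import compute_mean_std

# Hypothetical one-off call for one domain folder
stats = compute_mean_std('./data/amazon/images')
print(stats['mean'])   # three per-channel means, like the values hard-coded in data_loader.py
print(stats['std'])    # three per-channel standard deviations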

config.py

Lines changed: 10 additions & 0 deletions

# Paper: "In the training phase, we set the batch size to 128,
# base learning rate to 1e-3, weight decay to 5e-4, and momentum to 0.9"

lr = 1e-3
decay = 5e-4
momentum = 0.9
batch_size = 128
epochs = 20
n_classes = 31
# Weight of the CORAL term in the composite loss (see train.py).
# 0 disables domain adaptation and trains a source-only baseline.
lambda_coral = 0
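
Since lambda_coral only enters through the composite objective built in train.py, a value of 0 reduces training to plain source-domain classification; a tiny illustration with made-up loss values:

import config

classification_loss = 0.83   # hypothetical per-batch values, for illustration only
coral_loss = 0.12
composite_loss = classification_loss + config.lambda_coral * coral_loss
print(composite_loss)        # equals the classification loss while lambda_coral == 0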

coral.py

Lines changed: 31 additions & 0 deletions

import torch
import numpy as np


def coral(source, target):
    """
    CORAL loss between a batch of source features and a batch of target
    features: the squared Frobenius norm of the difference between their
    covariance matrices, scaled by 1 / (4 * d^2).
    """
    d = source.size(1)  # feature dimension

    source_c = compute_covariance(source)
    target_c = compute_covariance(target)

    loss = torch.sum(torch.mul((source_c - target_c), (source_c - target_c)))

    loss = loss / (4 * d * d)
    return loss


def compute_covariance(input_data):
    """
    Compute the sample covariance matrix of the input data (shape: batch_size x d).
    """
    n = input_data.size(0)  # batch_size

    # Column sums via a row vector of ones: (1 x n) @ (n x d) -> (1 x d)
    id_row = torch.ones(1, n, device=input_data.device)
    sum_column = torch.mm(id_row, input_data)
    mean_column = torch.div(sum_column, n)
    # Mean-correction term of the covariance: (1/n) * (1^T D)^T (1^T D)
    term_mul_2 = torch.mm(mean_column.t(), sum_column)
    d_t_d = torch.mm(input_data.t(), input_data)
    c = (d_t_d - term_mul_2) / (n - 1)

    return c
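
A minimal usage sketch (the shapes only assume what train.py does, i.e. feeding classifier outputs into the loss; the values here are random and purely illustrative):

import torch
from coral import coral

# Two batches of 31-dimensional features, standing in for the fc8 outputs
# of a source batch and a target batch
source_feats = torch.randn(128, 31)
target_feats = torch.randn(128, 31)

loss = coral(source_feats, target_feats)  # scalar tensor, >= 0
print(loss.item())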

data_loader.py

Lines changed: 38 additions & 0 deletions

import torch
from torchvision import transforms, datasets
import numpy as np
import matplotlib.pyplot as plt


def get_loader(name_dataset, batch_size, train=True):

    # Per-channel statistics computed with compute_mean_std.py
    mean_std = {
        'amazon': {
            'mean': [0.79235494, 0.7862071, 0.78418255],
            'std': [0.31496558, 0.3174693, 0.3193569]
        },
        'dslr': {
            'mean': [0.47086468, 0.44865608, 0.40637794],
            'std': [0.20395322, 0.19204104, 0.1996422]
        },
        'webcam': {
            'mean': [0.6119875, 0.6187739, 0.61730677],
            'std': [0.25063968, 0.25554898, 0.25773206]
        }
    }

    data_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_std[name_dataset]['mean'],
                             std=mean_std[name_dataset]['std'])
    ])

    dataset = datasets.ImageFolder(root='./data/%s/images' % name_dataset,
                                   transform=data_transform)
    dataset_loader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 shuffle=train,
                                                 num_workers=4)
    return dataset_loader
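
A quick sketch of how a loader is consumed, assuming the ./data/<name>/images layout with one sub-folder per class used throughout the repo:

from data_loader import get_loader

loader = get_loader(name_dataset='amazon', batch_size=32, train=True)
images, labels = next(iter(loader))
print(images.shape)   # (32, 3, 224, 224) normalized batch, if the source images are square
print(labels.shape)   # (32,) class indices taken from the folder names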

loading.ipynb

Lines changed: 125 additions & 0 deletions
Large diffs are not rendered by default.

model.py

Lines changed: 26 additions & 0 deletions

from torchvision.models import alexnet
import torch.nn as nn


class Net(nn.Module):

    def __init__(self, num_classes, pretrained=False):
        super(Net, self).__init__()

        # check https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py
        self.model = alexnet(pretrained=pretrained, num_classes=num_classes)

        # if we want to feed 448x448 images
        # self.model.avgpool = nn.AdaptiveAvgPool2d(1)

        # In case we want to apply the loss to any layer other than the last,
        # we need a forward hook on that layer
        # def save_features_layer_x(module, input, output):
        #     self.layer_x = output

        # This is a forward hook. It is executed each time forward() runs
        # (for AlexNet the hooked module would be one of the self.model.classifier layers)
        # self.model.layer4.register_forward_hook(save_features_layer_x)

    def forward(self, x):
        out = self.model(x)
        return out  # , self.layer_x
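
The commented-out hook above references layer4, which exists on ResNet but not on AlexNet; a working sketch for torchvision's AlexNet, assuming the standard layout where classifier[4] is the second 4096-unit linear layer (fc7), could look like this:

import torch
from model import Net

features = {}

def save_features_layer_x(module, input, output):
    # Stash the fc7 activations every time forward() runs
    features['fc7'] = output

net = Net(num_classes=31)
handle = net.model.classifier[4].register_forward_hook(save_features_layer_x)

out = net(torch.randn(2, 3, 224, 224))
print(features['fc7'].shape)   # (2, 4096), usable as an input to coral()

handle.remove()  # detach the hook once it is no longer needed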

train.py

Lines changed: 117 additions & 0 deletions

from torchvision.models import alexnet
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from tqdm import tqdm

from data_loader import get_loader
import config
from utils import accuracy, Tracker
from coral import coral


def train(model, optimizer, source_loader, target_loader, epoch=0):

    tracker = Tracker()
    model.train()
    tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99}

    # Trackers to monitor the classification and CORAL losses
    classification_loss_tracker = tracker.track('classification_loss', tracker_class(**tracker_params))
    coral_loss_tracker = tracker.track('CORAL_loss', tracker_class(**tracker_params))

    min_n_batches = min(len(source_loader), len(target_loader))

    tq = tqdm(range(min_n_batches), desc='{} E{:03d}'.format('Training + Adaptation', epoch), ncols=0)

    # Create the iterators once per epoch; the loop draws at most
    # min_n_batches batches from each, so neither is exhausted.
    source_iter = iter(source_loader)
    target_iter = iter(target_loader)

    for _ in tq:

        source_data, source_label = next(source_iter)
        target_data, _ = next(target_iter)  # Unsupervised Domain Adaptation: target labels are never used

        source_data, source_label = Variable(source_data), Variable(source_label)
        target_data = Variable(target_data)

        optimizer.zero_grad()

        out_source = model(source_data)
        out_target = model(target_data)

        classification_loss = F.cross_entropy(out_source, source_label)

        # This is where the magic happens: penalize the covariance mismatch
        # between source and target outputs
        coral_loss = coral(out_source, out_target)
        composite_loss = classification_loss + config.lambda_coral * coral_loss

        composite_loss.backward()
        optimizer.step()

        classification_loss_tracker.append(classification_loss.item())
        coral_loss_tracker.append(coral_loss.item())
        fmt = '{:.4f}'.format
        tq.set_postfix(classification_loss=fmt(classification_loss_tracker.mean.value),
                       coral_loss=fmt(coral_loss_tracker.mean.value))


def evaluate(model, data_loader, dataset_name, epoch=0):
    model.eval()

    tracker = Tracker()
    tracker_class, tracker_params = tracker.MeanMonitor, {}
    acc_tracker = tracker.track('accuracy', tracker_class(**tracker_params))

    loader = tqdm(data_loader, desc='{} E{:03d}'.format('Evaluating on %s' % dataset_name, epoch), ncols=0)

    accuracies = []
    with torch.no_grad():
        for target_data, target_label in loader:
            target_data = Variable(target_data)
            target_label = Variable(target_label)

            output = model(target_data)

            accuracies.append(accuracy(output, target_label))

            acc_tracker.append(sum(accuracies) / len(accuracies))
            fmt = '{:.4f}'.format
            loader.set_postfix(accuracy=fmt(acc_tracker.mean.value))


def main():

    source_train_loader = get_loader(name_dataset='amazon', batch_size=config.batch_size, train=True)
    target_train_loader = get_loader(name_dataset='webcam', batch_size=config.batch_size, train=True)

    source_evaluate_loader = get_loader(name_dataset='amazon', batch_size=config.batch_size, train=False)
    target_evaluate_loader = get_loader(name_dataset='webcam', batch_size=config.batch_size, train=False)

    n_classes = len(source_train_loader.dataset.classes)

    # ~ Paper: "We initialized the other layers with the parameters pre-trained on ImageNet"
    # check https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py
    model = alexnet(pretrained=True)
    # ~ Paper: "The dimension of the last fully connected layer (fc8) was set to the number of categories (31)"
    model.classifier[6] = nn.Linear(4096, config.n_classes)
    # ~ Paper: "... and initialized with N(0, 0.005)"
    torch.nn.init.normal_(model.classifier[6].weight, mean=0, std=5e-3)

    # Initialize the bias to a small constant (http://cs231n.github.io/neural-networks-2/#init)
    model.classifier[6].bias.data.fill_(0.01)

    # ~ Paper: "The learning rate of fc8 is set to 10 times the other layers as it is trained from scratch."
    optimizer = torch.optim.SGD([
        {'params': model.features.parameters()},
        {'params': model.classifier[:6].parameters()},
        # fc8 -> 7th element (index 6) of the classifier Sequential block
        {'params': model.classifier[6].parameters(), 'lr': 10 * config.lr}
    ], lr=config.lr, momentum=config.momentum)  # groups without an explicit lr use the default lr

    for i in range(config.epochs):
        train(model, optimizer, source_train_loader, target_train_loader, i)
        evaluate(model, source_evaluate_loader, 'source', i)
        evaluate(model, target_evaluate_loader, 'target', i)


if __name__ == '__main__':
    main()

utils.py

Lines changed: 88 additions & 0 deletions

import numpy as np
import matplotlib.pyplot as plt
import torch


def imshow(image_tensor, mean, std, title=None):
    """
    Imshow for normalized Tensors.
    Useful to visualize data coming from the data loader.
    """
    image = image_tensor.numpy().transpose((1, 2, 0))
    image = std * image + mean
    image = np.clip(image, 0, 1)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


def accuracy(output, target):
    _, predicted = torch.max(output.data, 1)
    total = target.size(0)
    correct = (predicted == target).sum().item()
    accuracy = correct / total

    return accuracy


class Tracker:
    """
    Collects named series of values and exposes running statistics
    through the attached monitors.
    """

    def __init__(self):
        self.data = {}

    def track(self, name, *monitors):
        l = Tracker.ListStorage(monitors)
        self.data.setdefault(name, []).append(l)
        return l

    def to_dict(self):
        return {k: list(map(list, v)) for k, v in self.data.items()}

    class ListStorage:
        def __init__(self, monitors=[]):
            self.data = []
            self.monitors = monitors
            for monitor in self.monitors:
                setattr(self, monitor.name, monitor)

        def append(self, item):
            for monitor in self.monitors:
                monitor.update(item)
            self.data.append(item)

        def __iter__(self):
            return iter(self.data)

    class MeanMonitor:
        name = 'mean'

        def __init__(self):
            self.n = 0
            self.total = 0

        def update(self, value):
            self.total += value
            self.n += 1

        @property
        def value(self):
            return self.total / self.n

    class MovingMeanMonitor:
        name = 'mean'

        def __init__(self, momentum=0.9):
            self.momentum = momentum
            self.first = True
            self.value = None

        def update(self, value):
            if self.first:
                self.value = value
                self.first = False
            else:
                m = self.momentum
                self.value = m * self.value + (1 - m) * value
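
A small sketch of how Tracker and its monitors fit together, mirroring the way train.py uses them, with made-up loss values:

from utils import Tracker

tracker = Tracker()
losses = tracker.track('loss', Tracker.MovingMeanMonitor(momentum=0.9))

for value in [1.0, 0.5, 0.25]:   # hypothetical per-batch losses
    losses.append(value)

print(losses.mean.value)   # exponentially smoothed mean, exposed under the monitor's name
print(tracker.to_dict())   # {'loss': [[1.0, 0.5, 0.25]]}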
