Commit 910dd3b

Merge pull request #731 from marrlab/list_inner_product
stable loss agg, keep batch structure, dial, List inner product
2 parents 5e5820b + 3845d43 commit 910dd3b

8 files changed: +40 −25 lines changed

domainlab/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -8,7 +8,13 @@
 g_inst_component_loss_agg = torch.sum
 g_tensor_batch_agg = torch.sum
 g_list_loss_agg = sum
-g_list_model_penalized_reg_agg = sum
+
+def g_list_model_penalized_reg_agg(list_penalized_reg):
+    """
+    aggregate along the list, but do not diminish the batch structure of the tensor
+    """
+    return torch.stack(list_penalized_reg, dim=0).sum(dim=0)
+
 g_str_cross_entropy_agg = "none"
 # component loss refers to aggregation of pixel loss, digit of KL divergences loss
 # instance loss currently use torch.sum, which is the same effect as torch.mean, the
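A brief sketch of what the new aggregator does on toy per-sample regularization tensors (batch size and values are illustrative, not from the repository): torch.stack collects the regularizers into shape (num_regularizers, batch_size) and sum(dim=0) collapses only the regularizer axis, so the batch dimension survives.

    import torch

    def g_list_model_penalized_reg_agg(list_penalized_reg):
        # same aggregation as in domainlab/__init__.py: stack to
        # (num_regularizers, batch_size), then sum over the regularizer axis only
        return torch.stack(list_penalized_reg, dim=0).sum(dim=0)

    batch_size = 4  # illustrative
    list_penalized_reg = [torch.ones(batch_size), 2 * torch.ones(batch_size)]
    out = g_list_model_penalized_reg_agg(list_penalized_reg)
    assert out.shape == (batch_size,)                      # batch structure preserved
    assert torch.allclose(out, 3 * torch.ones(batch_size))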

domainlab/algos/trainers/a_trainer.py

Lines changed: 8 additions & 2 deletions
@@ -217,8 +217,14 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         """
         list_reg_model, list_mu_model = self.decoratee.cal_reg_loss(
             tensor_x, tensor_y, tensor_d, others)
-        list_reg, list_mu = self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
-        return list_reg_model + list_reg, list_mu_model + list_mu
+        assert len(list_reg_model) == len(list_mu_model)
+
+        list_reg_trainer, list_mu_trainer = self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
+        assert len(list_reg_trainer) == len(list_mu_trainer)
+
+        list_loss = list_reg_model + list_reg_trainer
+        list_mu = list_mu_model + list_mu_trainer
+        return list_loss, list_mu

     def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         """

domainlab/algos/trainers/train_basic.py

Lines changed: 10 additions & 10 deletions
@@ -57,14 +57,17 @@ def after_epoch(self, epoch):
         assert flag_stop is not None
         return flag_stop

-    def log_r_loss(self, list_b_reg_loss):
+    def log_loss(self, list_b_reg_loss, loss_task, loss):
         """
         just for logging the self.epo_reg_loss_tr
         """
+        self.epo_task_loss_tr += loss_task.sum().detach().item()
+        #
         list_b_reg_loss_sumed = [ele.sum().detach().item()
                                  for ele in list_b_reg_loss]
         self.epo_reg_loss_tr = list(map(add, self.epo_reg_loss_tr,
                                         list_b_reg_loss_sumed))
+        self.epo_loss_tr += loss.detach().item()

     def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
         """
@@ -78,7 +81,6 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
         loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
         loss.backward()
         self.optimizer.step()
-        self.epo_loss_tr += loss.detach().item()
         self.after_batch(epoch, ind_batch)
         self.counter_batch += 1

@@ -88,15 +90,13 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others):
         """
         loss_task = self.model.cal_task_loss(tensor_x, tensor_y)

-        # only for logging
-        self.epo_task_loss_tr += loss_task.sum().detach().item()
-        #
-        list_reg_tr, list_mu_tr = self.cal_reg_loss(tensor_x, tensor_y,
+        list_reg_tr_batch, list_mu_tr = self.cal_reg_loss(tensor_x, tensor_y,
                                                           tensor_d, others)
-        #
-        self.log_r_loss(list_reg_tr)  # just for logging
-        reg_tr = self.model.inner_product(list_reg_tr, list_mu_tr)
+        tensor_batch_reg_loss_penalized = self.model.list_inner_product(
+            list_reg_tr_batch, list_mu_tr)
+        assert len(tensor_batch_reg_loss_penalized.shape) == 1
         loss_erm_agg = g_tensor_batch_agg(loss_task)
-        loss_reg_agg = g_tensor_batch_agg(reg_tr)
+        loss_reg_agg = g_tensor_batch_agg(tensor_batch_reg_loss_penalized)
         loss = self.model.multiplier4task_loss * loss_erm_agg + loss_reg_agg
+        self.log_loss(list_reg_tr_batch, loss_task, loss)
         return loss
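For orientation, a sketch of the loss composition that cal_loss now performs, assuming the task loss and the penalized regularization tensor are both 1-D over the mini-batch (tensor values and the multiplier below are illustrative):

    import torch

    g_tensor_batch_agg = torch.sum  # as defined in domainlab/__init__.py

    batch_size = 4
    loss_task = torch.rand(batch_size)                        # per-sample task loss
    tensor_batch_reg_loss_penalized = torch.rand(batch_size)  # per-sample penalized regularization
    assert len(tensor_batch_reg_loss_penalized.shape) == 1    # batch structure is still there

    multiplier4task_loss = 1.0  # illustrative value
    loss = multiplier4task_loss * g_tensor_batch_agg(loss_task) \
        + g_tensor_batch_agg(tensor_batch_reg_loss_penalized)
    # only this final aggregation collapses the batch dimension; logging happens afterwards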

domainlab/algos/trainers/train_dial.py

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@ def gen_adversarial(self, device, img_natural, vec_y):
         this is not necessarily constraint optimal due to nonlinearity,
         as the constraint epsilon is only considered ad-hoc
         """
-        # @FIXME: is there better way to initialize adversarial image?
         # ensure adversarial image not in computational graph
         steps_perturb = self.aconf.dial_steps_perturb
         scale = self.aconf.dial_noise_scale

domainlab/algos/trainers/train_matchdg.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ def tr_batch(self, epoch, batch_idx, x_e, y_e, d_e, others=None):
         if self.flag_erm:
             # decoratee can be both trainer or model
             list_loss_reg_rand, list_mu_reg = self.decoratee.cal_reg_loss(x_e, y_e, d_e, others)
-            loss_reg = self.model.inner_product(list_loss_reg_rand, list_mu_reg)
+            loss_reg = self.model.list_inner_product(list_loss_reg_rand, list_mu_reg)
             loss_task_rand = self.model.cal_task_loss(x_e, y_e)
             # loss_erm_rnd_loader, *_ = self.model.cal_loss(x_e, y_e, d_e, others)
             loss_erm_rnd_loader = loss_reg + loss_task_rand * self.model.multiplier4task_loss

domainlab/algos/trainers/train_mldg.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def tr_epoch(self, epoch):
             # since mldg's reg loss is on target domain,
             # no other trainer except hyperscheduler could decorate it unless we use state pattern
             # in the future to control source and target domain loader behavior
-            source_reg_tr = self.model.inner_product(list_source_reg_tr, list_source_mu_tr)
+            source_reg_tr = self.model.list_inner_product(list_source_reg_tr, list_source_mu_tr)
             # self.aconf.gamma_reg * loss_look_forward.sum()
             loss = loss_source_task.sum() + source_reg_tr.sum() +\
                 self.aconf.gamma_reg * loss_look_forward.sum()

domainlab/models/a_model.py

Lines changed: 12 additions & 8 deletions
@@ -43,21 +43,25 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
         calculate the loss
         """
         list_loss, list_multiplier = self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
-        loss_reg = self.inner_product(list_loss, list_multiplier)
+        loss_reg = self.list_inner_product(list_loss, list_multiplier)
         loss_task_alone = self.cal_task_loss(tensor_x, tensor_y)
         loss_task = self.multiplier4task_loss * loss_task_alone
         return loss_task + loss_reg, list_loss, loss_task_alone

-    def inner_product(self, list_loss_scalar, list_multiplier):
+    def list_inner_product(self, list_loss, list_multiplier):
         """
-        compute inner product between list of scalar loss and multiplier
-        - the first dimension of the tensor v_reg_loss is mini-batch
-          the second dimension is the number of regularizers
-        - the vector mmu has dimension the number of regularizers
+        compute inner product between list of regularization loss and multiplier
+        - the length of the list is the number of regularizers
+        - for each element of the list: the first dimension of the tensor is mini-batch
+        return value of list_inner_product should keep the minibatch structure, thus aggregation
+        here only aggregate along the list
         """
-        list_tuple = zip(list_loss_scalar, list_multiplier)
+        list_tuple = zip(list_loss, list_multiplier)
         list_penalized_reg = [mtuple[0]*mtuple[1] for mtuple in list_tuple]
-        return g_list_model_penalized_reg_agg(list_penalized_reg)
+        tensor_batch_penalized_loss = g_list_model_penalized_reg_agg(list_penalized_reg)
+        # return value of list_inner_product should keep the minibatch structure, thus aggregation
+        # here only aggregate along the list
+        return tensor_batch_penalized_loss

     @abc.abstractmethod
     def cal_task_loss(self, tensor_x, tensor_y):
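To make the docstring concrete, a toy run of the same computation as list_inner_product (two hypothetical regularizers; the multiplier values are chosen arbitrarily):

    import torch

    batch_size = 4
    # two regularizers, each yielding one value per sample in the mini-batch
    list_loss = [torch.rand(batch_size), torch.rand(batch_size)]
    list_multiplier = [1.0, 0.1]

    list_penalized_reg = [loss * mu for loss, mu in zip(list_loss, list_multiplier)]
    # same aggregation as g_list_model_penalized_reg_agg: collapse the list,
    # keep the mini-batch dimension
    tensor_batch_penalized_loss = torch.stack(list_penalized_reg, dim=0).sum(dim=0)
    assert tensor_batch_penalized_loss.shape == (batch_size,)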

domainlab/models/a_model_classif.py

Lines changed: 1 addition & 1 deletion
@@ -216,4 +216,4 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         """
         device = tensor_x.device
         bsize = tensor_x.shape[0]
-        return [torch.zeros(bsize, 1).to(device)], [0.0]
+        return [torch.zeros(bsize).to(device)], [0.0]
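A short sketch of why the shape change matters (batch size is illustrative): the basic trainer now asserts that the aggregated penalized regularization tensor is 1-D, and a (batch, 1) placeholder would come out of the list aggregation with two dimensions.

    import torch

    bsize = 4

    def agg(list_penalized_reg):
        # list aggregation used by the models, keeping the batch dimension
        return torch.stack(list_penalized_reg, dim=0).sum(dim=0)

    new_zero_reg = torch.zeros(bsize)     # new shape: (batch,)
    old_zero_reg = torch.zeros(bsize, 1)  # previous shape: (batch, 1)

    assert len(agg([0.0 * new_zero_reg]).shape) == 1  # passes the trainer's assert
    assert len(agg([0.0 * old_zero_reg]).shape) == 2  # would have tripped it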
