@@ -57,14 +57,17 @@ def after_epoch(self, epoch):
5757 assert flag_stop is not None
5858 return flag_stop
5959
60- def log_r_loss (self , list_b_reg_loss ):
60+ def log_loss (self , list_b_reg_loss , loss_task , loss ):
6161 """
6262 just for logging the self.epo_reg_loss_tr
6363 """
64+ self .epo_task_loss_tr += loss_task .sum ().detach ().item ()
65+ #
6466 list_b_reg_loss_sumed = [ele .sum ().detach ().item ()
6567 for ele in list_b_reg_loss ]
6668 self .epo_reg_loss_tr = list (map (add , self .epo_reg_loss_tr ,
6769 list_b_reg_loss_sumed ))
70+ self .epo_loss_tr += loss .detach ().item ()
6871
6972 def tr_batch (self , tensor_x , tensor_y , tensor_d , others , ind_batch , epoch ):
7073 """
@@ -78,7 +81,6 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
7881 loss = self .cal_loss (tensor_x , tensor_y , tensor_d , others )
7982 loss .backward ()
8083 self .optimizer .step ()
81- self .epo_loss_tr += loss .detach ().item ()
8284 self .after_batch (epoch , ind_batch )
8385 self .counter_batch += 1
8486
@@ -88,15 +90,13 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others):
8890 """
8991 loss_task = self .model .cal_task_loss (tensor_x , tensor_y )
9092
91- # only for logging
92- self .epo_task_loss_tr += loss_task .sum ().detach ().item ()
93- #
94- list_reg_tr , list_mu_tr = self .cal_reg_loss (tensor_x , tensor_y ,
93+ list_reg_tr_batch , list_mu_tr = self .cal_reg_loss (tensor_x , tensor_y ,
9594 tensor_d , others )
96- #
97- self . log_r_loss ( list_reg_tr ) # just for logging
98- reg_tr = self . model . inner_product ( list_reg_tr , list_mu_tr )
95+ tensor_batch_reg_loss_penalized = self . model . list_inner_product (
96+ list_reg_tr_batch , list_mu_tr )
97+ assert len ( tensor_batch_reg_loss_penalized . shape ) == 1
9998 loss_erm_agg = g_tensor_batch_agg (loss_task )
100- loss_reg_agg = g_tensor_batch_agg (reg_tr )
99+ loss_reg_agg = g_tensor_batch_agg (tensor_batch_reg_loss_penalized )
101100 loss = self .model .multiplier4task_loss * loss_erm_agg + loss_reg_agg
101+ self .log_loss (list_reg_tr_batch , loss_task , loss )
102102 return loss
0 commit comments