@@ -174,9 +174,10 @@ def train_loop(net,
         }, os.path.join(save_directory, "checkpoint_{}.pt".format(step_counter)))
         delete_old_checkpoints(save_directory, keep=5)

-        print("Epoch: {}".format(epoch))
-        print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
-        print("Steps: {}".format(step_counter))
+        print("\nEpoch: {}".format(epoch))
+        print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
+        print("Reconstruction Loss: {}".format(round(sum(l1_losses_total) / len(l1_losses_total), 3)))
+        print("Steps: {}\n".format(step_counter))
         if use_wandb:
             wandb.log({
                 "l1_loss": round(sum(l1_losses_total) / len(l1_losses_total), 5),
@@ -211,10 +212,6 @@ def train_loop(net,
         except IndexError:
             print("generating progress plots failed.")

-        if step_counter > steps:
-            # DONE
-            return
-
         if step_counter > 3 * postnet_start_steps:
             # Run manual SWA (torch builtin doesn't work unfortunately due to the use of weight norm in the postflow)
             checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=save_directory, n=2)
@@ -223,6 +220,9 @@ def train_loop(net,
             check_dict = torch.load(os.path.join(save_directory, "best.pt"), map_location=device)
             net.load_state_dict(check_dict["model"])

+        if step_counter > steps:
+            return  # DONE
+
         net.train()

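Note on the reordering above: the early return is moved behind the manual SWA block so that, on the final pass, the checkpoint-averaged weights reloaded from best.pt are in place before train_loop exits. The averaging call itself falls between the hunks shown here; the sketch below is only an assumption of how a parameter-wise mean over the two most recent checkpoints could look, written so its output matches the best.pt that the last hunk loads. The helper name average_checkpoints is hypothetical and not taken from this commit.

import torch


def average_checkpoints(checkpoint_paths, out_path):
    # Hypothetical sketch of manual SWA: element-wise mean of the "model"
    # state dicts from several checkpoints, saved so it can be reloaded.
    state_dicts = [torch.load(path, map_location="cpu")["model"] for path in checkpoint_paths]
    averaged = {}
    for key, value in state_dicts[0].items():
        if value.is_floating_point():
            # average learnable weights across all given checkpoints
            averaged[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
        else:
            # integer buffers (e.g. counters) cannot be meaningfully averaged; keep one copy
            averaged[key] = value.clone()
    torch.save({"model": averaged}, out_path)

With a helper along these lines, the SWA block would call it on checkpoint_paths and write best.pt just before the torch.load shown in the last hunk.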