Skip to content

Commit ee33997

Browse files
committed
minor logging optimization
1 parent 9503595 commit ee33997

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_meta_train_loop.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ def train_loop(net,
184184
default_embedding = style_embedding_function(
185185
batch_of_spectrograms=datasets[0][0][2].unsqueeze(0).to(device),
186186
batch_of_spectrogram_lengths=datasets[0][0][3].unsqueeze(0).to(device)).squeeze()
187-
print(f"\nTotal Steps: {step_counter}")
187+
print("Reconstruction Loss: {}".format(round(sum(l1_losses_total) / len(l1_losses_total), 3)))
188+
print("Steps: {}\n".format(step_counter))
188189
torch.save({
189190
"model" : net.state_dict(),
190191
"optimizer" : optimizer.state_dict(),

TrainingInterfaces/Text_to_Spectrogram/ToucanTTS/toucantts_train_loop.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,10 @@ def train_loop(net,
174174
}, os.path.join(save_directory, "checkpoint_{}.pt".format(step_counter)))
175175
delete_old_checkpoints(save_directory, keep=5)
176176

177-
print("Epoch: {}".format(epoch))
178-
print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
179-
print("Steps: {}".format(step_counter))
177+
print("\nEpoch: {}".format(epoch))
178+
print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
179+
print("Reconstruction Loss: {}".format(round(sum(l1_losses_total) / len(l1_losses_total), 3)))
180+
print("Steps: {}\n".format(step_counter))
180181
if use_wandb:
181182
wandb.log({
182183
"l1_loss" : round(sum(l1_losses_total) / len(l1_losses_total), 5),
@@ -211,10 +212,6 @@ def train_loop(net,
211212
except IndexError:
212213
print("generating progress plots failed.")
213214

214-
if step_counter > steps:
215-
# DONE
216-
return
217-
218215
if step_counter > 3 * postnet_start_steps:
219216
# Run manual SWA (torch builtin doesn't work unfortunately due to the use of weight norm in the postflow)
220217
checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=save_directory, n=2)
@@ -223,6 +220,9 @@ def train_loop(net,
223220
check_dict = torch.load(os.path.join(save_directory, "best.pt"), map_location=device)
224221
net.load_state_dict(check_dict["model"])
225222

223+
if step_counter > steps:
224+
return # DONE
225+
226226
net.train()
227227

228228

0 commit comments

Comments
 (0)