Skip to content

Commit 5639949

Browse files
authored
Update hifigan_train_loop.py
1 parent cb38fb7 commit 5639949

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

TrainingInterfaces/Spectrogram_to_Wave/HiFIGAN/hifigan_train_loop.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def train_loop(generator,
2424
epochs_per_save=1,
2525
path_to_checkpoint=None,
2626
batch_size=32,
27-
steps=2500000,
27+
epochs=100,
28+
# the idea is to only load a subset of the data that fits in RAM, then train for some epochs, then load new data and continue, and so on.
2829
resume=False,
2930
use_signal_processing_losses=False # https://github.com/csteinmetz1/auraloss remember to cite if used
3031
):
@@ -48,9 +49,9 @@ def train_loop(generator,
4849
d = discriminator.to(device)
4950
g.train()
5051
d.train()
51-
optimizer_g = torch.optim.Adam(g.parameters(), betas=(0.5, 0.9), lr=2.0e-4, weight_decay=0.0)
52+
optimizer_g = torch.optim.RAdam(g.parameters(), betas=(0.5, 0.9), lr=0.001, weight_decay=0.0)
5253
scheduler_g = MultiStepLR(optimizer_g, gamma=0.5, milestones=[200000, 400000, 600000, 800000])
53-
optimizer_d = torch.optim.Adam(d.parameters(), betas=(0.5, 0.9), lr=2.0e-4, weight_decay=0.0)
54+
optimizer_d = torch.optim.RAdam(d.parameters(), betas=(0.5, 0.9), lr=0.0005, weight_decay=0.0)
5455
scheduler_d = MultiStepLR(optimizer_d, gamma=0.5, milestones=[200000, 400000, 600000, 800000])
5556

5657
train_loader = DataLoader(dataset=train_dataset,
@@ -76,7 +77,7 @@ def train_loop(generator,
7677

7778
start_time = time.time()
7879

79-
for _ in range(steps):
80+
for _ in range(epochs):
8081

8182
epoch += 1
8283
discriminator_losses = list()
@@ -102,7 +103,7 @@ def train_loop(generator,
102103
if use_signal_processing_losses:
103104
for sl in signal_processing_loss_functions:
104105
signal_loss += sl(pred_wave, gold_wave)
105-
signal_processing_losses.append(signal_loss.item())
106+
signal_processing_losses.append(signal_loss.item() * 0.5)
106107
d_outs = d(pred_wave)
107108
d_gold_outs = d(gold_wave)
108109
if step_counter > 10000: # a little bit of warmup helps, but it's not that important
@@ -111,7 +112,7 @@ def train_loop(generator,
111112
adversarial_loss = torch.tensor([0.0]).to(device)
112113
mel_loss = mel_l1(pred_wave.squeeze(1), gold_wave)
113114
feature_matching_loss = feat_match_criterion(d_outs, d_gold_outs)
114-
generator_total_loss = mel_loss * 40.0 + adversarial_loss * 4.0 + feature_matching_loss * 0.3 + signal_loss
115+
generator_total_loss = mel_loss * 40.0 + adversarial_loss * 4.0 + feature_matching_loss * 0.3 + signal_loss * 0.5
115116
optimizer_g.zero_grad()
116117
generator_total_loss.backward()
117118
generator_losses.append(generator_total_loss.item())

0 commit comments

Comments (0)