@@ -24,7 +24,8 @@ def train_loop(generator,
2424 epochs_per_save = 1 ,
2525 path_to_checkpoint = None ,
2626 batch_size = 32 ,
27- steps = 2500000 ,
27+ epochs = 100 ,
28+ # the idea is to load only a subset of the data that fits in RAM, train on it for some epochs, then load new data and continue, and so on.
2829 resume = False ,
2930 use_signal_processing_losses = False # https://github.com/csteinmetz1/auraloss remember to cite if used
3031 ):
@@ -48,9 +49,9 @@ def train_loop(generator,
4849 d = discriminator .to (device )
4950 g .train ()
5051 d .train ()
51- optimizer_g = torch .optim .Adam (g .parameters (), betas = (0.5 , 0.9 ), lr = 2.0e-4 , weight_decay = 0.0 )
52+ optimizer_g = torch .optim .RAdam (g .parameters (), betas = (0.5 , 0.9 ), lr = 0.001 , weight_decay = 0.0 )
5253 scheduler_g = MultiStepLR (optimizer_g , gamma = 0.5 , milestones = [200000 , 400000 , 600000 , 800000 ])
53- optimizer_d = torch .optim .Adam (d .parameters (), betas = (0.5 , 0.9 ), lr = 2.0e-4 , weight_decay = 0.0 )
54+ optimizer_d = torch .optim .RAdam (d .parameters (), betas = (0.5 , 0.9 ), lr = 0.0005 , weight_decay = 0.0 )
5455 scheduler_d = MultiStepLR (optimizer_d , gamma = 0.5 , milestones = [200000 , 400000 , 600000 , 800000 ])
5556
5657 train_loader = DataLoader (dataset = train_dataset ,
@@ -76,7 +77,7 @@ def train_loop(generator,
7677
7778 start_time = time .time ()
7879
79- for _ in range (steps ):
80+ for _ in range (epochs ):
8081
8182 epoch += 1
8283 discriminator_losses = list ()
@@ -102,7 +103,7 @@ def train_loop(generator,
102103 if use_signal_processing_losses :
103104 for sl in signal_processing_loss_functions :
104105 signal_loss += sl (pred_wave , gold_wave )
105- signal_processing_losses .append (signal_loss .item ())
106+ signal_processing_losses .append (signal_loss .item () * 0.5 )
106107 d_outs = d (pred_wave )
107108 d_gold_outs = d (gold_wave )
108109 if step_counter > 10000 : # a little bit of warmup helps, but it's not that important
@@ -111,7 +112,7 @@ def train_loop(generator,
111112 adversarial_loss = torch .tensor ([0.0 ]).to (device )
112113 mel_loss = mel_l1 (pred_wave .squeeze (1 ), gold_wave )
113114 feature_matching_loss = feat_match_criterion (d_outs , d_gold_outs )
114- generator_total_loss = mel_loss * 40.0 + adversarial_loss * 4.0 + feature_matching_loss * 0.3 + signal_loss
115+ generator_total_loss = mel_loss * 40.0 + adversarial_loss * 4.0 + feature_matching_loss * 0.3 + signal_loss * 0.5
115116 optimizer_g .zero_grad ()
116117 generator_total_loss .backward ()
117118 generator_losses .append (generator_total_loss .item ())
0 commit comments