Skip to content

Commit 11283fc

Browse files
authored
Ensures that only GPT model is in training mode during XTTS GPT training (#3241)
* Ensures that only GPT model is in training mode during training * Fix parallel wavegan unit test
1 parent 14579a4 commit 11283fc

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

TTS/tts/layers/xtts/trainer/gpt_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,10 @@ def eval_step(self, batch, criterion):
318318
batch["cond_idxs"] = None
319319
return self.train_step(batch, criterion)
320320

321-
def on_epoch_start(self, trainer): # pylint: disable=W0613
322-
# guarante that dvae will be in eval mode after .train() on evaluation end
323-
self.dvae = self.dvae.eval()
321+
def on_train_epoch_start(self, trainer):
322+
trainer.model.eval() # the whole model to eval
323+
# put gpt model in training mode
324+
trainer.model.xtts.gpt.train()
324325

325326
def on_init_end(self, trainer): # pylint: disable=W0613
326327
# ignore similarities.pth on clearml save/upload

TTS/vocoder/configs/parallel_wavegan_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
9494
use_noise_augment: bool = False
9595
use_cache: bool = True
9696
steps_to_start_discriminator: int = 200000
97+
target_loss: str = "loss_1"
9798

9899
# LOSS PARAMETERS - overrides
99100
use_stft_loss: bool = True

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pandas>=1.4,<2.0
2727
# deps for training
2828
matplotlib>=3.7.0
2929
# coqui stack
30-
trainer
30+
trainer>=0.0.32
3131
# config management
3232
coqpit>=0.0.16
3333
# chinese g2p deps

0 commit comments

Comments
 (0)