Skip to content

Commit 0f3d868

Browse files
authored
Merge pull request #815 from coqui-ai/dev
v0.3.1
2 parents 0592a58 + 2766dd1 commit 0f3d868

File tree

7 files changed

+115
-52
lines changed

7 files changed

+115
-52
lines changed

TTS/tts/configs/glow_tts_config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass, field
2+
from typing import List
23

34
from TTS.tts.configs.shared_configs import BaseTTSConfig
45

@@ -167,3 +168,14 @@ class GlowTTSConfig(BaseTTSConfig):
167168
min_seq_len: int = 3
168169
max_seq_len: int = 500
169170
r: int = 1 # DO NOT CHANGE - TODO: make this immutable once coqpit implements it.
171+
172+
# testing
173+
test_sentences: List[str] = field(
174+
default_factory=lambda: [
175+
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
176+
"Be a voice, not an echo.",
177+
"I'm sorry Dave. I'm afraid I can't do that.",
178+
"This cake is great. It's so delicious and moist.",
179+
"Prior to November 22, 1963.",
180+
]
181+
)

TTS/tts/configs/speedy_speech_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
119119
hidden_channels=128,
120120
num_speakers=0,
121121
positional_encoding=True,
122-
detach_duration_predictor=True
122+
detach_duration_predictor=True,
123123
)
124124

125125
# multi-speaker settings

TTS/tts/layers/glow_tts/encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,9 @@ def forward(self, x, x_lengths, g=None):
165165
# set duration predictor input
166166
if g is not None:
167167
g_exp = g.expand(-1, -1, x.size(-1))
168-
x_dp = torch.cat([torch.detach(x), g_exp], 1)
168+
x_dp = torch.cat([x.detach(), g_exp], 1)
169169
else:
170-
x_dp = torch.detach(x)
170+
x_dp = x.detach()
171171
# final projection layer
172172
x_m = self.proj_m(x) * x_mask
173173
if not self.mean_only:

TTS/tts/layers/losses.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,11 @@ def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x
427427
return_dict = {}
428428
# flow loss - neg log likelihood
429429
pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2)
430-
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[1])
430+
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[2])
431431
# duration loss - MSE
432-
# loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
432+
loss_dur = torch.sum((o_dur_log - o_attn_dur) ** 2) / torch.sum(x_lengths)
433433
# duration loss - huber loss
434-
loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths)
434+
# loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths)
435435
return_dict["loss"] = log_mle + loss_dur
436436
return_dict["log_mle"] = log_mle
437437
return_dict["loss_dur"] = loss_dur

TTS/tts/models/glow_tts.py

Lines changed: 92 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
from torch import nn
5+
from torch.cuda.amp.autocast_mode import autocast
56
from torch.nn import functional as F
67

78
from TTS.tts.configs import GlowTTSConfig
@@ -68,6 +69,8 @@ def __init__(self, config: GlowTTSConfig):
6869
# TODO: make this adjustable
6970
self.c_in_channels = 256
7071

72+
self.run_data_dep_init = config.data_dep_init_steps > 0
73+
7174
self.encoder = Encoder(
7275
self.num_chars,
7376
out_channels=self.out_channels,
@@ -131,6 +134,18 @@ def compute_outputs(attn, o_mean, o_log_scale, x_mask):
131134
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
132135
return y_mean, y_log_scale, o_attn_dur
133136

137+
def unlock_act_norm_layers(self):
138+
"""Unlock activation normalization layers for data depended initalization."""
139+
for f in self.decoder.flows:
140+
if getattr(f, "set_ddi", False):
141+
f.set_ddi(True)
142+
143+
def lock_act_norm_layers(self):
144+
"""Lock activation normalization layers."""
145+
for f in self.decoder.flows:
146+
if getattr(f, "set_ddi", False):
147+
f.set_ddi(False)
148+
134149
def forward(
135150
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
136151
): # pylint: disable=dangerous-default-value
@@ -142,6 +157,7 @@ def forward(
142157
- y_lengths::math:`B`
143158
- g: :math:`[B, C] or B`
144159
"""
160+
# [B, T, C] -> [B, C, T]
145161
y = y.transpose(1, 2)
146162
y_max_length = y.size(2)
147163
# norm speaker embeddings
@@ -157,6 +173,7 @@ def forward(
157173
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
158174
# create masks
159175
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
176+
# [B, 1, T_en, T_de]
160177
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
161178
# decoder pass
162179
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
@@ -172,7 +189,7 @@ def forward(
172189
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
173190
attn = attn.squeeze(1).permute(0, 2, 1)
174191
outputs = {
175-
"model_outputs": z.transpose(1, 2),
192+
"z": z.transpose(1, 2),
176193
"logdet": logdet,
177194
"y_mean": y_mean.transpose(1, 2),
178195
"y_log_scale": y_log_scale.transpose(1, 2),
@@ -319,7 +336,8 @@ def inference(
319336
return outputs
320337

321338
def train_step(self, batch: dict, criterion: nn.Module):
322-
"""Perform a single training step by fetching the right set if samples from the batch.
339+
"""A single training step. Forward pass and loss computation. Run data depended initialization for the
340+
first `config.data_dep_init_steps` steps.
323341
324342
Args:
325343
batch (dict): [description]
@@ -332,31 +350,57 @@ def train_step(self, batch: dict, criterion: nn.Module):
332350
d_vectors = batch["d_vectors"]
333351
speaker_ids = batch["speaker_ids"]
334352

335-
outputs = self.forward(
336-
text_input,
337-
text_lengths,
338-
mel_input,
339-
mel_lengths,
340-
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
341-
)
342-
343-
loss_dict = criterion(
344-
outputs["model_outputs"],
345-
outputs["y_mean"],
346-
outputs["y_log_scale"],
347-
outputs["logdet"],
348-
mel_lengths,
349-
outputs["durations_log"],
350-
outputs["total_durations_log"],
351-
text_lengths,
352-
)
353+
if self.run_data_dep_init and self.training:
354+
# compute data-dependent initialization of activation norm layers
355+
self.unlock_act_norm_layers()
356+
with torch.no_grad():
357+
_ = self.forward(
358+
text_input,
359+
text_lengths,
360+
mel_input,
361+
mel_lengths,
362+
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
363+
)
364+
outputs = None
365+
loss_dict = None
366+
self.lock_act_norm_layers()
367+
else:
368+
# normal training step
369+
outputs = self.forward(
370+
text_input,
371+
text_lengths,
372+
mel_input,
373+
mel_lengths,
374+
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
375+
)
353376

377+
with autocast(enabled=False): # avoid mixed_precision in criterion
378+
loss_dict = criterion(
379+
outputs["z"].float(),
380+
outputs["y_mean"].float(),
381+
outputs["y_log_scale"].float(),
382+
outputs["logdet"].float(),
383+
mel_lengths,
384+
outputs["durations_log"].float(),
385+
outputs["total_durations_log"].float(),
386+
text_lengths,
387+
)
354388
return outputs, loss_dict
355389

356390
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
357-
model_outputs = outputs["model_outputs"]
358391
alignments = outputs["alignments"]
392+
text_input = batch["text_input"]
393+
text_lengths = batch["text_lengths"]
359394
mel_input = batch["mel_input"]
395+
d_vectors = batch["d_vectors"]
396+
speaker_ids = batch["speaker_ids"]
397+
398+
# model runs reverse flow to predict spectrograms
399+
pred_outputs = self.inference(
400+
text_input[:1],
401+
aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids},
402+
)
403+
model_outputs = pred_outputs["model_outputs"]
360404

361405
pred_spec = model_outputs[0].data.cpu().numpy()
362406
gt_spec = mel_input[0].data.cpu().numpy()
@@ -393,26 +437,29 @@ def test_run(self, ap):
393437
test_figures = {}
394438
test_sentences = self.config.test_sentences
395439
aux_inputs = self.get_aux_input()
396-
for idx, sen in enumerate(test_sentences):
397-
outputs = synthesis(
398-
self,
399-
sen,
400-
self.config,
401-
"cuda" in str(next(self.parameters()).device),
402-
ap,
403-
speaker_id=aux_inputs["speaker_id"],
404-
d_vector=aux_inputs["d_vector"],
405-
style_wav=aux_inputs["style_wav"],
406-
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
407-
use_griffin_lim=True,
408-
do_trim_silence=False,
409-
)
410-
411-
test_audios["{}-audio".format(idx)] = outputs["wav"]
412-
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
413-
outputs["outputs"]["model_outputs"], ap, output_fig=False
414-
)
415-
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
440+
if len(test_sentences) == 0:
441+
print(" | [!] No test sentences provided.")
442+
else:
443+
for idx, sen in enumerate(test_sentences):
444+
outputs = synthesis(
445+
self,
446+
sen,
447+
self.config,
448+
"cuda" in str(next(self.parameters()).device),
449+
ap,
450+
speaker_id=aux_inputs["speaker_id"],
451+
d_vector=aux_inputs["d_vector"],
452+
style_wav=aux_inputs["style_wav"],
453+
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
454+
use_griffin_lim=True,
455+
do_trim_silence=False,
456+
)
457+
458+
test_audios["{}-audio".format(idx)] = outputs["wav"]
459+
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
460+
outputs["outputs"]["model_outputs"], ap, output_fig=False
461+
)
462+
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
416463
return test_figures, test_audios
417464

418465
def preprocess(self, y, y_lengths, y_max_length, attn=None):
@@ -441,3 +488,7 @@ def get_criterion(self):
441488
from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel
442489

443490
return GlowTTSLoss()
491+
492+
def on_train_step_start(self, trainer):
493+
"""Decide on every training step wheter enable/disable data depended initialization."""
494+
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps

recipes/ljspeech/glow_tts/train_glowtts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
run_eval=True,
1616
test_delay_epochs=-1,
1717
epochs=1000,
18-
text_cleaner="english_cleaners",
19-
use_phonemes=False,
18+
text_cleaner="phoneme_cleaners",
19+
use_phonemes=True,
2020
phoneme_language="en-us",
2121
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
2222
print_step=25,
23-
print_eval=True,
24-
mixed_precision=False,
23+
print_eval=False,
24+
mixed_precision=True,
2525
output_path=output_path,
2626
datasets=[dataset_config],
2727
)

tests/tts_tests/test_glow_tts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_train_step():
6363
optimizer.zero_grad()
6464
outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, None)
6565
loss_dict = criterion(
66-
outputs["model_outputs"],
66+
outputs["z"],
6767
outputs["y_mean"],
6868
outputs["y_log_scale"],
6969
outputs["logdet"],

0 commit comments

Comments
 (0)