22
33import torch
44from torch import nn
5+ from torch .cuda .amp .autocast_mode import autocast
56from torch .nn import functional as F
67
78from TTS .tts .configs import GlowTTSConfig
@@ -68,6 +69,8 @@ def __init__(self, config: GlowTTSConfig):
6869 # TODO: make this adjustable
6970 self .c_in_channels = 256
7071
72+ self .run_data_dep_init = config .data_dep_init_steps > 0
73+
7174 self .encoder = Encoder (
7275 self .num_chars ,
7376 out_channels = self .out_channels ,
@@ -131,6 +134,18 @@ def compute_outputs(attn, o_mean, o_log_scale, x_mask):
131134 o_attn_dur = torch .log (1 + torch .sum (attn , - 1 )) * x_mask
132135 return y_mean , y_log_scale , o_attn_dur
133136
def unlock_act_norm_layers(self):
    """Unlock activation normalization layers for data-dependent initialization.

    Iterates over ``self.decoder.flows`` and calls ``set_ddi(True)`` on every
    flow that exposes a truthy ``set_ddi`` attribute. Flows without it are
    left untouched.
    """
    for flow in self.decoder.flows:
        toggle_ddi = getattr(flow, "set_ddi", False)
        if not toggle_ddi:
            # This flow has no data-dependent init switch; nothing to do.
            continue
        toggle_ddi(True)
142+
def lock_act_norm_layers(self):
    """Lock activation normalization layers.

    Iterates over ``self.decoder.flows`` and calls ``set_ddi(False)`` on every
    flow that exposes a truthy ``set_ddi`` attribute. Flows without it are
    left untouched.
    """
    for flow in self.decoder.flows:
        toggle_ddi = getattr(flow, "set_ddi", False)
        if not toggle_ddi:
            # This flow has no data-dependent init switch; nothing to do.
            continue
        toggle_ddi(False)
148+
134149 def forward (
135150 self , x , x_lengths , y , y_lengths = None , aux_input = {"d_vectors" : None , "speaker_ids" : None }
136151 ): # pylint: disable=dangerous-default-value
@@ -142,6 +157,7 @@ def forward(
142157 - y_lengths::math:`B`
143158 - g: :math:`[B, C] or B`
144159 """
160+ # [B, T, C] -> [B, C, T]
145161 y = y .transpose (1 , 2 )
146162 y_max_length = y .size (2 )
147163 # norm speaker embeddings
@@ -157,6 +173,7 @@ def forward(
157173 y , y_lengths , y_max_length , attn = self .preprocess (y , y_lengths , y_max_length , None )
158174 # create masks
159175 y_mask = torch .unsqueeze (sequence_mask (y_lengths , y_max_length ), 1 ).to (x_mask .dtype )
176+ # [B, 1, T_en, T_de]
160177 attn_mask = torch .unsqueeze (x_mask , - 1 ) * torch .unsqueeze (y_mask , 2 )
161178 # decoder pass
162179 z , logdet = self .decoder (y , y_mask , g = g , reverse = False )
@@ -172,7 +189,7 @@ def forward(
172189 y_mean , y_log_scale , o_attn_dur = self .compute_outputs (attn , o_mean , o_log_scale , x_mask )
173190 attn = attn .squeeze (1 ).permute (0 , 2 , 1 )
174191 outputs = {
175- "model_outputs " : z .transpose (1 , 2 ),
192+ "z " : z .transpose (1 , 2 ),
176193 "logdet" : logdet ,
177194 "y_mean" : y_mean .transpose (1 , 2 ),
178195 "y_log_scale" : y_log_scale .transpose (1 , 2 ),
@@ -319,7 +336,8 @@ def inference(
319336 return outputs
320337
321338 def train_step (self , batch : dict , criterion : nn .Module ):
322- """Perform a single training step by fetching the right set if samples from the batch.
339+ """A single training step. Forward pass and loss computation. Run data depended initialization for the
340+ first `config.data_dep_init_steps` steps.
323341
324342 Args:
325343 batch (dict): [description]
@@ -332,31 +350,57 @@ def train_step(self, batch: dict, criterion: nn.Module):
332350 d_vectors = batch ["d_vectors" ]
333351 speaker_ids = batch ["speaker_ids" ]
334352
335- outputs = self .forward (
336- text_input ,
337- text_lengths ,
338- mel_input ,
339- mel_lengths ,
340- aux_input = {"d_vectors" : d_vectors , "speaker_ids" : speaker_ids },
341- )
342-
343- loss_dict = criterion (
344- outputs ["model_outputs" ],
345- outputs ["y_mean" ],
346- outputs ["y_log_scale" ],
347- outputs ["logdet" ],
348- mel_lengths ,
349- outputs ["durations_log" ],
350- outputs ["total_durations_log" ],
351- text_lengths ,
352- )
353+ if self .run_data_dep_init and self .training :
354+ # compute data-dependent initialization of activation norm layers
355+ self .unlock_act_norm_layers ()
356+ with torch .no_grad ():
357+ _ = self .forward (
358+ text_input ,
359+ text_lengths ,
360+ mel_input ,
361+ mel_lengths ,
362+ aux_input = {"d_vectors" : d_vectors , "speaker_ids" : speaker_ids },
363+ )
364+ outputs = None
365+ loss_dict = None
366+ self .lock_act_norm_layers ()
367+ else :
368+ # normal training step
369+ outputs = self .forward (
370+ text_input ,
371+ text_lengths ,
372+ mel_input ,
373+ mel_lengths ,
374+ aux_input = {"d_vectors" : d_vectors , "speaker_ids" : speaker_ids },
375+ )
353376
377+ with autocast (enabled = False ): # avoid mixed_precision in criterion
378+ loss_dict = criterion (
379+ outputs ["z" ].float (),
380+ outputs ["y_mean" ].float (),
381+ outputs ["y_log_scale" ].float (),
382+ outputs ["logdet" ].float (),
383+ mel_lengths ,
384+ outputs ["durations_log" ].float (),
385+ outputs ["total_durations_log" ].float (),
386+ text_lengths ,
387+ )
354388 return outputs , loss_dict
355389
356390 def train_log (self , ap : AudioProcessor , batch : dict , outputs : dict ): # pylint: disable=no-self-use
357- model_outputs = outputs ["model_outputs" ]
358391 alignments = outputs ["alignments" ]
392+ text_input = batch ["text_input" ]
393+ text_lengths = batch ["text_lengths" ]
359394 mel_input = batch ["mel_input" ]
395+ d_vectors = batch ["d_vectors" ]
396+ speaker_ids = batch ["speaker_ids" ]
397+
398+ # model runs reverse flow to predict spectrograms
399+ pred_outputs = self .inference (
400+ text_input [:1 ],
401+ aux_input = {"x_lengths" : text_lengths [:1 ], "d_vectors" : d_vectors , "speaker_ids" : speaker_ids },
402+ )
403+ model_outputs = pred_outputs ["model_outputs" ]
360404
361405 pred_spec = model_outputs [0 ].data .cpu ().numpy ()
362406 gt_spec = mel_input [0 ].data .cpu ().numpy ()
@@ -393,26 +437,29 @@ def test_run(self, ap):
393437 test_figures = {}
394438 test_sentences = self .config .test_sentences
395439 aux_inputs = self .get_aux_input ()
396- for idx , sen in enumerate (test_sentences ):
397- outputs = synthesis (
398- self ,
399- sen ,
400- self .config ,
401- "cuda" in str (next (self .parameters ()).device ),
402- ap ,
403- speaker_id = aux_inputs ["speaker_id" ],
404- d_vector = aux_inputs ["d_vector" ],
405- style_wav = aux_inputs ["style_wav" ],
406- enable_eos_bos_chars = self .config .enable_eos_bos_chars ,
407- use_griffin_lim = True ,
408- do_trim_silence = False ,
409- )
410-
411- test_audios ["{}-audio" .format (idx )] = outputs ["wav" ]
412- test_figures ["{}-prediction" .format (idx )] = plot_spectrogram (
413- outputs ["outputs" ]["model_outputs" ], ap , output_fig = False
414- )
415- test_figures ["{}-alignment" .format (idx )] = plot_alignment (outputs ["alignments" ], output_fig = False )
440+ if len (test_sentences ) == 0 :
441+ print (" | [!] No test sentences provided." )
442+ else :
443+ for idx , sen in enumerate (test_sentences ):
444+ outputs = synthesis (
445+ self ,
446+ sen ,
447+ self .config ,
448+ "cuda" in str (next (self .parameters ()).device ),
449+ ap ,
450+ speaker_id = aux_inputs ["speaker_id" ],
451+ d_vector = aux_inputs ["d_vector" ],
452+ style_wav = aux_inputs ["style_wav" ],
453+ enable_eos_bos_chars = self .config .enable_eos_bos_chars ,
454+ use_griffin_lim = True ,
455+ do_trim_silence = False ,
456+ )
457+
458+ test_audios ["{}-audio" .format (idx )] = outputs ["wav" ]
459+ test_figures ["{}-prediction" .format (idx )] = plot_spectrogram (
460+ outputs ["outputs" ]["model_outputs" ], ap , output_fig = False
461+ )
462+ test_figures ["{}-alignment" .format (idx )] = plot_alignment (outputs ["alignments" ], output_fig = False )
416463 return test_figures , test_audios
417464
418465 def preprocess (self , y , y_lengths , y_max_length , attn = None ):
@@ -441,3 +488,7 @@ def get_criterion(self):
441488 from TTS .tts .layers .losses import GlowTTSLoss # pylint: disable=import-outside-toplevel
442489
443490 return GlowTTSLoss ()
491+
def on_train_step_start(self, trainer):
    """Decide at the start of every training step whether to run data-dependent initialization.

    Sets ``self.run_data_dep_init`` to ``True`` while ``trainer.total_steps_done``
    is still below ``self.data_dep_init_steps``, and to ``False`` afterwards.

    Args:
        trainer: Object exposing the ``total_steps_done`` counter.
    """
    steps_done = trainer.total_steps_done
    still_initializing = steps_done < self.data_dep_init_steps
    self.run_data_dep_init = still_initializing
0 commit comments