Commit c186583

remove transcript caches

1 parent acc144a · commit c186583

7 files changed: +125 additions, −116 deletions


InferenceInterfaces/ToucanTTSInterface.py

Lines changed: 1 addition & 1 deletion

@@ -166,7 +166,7 @@ def forward(self,
                                       pause_duration_scaling_factor=pause_duration_scaling_factor,
                                       prosody_creativity=prosody_creativity)

-        wave, _, _ = self.vocoder(mel.unsqueeze(0))
+        wave = self.vocoder(mel.unsqueeze(0))
         wave = wave.squeeze().cpu()
         wave = wave.numpy()
         sr = 24000
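
Editor's note: this call-site change reflects the vocoder's new return contract, which now yields only the waveform instead of a 3-tuple with intermediate upsampled signals. A minimal sketch of the new contract, using a hypothetical stand-in module (DummyVocoder is not part of the repo; 384 is the generator's total upsampling factor, see below):

import torch

class DummyVocoder(torch.nn.Module):
    # Hypothetical stand-in mimicking the new single-return contract.
    def forward(self, mel):  # mel: [batch, mel_bins, frames]
        return torch.zeros(mel.shape[0], 1, mel.shape[2] * 384)

vocoder = DummyVocoder()
mel = torch.randn(128, 32)           # fake spectrogram: 128 bins, 32 frames
wave = vocoder(mel.unsqueeze(0))     # new: one return value, nothing to discard
wave = wave.squeeze().cpu().numpy()  # same post-processing as in the interface
print(wave.shape)                    # (12288,)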

Modules/Vocoder/HiFiGAN_Discriminators.py

Lines changed: 14 additions & 42 deletions

@@ -9,8 +9,6 @@
 import torch
 import torch.nn.functional as F

-from Modules.Vocoder.Avocodo_Discriminators import MultiCoMBDiscriminator
-from Modules.Vocoder.Avocodo_Discriminators import MultiSubBandDiscriminator
 from Modules.Vocoder.SAN_modules import SANConv1d
 from Modules.Vocoder.SAN_modules import SANConv2d

@@ -456,10 +454,13 @@ def forward(self, x):
 class AvocodoHiFiGANJointDiscriminator(torch.nn.Module):
+    """
+    Contradicting the legacy name, the Avocodo parts were removed again for stability
+    """

     def __init__(self,
                  # Multi-scale discriminator related
-                 scales=3,
+                 scales=4,
                  scale_downsample_pooling="AvgPool1d",
                  scale_downsample_pooling_params={"kernel_size": 4,
                                                   "stride"     : 2,
@@ -471,7 +472,7 @@ def __init__(self,
                                                "max_downsample_channels"    : 1024,
                                                "max_groups"                 : 16,
                                                "bias"                       : True,
-                                               "downsample_scales"          : [4, 4, 4, 4, 1],
+                                               "downsample_scales"          : [4, 4, 4, 1],
                                                "nonlinear_activation"       : "LeakyReLU",
                                                "nonlinear_activation_params": {"negative_slope": 0.1}, },
                  follow_official_norm=True,
@@ -481,41 +482,14 @@ def __init__(self,
                                                 "out_channels"               : 1,
                                                 "kernel_sizes"               : [5, 3],
                                                 "channels"                   : 32,
-                                                "downsample_scales"          : [3, 3, 3, 3, 1],
+                                                "downsample_scales"          : [3, 3, 3, 1],
                                                 "max_downsample_channels"    : 1024,
                                                 "bias"                       : True,
                                                 "nonlinear_activation"       : "LeakyReLU",
                                                 "nonlinear_activation_params": {"negative_slope": 0.1},
                                                 "use_weight_norm"            : True,
                                                 "use_spectral_norm"          : False, },
-                 # CoMB discriminator related
-                 kernels=((7, 11, 11, 11, 11, 5),
-                          (11, 21, 21, 21, 21, 5),
-                          (15, 41, 41, 41, 41, 5)),
-                 channels=(16, 64, 256, 1024, 1024, 1024),
-                 groups=(1, 4, 16, 64, 256, 1),
-                 strides=(1, 1, 4, 4, 4, 1),
-                 # Sub-Band discriminator related
-                 tkernels=(7, 5, 3),
-                 fkernel=5,
-                 tchannels=(64, 128, 256, 256, 256),
-                 fchannels=(32, 64, 128, 128, 128),
-                 tstrides=((1, 1, 3, 3, 1),
-                           (1, 1, 3, 3, 1),
-                           (1, 1, 3, 3, 1)),
-                 fstride=(1, 1, 3, 3, 1),
-                 tdilations=(((5, 7, 11), (5, 7, 11), (5, 7, 11), (5, 7, 11), (5, 7, 11), (5, 7, 11)),
-                             ((3, 5, 7), (3, 5, 7), (3, 5, 7), (3, 5, 7), (3, 5, 7)),
-                             ((1, 2, 3), (1, 2, 3), (1, 2, 3), (1, 2, 3), (1, 2, 3))),
-                 fdilations=((1, 2, 3),
-                             (1, 2, 3),
-                             (1, 2, 3),
-                             (2, 3, 5),
-                             (2, 3, 5)),
-                 tsubband=(6, 11, 16),
-                 n=16,
-                 m=64,
-                 freq_init_ch=192):
+                 ):
         super().__init__()
         self.msd = HiFiGANMultiScaleDiscriminator(scales=scales,
                                                   downsample_pooling=scale_downsample_pooling,
@@ -524,10 +498,8 @@ def __init__(self,
                                                   follow_official_norm=follow_official_norm, )
         self.mpd = HiFiGANMultiPeriodDiscriminator(periods=periods,
                                                    discriminator_params=period_discriminator_params, )
-        self.mcmbd = MultiCoMBDiscriminator(kernels, channels, groups, strides)
-        self.msbd = MultiSubBandDiscriminator(tkernels, fkernel, tchannels, fchannels, tstrides, fstride, tdilations, fdilations, tsubband, n, m, freq_init_ch)

-    def forward(self, wave, intermediate_wave_upsampled_twice=None, intermediate_wave_upsampled_once=None, discriminator_train_flag=False):
+    def forward(self, wave, discriminator_train_flag=False):
         """
         Calculate forward propagation.
@@ -542,9 +514,9 @@ def forward(self, wave, intermediate_wave_upsampled_twice=None, intermediate_wave_upsampled_once=None, discriminator_train_flag=False):
         """
         msd_outs, msd_feats = self.msd(wave, discriminator_train_flag)
         mpd_outs, mpd_feats = self.mpd(wave, discriminator_train_flag)
-        mcmbd_outs, mcmbd_feats = self.mcmbd(wave_final=wave,
-                                             intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice,
-                                             intermediate_wave_upsampled_once=intermediate_wave_upsampled_once,
-                                             discriminator_train_flag=discriminator_train_flag)
-        msbd_outs, msbd_feats = self.msbd(wave, discriminator_train_flag)
-        return msd_outs + mpd_outs + mcmbd_outs + msbd_outs, msd_feats + mpd_feats + mcmbd_feats + msbd_feats
+        return msd_outs + mpd_outs, msd_feats + mpd_feats
+
+
+if __name__ == '__main__':
+    d = AvocodoHiFiGANJointDiscriminator()
+    print(d(torch.randn([2, 1, 12288 * 2])))
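
Editor's note: the `+` in the new return statement is Python list concatenation, not tensor addition; each sub-discriminator returns a list of outputs and a list of feature maps that the downstream losses iterate over. A hedged usage sketch, run from inside Modules/Vocoder/HiFiGAN_Discriminators.py and assuming the two-list return shown in the diff:

import torch

# Smoke test mirroring the new __main__ block: batch of 2 windows,
# mono channel dimension, 12288 * 2 = 24576 samples each.
d = AvocodoHiFiGANJointDiscriminator()
outs, feats = d(torch.randn([2, 1, 12288 * 2]))
print(len(outs))   # one entry per MSD scale and MPD period
print(len(feats))  # matching feature-map lists for the feature-matching loss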
Lines changed: 74 additions & 40 deletions

@@ -1,10 +1,9 @@
-import os
-import random
 from multiprocessing import Manager
 from multiprocessing import Process

 import librosa
 import numpy
+import numpy as np
 import soundfile as sf
 import torch
 from torch.utils.data import Dataset

@@ -19,8 +18,8 @@ def __init__(self,
                  list_of_original_paths,
                  list_of_synthetic_paths,
                  desired_samplingrate=24000,
-                 samples_per_segment=12288,  # = (8192 * 3) / 2, as I used 8192 for 16kHz previously
-                 loading_processes=max(os.cpu_count() - 2, 1)):
+                 samples_per_segment=12288 * 2,  # = (8192 * 3) / 2, as I used 8192 for 16kHz previously
+                 loading_processes=1):
         self.samples_per_segment = samples_per_segment
         self.desired_samplingrate = desired_samplingrate
         self.melspec_ap = AudioPreprocessor(input_sr=self.desired_samplingrate,

@@ -53,19 +52,18 @@ def cache_builder_process(self, path_split):
         try:
             path1, path2 = path

-            wave1, sr = sf.read(path1)
-            if len(wave1.shape) == 2:
-                wave1 = librosa.to_mono(numpy.transpose(wave1))
+            wave, sr = sf.read(path1)
+            if len(wave.shape) == 2:
+                wave = librosa.to_mono(numpy.transpose(wave))
             if sr != self.desired_samplingrate:
-                wave1 = librosa.resample(y=wave1, orig_sr=sr, target_sr=self.desired_samplingrate)
+                wave = librosa.resample(y=wave, orig_sr=sr, target_sr=self.desired_samplingrate)

-            wave2, sr = sf.read(path2)
-            if len(wave2.shape) == 2:
-                wave2 = librosa.to_mono(numpy.transpose(wave2))
-            if sr != self.desired_samplingrate:
-                wave2 = librosa.resample(y=wave2, orig_sr=sr, target_sr=self.desired_samplingrate)
+            if len(wave) > self.samples_per_segment + 2000:
+                spec = torch.load(path2, map_location="cpu")
+                self.waves.append((wave, spec))
+            else:
+                print("excluding short sample")

-            self.waves.append((wave1, wave2))
         except RuntimeError:
             print(f"Problem with the following path: {path}")

@@ -77,32 +75,68 @@ def __getitem__(self, index):
         """
         return a pair of high-res audio and corresponding low-res spectrogram as if it was predicted by the TTS
         """
-        try:
-            wave1 = self.waves[index][0]
-            wave2 = self.waves[index][1]
-            while len(wave1) < self.samples_per_segment + 50:  # + 50 is just to be extra sure
-                # catch files that are too short to apply meaningful signal processing and make them longer
-                wave1 = numpy.concatenate([wave1, numpy.zeros(shape=1000), wave1])
-                wave2 = numpy.concatenate([wave2, numpy.zeros(shape=1000), wave2])
-                # add some true silence in the mix, so the vocoder is exposed to that as well during training
-            wave1 = torch.Tensor(wave1)
-            wave2 = torch.Tensor(wave2)
-
-            max_audio_start = len(wave1) - self.samples_per_segment
-            audio_start = random.randint(0, max_audio_start)
-            segment1 = wave1[audio_start: audio_start + self.samples_per_segment]
-            segment2 = wave2[audio_start: audio_start + self.samples_per_segment]
-
-            resampled_segment = self.melspec_ap.resample(segment2).float()  # 16kHz spectrogram as input, 24kHz wave as output, see Blizzard 2021 DelightfulTTS
-            melspec = self.melspec_ap.audio_to_mel_spec_tensor(resampled_segment,
-                                                               explicit_sampling_rate=16000,
-                                                               normalize=False).transpose(0, 1)[:-1].transpose(0, 1)
-            return segment1.detach(), melspec.detach()
-        except RuntimeError:
-            print("encountered a runtime error, using fallback strategy")
-            if index == 0:
-                index = len(self.waves) - 1
-            return self.__getitem__(index - 1)
+        wave = self.waves[index][0]
+        wave = torch.Tensor(wave)
+
+        spec = self.waves[index][1]
+
+        spec_win, wave_win = get_matching_windows(waveform=wave, spectrogram=spec)
+        return wave_win.detach(), spec_win.detach()

     def __len__(self):
         return len(self.waves)
+
+
+def get_matching_windows(spectrogram, waveform, window_size_wave=24576, hop_length_spec=256, sample_rate_wave=24000, sample_rate_spec=16000):
+    """
+    Cut random matching windows from a spectrogram and waveform with perfectly aligned time axes.
+
+    Parameters:
+    - spectrogram: 2D numpy array (frames x freq_bins) of the spectrogram.
+    - waveform: 1D numpy array of the ground truth waveform.
+    - window_size_wave: Size of the window in waveform samples (default: 24576).
+    - hop_length_spec: Hop length used for spectrogram extraction (default: 200 samples for 16 kHz).
+    - sample_rate_wave: Sample rate of the waveform (default: 24000 Hz).
+    - sample_rate_spec: Sample rate used to create the spectrogram (default: 16000 Hz).
+
+    Returns:
+    - spec_window: A window cut from the spectrogram.
+    - wave_window: A window cut from the waveform.
+    """
+    spectrogram = spectrogram.transpose(0, 1)
+
+    # Calculate the number of samples per spectrogram frame in waveform's time
+    spec_frame_duration = hop_length_spec / sample_rate_spec
+    wave_sample_duration = 1 / sample_rate_wave
+    spec_to_wave_conversion_factor = wave_sample_duration / spec_frame_duration
+
+    num_frames = int(window_size_wave * spec_to_wave_conversion_factor)
+
+    # Ensure we can extract a full window from the spectrogram
+    max_start_frame = spectrogram.shape[0] - num_frames
+    if max_start_frame <= 0:
+        print(f"desired num frames: {num_frames}")
+        print(f"spec_to_wave_conversion_factor: {spec_to_wave_conversion_factor}")
+        print(f"spec_len: {spectrogram.shape[0]}")
+        raise ValueError("Spectrogram is too short to extract the desired window size.")
+
+    # Randomly choose a start frame from the spectrogram
+    start_frame = np.random.randint(0, max_start_frame)
+
+    # Calculate the start sample for the waveform based on the chosen start frame
+    start_sample = int(start_frame // spec_to_wave_conversion_factor)
+    end_sample = start_sample + window_size_wave
+
+    # Ensure the waveform can be fully sliced
+    if end_sample > len(waveform):
+        print(f"start_sample: {start_sample}")
+        print(f"end_sample: {end_sample}")
+        print(f"start_frame: {start_frame}")
+        print(f"spec_to_wave_conversion_factor: {spec_to_wave_conversion_factor}")
+        raise ValueError("Waveform is too short to extract the desired window size.")
+
+    # Extract matching windows
+    spec_window = spectrogram[start_frame:start_frame + num_frames, :].transpose(0, 1)
+    wave_window = waveform[start_sample:end_sample]
+
+    return spec_window, wave_window
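
Editor's note: the alignment in get_matching_windows hinges on one ratio: a hop of 256 samples at 16 kHz lasts 16 ms, which is exactly 384 samples at 24 kHz (the added docstring still says the hop default is 200, while the signature uses 256). The same arithmetic in exact integer form, as a quick sanity check of the defaults:

# One spectrogram frame (hop 256 at 16 kHz) spans 256 * 24000 / 16000 samples at 24 kHz.
samples_per_frame_at_24k = 256 * 24000 // 16000        # 384
num_frames = 24576 // samples_per_frame_at_24k         # 64 frames per training window
assert num_frames * samples_per_frame_at_24k == 24576  # windows tile exactly, no drift
print(num_frames, samples_per_frame_at_24k)            # 64 384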

Modules/Vocoder/HiFiGAN_Generator.py

Lines changed: 4 additions & 11 deletions

@@ -15,10 +15,10 @@ class HiFiGAN(torch.nn.Module):
     def __init__(self,
                  in_channels=128,
                  out_channels=1,
-                 channels=512,
+                 channels=768,
                  kernel_size=7,
-                 upsample_scales=(8, 6, 4, 2),  # CAREFUL: Avocodo assumes that there are always 4 upsample scales, because it takes intermediate results.
-                 upsample_kernel_sizes=(16, 12, 8, 4),
+                 upsample_scales=(8, 6, 2, 2, 2),  # CAREFUL: Avocodo assumes that there are always 4 upsample scales, because it takes intermediate results.
+                 upsample_kernel_sizes=(16, 12, 4, 4, 4),
                  resblock_kernel_sizes=(3, 7, 11),
                  resblock_dilations=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
                  use_additional_convs=True,
@@ -87,9 +87,6 @@ def __init__(self,
                                               1,
                                               padding=(kernel_size - 1) // 2, ), torch.nn.Tanh(), )

-        self.out_proj_x1 = torch.nn.Conv1d(channels // 4, 1, 7, 1, padding=3)
-        self.out_proj_x2 = torch.nn.Conv1d(channels // 8, 1, 7, 1, padding=3)
-
         # apply weight norm
         self.apply_weight_norm()

@@ -118,13 +115,9 @@ def forward(self, c):
             for j in range(self.num_blocks):
                 cs += self.blocks[i * self.num_blocks + j](c)
             c = cs / self.num_blocks
-            if i == 1:
-                x1 = self.out_proj_x1(c)
-            elif i == 2:
-                x2 = self.out_proj_x2(c)
         c = self.output_conv(c)

-        return c, x2, x1
+        return c

     def reset_parameters(self):
         """

Modules/Vocoder/HiFiGAN_train_loop.py

Lines changed: 12 additions & 11 deletions

@@ -19,6 +19,10 @@
 from run_weight_averaging import load_net_bigvgan


+def collate_fn(batch):
+    return torch.stack([x[0] for x in batch]), torch.stack([x[1] for x in batch])
+
+
 def train_loop(generator,
                discriminator,
                train_dataset,
@@ -29,7 +33,7 @@ def train_loop(generator,
                batch_size=32,
                epochs=100,
                resume=False,
-               generator_steps_per_discriminator_step=5,
+               generator_steps_per_discriminator_step=2,
                generator_warmup=30000,
                use_wandb=False,
                finetune=False
@@ -52,11 +56,12 @@ def train_loop(generator,
     train_loader = DataLoader(dataset=train_dataset,
                               batch_size=batch_size,
                               shuffle=True,
-                              num_workers=8,
+                              num_workers=16,
                               pin_memory=True,
                               drop_last=True,
                               prefetch_factor=2,
-                              persistent_workers=True)
+                              persistent_workers=True,
+                              collate_fn=collate_fn)

     if resume:
         path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=model_save_dir)
@@ -93,18 +98,16 @@ def train_loop(generator,

             gold_wave = datapoint[0].to(device).unsqueeze(1)
             melspec = datapoint[1].to(device)
-            pred_wave, intermediate_wave_upsampled_twice, intermediate_wave_upsampled_once = g(melspec)
+            pred_wave = g(melspec)
             if torch.any(torch.isnan(pred_wave)):
                 print("A NaN in the wave! Skipping...")
                 continue

             mel_loss = mel_l1(pred_wave.squeeze(1), gold_wave)
-            generator_total_loss = mel_loss * 85.0
+            generator_total_loss = mel_loss * 45.0

             if step_counter > generator_warmup + 100:  # a bit of warmup helps, but it's not that important
-                d_outs, d_fmaps = d(wave=pred_wave,
-                                    intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice,
-                                    intermediate_wave_upsampled_once=intermediate_wave_upsampled_once)
+                d_outs, d_fmaps = d(wave=pred_wave)
                 adversarial_loss = generator_adv_loss(d_outs)
                 adversarial_losses.append(adversarial_loss.item())
                 generator_total_loss = generator_total_loss + adversarial_loss * 2  # based on own experience
@@ -136,8 +139,6 @@ def train_loop(generator,

             if step_counter > generator_warmup and step_counter % generator_steps_per_discriminator_step == 0:
                 d_outs, d_fmaps = d(wave=pred_wave.detach(),
-                                    intermediate_wave_upsampled_twice=intermediate_wave_upsampled_twice.detach(),
-                                    intermediate_wave_upsampled_once=intermediate_wave_upsampled_once.detach(),
                                     discriminator_train_flag=True)
                 d_gold_outs, d_gold_fmaps = d(gold_wave,
                                               discriminator_train_flag=True)  # have to recompute unfortunately due to autograd behaviour
@@ -168,7 +169,7 @@ def train_loop(generator,
             g.train()
             delete_old_checkpoints(model_save_dir, keep=5)

-            checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=model_save_dir, n=2)
+            checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=model_save_dir, n=1)
             averaged_model, _ = average_checkpoints(checkpoint_paths, load_func=load_net_bigvgan)
             torch.save(averaged_model.state_dict(), os.path.join(model_save_dir, "best.pt"))
