Commit 81075a6

Merge pull request #11 from Flux9665/multi_lingual_multi_speaker
fix language ID not being used properly
2 parents 97c9006 + 944cede commit 81075a6

27 files changed: +1260 -1944 lines

InferenceInterfaces/InferenceFastSpeech2.py
Lines changed: 36 additions & 9 deletions

@@ -3,6 +3,7 @@
 
 import librosa.display as lbd
 import matplotlib.pyplot as plt
+import noisereduce
 import sounddevice
 import soundfile
 import torch
@@ -16,45 +17,65 @@
 
 class InferenceFastSpeech2(torch.nn.Module):
 
-    def __init__(self, device="cpu", model_name="Meta", language="en"):
+    def __init__(self, device="cpu", model_name="Meta", language="en", noise_reduce=False):
         super().__init__()
         self.device = device
         self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)
         checkpoint = torch.load(os.path.join("Models", f"FastSpeech2_{model_name}", "best.pt"), map_location='cpu')
+        self.use_lang_id = True
         try:
             self.phone2mel = FastSpeech2(weights=checkpoint["model"]).to(torch.device(device))  # multi speaker multi language
         except RuntimeError:
             try:
+                self.use_lang_id = False
                 self.phone2mel = FastSpeech2(weights=checkpoint["model"], lang_embs=None).to(torch.device(device))  # multi speaker single language
             except RuntimeError:
                 self.phone2mel = FastSpeech2(weights=checkpoint["model"], lang_embs=None, utt_embed_dim=None).to(torch.device(device))  # single speaker
         self.mel2wav = HiFiGANGenerator(path_to_weights=os.path.join("Models", "HiFiGAN_combined", "best.pt")).to(torch.device(device))
         self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
         self.phone2mel.eval()
         self.mel2wav.eval()
-        self.lang_id = get_language_id(language)
+        if self.use_lang_id:
+            self.lang_id = get_language_id(language)
+        else:
+            self.lang_id = None
         self.to(torch.device(device))
+        self.noise_reduce = noise_reduce
+        if self.noise_reduce:
+            self.prototypical_noise = None
+            self.update_noise_profile()
 
     def set_utterance_embedding(self, path_to_reference_audio):
         wave, sr = soundfile.read(path_to_reference_audio)
         self.default_utterance_embedding = ProsodicConditionExtractor(sr=sr).extract_condition_from_reference_wave(wave).to(self.device)
+        if self.noise_reduce:
+            self.update_noise_profile()
+
+    def update_noise_profile(self):
+        self.noise_reduce = False
+        self.prototypical_noise = self("~." * 100, input_is_phones=True).cpu().numpy()
+        self.noise_reduce = True
 
     def set_language(self, lang_id):
         """
         The id parameter actually refers to the shorthand. This has become ambiguous with the introduction of the actual language IDs
         """
         self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True)
-        self.lang_id = get_language_id(lang_id).to(self.device)
+        if self.use_lang_id:
+            self.lang_id = get_language_id(lang_id).to(self.device)
+        else:
+            self.lang_id = None
 
-    def forward(self, text, view=False, durations=None, pitch=None, energy=None):
+    def forward(self, text, view=False, durations=None, pitch=None, energy=None, input_is_phones=False):
         with torch.inference_mode():
-            phones = self.text2phone.string_to_tensor(text).to(torch.device(self.device))
+            phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
             mel, durations, pitch, energy = self.phone2mel(phones,
                                                            return_duration_pitch_energy=True,
                                                            utterance_embedding=self.default_utterance_embedding,
                                                            durations=durations,
                                                            pitch=pitch,
-                                                           energy=energy)
+                                                           energy=energy,
+                                                           lang_id=self.lang_id)
             mel = mel.transpose(0, 1)
             wave = self.mel2wav(mel)
         if view:
@@ -78,13 +99,19 @@ def forward(self, text, view=False, durations=None, pitch=None, energy=None):
             ax[0].set_title(text)
             plt.subplots_adjust(left=0.05, bottom=0.1, right=0.95, top=.9, wspace=0.0, hspace=0.0)
             plt.show()
+        if self.noise_reduce:
+            wave = torch.tensor(noisereduce.reduce_noise(y=wave.cpu().numpy(), y_noise=self.prototypical_noise, sr=48000, stationary=True), device=self.device)
         return wave
 
     def read_to_file(self, text_list, file_location, silent=False, dur_list=None, pitch_list=None, energy_list=None):
         """
-        :param silent: Whether to be verbose about the process
-        :param text_list: A list of strings to be read
-        :param file_location: The path and name of the file it should be saved to
+        Args:
+            silent: Whether to be verbose about the process
+            text_list: A list of strings to be read
+            file_location: The path and name of the file it should be saved to
+            energy_list: list of energy tensors to be used for the texts
+            pitch_list: list of pitch tensors to be used for the texts
+            dur_list: list of duration tensors to be used for the texts
         """
        if not dur_list:
            dur_list = []
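
For context, a minimal usage sketch of the changed interface (assumes the Meta FastSpeech2 checkpoint and the HiFiGAN checkpoint referenced above are present locally; the sentence text is illustrative). The core of the fix is the lang_id=self.lang_id argument in the phone2mel call; previously the language embedding was never selected at inference time:

from InferenceInterfaces.InferenceFastSpeech2 import InferenceFastSpeech2

# noise_reduce=True triggers update_noise_profile(), which synthesizes a
# prototypical silence ("~." * 100 as phoneme input) and uses it as the
# stationary noise profile for noisereduce.
tts = InferenceFastSpeech2(device="cpu", model_name="Meta", language="en", noise_reduce=True)
tts.set_language("de")  # swaps the phonemizer and, since use_lang_id is True for this checkpoint, the language embedding
wave = tts("Hallo Welt!")  # forward() now passes lang_id through to phone2mel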

Preprocessing/ArticulatoryCombinedTextFrontend.py
Lines changed: 3 additions & 3 deletions

@@ -288,9 +288,7 @@ def english_text_expansion(text):
 
 
 def get_language_id(language):
-    if language == "en":
-        return torch.LongTensor([0])
-    elif language == "de":
+    if language == "de":
         return torch.LongTensor([1])
     elif language == "el":
         return torch.LongTensor([2])
@@ -312,6 +310,8 @@ def get_language_id(language):
         return torch.LongTensor([10])
     elif language == "it":
         return torch.LongTensor([11])
+    elif language == "en":
+        return torch.LongTensor([12])
 
 
 if __name__ == '__main__':
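
The observable effect of the reordering above, as a quick sanity check (assumes the module is importable):

from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id

# "en" used to map to embedding index 0; it now maps to 12, at the end of the
# table, while all other languages keep their previous indices.
assert get_language_id("en").item() == 12
assert get_language_id("de").item() == 1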

Preprocessing/AudioPreprocessor.py
Lines changed: 11 additions & 3 deletions

@@ -11,7 +11,7 @@
 
 class AudioPreprocessor:
 
-    def __init__(self, input_sr, output_sr=None, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False, device="cpu"):
+    def __init__(self, input_sr, output_sr=None, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False, device="cpu", fmax_for_spec=8000):
         """
         The parameters are by default set up to do well
         on a 16kHz signal. A different sampling rate may
@@ -28,6 +28,7 @@ def __init__(self, input_sr, output_sr=None, melspec_buckets=80, hop_length=256,
         self.mel_buckets = melspec_buckets
         self.meter = pyln.Meter(input_sr)
         self.final_sr = input_sr
+        self.fmax_for_spec = fmax_for_spec
         if cut_silence:
             torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
             # careful: assumes 16kHz or 8kHz audio
@@ -58,7 +59,12 @@ def cut_silence_from_audio(self, audio):
         """
         with torch.inference_mode():
             speech_timestamps = self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr)
-            return audio[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
+            try:
+                result = audio[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
+                return result
+            except IndexError:
+                print("Audio might be too short to cut silences from front and back.")
+                return audio
 
     def to_mono(self, x):
         """
@@ -82,7 +88,7 @@ def normalize_loudness(self, audio):
         peak_normed = numpy.divide(loud_normed, peak)
         return peak_normed
 
-    def logmelfilterbank(self, audio, sampling_rate, fmin=40, fmax=8000, eps=1e-10):
+    def logmelfilterbank(self, audio, sampling_rate, fmin=40, fmax=None, eps=1e-10):
         """
         Compute log-Mel filterbank
 
@@ -91,6 +97,8 @@ def logmelfilterbank(self, audio, sampling_rate, fmin=40, fmax=8000, eps=1e-10):
         compatibility, this is kept for now. If there is ever a reason to completely re-train
         all models, this would be a good opportunity to make the switch.
         """
+        if fmax is None:
+            fmax = self.fmax_for_spec
         if isinstance(audio, torch.Tensor):
             audio = audio.numpy()
         # get amplitude spectrogram
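
A short sketch of how the new fmax_for_spec parameter flows through (values illustrative; assumes a 16 kHz signal and cut_silence=False so the silero model is not loaded):

import numpy

from Preprocessing.AudioPreprocessor import AudioPreprocessor

# fmax_for_spec, fixed at construction time, is now the fallback whenever
# logmelfilterbank() is called with its new default of fmax=None.
ap = AudioPreprocessor(input_sr=16000, cut_silence=False, fmax_for_spec=8000)
wave = numpy.random.uniform(-0.1, 0.1, 16000).astype(numpy.float32)  # one second of noise, illustrative
mel = ap.logmelfilterbank(wave, sampling_rate=16000)  # uses fmax = ap.fmax_for_spec = 8000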

TrainingInterfaces/Text_to_Spectrogram/AutoAligner/AlignerDataset.py
Lines changed: 9 additions & 13 deletions

@@ -8,7 +8,6 @@
 from speechbrain.pretrained import EncoderClassifier
 from torch.multiprocessing import Manager
 from torch.multiprocessing import Process
-from torch.multiprocessing import set_start_method
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
@@ -28,17 +27,12 @@ def __init__(self,
                  cut_silences=True,
                  rebuild_cache=False,
                  verbose=False,
-                 device="cpu"):
+                 device="cpu",
+                 phone_input=False):
         os.makedirs(cache_dir, exist_ok=True)
         if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
-            if (device == "cuda" or device == torch.device("cuda")) and cut_silences:
-                try:
-                    set_start_method('spawn')  # in order to be able to make use of cuda in multiprocessing
-                except RuntimeError:
-                    pass
-            elif cut_silences:
-                torch.set_num_threads(1)
             if cut_silences:
+                torch.set_num_threads(1)
                 torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                model='silero_vad',
                                force_reload=False,
@@ -68,7 +62,8 @@ def __init__(self,
                                                  max_len_in_seconds,
                                                  cut_silences,
                                                  verbose,
-                                                 device),
+                                                 "cpu",
+                                                 phone_input),
                                            daemon=True))
                process_list[-1].start()
        for process in process_list:
@@ -140,7 +135,8 @@ def cache_builder_process(self,
                              max_len,
                              cut_silences,
                              verbose,
-                              device):
+                              device,
+                              phone_input):
        process_internal_dataset_chunk = list()
        tf = ArticulatoryCombinedTextFrontend(language=lang, use_word_boundaries=False)
        _, sr = sf.read(path_list[0])
@@ -171,9 +167,9 @@ def cache_builder_process(self,
                # raw audio preprocessing is done
                transcript = self.path_to_transcript_dict[path]
                try:
-                    cached_text = tf.string_to_tensor(transcript, handle_missing=False).squeeze(0).cpu().numpy()
+                    cached_text = tf.string_to_tensor(transcript, handle_missing=False, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                except KeyError:
-                    tf.string_to_tensor(transcript, handle_missing=True).squeeze(0).cpu().numpy()
+                    tf.string_to_tensor(transcript, handle_missing=True, input_phonemes=phone_input).squeeze(0).cpu().numpy()
                    continue  # we skip sentences with unknown symbols
                try:
                    if len(cached_text[0]) != 66:
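
With the simplification above, the cache-building workers always receive "cpu", so the spawn start-method workaround for CUDA in multiprocessing is no longer needed. A hypothetical construction using the new flag (path_to_transcript_dict, cache_dir, lang, and phone_input are visible in this diff; the rest of the signature is assumed):

from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset

# Transcripts here are already phoneme strings, so phone_input=True makes the
# text frontend skip grapheme-to-phoneme conversion via input_phonemes=True.
path_to_transcript_dict = {"/data/audio_0001.wav": "haloː vɛlt"}  # illustrative
dataset = AlignerDataset(path_to_transcript_dict,
                         cache_dir="Corpora/my_corpus",  # illustrative path
                         lang="de",
                         phone_input=True)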

TrainingInterfaces/Text_to_Spectrogram/FastSpeech2/meta_train_loop.py
Lines changed: 4 additions & 8 deletions

@@ -1,3 +1,5 @@
+import random
+
 import librosa.display as lbd
 import matplotlib.pyplot as plt
 import torch
@@ -48,13 +50,7 @@ def train_loop(net,
         train_iters.append(iter(train_loaders[-1]))
     default_embeddings = {"en": None, "de": None, "el": None, "es": None, "fi": None, "ru": None, "hu": None, "nl": None, "fr": None}
     for index, lang in enumerate(["en", "de", "el", "es", "fi", "ru", "hu", "nl", "fr"]):
-        default_embedding = None
-        for datapoint in datasets[index]:
-            if default_embedding is None:
-                default_embedding = datapoint[7].squeeze()
-            else:
-                default_embedding = default_embedding + datapoint[7].squeeze()
-        default_embeddings[lang] = (default_embedding / len(datasets[index])).to(device)
+        default_embeddings[lang] = datasets[index][0][7].squeeze().to(device)
     optimizer = torch.optim.RAdam(net.parameters(), lr=lr, eps=1.0e-06, weight_decay=0.0)
     grad_scaler = GradScaler()
     scheduler = WarmupScheduler(optimizer, warmup_steps=warmup_steps)
@@ -84,7 +80,7 @@ def train_loop(net,
     # =============================
     for step in tqdm(range(step_counter, steps)):
         batches = []
-        for index in range(len(datasets)):
+        for index in random.sample(list(range(len(datasets))), len(datasets)):
             # we get one batch for each task (i.e. language in this case)
             try:
                 batch = next(train_iters[index])
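
The loop change is just a shuffled traversal of the task indices; in isolation (three stand-in tasks for illustration):

import random

languages = ["en", "de", "el"]  # stand-ins for the per-language datasets
# Before: range(len(datasets)) visited the languages in the same fixed order on
# every step. After: each step draws a fresh permutation, so no language is
# systematically first or last when the per-task batches are accumulated.
for index in random.sample(list(range(len(languages))), len(languages)):
    print(index, languages[index])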
New file
Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+import random
+
+import torch
+from torch.utils.data import ConcatDataset
+
+from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.FastSpeech2 import FastSpeech2
+from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.fastspeech2_train_loop import train_loop
+from Utility.corpus_preparation import prepare_fastspeech_corpus
+from Utility.path_to_transcript_dicts import *
+
+
+def run(gpu_id, resume_checkpoint, finetune, model_dir, resume):
+    if gpu_id == "cpu":
+        os.environ["CUDA_VISIBLE_DEVICES"] = ""
+        device = torch.device("cpu")
+
+    else:
+        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+        os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
+        device = torch.device("cuda")
+
+    torch.manual_seed(131714)
+    random.seed(131714)
+    torch.random.manual_seed(131714)
+
+    print("Preparing")
+
+    if model_dir is not None:
+        save_dir = model_dir
+    else:
+        save_dir = os.path.join("Models", "FastSpeech2_English")
+    os.makedirs(save_dir, exist_ok=True)
+
+    datasets = list()
+    datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_nancy(),
+                                              corpus_dir=os.path.join("Corpora", "Nancy"),
+                                              lang="en"))
+
+    datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_ljspeech(),
+                                              corpus_dir=os.path.join("Corpora", "LJSpeech"),
+                                              lang="en"))
+
+    datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_libritts_all_clean(),
+                                              corpus_dir=os.path.join("Corpora", "libri_all_clean"),
+                                              lang="en"))
+
+    datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_vctk(),
+                                              corpus_dir=os.path.join("Corpora", "vctk"),
+                                              lang="en"))
+
+    datasets.append(prepare_fastspeech_corpus(transcript_dict=build_path_to_transcript_dict_nvidia_hifitts(),
+                                              corpus_dir=os.path.join("Corpora", "hifi"),
+                                              lang="en"))
+
+    train_set = ConcatDataset(datasets)
+
+    model = FastSpeech2(lang_embs=100)
+    # because we want to finetune it, we treat it as multilingual, even though we are only interested in English here
+
+    print("Training model")
+    train_loop(net=model,
+               train_dataset=train_set,
+               device=device,
+               save_directory=save_dir,
+               steps=500000,
+               batch_size=10,
+               lang="en",
+               lr=0.001,
+               epochs_per_save=1,
+               warmup_steps=4000,
+               path_to_checkpoint="Models/FastSpeech2_Meta/best.pt",
+               fine_tune=True,
+               resume=resume)
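
Judging by its signature, run() follows the same contract as the other training pipelines in the repo; a hypothetical direct invocation (note that resume_checkpoint and finetune are accepted but unused here, since the checkpoint path and fine_tune=True are hardcoded in the train_loop call):

run(gpu_id="0",              # or "cpu" to force CPU training
    resume_checkpoint=None,  # unused in this pipeline
    finetune=True,           # unused in this pipeline
    model_dir=None,          # defaults to Models/FastSpeech2_English
    resume=False)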
