Commit e6aeac6

Author: Sarina Meyer (committed)
Added IMSToucan
1 parent 8505eb6 commit e6aeac6


47 files changed: +6113 additions, 0 deletions

IMSToucan/.gitignore

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
.idea
*.pyc
*.png
*.pdf
tensorboard_logs
Corpora
Models
*_graph
*.out
*.wav
audios/notes.txt
audios/
*playground*
apex/
pretrained_models/
*.json
.tmp/
.vscode/
split/
singing/
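
The ignore list keeps IDE metadata (.idea, .vscode/), Python bytecode, generated plots and audio, corpora, logs, and model checkpoints out of version control, so only source files stay tracked. As a rough illustration of how such patterns match, here is a sketch assuming the third-party pathspec package (pip install pathspec), which mimics git's wildmatch rules and is not part of this commit:

    import pathspec  # third-party package that reimplements git's wildmatch semantics

    # A handful of the patterns from the .gitignore above
    patterns = [".idea", "*.pyc", "Models", "pretrained_models/", "*.json", "audios/"]
    spec = pathspec.PathSpec.from_lines("gitwildmatch", patterns)

    print(spec.match_file("Models/best_checkpoint.pt"))  # True: everything under Models/ is ignored
    print(spec.match_file("config/settings.json"))       # True: *.json matches at any depth
    print(spec.match_file("run_training.py"))            # False: source files stay tracked
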
Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
import itertools
import os

import librosa.display as lbd
import matplotlib.pyplot as plt
import noisereduce
import sounddevice
import soundfile
import torch

from ..InferenceInterfaces.InferenceArchitectures.InferenceFastSpeech2 import FastSpeech2
from ..InferenceInterfaces.InferenceArchitectures.InferenceHiFiGAN import HiFiGANGenerator
from ..Preprocessing.AudioPreprocessor import AudioPreprocessor
from ..Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from ..Preprocessing.TextFrontend import get_language_id
from ..TrainingInterfaces.Spectrogram_to_Embedding.StyleEmbedding import StyleEmbedding


class AnonFastSpeech2(torch.nn.Module):

    def __init__(self, path_to_hifigan_model, path_to_fastspeech_model, path_to_embed_model, device="cpu", language="en", noise_reduce=False):
        super().__init__()
        self.device = device
        self.audio_preprocessor = AudioPreprocessor(input_sr=16000, output_sr=16000, cut_silence=True, device=self.device)
        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)
        checkpoint = torch.load(path_to_fastspeech_model, map_location='cpu')
        try:
            # A monolingual checkpoint loads without language embeddings; if that
            # fails, fall back to the multilingual architecture.
            self.use_lang_id = False
            self.phone2mel = FastSpeech2(weights=checkpoint["model"]).to(torch.device(device))
        except RuntimeError:
            print("Loading a multilingual model, which is strange for this purpose. Please double-check that the correct model is being loaded.")
            self.use_lang_id = True
            self.phone2mel = FastSpeech2(weights=checkpoint["model"], lang_embs=1000).to(torch.device(device))
        self.mel2wav = HiFiGANGenerator(path_to_weights=path_to_hifigan_model).to(torch.device(device))
        self.style_embedding_function = StyleEmbedding()
        check_dict = torch.load(path_to_embed_model, map_location="cpu")
        self.style_embedding_function.load_state_dict(check_dict["style_emb_func"])
        self.style_embedding_function.to(self.device)
        self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
        self.phone2mel.eval()
        self.mel2wav.eval()
        if self.use_lang_id:
            self.lang_id = get_language_id(language)
        else:
            self.lang_id = None
        self.to(torch.device(device))
        self.noise_reduce = noise_reduce
        if self.noise_reduce:
            self.prototypical_noise = None
            self.update_noise_profile()  # expected to fill self.prototypical_noise for noisereduce

    def set_utterance_embedding(self, path_to_reference_audio="", embedding=None):
        # Either use a precomputed embedding directly or derive one from a reference recording.
        if embedding is not None:
            self.default_utterance_embedding = embedding.squeeze().to(self.device)
            return
        assert os.path.exists(path_to_reference_audio)
        wave, sr = soundfile.read(path_to_reference_audio)
        if sr != self.audio_preprocessor.sr:
            self.audio_preprocessor = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True, device=self.device)
        spec = self.audio_preprocessor.audio_to_mel_spec_tensor(wave).transpose(0, 1)
        spec_len = torch.LongTensor([len(spec)])
        self.default_utterance_embedding = self.style_embedding_function(spec.unsqueeze(0).to(self.device),
                                                                         spec_len.unsqueeze(0).to(self.device)).squeeze()

    def set_language(self, lang_id):
        """
        The lang_id parameter actually refers to the language shorthand, which has become
        ambiguous with the introduction of the actual language IDs.
        """
        self.set_phonemizer_language(lang_id=lang_id)
        self.set_accent_language(lang_id=lang_id)

    def set_phonemizer_language(self, lang_id):
        self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True)

    def set_accent_language(self, lang_id):
        if self.use_lang_id:
            self.lang_id = get_language_id(lang_id).to(self.device)
        else:
            self.lang_id = None

    def forward(self,
                text,
                view=False,
                duration_scaling_factor=1.0,
                pitch_variance_scale=1.0,
                energy_variance_scale=1.0,
                durations=None,
                pitch=None,
                energy=None,
                text_is_phonemes=False):
        """
        duration_scaling_factor: reasonable values are 0.5 < scale < 1.5.
            1.0 means no scaling happens, higher values increase durations for the whole
            utterance, lower values decrease durations for the whole utterance.
        pitch_variance_scale: reasonable values are 0.0 < scale < 2.0.
            1.0 means no scaling happens, higher values increase variance of the pitch curve,
            lower values decrease variance of the pitch curve.
        energy_variance_scale: reasonable values are 0.0 < scale < 2.0.
            1.0 means no scaling happens, higher values increase variance of the energy curve,
            lower values decrease variance of the energy curve.
        """
        with torch.inference_mode():
            phones = self.text2phone.string_to_tensor(text, input_phonemes=text_is_phonemes).to(torch.device(self.device))
            mel, durations, pitch, energy = self.phone2mel(phones,
                                                           return_duration_pitch_energy=True,
                                                           utterance_embedding=self.default_utterance_embedding,
                                                           durations=durations,
                                                           pitch=pitch,
                                                           energy=energy,
                                                           lang_id=self.lang_id)
            mel = mel.transpose(0, 1)
            wave = self.mel2wav(mel)
        if view:
            from ..Utility.utils import cumsum_durations
            fig, ax = plt.subplots(nrows=2, ncols=1)
            ax[0].plot(wave.cpu().numpy())
            lbd.specshow(mel.cpu().numpy(),
                         ax=ax[1],
                         sr=16000,
                         cmap='GnBu',
                         y_axis='mel',
                         x_axis=None,
                         hop_length=256)
            ax[0].yaxis.set_visible(False)
            ax[1].yaxis.set_visible(False)
            duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
            ax[1].set_xticks(duration_splits, minor=True)
            ax[1].xaxis.grid(True, which='minor')
            ax[1].set_xticks(label_positions, minor=False)
            ax[1].set_xticklabels(self.text2phone.get_phone_string(text, for_plot_labels=True))
            ax[0].set_title(text)
            plt.subplots_adjust(left=0.05, bottom=0.1, right=0.95, top=.9, wspace=0.0, hspace=0.0)
            plt.show()
        if self.noise_reduce:
            wave = torch.tensor(noisereduce.reduce_noise(y=wave.cpu().numpy(), y_noise=self.prototypical_noise, sr=48000, stationary=True), device=self.device)
        return wave

    def read_to_file(self,
                     text_list,
                     file_location,
                     duration_scaling_factor=1.0,
                     pitch_variance_scale=1.0,
                     energy_variance_scale=1.0,
                     silent=False,
                     dur_list=None,
                     pitch_list=None,
                     energy_list=None):
        """
        Args:
            silent: Whether to be verbose about the process
            text_list: A list of strings to be read
            file_location: The path and name of the file it should be saved to
            energy_list: list of energy tensors to be used for the texts
            pitch_list: list of pitch tensors to be used for the texts
            dur_list: list of duration tensors to be used for the texts
            duration_scaling_factor: reasonable values are 0.5 < scale < 1.5.
                1.0 means no scaling happens, higher values increase durations for the whole
                utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.0 < scale < 2.0.
                1.0 means no scaling happens, higher values increase variance of the pitch curve,
                lower values decrease variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.0 < scale < 2.0.
                1.0 means no scaling happens, higher values increase variance of the energy curve,
                lower values decrease variance of the energy curve.
        """
        if not dur_list:
            dur_list = []
        if not pitch_list:
            pitch_list = []
        if not energy_list:
            energy_list = []
        wav = None
        silence = torch.zeros([24000])
        for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
            if text.strip() != "":
                if not silent:
                    print("Now synthesizing: {}".format(text))
                # zip_longest pads the prosody lists with None, so guard before moving to device
                if durations is not None:
                    durations = durations.to(self.device)
                if pitch is not None:
                    pitch = pitch.to(self.device)
                if energy is not None:
                    energy = energy.to(self.device)
                utterance = self(text,
                                 durations=durations,
                                 pitch=pitch,
                                 energy=energy,
                                 duration_scaling_factor=duration_scaling_factor,
                                 pitch_variance_scale=pitch_variance_scale,
                                 energy_variance_scale=energy_variance_scale).cpu()
                if wav is None:
                    wav = utterance
                else:
                    wav = torch.cat((wav, utterance), 0)
                wav = torch.cat((wav, silence), 0)
        soundfile.write(file=file_location, data=wav.cpu().numpy(), samplerate=48000)

    def read_aloud(self,
                   text,
                   view=False,
                   duration_scaling_factor=1.0,
                   pitch_variance_scale=1.0,
                   energy_variance_scale=1.0,
                   blocking=False):
        if text.strip() == "":
            return
        wav = self(text,
                   view,
                   duration_scaling_factor=duration_scaling_factor,
                   pitch_variance_scale=pitch_variance_scale,
                   energy_variance_scale=energy_variance_scale).cpu()
        wav = torch.cat((wav, torch.zeros([24000])), 0)
        if not blocking:
            sounddevice.play(wav.numpy(), samplerate=48000)
        else:
            sounddevice.play(torch.cat((wav, torch.zeros([12000])), 0).numpy(), samplerate=48000)
            sounddevice.wait()
