
Commit 2795dfd

remove transcript caches
1 parent c186583 commit 2795dfd

File tree

2 files changed: +299 −24 lines changed
Lines changed: 296 additions & 0 deletions
@@ -0,0 +1,296 @@
"""
This script is meant to be executed from the top level of the repo so that all the paths resolve. It is just kept here for clean storage.
"""

import itertools
import os  # used for path handling and CUDA env vars below; imported explicitly rather than relying on the wildcard import further down

import librosa
import matplotlib.pyplot as plt
import sounddevice
import soundfile
import soundfile as sf
import torch
from speechbrain.pretrained import EncoderClassifier
from torchaudio.transforms import Resample
from tqdm import tqdm

from Modules.Aligner.Aligner import Aligner
from Modules.ToucanTTS.DurationCalculator import DurationCalculator
from Modules.ToucanTTS.EnergyCalculator import EnergyCalculator
from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Modules.ToucanTTS.PitchCalculator import Parselmouth
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.TextFrontend import get_language_id
from Preprocessing.articulatory_features import get_feature_to_index_lookup
from Utility.path_to_transcript_dicts import *
from Utility.storage_config import MODELS_DIR
from Utility.storage_config import PREPROCESSING_DIR
from Utility.utils import float2pcm

class ToucanTTSInterface(torch.nn.Module):

    def __init__(self,
                 device="cpu",  # device that everything computes on. If a cuda device is available, this can speed things up by an order of magnitude.
                 tts_model_path=os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt"),  # path to the ToucanTTS checkpoint or just a shorthand if run standalone
                 vocoder_model_path=os.path.join(MODELS_DIR, "Vocoder", "best.pt"),  # path to the Vocoder checkpoint (not loaded in this trimmed script)
                 language="eng",  # initial language of the model, can be changed later with the setter methods
                 ):
        super().__init__()
        self.device = device
        if not tts_model_path.endswith(".pt"):
            tts_model_path = os.path.join(MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt")

        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True, device=device)
        checkpoint = torch.load(tts_model_path, map_location='cpu')
        self.phone2mel = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"])
        with torch.no_grad():
            self.phone2mel.store_inverse_all()  # this also removes weight norm
        self.phone2mel = self.phone2mel.to(torch.device(device))
        self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                           run_opts={"device": str(device)},
                                                                           savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))
        self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
        self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, device=device)
        self.phone2mel.eval()
        self.lang_id = get_language_id(language)
        self.to(torch.device(device))
        self.eval()

    def set_utterance_embedding(self, path_to_reference_audio="", embedding=None):
        if embedding is not None:
            self.default_utterance_embedding = embedding.squeeze().to(self.device)
            return
        if type(path_to_reference_audio) != list:
            path_to_reference_audio = [path_to_reference_audio]
        if len(path_to_reference_audio) > 0:
            for path in path_to_reference_audio:
                assert os.path.exists(path)
            speaker_embs = list()
            for path in path_to_reference_audio:
                wave, sr = soundfile.read(path)
                if len(wave.shape) > 1:  # oh no, we found a stereo audio!
                    if len(wave[0]) == 2:  # let's figure out whether we need to switch the axes
                        wave = wave.transpose()  # if yes, we switch the axes.
                    wave = librosa.to_mono(wave)
                wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(torch.tensor(wave, device=self.device, dtype=torch.float32))
                speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).squeeze().unsqueeze(0)).squeeze()
                speaker_embs.append(speaker_embedding)
            self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)

    def set_language(self, lang_id):
        self.set_phonemizer_language(lang_id=lang_id)
        self.set_accent_language(lang_id=lang_id)

    def set_phonemizer_language(self, lang_id):
        self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True, device=self.device)

    def set_accent_language(self, lang_id):
        if lang_id in {'ajp', 'ajt', 'lak', 'lno', 'nul', 'pii', 'plj', 'slq', 'smd', 'snb', 'tpw', 'wya', 'zua', 'en-us', 'en-sc', 'fr-be', 'fr-sw', 'pt-br', 'spa-lat', 'vi-ctr', 'vi-so'}:
            if lang_id == 'vi-so' or lang_id == 'vi-ctr':
                lang_id = 'vie'
            elif lang_id == 'spa-lat':
                lang_id = 'spa'
            elif lang_id == 'pt-br':
                lang_id = 'por'
            elif lang_id == 'fr-sw' or lang_id == 'fr-be':
                lang_id = 'fra'
            elif lang_id == 'en-sc' or lang_id == 'en-us':
                lang_id = 'eng'
            else:
                lang_id = 'eng'
        self.lang_id = get_language_id(lang_id).to(self.device)

    def forward(self,
                text,
                duration_scaling_factor=1.0,
                pitch_variance_scale=1.0,
                energy_variance_scale=1.0,
                pause_duration_scaling_factor=1.0,
                durations=None,
                pitch=None,
                energy=None,
                input_is_phones=False,
                prosody_creativity=0.1):
        with torch.inference_mode():
            phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
            mel, _, _, _ = self.phone2mel(phones,
                                          return_duration_pitch_energy=True,
                                          utterance_embedding=self.default_utterance_embedding,
                                          durations=durations,
                                          pitch=pitch,
                                          energy=energy,
                                          lang_id=self.lang_id,
                                          duration_scaling_factor=duration_scaling_factor,
                                          pitch_variance_scale=pitch_variance_scale,
                                          energy_variance_scale=energy_variance_scale,
                                          pause_duration_scaling_factor=pause_duration_scaling_factor,
                                          prosody_creativity=prosody_creativity)
        return mel

    def read_to_file(self,
                     text_list,
                     file_location,
                     duration_scaling_factor=1.0,
                     pitch_variance_scale=1.0,
                     energy_variance_scale=1.0,
                     pause_duration_scaling_factor=1.0,
                     dur_list=None,
                     pitch_list=None,
                     energy_list=None,
                     prosody_creativity=0.1):
        if not dur_list:
            dur_list = []
        if not pitch_list:
            pitch_list = []
        if not energy_list:
            energy_list = []
        for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
            spoken_sentence = self(text,
                                   durations=durations.to(self.device) if durations is not None else None,
                                   pitch=pitch.to(self.device) if pitch is not None else None,
                                   energy=energy.to(self.device) if energy is not None else None,
                                   duration_scaling_factor=duration_scaling_factor,
                                   pitch_variance_scale=pitch_variance_scale,
                                   energy_variance_scale=energy_variance_scale,
                                   pause_duration_scaling_factor=pause_duration_scaling_factor,
                                   prosody_creativity=prosody_creativity)
            spoken_sentence = spoken_sentence.cpu()

        torch.save(f=file_location, obj=spoken_sentence)

    def read_aloud(self,
                   text,
                   view=False,
                   duration_scaling_factor=1.0,
                   pitch_variance_scale=1.0,
                   energy_variance_scale=1.0,
                   blocking=False,
                   prosody_creativity=0.1):
        # note: forward() in this trimmed script returns only a spectrogram, so this convenience method would
        # additionally need the vocoder; it is kept for completeness but is not used by the caching run below.
        if text.strip() == "":
            return
        wav, sr = self(text,
                       view,
                       duration_scaling_factor=duration_scaling_factor,
                       pitch_variance_scale=pitch_variance_scale,
                       energy_variance_scale=energy_variance_scale,
                       prosody_creativity=prosody_creativity)
        silence = torch.zeros([sr // 2])
        wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
        sounddevice.play(float2pcm(wav), samplerate=sr)
        if view:
            plt.show()
        if blocking:
            sounddevice.wait()

class UtteranceCloner:

    def __init__(self, model_id, device, language="eng"):
        self.tts = ToucanTTSInterface(device=device, tts_model_path=model_id)
        self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, cut_silence=False)
        self.tf = ArticulatoryCombinedTextFrontend(language=language, device=device)
        self.device = device
        acoustic_checkpoint_path = os.path.join(PREPROCESSING_DIR, "libri_all_clean", "Aligner", "aligner.pt")
        self.aligner_weights = torch.load(acoustic_checkpoint_path, map_location=device)["asr_model"]
        self.acoustic_model = Aligner()
        self.acoustic_model = self.acoustic_model.to(self.device)
        self.acoustic_model.load_state_dict(self.aligner_weights)
        self.acoustic_model.eval()
        self.parsel = Parselmouth(reduction_factor=1, fs=16000)
        self.energy_calc = EnergyCalculator(reduction_factor=1, fs=16000)
        self.dc = DurationCalculator(reduction_factor=1)

    def extract_prosody(self, transcript, ref_audio_path, lang="eng", on_line_fine_tune=False):
        wave, sr = sf.read(ref_audio_path)
        if self.tf.language != lang:
            self.tf = ArticulatoryCombinedTextFrontend(language=lang, device=self.device)
        if self.ap.input_sr != sr:
            self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=False)
        try:
            norm_wave = self.ap.normalize_audio(audio=wave)
        except ValueError:
            print('Something went wrong, the reference wave might be too short.')
            raise RuntimeError

        norm_wave_length = torch.LongTensor([len(norm_wave)])
        text = self.tf.string_to_tensor(transcript, handle_missing=False).squeeze(0)
        features = self.ap.audio_to_mel_spec_tensor(audio=norm_wave, explicit_sampling_rate=16000).transpose(0, 1)
        feature_length = torch.LongTensor([len(features)]).numpy()

        text_without_word_boundaries = list()
        indexes_of_word_boundaries = list()
        for phoneme_index, vector in enumerate(text):
            if vector[get_feature_to_index_lookup()["word-boundary"]] == 0:
                text_without_word_boundaries.append(vector.numpy().tolist())
            else:
                indexes_of_word_boundaries.append(phoneme_index)
        matrix_without_word_boundaries = torch.Tensor(text_without_word_boundaries)

        alignment_path = self.acoustic_model.inference(features=features.to(self.device),
                                                       tokens=matrix_without_word_boundaries.to(self.device),
                                                       return_ctc=False)

        duration = self.dc(torch.LongTensor(alignment_path), vis=None).cpu()

        for index_of_word_boundary in indexes_of_word_boundaries:
            duration = torch.cat([duration[:index_of_word_boundary],
                                  torch.LongTensor([0]),  # insert a 0 duration wherever there is a word boundary
                                  duration[index_of_word_boundary:]])

        energy = self.energy_calc(input_waves=norm_wave.unsqueeze(0),
                                  input_waves_lengths=norm_wave_length,
                                  feats_lengths=feature_length,
                                  text=text,
                                  durations=duration.unsqueeze(0),
                                  durations_lengths=torch.LongTensor([len(duration)]))[0].squeeze(0).cpu()
        pitch = self.parsel(input_waves=norm_wave.unsqueeze(0),
                            input_waves_lengths=norm_wave_length,
                            feats_lengths=feature_length,
                            text=text,
                            durations=duration.unsqueeze(0),
                            durations_lengths=torch.LongTensor([len(duration)]))[0].squeeze(0).cpu()
        return duration, pitch, energy

    def clone_utterance(self,
                        path_to_reference_audio_for_intonation,
                        path_to_reference_audio_for_voice,
                        transcription_of_intonation_reference,
                        filename_of_result=None,
                        lang="eng"):
        self.tts.set_utterance_embedding(path_to_reference_audio=path_to_reference_audio_for_voice)
        duration, pitch, energy = self.extract_prosody(transcription_of_intonation_reference,
                                                       path_to_reference_audio_for_intonation,
                                                       lang=lang)
        self.tts.set_language(lang)
        # the original call passed view=False, but forward() above takes no view argument, so it is dropped here
        cloned_speech = self.tts(transcription_of_intonation_reference, durations=duration, pitch=pitch.transpose(0, 1), energy=energy.transpose(0, 1))
        if filename_of_result is not None:
            torch.save(f=filename_of_result, obj=cloned_speech)

class Reader:

    def __init__(self, language, device="cuda", model_id="Meta"):
        self.tts = UtteranceCloner(device=device, model_id=model_id, language=language)
        self.language = language

    def read_texts(self, sentence, filename, speaker_reference):
        self.tts.clone_utterance(speaker_reference,
                                 speaker_reference,
                                 sentence,
                                 filename_of_result=filename,
                                 lang=self.language)


if __name__ == '__main__':

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "5"

    all_dict = build_path_to_transcript_libritts_all_clean()

    reader = Reader(language="eng")
    for path in tqdm(all_dict):
        filename = path.replace(".wav", "_synthetic_spec.pt")
        reader.read_texts(sentence=all_dict[path], filename=filename, speaker_reference=path)

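The spectrogram caches written by the script above are what the HiFiGAN recipe below pairs with the original recordings. As a minimal sketch of how such a cached pair could be loaded back (the paths and variable names are illustrative; only torch.load/torch.save and the "_synthetic_spec.pt" suffix come from the code above):

    import soundfile
    import torch

    wav_path = "/data/LibriTTS/speaker/utterance.wav"           # hypothetical path into the corpus
    spec_path = wav_path.replace(".wav", "_synthetic_spec.pt")  # suffix used by the caching script above

    wave, sr = soundfile.read(wav_path)     # the natural recording
    synthetic_spec = torch.load(spec_path)  # the re-synthesized spectrogram saved in clone_utterance
    print(synthetic_spec.shape, len(wave), sr)
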
Recipes/HiFiGAN_e2e.py

Lines changed: 3 additions & 24 deletions
@@ -29,7 +29,9 @@ def run(gpu_id, resume_checkpoint, finetune, resume, model_dir, use_wandb, wandb
     model_save_dir = os.path.join(MODELS_DIR, "HiFiGAN_e2e_scratch_direct_cont")
     os.makedirs(model_save_dir, exist_ok=True)
 
-    print("Preparing new data...")
+    # To prepare the data, have a look at Modules/Vocoder/run_end-to-end_data_creation
+
+    print("Collecting new data...")
 
     file_lists_for_this_run_combined = list()
     file_lists_for_this_run_combined_synthetic = list()
@@ -41,29 +43,6 @@ def run(gpu_id, resume_checkpoint, finetune, resume, model_dir, use_wandb, wandb
         if os.path.exists(f.replace(".wav", "_synthetic_spec.pt")):
             file_lists_for_this_run_combined.append(f)
             file_lists_for_this_run_combined_synthetic.append(f.replace(".wav", "_synthetic_spec.pt"))
-    """
-    fl = list(build_path_to_transcript_hui_others().keys())
-    fisher_yates_shuffle(fl)
-    fisher_yates_shuffle(fl)
-    for i, f in enumerate(fl):
-        if os.path.exists(f.replace(".wav", "_synthetic.wav")):
-            file_lists_for_this_run_combined.append(f)
-            file_lists_for_this_run_combined_synthetic.append(f.replace(".wav", "_synthetic.wav"))
-    fl = list(build_path_to_transcript_aishell3().keys())
-    fisher_yates_shuffle(fl)
-    fisher_yates_shuffle(fl)
-    for i, f in enumerate(fl):
-        if os.path.exists(f.replace(".wav", "_synthetic.wav")):
-            file_lists_for_this_run_combined.append(f)
-            file_lists_for_this_run_combined_synthetic.append(f.replace(".wav", "_synthetic.wav"))
-    fl = list(build_path_to_transcript_jvs().keys())
-    fisher_yates_shuffle(fl)
-    fisher_yates_shuffle(fl)
-    for i, f in enumerate(fl):
-        if os.path.exists(f.replace(".wav", "_synthetic.wav")):
-            file_lists_for_this_run_combined.append(f)
-            file_lists_for_this_run_combined_synthetic.append(f.replace(".wav", "_synthetic.wav"))
-    """
     print("filepaths collected")
 
     train_set = HiFiGANDataset(list_of_original_paths=file_lists_for_this_run_combined,
