Skip to content

Commit 5db335c

Browse files
committed
add prosody cloning and bigvgan vocoder for increased variance
1 parent bb2d9ab commit 5db335c

File tree

6 files changed

+209
-8
lines changed

6 files changed

+209
-8
lines changed

Architectures/Vocoder/BigVGAN.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright (c) 2022 NVIDIA CORPORATION.
2+
# Licensed under the MIT license.
3+
4+
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5+
# LICENSE is in incl_licenses directory.
6+
7+
import torch
8+
from alias_free_torch import Activation1d
9+
from torch.nn import Conv1d
10+
from torch.nn import ConvTranspose1d
11+
from torch.nn import ModuleList
12+
from torch.nn.utils import remove_weight_norm
13+
from torch.nn.utils import weight_norm
14+
15+
from Architectures.Vocoder.AMP import AMPBlock1
16+
from Architectures.Vocoder.Snake import SnakeBeta
17+
18+
19+
class BigVGAN(torch.nn.Module):
20+
# this is the main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
21+
22+
def __init__(self,
23+
num_mels=128,
24+
upsample_initial_channel=512,
25+
upsample_rates=(8, 6, 4, 2), # CAREFUL: Avocodo discriminator assumes that there are always 4 upsample scales, because it takes intermediate results.
26+
upsample_kernel_sizes=(16, 12, 8, 4),
27+
resblock_kernel_sizes=(3, 7, 11),
28+
resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
29+
weights=None
30+
):
31+
super(BigVGAN, self).__init__()
32+
33+
self.num_kernels = len(resblock_kernel_sizes)
34+
self.num_upsamples = len(upsample_rates)
35+
36+
# pre conv
37+
self.conv_pre = weight_norm(Conv1d(num_mels, upsample_initial_channel, 7, 1, padding=3))
38+
39+
# transposed conv-based upsamplers. does not apply anti-aliasing
40+
self.ups = ModuleList()
41+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
42+
self.ups.append(ModuleList([
43+
weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i),
44+
upsample_initial_channel // (2 ** (i + 1)),
45+
k, u, padding=(k - u) // 2))
46+
]))
47+
48+
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
49+
self.resblocks = ModuleList()
50+
for i in range(len(self.ups)):
51+
ch = upsample_initial_channel // (2 ** (i + 1))
52+
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
53+
self.resblocks.append(AMPBlock1(ch, k, d))
54+
55+
# post conv
56+
activation_post = SnakeBeta(ch, alpha_logscale=True)
57+
self.activation_post = Activation1d(activation=activation_post)
58+
59+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
60+
61+
# weight initialization
62+
for i in range(len(self.ups)):
63+
self.ups[i].apply(init_weights)
64+
self.conv_post.apply(init_weights)
65+
66+
# for Avocodo discriminator
67+
self.out_proj_x1 = torch.nn.Conv1d(upsample_initial_channel // 4, 1, 7, 1, padding=3)
68+
self.out_proj_x2 = torch.nn.Conv1d(upsample_initial_channel // 8, 1, 7, 1, padding=3)
69+
70+
if weights is not None:
71+
self.load_state_dict(weights)
72+
73+
def forward(self, x):
74+
# pre conv
75+
x = self.conv_pre(x)
76+
77+
for i in range(self.num_upsamples):
78+
# upsampling
79+
for i_up in range(len(self.ups[i])):
80+
x = self.ups[i][i_up](x)
81+
# AMP blocks
82+
xs = None
83+
for j in range(self.num_kernels):
84+
if xs is None:
85+
xs = self.resblocks[i * self.num_kernels + j](x)
86+
else:
87+
xs += self.resblocks[i * self.num_kernels + j](x)
88+
x = xs / self.num_kernels
89+
if i == 1:
90+
x1 = self.out_proj_x1(x)
91+
elif i == 2:
92+
x2 = self.out_proj_x2(x)
93+
94+
# post conv
95+
x = self.activation_post(x)
96+
x = self.conv_post(x)
97+
x = torch.tanh(x)
98+
99+
return x, x2, x1
100+
101+
def remove_weight_norm(self):
102+
print('Removing weight norm...')
103+
for l in self.ups:
104+
for l_i in l:
105+
remove_weight_norm(l_i)
106+
for l in self.resblocks:
107+
l.remove_weight_norm()
108+
remove_weight_norm(self.conv_pre)
109+
remove_weight_norm(self.conv_post)
110+
111+
112+
def init_weights(m, mean=0.0, std=0.01):
113+
classname = m.__class__.__name__
114+
if classname.find("Conv") != -1:
115+
m.weight.data.normal_(mean, std)
116+
117+
118+
def apply_weight_norm(m):
119+
classname = m.__class__.__name__
120+
if classname.find("Conv") != -1:
121+
weight_norm(m)
122+
123+
124+
def get_padding(kernel_size, dilation=1):
125+
return int((kernel_size * dilation - dilation) / 2)
126+
127+
128+
if __name__ == '__main__':
129+
print(BigVGAN()(torch.randn([1, 128, 100]))[0].shape)

InferenceInterfaces/ToucanTTSInterface.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from Architectures.EmbeddingModel.StyleEmbedding import StyleEmbedding
1313
from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
14+
from Architectures.Vocoder.BigVGAN import BigVGAN
1415
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
1516
from Preprocessing.AudioPreprocessor import AudioPreprocessor
1617
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
@@ -28,7 +29,7 @@ def __init__(self,
2829
vocoder_model_path=os.path.join(MODELS_DIR, f"Vocoder", "best.pt"), # path to the Vocoder checkpoint
2930
embedding_model_path=None,
3031
language="eng", # initial language of the model, can be changed later with the setter methods
31-
):
32+
use_bigvgan=False):
3233
super().__init__()
3334
self.device = device
3435
if not tts_model_path.endswith(".pt"):
@@ -67,8 +68,13 @@ def __init__(self,
6768
################################
6869
# load mel to wave model #
6970
################################
70-
vocoder_checkpoint = torch.load(vocoder_model_path, map_location="cpu")
71-
self.vocoder = HiFiGAN()
71+
if use_bigvgan:
72+
vocoder_checkpoint = torch.load(os.path.join(MODELS_DIR, f"Vocoder", "bigvgan.pt"), map_location="cpu")
73+
self.vocoder = BigVGAN()
74+
else:
75+
vocoder_checkpoint = torch.load(vocoder_model_path, map_location="cpu")
76+
self.vocoder = HiFiGAN()
77+
7278
self.vocoder.load_state_dict(vocoder_checkpoint)
7379
self.vocoder = self.vocoder.to(device).eval()
7480
self.vocoder.remove_weight_norm()

InferenceInterfaces/UtteranceCloner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class UtteranceCloner:
2323
Useful for Privacy Applications
2424
"""
2525

26-
def __init__(self, model_id, device, language="eng"):
27-
self.tts = ToucanTTSInterface(device=device, tts_model_path=model_id)
26+
def __init__(self, model_id, device, language="eng", use_bigvgan=False):
27+
self.tts = ToucanTTSInterface(device=device, tts_model_path=model_id, use_bigvgan=use_bigvgan)
2828
self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, cut_silence=False)
2929
self.tf = ArticulatoryCombinedTextFrontend(language=language)
3030
self.device = device
@@ -41,7 +41,7 @@ def __init__(self, model_id, device, language="eng"):
4141
torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets
4242
# this to false globally during model loading rather than using inference_mode or no_grad
4343

44-
def extract_prosody(self, transcript, ref_audio_path, lang="eng", on_line_fine_tune=True):
44+
def extract_prosody(self, transcript, ref_audio_path, lang="eng", on_line_fine_tune=False):
4545
acoustic_model = Aligner()
4646
acoustic_model.load_state_dict(self.aligner_weights)
4747
acoustic_model = acoustic_model.to(self.device)

run_asvspoof_generation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
PATH_TO_GENERATION_FILE = "p1_ttsvc_surrogate.tsv"
1212
PATH_TO_OUTPUT_DIR = "asv_spoof_outputs_no_pros"
1313
DEVICE = "cuda"
14+
USE_BIGVGAN = False
1415

1516

1617
def build_path_to_transcript_dict_mls_english():
@@ -28,7 +29,7 @@ def build_path_to_transcript_dict_mls_english():
2829

2930
if __name__ == '__main__':
3031
print("loading model...")
31-
tts = ToucanTTSInterface(device=DEVICE, tts_model_path="ASVSpoof")
32+
tts = ToucanTTSInterface(device=DEVICE, tts_model_path="ASVSpoof", use_bigvgan=USE_BIGVGAN)
3233
print("prepare path to transcript lookup...")
3334
path_to_transcript_dict = build_path_to_transcript_dict_mls_english()
3435
filename_to_path = dict()
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
2+
3+
import librosa
4+
import soundfile as sf
5+
from tqdm import tqdm
6+
7+
from InferenceInterfaces.UtteranceCloner import UtteranceCloner
8+
from Utility.utils import float2pcm
9+
10+
PATH_TO_MLS_ENGLISH_TRAIN = "/mount/resources/speech/corpora/MultiLingLibriSpeech/mls_english/train"
11+
PATH_TO_GENERATION_FILE = "p1_ttsvc_surrogate.tsv"
12+
PATH_TO_OUTPUT_DIR = "asv_spoof_outputs_with_pros"
13+
DEVICE = "cuda"
14+
USE_BIGVGAN = False
15+
16+
17+
def build_path_to_transcript_dict_mls_english():
18+
path_to_transcript = dict()
19+
with open(os.path.join(PATH_TO_MLS_ENGLISH_TRAIN, "transcripts.txt"), "r", encoding="utf8") as file:
20+
lookup = file.read()
21+
for line in lookup.split("\n"):
22+
if line.strip() != "":
23+
fields = line.split("\t")
24+
wav_folders = fields[0].split("_")
25+
wav_path = f"{PATH_TO_MLS_ENGLISH_TRAIN}/audio/{wav_folders[0]}/{wav_folders[1]}/{fields[0]}.flac"
26+
path_to_transcript[wav_path] = fields[1]
27+
return path_to_transcript
28+
29+
30+
if __name__ == '__main__':
31+
print("loading model...")
32+
uc = UtteranceCloner(model_id="ASVSpoof", device=DEVICE, language="eng", use_bigvgan=USE_BIGVGAN)
33+
print("prepare path to transcript lookup...")
34+
path_to_transcript_dict = build_path_to_transcript_dict_mls_english()
35+
filename_to_path = dict()
36+
for p in path_to_transcript_dict:
37+
filename_to_path[p.split("/")[-1].rstrip(".flac")] = p
38+
with open(PATH_TO_GENERATION_FILE, "r") as file:
39+
generation_list = file.read().split("\n")
40+
os.makedirs(PATH_TO_OUTPUT_DIR, exist_ok=True)
41+
print("generating audios...")
42+
for generation_item in tqdm(generation_list):
43+
if generation_item == "":
44+
continue
45+
speaker_id, voice_sources, _, prosody_source, output_name = generation_item.split()
46+
voice_source_list = voice_sources.split(",")
47+
transcript = path_to_transcript_dict[filename_to_path[prosody_source]]
48+
source_list = list()
49+
for source in voice_source_list:
50+
source_list.append(filename_to_path[source])
51+
52+
cloned_utterance = uc.clone_utterance(path_to_reference_audio_for_voice=source_list,
53+
path_to_reference_audio_for_intonation=filename_to_path[prosody_source],
54+
transcription_of_intonation_reference=transcript)
55+
56+
resampled_utt = librosa.resample(cloned_utterance, orig_sr=24000, target_sr=16000)
57+
sf.write(file=f"{PATH_TO_OUTPUT_DIR}/" + output_name + ".flac", data=float2pcm(resampled_utt), samplerate=16000, subtype="PCM_16")

run_model_downloader.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,21 @@ def download_models():
3232
reporthook=report)
3333

3434
#############
35-
print("Downloading Vocoder")
35+
print("Downloading HiFiGAN Vocoder")
3636
os.makedirs(os.path.join(MODELS_DIR, "Vocoder"), exist_ok=True)
3737
filename, headers = urllib.request.urlretrieve(
3838
url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v2.asvspoof/hifigan.pt",
3939
filename=os.path.abspath(os.path.join(MODELS_DIR, "Vocoder", "best.pt")),
4040
reporthook=report)
4141

42+
#############
43+
print("Downloading BigVGAN Vocoder")
44+
os.makedirs(os.path.join(MODELS_DIR, "Vocoder"), exist_ok=True)
45+
filename, headers = urllib.request.urlretrieve(
46+
url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v2.asvspoof/bigvgan.pt",
47+
filename=os.path.abspath(os.path.join(MODELS_DIR, "Vocoder", "bigvgan.pt")),
48+
reporthook=report)
49+
4250
#############
4351
print("Downloading Embedding Model")
4452
os.makedirs(os.path.join(MODELS_DIR, "Embedding"), exist_ok=True)

0 commit comments

Comments
 (0)