
Commit 3bc32c9

Author: Sarina Meyer (committed)
Updated code to latest version using prosody cloning and GAN embeddings
1 parent e6aeac6 commit 3bc32c9

19 files changed (+1055, -174 lines)

README.md

Lines changed: 39 additions & 26 deletions
@@ -1,22 +1,52 @@
  # Speaker Anonymization
  
  This repository contains the speaker anonymization system developed at the Institute for Natural Language Processing
- (IMS) at the University of Stuttgart, Germany. The system is described in our paper [*Speaker Anonymization with
- Phonetic Intermediate Representations*](https://www.isca-speech.org/archive/interspeech_2022/meyer22b_interspeech.html).
+ (IMS) at the University of Stuttgart, Germany. The system is described in the following papers:
+ 
+ | Paper | Published at | Branch | Demo |
+ |-------|--------------|--------|------|
+ | [Speaker Anonymization with Phonetic Intermediate Representations](https://www.isca-speech.org/archive/interspeech_2022/meyer22b_interspeech.html) | [Interspeech 2022](https://www.interspeech2022.org/) | [phonetic_representations](https://github.com/DigitalPhonetics/speaker-anonymization/tree/phonetic_representations) | [https://huggingface.co/spaces/sarinam/speaker-anonymization](https://huggingface.co/spaces/sarinam/speaker-anonymization) |
+ | [Anonymizing Speech with Generative Adversarial Networks to Preserve Speaker Privacy](https://arxiv.org/abs/2210.07002) | Soon at [SLT 2022](https://slt2022.org/) | coming soon | coming soon |
+ 
+ If you want to see the code for the respective papers, go to the branch referenced in the table. The latest version
+ of our system can be found here on the main branch.
  
  **Check out our live demo on Hugging Face: [https://huggingface.co/spaces/sarinam/speaker-anonymization](https://huggingface.co/spaces/sarinam/speaker-anonymization)**
  
  **Also check out [our contribution](https://www.voiceprivacychallenge.org/results-2022/docs/3___T04.pdf) to the [Voice Privacy Challenge 2022](https://www.voiceprivacychallenge.org/results-2022/)!**
  
- **The code and live demo for our latest paper [Anonymizing Speech with Generative Adversarial Networks to Preserve Speaker Privacy](https://arxiv.org/abs/2210.07002) are going to be added soon.**
  
  ## System Description
  The system is based on the Voice Privacy Challenge 2020, which is included as a submodule. It uses the basic idea of
  speaker embedding anonymization with neural synthesis, and uses the data and evaluation framework of the challenge.
- For a detailed description of the system, please read our paper linked above.
+ For a detailed description of the system, please read our Interspeech paper linked above.
+ 
+ ### Added Features
+ Since the publication of the first paper, some features have been added. The new structure of the pipeline and its
+ capabilities include:
+ * **GAN-based speaker anonymization**: We show in [this paper](https://arxiv.org/abs/2210.07002) that a Wasserstein
+   GAN can be trained to generate artificial speaker embeddings that resemble real ones but are not connected to any
+   known speaker -- in our opinion, a crucial condition for anonymization. The current GAN model in the latest
+   release v2.0 has been trained to generate a custom type of 128-dimensional speaker embeddings (also included in
+   our speech synthesis toolkit [IMSToucan](https://github.com/DigitalPhonetics/IMS-Toucan)) instead of x-vectors or
+   ECAPA-TDNN embeddings.
+ * **Prosody cloning**: We now provide an option to transfer the original prosody to the anonymized audio via
+   [prosody cloning](https://arxiv.org/abs/2206.12229)! If you want to avoid exact cloning and instead modify the
+   prosody slightly (but randomly, to avoid reversibility), use the random offset thresholds. They are given as a
+   lower and an upper threshold, as percentages of the original value. For instance, with thresholds of (80, 120),
+   the pitch and energy values of each phone are multiplied by a random value between 80% and 120% (either weakening
+   or amplifying the signal); see the sketch after this list.
+ * **ASR**: Our ASR now uses a [Branchformer](https://arxiv.org/abs/2207.02971) encoder and includes word
+   boundaries and stress markers in its output.
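
A minimal sketch of the random-offset idea, assuming per-phone pitch and energy arrays; the function and its interface are illustrative, not the repository's actual API:

```python
import numpy as np

def apply_random_offsets(pitch, energy, low=80, high=120, seed=None):
    """Scale each phone's pitch and energy by a random factor drawn from [low%, high%]."""
    rng = np.random.default_rng(seed)
    pitch_scale = rng.uniform(low, high, size=len(pitch)) / 100.0
    energy_scale = rng.uniform(low, high, size=len(energy)) / 100.0
    return pitch * pitch_scale, energy * energy_scale

# Thresholds (80, 120): each phone keeps between 80% and 120% of its original pitch and energy.
pitch, energy = apply_random_offsets(np.array([120.0, 95.0, 101.5]), np.array([0.7, 0.4, 0.9]))
```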
  
  ![architecture](figures/architecture.png)
  
+ The current code on the main branch expects the models of release v2.0. If you want to use the pipeline as presented
+ at Interspeech 2022, please go to the
+ [phonetic_representations branch](https://github.com/DigitalPhonetics/speaker-anonymization/tree/phonetic_representations).
  
  ## Installation
  ### 1. Clone repository
@@ -26,13 +56,14 @@ git clone --recurse-submodules https://github.com/DigitalPhonetics/speaker-anony
  ```
  
  ### 2. Download models
- Download the models [from the release page (v1.0)](https://github.com/DigitalPhonetics/speaker-anonymization/releases/tag/v1.0), unzip the folders and place them into a *models* folder as stated in the release notes. Make sure not to unzip the single ASR models, only the outer folder.
+ Download the models [from the release page (v2.0)](https://github.com/DigitalPhonetics/speaker-anonymization/releases/tag/v2.0), unzip the folders and place them into a *models* folder as stated in the release notes. Make sure not to unzip the single ASR models, only the outer folder.
  ```
  cd speaker-anonymization
  mkdir models
  cd models
  for file in anonymization asr tts; do
- wget https://github.com/DigitalPhonetics/speaker-anonymization/releases/download/v1.0/${file}.zip
+ wget https://github.com/DigitalPhonetics/speaker-anonymization/releases/download/v2.0/${file}.zip
  unzip ${file}.zip
  rm ${file}.zip
  done
@@ -98,7 +129,8 @@ on CPU (not recommended).
  The script will anonymize the development and test data of LibriSpeech and VCTK in three steps:
  1. ASR: Recognition of the linguistic content, output in the form of text or phone sequences
  2. Anonymization: Modification of speaker embeddings, output as torch vectors
- 3. TTS: Synthesis based on the recognized transcription and the anonymized speaker embedding, output as audio files (wav)
+ 3. TTS: Synthesis based on the recognized transcription, the extracted prosody and the anonymized speaker embedding,
+   output as audio files (wav)
  
  Each module produces intermediate results that are saved to disk. A module is only executed if previous intermediate
  results for the dependent pipeline combination do not exist or if recomputation is forced. Otherwise, the previous
@@ -119,25 +151,6 @@ Finally, for clarity, the most important parts of the evaluation results as well
  the [results](results) directory.
  
  
- ## Models
- The following table lists all models for each module that are reported in the paper and are included in this
- repository. Each model is given by its name in the directory and the name used in the paper. In the *settings*
- dictionary in [run_inference.py](run_inference.py), the model name should be used. The *x* in the *Default* column
- marks the models that are used in the main configuration of the system.
- 
- | Module | Default | Model name | Name in paper |
- |--------|---------|------------|---------------|
- | ASR | x | asr_tts-phn_en.zip | phones |
- | | | asr_stt_en | STT |
- | | | asr_tts_en.zip | TTS |
- | Anonymization | x | pool_minmax_ecapa+xvector | pool |
- | | | pool_raw_ecapa+xvector | pool raw |
- | | | random_in-scale_ecapa+xvector | random |
- | TTS | x | trained_on_ground_truth_phonemes.pt | Libri100 |
- | | | trained_on_asr_phoneme_outputs.pt | Libri100 + finetuned |
- | | | trained_on_libri600_asr_phoneme_outputs.pt | Libri600 |
- | | | trained_on_libri600_ground_truth_phonemes.pt | Libri600 + finetuned |
  ## Citation
  ```
  @inproceedings{meyer22b_interspeech,

anonymization/WGAN/README.md

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
# Wasserstein GAN with Quadratic Transport Cost for the Generation of Artificial Speaker Embeddings

This model is also used in our [IMS Toucan toolkit](https://github.com/DigitalPhonetics/IMS-Toucan/tree/ControllableMultilingual) to control voices in multi-speaker speech synthesis.
Check it out!

anonymization/WGAN/__init__.py

Whitespace-only changes.

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
import torch

from .init_wgan import create_wgan


class EmbeddingsGenerator:
    """Samples artificial speaker embeddings from a trained Wasserstein GAN."""

    def __init__(self, gan_path, device):
        self.device = device
        self.gan_path = gan_path

        self.mean = None
        self.std = None
        self.wgan = None

        self._load_model(self.gan_path)

    def generate_embeddings(self, n=1000):
        # Draw n embeddings from the generator without tracking gradients.
        return self.wgan.sample_generator(num_samples=n, nograd=True, return_intermediate=False)

    def _load_model(self, path):
        # Rebuild the GAN from the hyperparameters stored in the checkpoint,
        # then restore the generator and critic weights.
        gan_checkpoint = torch.load(path, map_location="cpu")

        self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device)
        self.wgan.G.load_state_dict(gan_checkpoint['generator_state_dict'])
        self.wgan.D.load_state_dict(gan_checkpoint['critic_state_dict'])

        self.mean = gan_checkpoint["mean"]
        self.std = gan_checkpoint["std"]

anonymization/WGAN/init_wgan.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn

from .wgan_qc import WassersteinGanQuadraticCost
from .resnet_1 import ResNet_D, ResNet_G


def create_wgan(parameters, device, optimizer='adam'):
    # Build the generator/critic pair; only the ResNet architecture is implemented.
    if parameters['model'] == 'resnet':
        generator, discriminator = init_resnet(parameters)
    else:
        raise NotImplementedError

    if optimizer == 'adam':
        optimizer_g = torch.optim.Adam(generator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas'])
        optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas'])
    elif optimizer == 'rmsprop':
        optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=parameters['learning_rate'])
        # The critic's optimizer must update the critic's own parameters.
        optimizer_d = torch.optim.RMSprop(discriminator.parameters(), lr=parameters['learning_rate'])
    else:
        raise NotImplementedError

    criterion = torch.nn.MSELoss()

    gan = WassersteinGanQuadraticCost(generator,
                                      discriminator,
                                      optimizer_g,
                                      optimizer_d,
                                      criterion=criterion,
                                      data_dimensions=parameters['data_dim'],
                                      epochs=parameters['epochs'],
                                      batch_size=parameters['batch_size'],
                                      device=device,
                                      n_max_iterations=parameters['n_max_iterations'],
                                      gamma=parameters['gamma'])

    return gan


def init_resnet(parameters):
    critic = ResNet_D(parameters['data_dim'][-1], parameters['size'], nfilter=parameters['nfilter'],
                      nfilter_max=parameters['nfilter_max'])
    generator = ResNet_G(parameters['data_dim'][-1], parameters['z_dim'], parameters['size'],
                         nfilter=parameters['nfilter'], nfilter_max=parameters['nfilter_max'])

    generator.apply(weights_init_G)
    critic.apply(weights_init_D)

    return generator, critic


def weights_init_D(m):
    # Kaiming initialization for the critic's convolutions (fan-out, leaky ReLU).
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
    elif classname.find('BatchNorm') != -1:
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


def weights_init_G(m):
    # Kaiming initialization for the generator's convolutions (fan-in, leaky ReLU).
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
    elif classname.find('BatchNorm') != -1:
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
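
For reference, a minimal sketch of the `parameters` dictionary that `create_wgan` expects. All values below are illustrative assumptions; the released checkpoints carry the real values in their `model_parameters` entry, so a dictionary like this is only needed when training a new GAN:

```python
import torch

# Illustrative only: the released checkpoints carry their own 'model_parameters'.
parameters = {
    'model': 'resnet',            # only the ResNet generator/critic pair is implemented
    'data_dim': [128],            # dimensionality of the speaker embeddings
    'z_dim': 32,                  # latent noise dimension (assumed)
    'size': 32,                   # internal feature-map size, a power of two >= 4 (assumed)
    'nfilter': 64,                # base number of filters (assumed)
    'nfilter_max': 512,           # cap on the number of filters (assumed)
    'learning_rate': 1e-4,        # (assumed)
    'betas': (0.5, 0.9),          # Adam betas (assumed)
    'epochs': 100,                # (assumed)
    'batch_size': 64,             # (assumed)
    'n_max_iterations': 100000,   # (assumed)
    'gamma': 0.1,                 # regularization strength of the quadratic-cost objective (assumed)
}

wgan = create_wgan(parameters, device=torch.device('cpu'))
```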

anonymization/WGAN/resnet_1.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
import numpy as np
import torch
from torch import nn


class ResNet_G(nn.Module):
    """Generator: maps a latent vector z to a speaker embedding of size data_dim."""

    def __init__(self, data_dim, z_dim, size, nfilter=64, nfilter_max=512, bn=True, res_ratio=0.1, **kwargs):
        super().__init__()
        self.input_dim = z_dim
        self.output_dim = z_dim
        self.dropout_rate = 0

        s0 = self.s0 = 4
        nf = self.nf = nfilter
        nf_max = self.nf_max = nfilter_max
        self.bn = bn
        self.z_dim = z_dim

        # Submodules
        nlayers = int(np.log2(size / s0))
        self.nf0 = min(nf_max, nf * 2 ** (nlayers + 1))

        self.fc = nn.Linear(z_dim, self.nf0 * s0 * s0)
        if self.bn:
            self.bn1d = nn.BatchNorm1d(self.nf0 * s0 * s0)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

        # Stack of residual blocks, upsampling from s0 x s0 to size x size.
        blocks = []
        for i in range(nlayers, 0, -1):
            nf0 = min(nf * 2 ** (i + 1), nf_max)
            nf1 = min(nf * 2 ** i, nf_max)
            blocks += [
                ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio),
                nn.Upsample(scale_factor=2)
            ]

        nf0 = min(nf * 2, nf_max)
        nf1 = min(nf, nf_max)
        blocks += [
            ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio),
            ResNetBlock(nf1, nf1, bn=self.bn, res_ratio=res_ratio)
        ]

        self.resnet = nn.Sequential(*blocks)
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)

        # Final projection from the flattened 3 x size x size map to the embedding.
        self.fc_out = nn.Linear(3 * size * size, data_dim)

    def forward(self, z, return_intermediate=False):
        batch_size = z.size(0)
        out = self.fc(z)
        if self.bn:
            out = self.bn1d(out)
        out = self.relu(out)
        if return_intermediate:
            l_1 = out.detach().clone()
        out = out.view(batch_size, self.nf0, self.s0, self.s0)

        out = self.resnet(out)

        out = self.conv_img(out)
        out = self.relu(out)
        out = self.fc_out(out.flatten(1))

        if return_intermediate:
            return out, l_1
        return out

    def sample_latent(self, n_samples, z_size):
        return torch.randn((n_samples, z_size))


class ResNet_D(nn.Module):
    """Critic: maps a speaker embedding to a scalar score."""

    def __init__(self, data_dim, size, nfilter=64, nfilter_max=512, res_ratio=0.1):
        super().__init__()
        s0 = self.s0 = 4
        nf = self.nf = nfilter
        nf_max = self.nf_max = nfilter_max
        self.size = size

        # Submodules
        nlayers = int(np.log2(size / s0))
        self.nf0 = min(nf_max, nf * 2 ** nlayers)

        nf0 = min(nf, nf_max)
        nf1 = min(nf * 2, nf_max)
        blocks = [
            ResNetBlock(nf0, nf0, bn=False, res_ratio=res_ratio),
            ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio)
        ]

        # Lift the 1D embedding to a 3 x size x size feature map.
        self.fc_input = nn.Linear(data_dim, 3 * size * size)

        for i in range(1, nlayers + 1):
            nf0 = min(nf * 2 ** i, nf_max)
            nf1 = min(nf * 2 ** (i + 1), nf_max)
            blocks += [
                nn.AvgPool2d(3, stride=2, padding=1),
                ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio),
            ]

        self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.resnet = nn.Sequential(*blocks)

        self.fc = nn.Linear(self.nf0 * s0 * s0, 1)

    def forward(self, x):
        batch_size = x.size(0)

        out = self.fc_input(x)
        out = self.relu(out).view(batch_size, 3, self.size, self.size)

        out = self.relu(self.conv_img(out))
        out = self.resnet(out)
        out = out.view(batch_size, self.nf0 * self.s0 * self.s0)
        out = self.fc(out)

        return out


class ResNetBlock(nn.Module):
    """Residual block with a scaled residual branch (res_ratio) and optional batch norm."""

    def __init__(self, fin, fout, fhidden=None, bn=True, res_ratio=0.1):
        super().__init__()
        # Attributes
        self.bn = bn
        self.is_bias = not bn
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        if fhidden is None:
            self.fhidden = min(fin, fout)
        else:
            self.fhidden = fhidden
        self.res_ratio = res_ratio

        # Submodules
        self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1, bias=self.is_bias)
        if self.bn:
            self.bn2d_0 = nn.BatchNorm2d(self.fhidden)
        self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=self.is_bias)
        if self.bn:
            self.bn2d_1 = nn.BatchNorm2d(self.fout)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)
            if self.bn:
                self.bn2d_s = nn.BatchNorm2d(self.fout)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x_s = self._shortcut(x)
        dx = self.conv_0(x)
        if self.bn:
            dx = self.bn2d_0(dx)
        dx = self.relu(dx)
        dx = self.conv_1(dx)
        if self.bn:
            dx = self.bn2d_1(dx)
        out = self.relu(x_s + self.res_ratio * dx)
        return out

    def _shortcut(self, x):
        if self.learned_shortcut:
            x_s = self.conv_s(x)
            if self.bn:
                x_s = self.bn2d_s(x_s)
        else:
            x_s = x
        return x_s
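
To sanity-check the tensor shapes, a minimal smoke test of the generator/critic pair. The hyperparameters are deliberately tiny, illustrative assumptions, chosen so that `nfilter * size / 4` reaches `nfilter_max` (which this implementation needs for the critic's final channel count to match its fully connected layer):

```python
# Tiny, illustrative hyperparameters -- not the released model's configuration.
g = ResNet_G(data_dim=128, z_dim=32, size=8, nfilter=8, nfilter_max=16)
d = ResNet_D(data_dim=128, size=8, nfilter=8, nfilter_max=16)

z = g.sample_latent(n_samples=4, z_size=32)  # (4, 32) standard normal noise
embeddings = g(z)                            # (4, 128) artificial speaker embeddings
scores = d(embeddings)                       # (4, 1) critic scores
print(embeddings.shape, scores.shape)
```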
