Skip to content

Commit 3faccbd

Browse files
authored
Fix dataset handling with the new embedding file keys (#1991)
1 parent 0a112f7 commit 3faccbd

File tree

5 files changed

+34
-20
lines changed

5 files changed

+34
-20
lines changed

TTS/bin/compute_embeddings.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,9 @@
102102
for idx, fields in enumerate(tqdm(samples)):
103103
class_name = fields[class_name_key]
104104
audio_file = fields["audio_file"]
105-
dataset_name = fields["dataset_name"]
105+
embedding_key = fields["audio_unique_name"]
106106
root_path = fields["root_path"]
107107

108-
relfilepath = os.path.splitext(audio_file.replace(root_path, ""))[0]
109-
embedding_key = f"{dataset_name}#{relfilepath}"
110108
if args.old_file is not None and embedding_key in encoder_manager.clip_ids:
111109
# get the embedding from the old file
112110
embedd = encoder_manager.get_embedding_by_clip(embedding_key)

TTS/tts/datasets/__init__.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,18 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
5959
return items[:eval_split_size], items[eval_split_size:]
6060

6161

62+
def add_extra_keys(metadata, language, dataset_name):
63+
for item in metadata:
64+
# add language name
65+
item["language"] = language
66+
# add unique audio name
67+
relfilepath = os.path.splitext(item["audio_file"].replace(item["root_path"], ""))[0]
68+
audio_unique_name = f"{dataset_name}#{relfilepath}"
69+
item["audio_unique_name"] = audio_unique_name
70+
71+
return metadata
72+
73+
6274
def load_tts_samples(
6375
datasets: Union[List[Dict], Dict],
6476
eval_split=True,
@@ -111,15 +123,15 @@ def load_tts_samples(
111123
# load train set
112124
meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
113125
assert len(meta_data_train) > 0, f" [!] No training samples found in {root_path}/{meta_file_train}"
114-
meta_data_train = [{**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_train]
126+
127+
meta_data_train = add_extra_keys(meta_data_train, language, dataset_name)
128+
115129
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
116130
# load evaluation split if set
117131
if eval_split:
118132
if meta_file_val:
119133
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
120-
meta_data_eval = [
121-
{**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_eval
122-
]
134+
meta_data_eval = add_extra_keys(meta_data_eval, language, dataset_name)
123135
else:
124136
meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
125137
meta_data_eval_all += meta_data_eval

TTS/tts/datasets/dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def load_data(self, idx):
256256
"speaker_name": item["speaker_name"],
257257
"language_name": item["language"],
258258
"wav_file_name": os.path.basename(item["audio_file"]),
259+
"audio_unique_name": item["audio_unique_name"],
259260
}
260261
return sample
261262

@@ -397,8 +398,8 @@ def collate_fn(self, batch):
397398
language_ids = None
398399
# get pre-computed d-vectors
399400
if self.d_vector_mapping is not None:
400-
wav_files_names = list(batch["wav_file_name"])
401-
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
401+
embedding_keys = list(batch["audio_unique_name"])
402+
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys]
402403
else:
403404
d_vectors = None
404405

TTS/tts/models/vits.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def __getitem__(self, idx):
284284
"wav_file": wav_filename,
285285
"speaker_name": item["speaker_name"],
286286
"language_name": item["language"],
287+
"audio_unique_name": item["audio_unique_name"],
287288
}
288289

289290
@property
@@ -308,6 +309,7 @@ def collate_fn(self, batch):
308309
- language_names: :math:`[B]`
309310
- audiofile_paths: :math:`[B]`
310311
- raw_texts: :math:`[B]`
312+
- audio_unique_names: :math:`[B]`
311313
"""
312314
# convert list of dicts to dict of lists
313315
B = len(batch)
@@ -348,6 +350,7 @@ def collate_fn(self, batch):
348350
"language_names": batch["language_name"],
349351
"audio_files": batch["wav_file"],
350352
"raw_text": batch["raw_text"],
353+
"audio_unique_names": batch["audio_unique_name"],
351354
}
352355

353356

@@ -1470,7 +1473,7 @@ def format_batch(self, batch: Dict) -> Dict:
14701473
# get d_vectors from audio file names
14711474
if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file:
14721475
d_vector_mapping = self.speaker_manager.embeddings
1473-
d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
1476+
d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]]
14741477
d_vectors = torch.FloatTensor(d_vectors)
14751478

14761479
# get language ids from language names

tests/data/ljspeech/speakers.json

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"LJ001-0001.wav": {
2+
"#/wavs/LJ001-0001": {
33
"name": "ljspeech-0",
44
"embedding": [
55
0.05539746582508087,
@@ -260,7 +260,7 @@
260260
-0.09469571709632874
261261
]
262262
},
263-
"LJ001-0002.wav": {
263+
"#/wavs/LJ001-0002": {
264264
"name": "ljspeech-1",
265265
"embedding": [
266266
0.05539746582508087,
@@ -521,7 +521,7 @@
521521
-0.09469571709632874
522522
]
523523
},
524-
"LJ001-0003.wav": {
524+
"#/wavs/LJ001-0003": {
525525
"name": "ljspeech-2",
526526
"embedding": [
527527
0.05539746582508087,
@@ -782,7 +782,7 @@
782782
-0.09469571709632874
783783
]
784784
},
785-
"LJ001-0004.wav": {
785+
"#/wavs/LJ001-0004": {
786786
"name": "ljspeech-3",
787787
"embedding": [
788788
0.05539746582508087,
@@ -1043,7 +1043,7 @@
10431043
-0.09469571709632874
10441044
]
10451045
},
1046-
"LJ001-0005.wav": {
1046+
"#/wavs/LJ001-0005": {
10471047
"name": "ljspeech-4",
10481048
"embedding": [
10491049
0.05539746582508087,
@@ -1304,7 +1304,7 @@
13041304
-0.09469571709632874
13051305
]
13061306
},
1307-
"LJ001-0006.wav": {
1307+
"#/wavs/LJ001-0006": {
13081308
"name": "ljspeech-5",
13091309
"embedding": [
13101310
0.05539746582508087,
@@ -1565,7 +1565,7 @@
15651565
-0.09469571709632874
15661566
]
15671567
},
1568-
"LJ001-0007.wav": {
1568+
"#/wavs/LJ001-0007": {
15691569
"name": "ljspeech-6",
15701570
"embedding": [
15711571
0.05539746582508087,
@@ -1826,7 +1826,7 @@
18261826
-0.09469571709632874
18271827
]
18281828
},
1829-
"LJ001-0008.wav": {
1829+
"#/wavs/LJ001-0008": {
18301830
"name": "ljspeech-7",
18311831
"embedding": [
18321832
0.05539746582508087,
@@ -2087,7 +2087,7 @@
20872087
-0.09469571709632874
20882088
]
20892089
},
2090-
"LJ001-0009.wav": {
2090+
"#/wavs/LJ001-0009": {
20912091
"name": "ljspeech-8",
20922092
"embedding": [
20932093
0.05539746582508087,
@@ -2348,7 +2348,7 @@
23482348
-0.09469571709632874
23492349
]
23502350
},
2351-
"LJ001-0010.wav": {
2351+
"#/wavs/LJ001-0010": {
23522352
"name": "ljspeech-9",
23532353
"embedding": [
23542354
0.05539746582508087,

0 commit comments

Comments
 (0)