Skip to content

Commit d6ad9a0

Browse files
authored
Fix colliding dataset cache file names (#1994)
* Fix colliding dataset cache file names * Remove unused code
1 parent 3faccbd commit d6ad9a0

34 files changed

+17
-12
lines changed

TTS/tts/datasets/dataset.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import collections
23
import os
34
import random
@@ -34,6 +35,12 @@ def noise_augment_audio(wav):
3435
return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
3536

3637

38+
def string2filename(string):
39+
# generate a safe and reversible filename based on a string
40+
filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
41+
return filename
42+
43+
3744
class TTSDataset(Dataset):
3845
def __init__(
3946
self,
@@ -201,7 +208,7 @@ def get_phonemes(self, idx, text):
201208
def get_f0(self, idx):
202209
out_dict = self.f0_dataset[idx]
203210
item = self.samples[idx]
204-
assert item["audio_file"] == out_dict["audio_file"]
211+
assert item["audio_unique_name"] == out_dict["audio_unique_name"]
205212
return out_dict
206213

207214
@staticmethod
@@ -561,19 +568,18 @@ def __init__(
561568

562569
def __getitem__(self, index):
563570
item = self.samples[index]
564-
ids = self.compute_or_load(item["audio_file"], item["text"])
571+
ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"])
565572
ph_hat = self.tokenizer.ids_to_text(ids)
566573
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
567574

568575
def __len__(self):
569576
return len(self.samples)
570577

571-
def compute_or_load(self, wav_file, text):
578+
def compute_or_load(self, file_name, text):
572579
"""Compute phonemes for the given text.
573580
574581
If the phonemes are already cached, load them from cache.
575582
"""
576-
file_name = os.path.splitext(os.path.basename(wav_file))[0]
577583
file_ext = "_phoneme.npy"
578584
cache_path = os.path.join(self.cache_path, file_name + file_ext)
579585
try:
@@ -670,11 +676,11 @@ def __init__(
670676

671677
def __getitem__(self, idx):
672678
item = self.samples[idx]
673-
f0 = self.compute_or_load(item["audio_file"])
679+
f0 = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
674680
if self.normalize_f0:
675681
assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
676682
f0 = self.normalize(f0)
677-
return {"audio_file": item["audio_file"], "f0": f0}
683+
return {"audio_unique_name": item["audio_unique_name"], "f0": f0}
678684

679685
def __len__(self):
680686
return len(self.samples)
@@ -706,8 +712,7 @@ def get_pad_id(self):
706712
return self.pad_id
707713

708714
@staticmethod
709-
def create_pitch_file_path(wav_file, cache_path):
710-
file_name = os.path.splitext(os.path.basename(wav_file))[0]
715+
def create_pitch_file_path(file_name, cache_path):
711716
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
712717
return pitch_file
713718

@@ -745,26 +750,26 @@ def denormalize(self, pitch):
745750
pitch[zero_idxs] = 0.0
746751
return pitch
747752

748-
def compute_or_load(self, wav_file):
753+
def compute_or_load(self, wav_file, audio_unique_name):
749754
"""
750755
compute pitch and return a numpy array of pitch values
751756
"""
752-
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
757+
pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
753758
if not os.path.exists(pitch_file):
754759
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
755760
else:
756761
pitch = np.load(pitch_file)
757762
return pitch.astype(np.float32)
758763

759764
def collate_fn(self, batch):
760-
audio_file = [item["audio_file"] for item in batch]
765+
audio_unique_name = [item["audio_unique_name"] for item in batch]
761766
f0s = [item["f0"] for item in batch]
762767
f0_lens = [len(item["f0"]) for item in batch]
763768
f0_lens_max = max(f0_lens)
764769
f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
765770
for i, f0_len in enumerate(f0_lens):
766771
f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
767-
return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
772+
return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens}
768773

769774
def print_logs(self, level: int = 0) -> None:
770775
indent = "\t" * level
0 Bytes
Binary file not shown.
-700 Bytes
Binary file not shown.
-244 Bytes
Binary file not shown.
-704 Bytes
Binary file not shown.
-440 Bytes
Binary file not shown.
-652 Bytes
Binary file not shown.
-412 Bytes
Binary file not shown.
-588 Bytes
Binary file not shown.
-208 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)