+import base64
 import collections
 import os
 import random
@@ -34,6 +35,12 @@ def noise_augment_audio(wav):
     return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)
 
 
+def string2filename(string):
+    # generate a safe and reversible filename based on a string
+    filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")
+    return filename
+
+
 class TTSDataset(Dataset):
     def __init__(
         self,
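As a quick, illustrative sanity check on the new helper (the sample name below is made up; real values come from each sample's audio_unique_name), the urlsafe base64 output contains no path separators and decodes back to the original string:

import base64

name = "ljspeech#wavs/LJ001-0001.wav"  # hypothetical audio_unique_name
fname = base64.urlsafe_b64encode(name.encode("utf-8")).decode("utf-8")  # same value string2filename(name) returns
assert "/" not in fname and "\\" not in fname  # urlsafe alphabet is only A-Z, a-z, 0-9, "-", "_", "="
assert base64.urlsafe_b64decode(fname).decode("utf-8") == name  # reversible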
@@ -201,7 +208,7 @@ def get_phonemes(self, idx, text):
     def get_f0(self, idx):
         out_dict = self.f0_dataset[idx]
         item = self.samples[idx]
-        assert item["audio_file"] == out_dict["audio_file"]
+        assert item["audio_unique_name"] == out_dict["audio_unique_name"]
         return out_dict
 
     @staticmethod
@@ -561,19 +568,18 @@ def __init__(
 
     def __getitem__(self, index):
         item = self.samples[index]
-        ids = self.compute_or_load(item["audio_file"], item["text"])
+        ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"])
         ph_hat = self.tokenizer.ids_to_text(ids)
         return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
 
     def __len__(self):
         return len(self.samples)
 
-    def compute_or_load(self, wav_file, text):
+    def compute_or_load(self, file_name, text):
         """Compute phonemes for the given text.
 
         If the phonemes are already cached, load them from cache.
         """
-        file_name = os.path.splitext(os.path.basename(wav_file))[0]
         file_ext = "_phoneme.npy"
         cache_path = os.path.join(self.cache_path, file_name + file_ext)
         try:
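For illustration only (the cache directory and unique name below are assumed, not taken from this change), a phoneme cache entry is now keyed by the encoded unique name rather than the wav basename, so two samples that happen to share a basename no longer collide:

import base64
import os

cache_path = "phoneme_cache"  # hypothetical PhonemeDataset.cache_path
unique_name = "vctk#p225/p225_001.wav"  # hypothetical audio_unique_name
file_name = base64.urlsafe_b64encode(unique_name.encode("utf-8")).decode("utf-8")  # string2filename()
npy_path = os.path.join(cache_path, file_name + "_phoneme.npy")  # where compute_or_load() caches the token ids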
@@ -670,11 +676,11 @@ def __init__(
 
     def __getitem__(self, idx):
         item = self.samples[idx]
-        f0 = self.compute_or_load(item["audio_file"])
+        f0 = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
         if self.normalize_f0:
             assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
             f0 = self.normalize(f0)
-        return {"audio_file": item["audio_file"], "f0": f0}
+        return {"audio_unique_name": item["audio_unique_name"], "f0": f0}
 
     def __len__(self):
         return len(self.samples)
@@ -706,8 +712,7 @@ def get_pad_id(self):
         return self.pad_id
 
     @staticmethod
-    def create_pitch_file_path(wav_file, cache_path):
-        file_name = os.path.splitext(os.path.basename(wav_file))[0]
+    def create_pitch_file_path(file_name, cache_path):
         pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
         return pitch_file
 
@@ -745,26 +750,26 @@ def denormalize(self, pitch):
         pitch[zero_idxs] = 0.0
         return pitch
 
-    def compute_or_load(self, wav_file):
+    def compute_or_load(self, wav_file, audio_unique_name):
         """
         compute pitch and return a numpy array of pitch values
         """
-        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
+        pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(pitch_file):
             pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)
 
     def collate_fn(self, batch):
-        audio_file = [item["audio_file"] for item in batch]
+        audio_unique_name = [item["audio_unique_name"] for item in batch]
         f0s = [item["f0"] for item in batch]
         f0_lens = [len(item["f0"]) for item in batch]
         f0_lens_max = max(f0_lens)
         f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
         for i, f0_len in enumerate(f0_lens):
             f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
-        return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
+        return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens}
 
     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
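A small sketch of the new two-argument F0Dataset.compute_or_load call (paths and names below are illustrative): the wav path is still what pitch is computed from on a cache miss, while the encoded unique name only determines where the _pitch.npy cache file lives:

import base64
import os

wav_file = "/data/VCTK/wav48/p225/p225_001.wav"  # hypothetical path, read only when the cache misses
unique_name = "vctk#p225/p225_001.wav"  # hypothetical audio_unique_name for the same sample
file_name = base64.urlsafe_b64encode(unique_name.encode("utf-8")).decode("utf-8")  # string2filename()
pitch_file = os.path.join("f0_cache", file_name + "_pitch.npy")  # create_pitch_file_path(file_name, "f0_cache")
# f0_dataset.compute_or_load(wav_file, file_name) loads pitch_file if it exists,
# otherwise computes pitch from wav_file and saves it to pitch_file.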