diff --git a/TTS/server/server.py b/TTS/server/server.py index 6b2141a9aa..89e8891599 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -57,6 +57,35 @@ def convert_boolean(x): parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.") parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.") parser.add_argument("--show_details", type=convert_boolean, default=False, help="Generate model detail page.") + + # --- NEW: option serveur pour la vitesse/rythme par défaut (VITS) --- + # length_scale < 1.0 = plus rapide ; > 1.0 = plus lent + parser.add_argument( + "--length_scale_default", + type=float, + default=1.0, + help="Default VITS length_scale. Smaller=faster, larger=slower.", + ) + # -------------------------------------------------------------------- + + # --- NEW: options serveur pour le contrôle de la variabilité à l'inférence --- + # inference_noise_scale: variation prosodique globale au décodeur + # inference_noise_scale_dp: variation sur le duration predictor (rythme local) + # Par défaut None = ne pas forcer et laisser la valeur du modèle. + parser.add_argument( + "--inference_noise_scale_default", + type=float, + default=None, + help="Default inference_noise_scale. Higher=more variation. None keeps model default.", + ) + parser.add_argument( + "--inference_noise_scale_dp_default", + type=float, + default=None, + help="Default inference_noise_scale_dp (duration predictor). Higher=more timing variation. None keeps model default.", + ) + # ------------------------------------------------------------------------------ + return parser @@ -127,6 +156,126 @@ def convert_boolean(x): use_gst = synthesizer.tts_config.get("use_gst", False) app = Flask(__name__) +# --- NEW: helpers pour length_scale côté requête --- +# On lit un éventuel paramètre de requête/entête et on applique sur le modèle VITS. +# Acceptés: header "length-scale" ou "length_scale", champs GET/POST "length_scale". +def _read_length_scale_from_request() -> Union[None, float]: + val = ( + request.headers.get("length-scale") + or request.headers.get("length_scale") + or request.values.get("length_scale") + ) + if val is None or val == "": + return None + try: + return float(val) + except Exception: + return None # on ignore silencieusement si non numérique + + +def _apply_length_scale_temporarily(ls: Union[None, float]): + """ + Applique length_scale sur tts_model si possible et renvoie un callback de reset. + - Si ls est None: on applique la valeur par défaut serveur. + - Si le modèle ne possède pas 'length_scale': on ne fait rien. + """ + # valeur à appliquer pour cette synthèse + target = args.length_scale_default if ls is None else ls + + if hasattr(synthesizer.tts_model, "length_scale"): + # sauvegarde pour reset après synthèse + old = synthesizer.tts_model.length_scale + synthesizer.tts_model.length_scale = target + + def _reset(): + try: + synthesizer.tts_model.length_scale = old + except Exception: + pass + + return _reset + else: + # pas de support length_scale sur ce modèle + def _noop(): + return None + + return _noop +# --------------------------------------------------- + +# --- NEW: helpers pour inference_noise_scale et inference_noise_scale_dp --- +# Lecture depuis headers/params et application temporaire avec reset après inférence. +def _read_float_from_request(*keys) -> Union[None, float]: + """ + Lit la première clé disponible dans headers ou params et tente un float. + Retourne None si absente ou invalide. + """ + for k in keys: + v = request.headers.get(k) + if v is None or v == "": + v = request.values.get(k) + if v not in (None, ""): + try: + return float(v) + except Exception: + return None + return None + + +def _apply_inference_noise_scale_temporarily(val: Union[None, float]): + """ + Applique inference_noise_scale pour cette requête. + - Si val est None: utilise --inference_noise_scale_default s'il est fourni, sinon ne force rien. + """ + to_apply = val if val is not None else args.inference_noise_scale_default + if to_apply is None: + # rien à faire + def _noop(): + return None + return _noop + + if hasattr(synthesizer.tts_model, "inference_noise_scale"): + old = synthesizer.tts_model.inference_noise_scale + synthesizer.tts_model.inference_noise_scale = to_apply + + def _reset(): + try: + synthesizer.tts_model.inference_noise_scale = old + except Exception: + pass + return _reset + else: + def _noop(): + return None + return _noop + + +def _apply_inference_noise_scale_dp_temporarily(val: Union[None, float]): + """ + Applique inference_noise_scale_dp pour cette requête. + - Si val est None: utilise --inference_noise_scale_dp_default s'il est fourni, sinon ne force rien. + """ + to_apply = val if val is not None else args.inference_noise_scale_dp_default + if to_apply is None: + def _noop(): + return None + return _noop + + if hasattr(synthesizer.tts_model, "inference_noise_scale_dp"): + old = synthesizer.tts_model.inference_noise_scale_dp + synthesizer.tts_model.inference_noise_scale_dp = to_apply + + def _reset(): + try: + synthesizer.tts_model.inference_noise_scale_dp = old + except Exception: + pass + return _reset + else: + def _noop(): + return None + return _noop +# --------------------------------------------------------------------------- + def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]: """Transform an uri style_wav, in either a string (path to wav file to be use for style transfer) @@ -197,10 +346,33 @@ def tts(): style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "") style_wav = style_wav_uri_to_dict(style_wav) + # --- NEW: lecture et application temporaire du length_scale --- + # Permet de contrôler la vitesse/rythme depuis la requête. + req_ls = _read_length_scale_from_request() + _reset_length_scale = _apply_length_scale_temporarily(req_ls) + # -------------------------------------------------------------- + + # --- NEW: lecture et application des bruits d'inférence --- + # Headers/params acceptés: + # - inference-noise-scale, inference_noise_scale + # - inference-noise-scale-dp, inference_noise_scale_dp + req_ins = _read_float_from_request("inference-noise-scale", "inference_noise_scale") + req_ins_dp = _read_float_from_request("inference-noise-scale-dp", "inference_noise_scale_dp") + _reset_ins = _apply_inference_noise_scale_temporarily(req_ins) + _reset_ins_dp = _apply_inference_noise_scale_dp_temporarily(req_ins_dp) + # ----------------------------------------------------------- + print(f" > Model input: {text}") print(f" > Speaker Idx: {speaker_idx}") print(f" > Language Idx: {language_idx}") - wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav) + try: + wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav) + finally: + # --- NEW: on rétablit les valeurs précédentes après synthèse --- + _reset_length_scale() + _reset_ins() + _reset_ins_dp() + # --------------------------------------------------------------- out = io.BytesIO() synthesizer.save_wav(wavs, out) return send_file(out, mimetype="audio/wav") @@ -241,10 +413,42 @@ def mary_tts_api_process(): data = parse_qs(request.get_data(as_text=True)) # NOTE: we ignore param. LOCALE and VOICE for now since we have only one active model text = data.get("INPUT_TEXT", [""])[0] + # --- NEW: support length_scale en POST MaryTTS (optionnel) --- + # Si un client envoie length_scale dans le form-url-encoded, on le lit ici. + ls_str = data.get("length_scale", [None])[0] + req_ls = float(ls_str) if ls_str not in (None, "") else None + + # --- NEW: support des bruits d'inférence en POST MaryTTS --- + ins_str = data.get("inference_noise_scale", [None])[0] + ins_dp_str = data.get("inference_noise_scale_dp", [None])[0] + req_ins = float(ins_str) if ins_str not in (None, "") else None + req_ins_dp = float(ins_dp_str) if ins_dp_str not in (None, "") else None + # ----------------------------------------------------------- else: text = request.args.get("INPUT_TEXT", "") + # --- NEW: support length_scale en GET MaryTTS (optionnel) --- + req_ls = _read_length_scale_from_request() + # --- NEW: support des bruits d'inférence en GET MaryTTS --- + req_ins = _read_float_from_request("inference-noise-scale", "inference_noise_scale") + req_ins_dp = _read_float_from_request("inference-noise-scale-dp", "inference_noise_scale_dp") + # ------------------------------------------------------------ + + # --- NEW: application temporaire du length_scale --- + _reset_length_scale = _apply_length_scale_temporarily(req_ls) + # --- NEW: application temporaire des bruits d'inférence --- + _reset_ins = _apply_inference_noise_scale_temporarily(req_ins) + _reset_ins_dp = _apply_inference_noise_scale_dp_temporarily(req_ins_dp) + # --------------------------------------------------- + print(f" > Model input: {text}") - wavs = synthesizer.tts(text) + try: + wavs = synthesizer.tts(text) + finally: + # --- NEW: reset après synthèse --- + _reset_length_scale() + _reset_ins() + _reset_ins_dp() + # --------------------------------- out = io.BytesIO() synthesizer.save_wav(wavs, out) return send_file(out, mimetype="audio/wav") diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 19fb25bef8..a059f57f5a 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -2,10 +2,13 @@ import collections import os import random +import time +import tempfile from typing import Dict, List, Union import numpy as np import torch +import torch.distributed as dist import tqdm from torch.utils.data import Dataset @@ -15,11 +18,80 @@ import mutagen -# to prevent too many open files error as suggested here -# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +# Eviter "too many open files" torch.multiprocessing.set_sharing_strategy("file_system") +# ----------------------- +# DDP helpers +# ----------------------- +def _is_ddp(): + return dist.is_available() and dist.is_initialized() + + +def _rank(): + return dist.get_rank() if _is_ddp() else 0 + + +def _barrier(): + if _is_ddp(): + dist.barrier() + + +# ----------------------- +# Utils +# ----------------------- +def _np_load_retry(path, retries=5, delay=0.1): + """ + np.load robuste sans pickle. Retry si lecture partielle pendant une écriture atomique. + """ + for _ in range(retries): + try: + return np.load(path, allow_pickle=False) + except FileNotFoundError: + return None + except Exception: + time.sleep(delay) + try: + return np.load(path, allow_pickle=False) + except Exception: + return None + + +def _np_save_atomic(path, arr): + """ + Ecriture atomique: fichier temporaire puis replace. + """ + d = os.path.dirname(path) + os.makedirs(d, exist_ok=True) + with tempfile.NamedTemporaryFile(dir=d, delete=False) as tmp: + tmp_name = tmp.name + np.save(tmp_name, arr, allow_pickle=False) + os.replace(tmp_name, path) + + +def _safe_makedirs_once(path: str): + """Créer un dossier une seule fois à travers les ranks et synchroniser.""" + if path is None: + return + if _is_ddp(): + if _rank() == 0: + os.makedirs(path, exist_ok=True) + _barrier() + else: + os.makedirs(path, exist_ok=True) + + +def _ensure_finite_np(x, name): + if not np.all(np.isfinite(x)): + raise RuntimeError(f"Non-finite in {name} (numpy)") + + +def _ensure_finite_t(x: torch.Tensor, name): + if not torch.isfinite(x).all(): + raise RuntimeError(f"Non-finite in {name} (torch)") + + def _parse_sample(item): language_name = None attn_file = None @@ -39,7 +111,7 @@ def noise_augment_audio(wav): def string2filename(string): - # generate a safe and reversible filename based on a string + # nom de fichier sûr et réversible filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore") return filename @@ -47,8 +119,9 @@ def string2filename(string): def get_audio_size(audiopath): extension = audiopath.rpartition(".")[-1].lower() if extension not in {"mp3", "wav", "flac"}: - raise RuntimeError(f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!") - + raise RuntimeError( + f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" + ) audio_info = mutagen.File(audiopath).info return int(audio_info.length * audio_info.sample_rate) @@ -80,65 +153,6 @@ def __init__( start_by_longest: bool = False, verbose: bool = False, ): - """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. - - If you need something different, you can subclass and override. - - Args: - outputs_per_step (int): Number of time frames predicted per step. - - compute_linear_spec (bool): compute linear spectrogram if True. - - ap (TTS.tts.utils.AudioProcessor): Audio processor object. - - samples (list): List of dataset samples. - - tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else - use the given. Defaults to None. - - compute_f0 (bool): compute f0 if True. Defaults to False. - - compute_energy (bool): compute energy if True. Defaults to False. - - f0_cache_path (str): Path to store f0 cache. Defaults to None. - - energy_cache_path (str): Path to store energy cache. Defaults to None. - - return_wav (bool): Return the waveform of the sample. Defaults to False. - - batch_group_size (int): Range of batch randomization after sorting - sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a - batch. Set 0 to disable. Defaults to 0. - - min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored. - Defaults to 0. - - max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored. - Defaults to float("inf"). - - min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored. - Defaults to 0. - - max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored. - The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to - this value if you encounter an OOM error in training. Defaults to float("inf"). - - phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a - separate file. Defaults to None. - - precompute_num_workers (int): Number of workers to precompute features. Defaults to 0. - - speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the - embedding layer. Defaults to None. - - d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None. - - use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. - - start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. - - verbose (bool): Print diagnostic information. Defaults to false. - """ super().__init__() self.batch_group_size = batch_group_size self._samples = samples @@ -162,8 +176,6 @@ def __init__( self.start_by_longest = start_by_longest self.verbose = verbose - self.rescue_item_idx = 1 - self.pitch_computed = False self.tokenizer = tokenizer if self.tokenizer.use_phonemes: @@ -220,8 +232,9 @@ def print_logs(self, level: int = 0) -> None: print(f"{indent}| > Number of instances : {len(self.samples)}") def load_wav(self, filename): - waveform = self.ap.load_wav(filename) + waveform = self.ap.load_wav(filename).astype(np.float32) assert waveform.size > 0 + _ensure_finite_np(waveform, "waveform") return waveform def get_phonemes(self, idx, text): @@ -244,63 +257,64 @@ def get_energy(self, idx): @staticmethod def get_attn_mask(attn_file): - return np.load(attn_file) + attn = np.load(attn_file, allow_pickle=False) + _ensure_finite_np(attn, "attn") + return attn def get_token_ids(self, idx, text): if self.tokenizer.use_phonemes: token_ids = self.get_phonemes(idx, text)["token_ids"] else: token_ids = self.tokenizer.text_to_ids(text) - return np.array(token_ids, dtype=np.int32) - - def load_data(self, idx): - item = self.samples[idx] + token_ids = np.array(token_ids, dtype=np.int32) + _ensure_finite_np(token_ids, "token_ids") + return token_ids + + def load_data(self, start_idx): + # Remplace la récursion par une boucle bornée + tries = 0 + idx = start_idx + n = len(self.samples) + while tries < n: + item = self.samples[idx] + + raw_text = item["text"] + wav = self.load_wav(item["audio_file"]) + + if self.use_noise_augment: + wav = noise_augment_audio(wav) + _ensure_finite_np(wav, "wav_aug") + + token_ids = self.get_token_ids(idx, item["text"]) + + attn = None + if "alignment_file" in item: + attn = self.get_attn_mask(item["alignment_file"]) + + if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len: + tries += 1 + idx = (idx + 1) % n + continue + + f0 = self.get_f0(idx)["f0"] if self.compute_f0 else None + energy = self.get_energy(idx)["energy"] if self.compute_energy else None + + sample = { + "raw_text": raw_text, + "token_ids": token_ids, + "wav": wav, + "pitch": f0, + "energy": energy, + "attn": attn, + "item_idx": item["audio_file"], + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "wav_file_name": os.path.basename(item["audio_file"]), + "audio_unique_name": item["audio_unique_name"], + } + return sample - raw_text = item["text"] - - wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32) - - # apply noise for augmentation - if self.use_noise_augment: - wav = noise_augment_audio(wav) - - # get token ids - token_ids = self.get_token_ids(idx, item["text"]) - - # get pre-computed attention maps - attn = None - if "alignment_file" in item: - attn = self.get_attn_mask(item["alignment_file"]) - - # after phonemization the text length may change - # this is a shareful 🤭 hack to prevent longer phonemes - # TODO: find a better fix - if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len: - self.rescue_item_idx += 1 - return self.load_data(self.rescue_item_idx) - - # get f0 values - f0 = None - if self.compute_f0: - f0 = self.get_f0(idx)["f0"] - energy = None - if self.compute_energy: - energy = self.get_energy(idx)["energy"] - - sample = { - "raw_text": raw_text, - "token_ids": token_ids, - "wav": wav, - "pitch": f0, - "energy": energy, - "attn": attn, - "item_idx": item["audio_file"], - "speaker_name": item["speaker_name"], - "language_name": item["language"], - "wav_file_name": os.path.basename(item["audio_file"]), - "audio_unique_name": item["audio_unique_name"], - } - return sample + raise RuntimeError(" [!] No valid sample found after scanning the dataset.") @staticmethod def _compute_lengths(samples): @@ -351,12 +365,8 @@ def _select_samples_by_idx(idxs, samples): return samples_new def preprocess_samples(self): - r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length - range. - """ samples = self._compute_lengths(self.samples) - # sort items based on the sequence length in ascending order text_lengths = [i["text_length"] for i in samples] audio_lengths = [i["audio_length"] for i in samples] text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) @@ -365,7 +375,6 @@ def preprocess_samples(self): ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) samples = self._select_samples_by_idx(keep_idx, samples) - sorted_idxs = self.sort_by_length(samples) if self.start_by_longest: @@ -378,13 +387,9 @@ def preprocess_samples(self): if len(samples) == 0: raise RuntimeError(" [!] No samples left") - # shuffle batch groups - # create batches with similar length items - # the larger the `batch_group_size`, the higher the length variety in a batch. if self.batch_group_size > 0: samples = self.create_buckets(samples, self.batch_group_size) - # update items to the new sorted items audio_lengths = [s["audio_length"] for s in samples] text_lengths = [s["text_length"] for s in samples] self.samples = samples @@ -403,58 +408,35 @@ def preprocess_samples(self): @staticmethod def _sort_batch(batch, text_lengths): - """Sort the batch by the input text length for RNN efficiency. - - Args: - batch (Dict): Batch returned by `__getitem__`. - text_lengths (List[int]): Lengths of the input character sequences. - """ text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True) batch = [batch[idx] for idx in ids_sorted_decreasing] return batch, text_lengths, ids_sorted_decreasing def collate_fn(self, batch): - r""" - Perform preprocessing and create a final data batch: - 1. Sort batch instances by text-length - 2. Convert Audio signal to features. - 3. PAD sequences wrt r. - 4. Load to Torch. - """ - - # Puts each data field into a tensor with outer dimension batch size + # Puts each data field into un tenseur avec dimension batch if isinstance(batch[0], collections.abc.Mapping): - token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) + token_ids_lengths = np.array([len(d["token_ids"]) for d in batch], dtype=np.int64) - # sort items with text input length for RNN efficiency + # sort by text length batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths) - # convert list of dicts to dict of lists + # list[dict] -> dict[list] batch = {k: [dic[k] for dic in batch] for k in batch[0]} - # get language ids from language names - if self.language_id_mapping is not None: - language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]] - else: - language_ids = None - # get pre-computed d-vectors + # mappings + language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]] if self.language_id_mapping is not None else None if self.d_vector_mapping is not None: embedding_keys = list(batch["audio_unique_name"]) d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys] else: d_vectors = None + speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]] if self.speaker_id_mapping else None - # get numerical speaker ids from speaker names - if self.speaker_id_mapping: - speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]] - else: - speaker_ids = None - # compute features + # features mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]] - + for m in mel: + _ensure_finite_np(m, "mel") mel_lengths = [m.shape[1] for m in mel] - - # lengths adjusted by the reduction factor mel_lengths_adjusted = [ m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) if m.shape[1] % self.outputs_per_step @@ -462,76 +444,65 @@ def collate_fn(self, batch): for m in mel ] - # compute 'stop token' targets - stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths] - - # PAD stop targets + # stop targets + stop_targets = [np.array([0.0] * (ml - 1) + [1.0], dtype=np.float32) for ml in mel_lengths] stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) - # PAD sequences with longest instance in the batch token_ids = prepare_data(batch["token_ids"]).astype(np.int32) + mel = prepare_tensor(mel, self.outputs_per_step).transpose(0, 2, 1) # BxTxD - # PAD features with longest instance - mel = prepare_tensor(mel, self.outputs_per_step) - - # B x D x T --> B x T x D - mel = mel.transpose(0, 2, 1) - - # convert things to pytorch token_ids_lengths = torch.LongTensor(token_ids_lengths) token_ids = torch.LongTensor(token_ids) mel = torch.FloatTensor(mel).contiguous() mel_lengths = torch.LongTensor(mel_lengths) stop_targets = torch.FloatTensor(stop_targets) - # speaker vectors if d_vectors is not None: d_vectors = torch.FloatTensor(d_vectors) - + _ensure_finite_t(d_vectors, "d_vectors") if speaker_ids is not None: speaker_ids = torch.LongTensor(speaker_ids) - if language_ids is not None: language_ids = torch.LongTensor(language_ids) - # compute linear spectrogram linear = None if self.compute_linear_spec: linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] - linear = prepare_tensor(linear, self.outputs_per_step) - linear = linear.transpose(0, 2, 1) + for l in linear: + _ensure_finite_np(l, "linear") + linear = prepare_tensor(linear, self.outputs_per_step).transpose(0, 2, 1) assert mel.shape[1] == linear.shape[1] linear = torch.FloatTensor(linear).contiguous() - # format waveforms wav_padded = None if self.return_wav: wav_lengths = [w.shape[0] for w in batch["wav"]] max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length wav_lengths = torch.LongTensor(wav_lengths) - wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len) + wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len, dtype=torch.float32) for i, w in enumerate(batch["wav"]): mel_length = mel_lengths_adjusted[i] w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge") w = w[: mel_length * self.ap.hop_length] - wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) + wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w.astype(np.float32)) wav_padded.transpose_(1, 2) - # format F0 if self.compute_f0: - pitch = prepare_data(batch["pitch"]) + pitch = prepare_data(batch["pitch"]).astype(np.float32) assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}" - pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT + pitch = torch.from_numpy(pitch)[:, None, :].contiguous().to(torch.float32) + _ensure_finite_t(pitch, "pitch_batch") else: pitch = None - # format energy + if self.compute_energy: - energy = prepare_data(batch["energy"]) + energy = prepare_data(batch["energy"]).astype(np.float32) assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}" - energy = torch.FloatTensor(energy)[:, None, :].contiguous() # B x 1 xT + energy = torch.from_numpy(energy)[:, None, :].contiguous().to(torch.float32) + _ensure_finite_t(energy, "energy_batch") else: energy = None - # format attention masks + attns = None if batch["attn"][0] is not None: attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] @@ -540,9 +511,13 @@ def collate_fn(self, batch): pad1 = token_ids.shape[1] - attn.shape[0] assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" attn = np.pad(attn, [[0, pad1], [0, pad2]]) + _ensure_finite_np(attn, "attn_padded") attns[idx] = attn attns = prepare_tensor(attns, self.outputs_per_step) attns = torch.FloatTensor(attns).unsqueeze(1) + # Sanity check finale des dims + assert attns.shape[2] == token_ids.shape[1] and attns.shape[3] == mel.shape[1], \ + f"[!] attn shape mismatch {attns.shape} vs tokens {token_ids.shape} / mel {mel.shape}" return { "token_id": token_ids, @@ -564,36 +539,16 @@ def collate_fn(self, batch): "audio_unique_names": batch["audio_unique_name"], } - raise TypeError( - ( - "batch must contain tensors, numbers, dicts or lists;\ - found {}".format( - type(batch[0]) - ) - ) - ) + raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}".format(type(batch[0])))) class PhonemeDataset(Dataset): - """Phoneme Dataset for converting input text to phonemes and then token IDs - - At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data - loading latency. If `cache_path` is already present, it skips the pre-computation. - - Args: - samples (Union[List[List], List[Dict]]): - List of samples. Each sample is a list or a dict. - - tokenizer (TTSTokenizer): - Tokenizer to convert input text to phonemes. - - cache_path (str): - Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation. - - precompute_num_workers (int): - Number of workers used for pre-computing the phonemes. Defaults to 0. """ - + DDP-safe phoneme caching: + * rank 0 crée le dossier et peut pré-calculer + * écriture atomique .npy + * lecture robuste sans pickle, avec retries + """ def __init__( self, samples: Union[List[Dict], List[List]], @@ -604,9 +559,25 @@ def __init__( self.samples = samples self.tokenizer = tokenizer self.cache_path = cache_path - if cache_path is not None and not os.path.exists(cache_path): - os.makedirs(cache_path) - self.precompute(precompute_num_workers) + + # créer dossier une fois + need_precompute = False + if cache_path is not None: + if _is_ddp(): + if _rank() == 0 and not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + need_precompute = True + _barrier() + else: + if not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + need_precompute = True + + # précompute seulement par rank 0, puis sync + if need_precompute and precompute_num_workers > 0: + if _rank() == 0: + self.precompute(num_workers=0) # 0 pour éviter write concurrent + _barrier() def __getitem__(self, index): item = self.samples[index] @@ -618,31 +589,41 @@ def __len__(self): return len(self.samples) def compute_or_load(self, file_name, text, language): - """Compute phonemes for the given text. - - If the phonemes are already cached, load them from cache. - """ file_ext = "_phoneme.npy" cache_path = os.path.join(self.cache_path, file_name + file_ext) - try: - ids = np.load(cache_path) - except FileNotFoundError: - ids = self.tokenizer.text_to_ids(text, language=language) - np.save(cache_path, ids) + + # lecture robuste si présent + if os.path.exists(cache_path): + ids = _np_load_retry(cache_path) + if ids is not None: + ids = np.asarray(ids, dtype=np.int64) + _ensure_finite_np(ids, "phoneme_ids_cached") + return ids + + # calcul local + ids = self.tokenizer.text_to_ids(text, language=language) + ids = np.asarray(ids, dtype=np.int64) + _ensure_finite_np(ids, "phoneme_ids_new") + + # seul rank 0 écrit (atomique) + if _rank() == 0: + try: + _np_save_atomic(cache_path, ids) + except Exception: + try: + if os.path.exists(cache_path): + os.remove(cache_path) + except Exception: + pass return ids def get_pad_id(self): - """Get pad token ID for sequence padding""" return self.tokenizer.pad_id - def precompute(self, num_workers=1): - """Precompute phonemes for all samples. - - We use pytorch dataloader because we are lazy. - """ + def precompute(self, num_workers=0): print("[*] Pre-computing phonemes...") with tqdm.tqdm(total=len(self)) as pbar: - batch_size = num_workers if num_workers > 0 else 1 + batch_size = 1 dataloder = torch.utils.data.DataLoader( batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn ) @@ -670,29 +651,7 @@ def print_logs(self, level: int = 0) -> None: class F0Dataset: - """F0 Dataset for computing F0 from wav files in CPU - - Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It - also computes the mean and std of F0 values if `normalize_f0` is True. - - Args: - samples (Union[List[List], List[Dict]]): - List of samples. Each sample is a list or a dict. - - ap (AudioProcessor): - AudioProcessor to compute F0 from wav files. - - cache_path (str): - Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation. - Defaults to None. - - precompute_num_workers (int): - Number of workers used for pre-computing the F0 values. Defaults to 0. - - normalize_f0 (bool): - Whether to normalize F0 values by mean and std. Defaults to True. - """ - + """DDP-safe F0 cache (single writer + atomic save).""" def __init__( self, samples: Union[List[List], List[Dict]], @@ -711,9 +670,24 @@ def __init__( self.pad_id = 0.0 self.mean = None self.std = None - if cache_path is not None and not os.path.exists(cache_path): - os.makedirs(cache_path) - self.precompute(precompute_num_workers) + + need_precompute = False + if cache_path is not None: + if _is_ddp(): + if _rank() == 0 and not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + need_precompute = True + _barrier() + else: + if not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + need_precompute = True + + if need_precompute and precompute_num_workers > 0: + if _rank() == 0: + self.precompute(num_workers=0) + _barrier() + if normalize_f0: self.load_stats(cache_path) @@ -731,8 +705,7 @@ def __len__(self): def precompute(self, num_workers=0): print("[*] Pre-computing F0s...") with tqdm.tqdm(total=len(self)) as pbar: - batch_size = num_workers if num_workers > 0 else 1 - # we do not normalize at preproessing + batch_size = 1 normalize_f0 = self.normalize_f0 self.normalize_f0 = False dataloder = torch.utils.data.DataLoader( @@ -741,67 +714,71 @@ def precompute(self, num_workers=0): computed_data = [] for batch in dataloder: f0 = batch["f0"] - computed_data.append(f for f in f0) + computed_data.extend(f0) # corrige append de générateur pbar.update(batch_size) self.normalize_f0 = normalize_f0 if self.normalize_f0: - computed_data = [tensor for batch in computed_data for tensor in batch] # flatten pitch_mean, pitch_std = self.compute_pitch_stats(computed_data) pitch_stats = {"mean": pitch_mean, "std": pitch_std} - np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + _np_save_atomic(os.path.join(self.cache_path, "pitch_stats.npy"), pitch_stats) def get_pad_id(self): return self.pad_id @staticmethod def create_pitch_file_path(file_name, cache_path): - pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") - return pitch_file + return os.path.join(cache_path, file_name + "_pitch.npy") @staticmethod def _compute_and_save_pitch(ap, wav_file, pitch_file=None): - wav = ap.load_wav(wav_file) - pitch = ap.compute_f0(wav) + wav = ap.load_wav(wav_file).astype(np.float32) + _ensure_finite_np(wav, "f0_wav") + pitch = ap.compute_f0(wav).astype(np.float32) + _ensure_finite_np(pitch, "f0_values") if pitch_file: - np.save(pitch_file, pitch) + _np_save_atomic(pitch_file, pitch) return pitch @staticmethod def compute_pitch_stats(pitch_vecs): - nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) if len(pitch_vecs) > 0 else np.array([1.0], dtype=np.float32) mean, std = np.mean(nonzeros), np.std(nonzeros) - return mean, std + return np.float32(mean), np.float32(std if std > 1e-8 else 1.0) def load_stats(self, cache_path): stats_path = os.path.join(cache_path, "pitch_stats.npy") - stats = np.load(stats_path, allow_pickle=True).item() - self.mean = stats["mean"].astype(np.float32) - self.std = stats["std"].astype(np.float32) + stats = _np_load_retry(stats_path) + if stats is None: + return + stats = stats.item() + self.mean = np.float32(stats["mean"]) + self.std = np.float32(stats["std"] if stats["std"] > 1e-8 else 1.0) def normalize(self, pitch): zero_idxs = np.where(pitch == 0.0)[0] pitch = pitch - self.mean pitch = pitch / self.std pitch[zero_idxs] = 0.0 - return pitch + return pitch.astype(np.float32) def denormalize(self, pitch): zero_idxs = np.where(pitch == 0.0)[0] - pitch *= self.std - pitch += self.mean + pitch = pitch * self.std + pitch = pitch + self.mean pitch[zero_idxs] = 0.0 - return pitch + return pitch.astype(np.float32) def compute_or_load(self, wav_file, audio_unique_name): - """ - compute pitch and return a numpy array of pitch values - """ pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path) - if not os.path.exists(pitch_file): - pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) - else: - pitch = np.load(pitch_file) + if os.path.exists(pitch_file): + pitch = _np_load_retry(pitch_file) + if pitch is not None: + pitch = pitch.astype(np.float32) + _ensure_finite_np(pitch, "f0_cached") + return pitch + + pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file if _rank() == 0 else None) return pitch.astype(np.float32) def collate_fn(self, batch): @@ -809,9 +786,10 @@ def collate_fn(self, batch): f0s = [item["f0"] for item in batch] f0_lens = [len(item["f0"]) for item in batch] f0_lens_max = max(f0_lens) - f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id()) + # FloatTensor, pas LongTensor + f0s_torch = torch.full((len(f0s), f0_lens_max), fill_value=self.get_pad_id(), dtype=torch.float32) for i, f0_len in enumerate(f0_lens): - f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i]) + f0s_torch[i, :f0_len] = torch.from_numpy(f0s[i].astype(np.float32)) return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens} def print_logs(self, level: int = 0) -> None: @@ -822,29 +800,7 @@ def print_logs(self, level: int = 0) -> None: class EnergyDataset: - """Energy Dataset for computing Energy from wav files in CPU - - Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It - also computes the mean and std of Energy values if `normalize_Energy` is True. - - Args: - samples (Union[List[List], List[Dict]]): - List of samples. Each sample is a list or a dict. - - ap (AudioProcessor): - AudioProcessor to compute Energy from wav files. - - cache_path (str): - Path to cache Energy values. If `cache_path` is already present or None, it skips the pre-computation. - Defaults to None. - - precompute_num_workers (int): - Number of workers used for pre-computing the Energy values. Defaults to 0. - - normalize_Energy (bool): - Whether to normalize Energy values by mean and std. Defaults to True. - """ - + """DDP-safe Energy cache (single writer + atomic save).""" def __init__( self, samples: Union[List[List], List[Dict]], @@ -862,9 +818,24 @@ def __init__( self.pad_id = 0.0 self.mean = None self.std = None - if cache_path is not None and not os.path.exists(cache_path): - os.makedirs(cache_path) - self.precompute(precompute_num_workers) + + need_precompute = False + if cache_path is not None: + if _is_ddp(): + if _rank() == 0 and not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + need_precompute = True + _barrier() + else: + if not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + need_precompute = True + + if need_precompute and precompute_num_workers > 0: + if _rank() == 0: + self.precompute(num_workers=0) + _barrier() + if normalize_energy: self.load_stats(cache_path) @@ -880,10 +851,9 @@ def __len__(self): return len(self.samples) def precompute(self, num_workers=0): - print("[*] Pre-computing energys...") + print("[*] Pre-computing energys...]") with tqdm.tqdm(total=len(self)) as pbar: - batch_size = num_workers if num_workers > 0 else 1 - # we do not normalize at preproessing + batch_size = 1 normalize_energy = self.normalize_energy self.normalize_energy = False dataloder = torch.utils.data.DataLoader( @@ -892,68 +862,75 @@ def precompute(self, num_workers=0): computed_data = [] for batch in dataloder: energy = batch["energy"] - computed_data.append(e for e in energy) + computed_data.extend(energy) # corrige append de générateur pbar.update(batch_size) self.normalize_energy = normalize_energy if self.normalize_energy: - computed_data = [tensor for batch in computed_data for tensor in batch] # flatten energy_mean, energy_std = self.compute_energy_stats(computed_data) energy_stats = {"mean": energy_mean, "std": energy_std} - np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True) + _np_save_atomic(os.path.join(self.cache_path, "energy_stats.npy"), energy_stats) def get_pad_id(self): return self.pad_id @staticmethod - def create_energy_file_path(wav_file, cache_path): - file_name = os.path.splitext(os.path.basename(wav_file))[0] - energy_file = os.path.join(cache_path, file_name + "_energy.npy") - return energy_file + def create_energy_file_path(wav_or_name, cache_path): + # compat: accepte chemin complet ou nom unique + base = os.path.splitext(os.path.basename(wav_or_name))[0] + return os.path.join(cache_path, base + "_energy.npy") @staticmethod def _compute_and_save_energy(ap, wav_file, energy_file=None): - wav = ap.load_wav(wav_file) - energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length) + wav = ap.load_wav(wav_file).astype(np.float32) + _ensure_finite_np(wav, "energy_wav") + energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length).astype( + np.float32 + ) + _ensure_finite_np(energy, "energy_values") if energy_file: - np.save(energy_file, energy) + _np_save_atomic(energy_file, energy) return energy @staticmethod def compute_energy_stats(energy_vecs): - nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs]) + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs]) if len(energy_vecs) > 0 else np.array([1.0], dtype=np.float32) mean, std = np.mean(nonzeros), np.std(nonzeros) - return mean, std + return np.float32(mean), np.float32(std if std > 1e-8 else 1.0) def load_stats(self, cache_path): stats_path = os.path.join(cache_path, "energy_stats.npy") - stats = np.load(stats_path, allow_pickle=True).item() - self.mean = stats["mean"].astype(np.float32) - self.std = stats["std"].astype(np.float32) + stats = _np_load_retry(stats_path) + if stats is None: + return + stats = stats.item() + self.mean = np.float32(stats["mean"]) + self.std = np.float32(stats["std"] if stats["std"] > 1e-8 else 1.0) def normalize(self, energy): zero_idxs = np.where(energy == 0.0)[0] energy = energy - self.mean energy = energy / self.std energy[zero_idxs] = 0.0 - return energy + return energy.astype(np.float32) def denormalize(self, energy): zero_idxs = np.where(energy == 0.0)[0] - energy *= self.std - energy += self.mean + energy = energy * self.std + energy = energy + self.mean energy[zero_idxs] = 0.0 - return energy + return energy.astype(np.float32) def compute_or_load(self, wav_file, audio_unique_name): - """ - compute energy and return a numpy array of energy values - """ energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path) - if not os.path.exists(energy_file): - energy = self._compute_and_save_energy(self.ap, wav_file, energy_file) - else: - energy = np.load(energy_file) + if os.path.exists(energy_file): + energy = _np_load_retry(energy_file) + if energy is not None: + energy = energy.astype(np.float32) + _ensure_finite_np(energy, "energy_cached") + return energy + + energy = self._compute_and_save_energy(self.ap, wav_file, energy_file if _rank() == 0 else None) return energy.astype(np.float32) def collate_fn(self, batch): @@ -961,9 +938,10 @@ def collate_fn(self, batch): energys = [item["energy"] for item in batch] energy_lens = [len(item["energy"]) for item in batch] energy_lens_max = max(energy_lens) - energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id()) + # FloatTensor, pas LongTensor + energys_torch = torch.full((len(energys), energy_lens_max), fill_value=self.get_pad_id(), dtype=torch.float32) for i, energy_len in enumerate(energy_lens): - energys_torch[i, :energy_len] = torch.LongTensor(energys[i]) + energys_torch[i, :energy_len] = torch.from_numpy(energys[i].astype(np.float32)) return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens} def print_logs(self, level: int = 0) -> None: diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index 74d3910b51..0df6dd1f1b 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -66,6 +66,8 @@ def replace_symbols(text, lang="en"): text = text.replace(":", ",") if lang == "en": text = text.replace("&", " and ") + elif lang == "mlg": + text = text.replace("&", " sy ") elif lang == "fr": text = text.replace("&", " et ") elif lang == "pt": @@ -119,6 +121,15 @@ def english_cleaners(text): text = collapse_whitespace(text) return text +def malagasy_cleaners(text): + """Pipeline for Malagasy text, including number and abbreviation expansion.""" + # text = convert_to_ascii(text) + text = lowercase(text) + text = replace_symbols(text, lang="mlg") + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + def phoneme_cleaners(text): """Pipeline for phonemes mode, including number and abbreviation expansion."""