44 changes: 40 additions & 4 deletions TTS/api.py
@@ -24,6 +24,7 @@ def __init__(
vocoder_config_path: str = None,
progress_bar: bool = True,
gpu=False,
device: str = None,
):
"""🐸TTS python interface that allows to load and use the released models.

@@ -66,8 +67,15 @@ def __init__(
self.synthesizer = None
self.voice_converter = None
self.model_name = ""
if gpu:
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")

# Handle device parameter with backward compatibility for gpu flag
if gpu and device is None:
warnings.warn("`gpu` will be deprecated. Please use `device='cuda'` or `tts.to(device)` instead.")
device = "cuda" if gpu else "cpu"
elif device is None:
device = "cpu"

self.device = device

if model_name is not None and len(model_name) > 0:
if "tts_models" in model_name:
@@ -118,6 +126,34 @@ def languages(self):
return None
return self.synthesizer.tts_model.language_manager.language_names

def to(self, device):
"""Move the model to a different device.

Args:
device (str): Device to move the model to (e.g., "cuda", "cpu", "mps").

Returns:
self: Returns self for method chaining.
"""
import torch
self.device = device if isinstance(device, str) else str(device)

# Move synthesizer models to device
if self.synthesizer is not None:
if self.synthesizer.tts_model is not None:
self.synthesizer.tts_model.to(device)
# Update synthesizer's device
self.synthesizer.device = torch.device(device)
self.synthesizer.use_cuda = (torch.device(device).type == "cuda")
if self.synthesizer.vocoder_model is not None:
self.synthesizer.vocoder_model.to(device)

# Move voice converter model to device
if self.voice_converter is not None:
self.voice_converter.to(device)

return self

@staticmethod
def get_models_file_path():
return Path(__file__).parent / ".models.json"
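
A minimal usage sketch of the chaining this enables (the model name and device strings below are illustrative, not taken from this PR):

from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
tts.to("cuda")   # returns self; synthesizer, vocoder, and voice converter all follow the move
tts.to("cpu")    # `tts.device` and the synthesizer's `use_cuda` stay in sync
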
@@ -184,7 +220,7 @@ def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
encoder_checkpoint=None,
encoder_config=None,
model_dir=model_dir,
use_cuda=gpu,
device=self.device,
)

def load_tts_model_by_path(
@@ -209,7 +245,7 @@ def load_tts_model_by_path(
vocoder_config=vocoder_config,
encoder_checkpoint=None,
encoder_config=None,
use_cuda=gpu,
device=self.device,
)

def _check_arguments(
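
At construction time, the intent is that `device` supersedes the boolean `gpu` flag; a hedged sketch of both call styles (the model name is illustrative):

from TTS.api import TTS

tts = TTS("tts_models/en/ljspeech/tacotron2-DDC", device="cuda")  # preferred
tts = TTS("tts_models/en/ljspeech/tacotron2-DDC", gpu=True)       # still accepted, emits the deprecation warning
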
31 changes: 28 additions & 3 deletions TTS/tts/layers/xtts/gpt.py
@@ -587,25 +587,50 @@ def generate(
**hf_generate_kwargs,
):
gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)

# Ensure attention_mask is provided to the HF generate call.
# If the caller did not specify one, create a full-ones mask
# matching the shape of the input ids (including prefix tokens).
hf_kwargs = dict(hf_generate_kwargs)
if "attention_mask" not in hf_kwargs:
hf_kwargs["attention_mask"] = torch.ones_like(
gpt_inputs,
dtype=torch.long,
device=gpt_inputs.device,
)

gen = self.gpt_inference.generate(
gpt_inputs,
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
**hf_generate_kwargs,
**hf_kwargs,
)
if "return_dict_in_generate" in hf_generate_kwargs:
if "return_dict_in_generate" in hf_kwargs:
return gen.sequences[:, gpt_inputs.shape[1] :], gen
return gen[:, gpt_inputs.shape[1] :]
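
The defaulting pattern above, shown in isolation as a standalone sketch (the tensor shape is illustrative; with no padding in the prompt, a full-ones mask simply attends to every position):

import torch

input_ids = torch.zeros(1, 37, dtype=torch.long)            # stand-in for prefix + text tokens
hf_kwargs = {}
hf_kwargs.setdefault("attention_mask", torch.ones_like(input_ids))
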

def get_generator(self, fake_inputs, **hf_generate_kwargs):
# Ensure attention_mask is provided to the HF streaming generate call
# and that `return_dict_in_generate` is True when requesting hidden
# states, to avoid configuration warnings in newer transformers.
hf_kwargs = dict(hf_generate_kwargs)
if "attention_mask" not in hf_kwargs:
hf_kwargs["attention_mask"] = torch.ones_like(
fake_inputs,
dtype=torch.long,
device=fake_inputs.device,
)
if hf_kwargs.get("output_hidden_states"):
hf_kwargs.setdefault("return_dict_in_generate", True)

return self.gpt_inference.generate_stream(
fake_inputs,
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1],
do_stream=True,
**hf_generate_kwargs,
**hf_kwargs,
)
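
The `output_hidden_states`/`return_dict_in_generate` pairing mirrors plain Hugging Face generation, where hidden states are only surfaced when both flags are set; a hedged, self-contained sketch with an ordinary causal LM (gpt2 is an arbitrary stand-in, not the XTTS model):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
enc = tok("hello there", return_tensors="pt")
out = model.generate(**enc, max_new_tokens=5,
                     output_hidden_states=True, return_dict_in_generate=True)
print(len(out.hidden_states))   # one entry of per-layer hidden states per generated step
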
4 changes: 2 additions & 2 deletions TTS/tts/layers/xtts/gpt_inference.py
@@ -2,11 +2,11 @@

import torch
from torch import nn
from transformers import GPT2PreTrainedModel
from transformers import GPT2PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class GPT2InferenceModel(GPT2PreTrainedModel):
class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
"""Override GPT2LMHeadModel to allow for prefix conditioning."""

def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
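
This mirrors the upstream deprecation: recent transformers releases (the upstream notice points at v4.50) no longer give `PreTrainedModel` subclasses `generate()` for free, so custom models must mix in `GenerationMixin` explicitly. A minimal sketch of the pattern (the class name is an illustrative stand-in):

from transformers import GPT2PreTrainedModel, GenerationMixin

class PrefixConditionedLM(GPT2PreTrainedModel, GenerationMixin):  # illustrative stand-in
    """Custom head that keeps `.generate()` working on newer transformers."""
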
40 changes: 31 additions & 9 deletions TTS/tts/layers/xtts/stream_generator.py
@@ -130,18 +130,15 @@ def generate(

# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
# Use a dedicated StreamGenerationConfig derived from the model config,
# without emitting the legacy deprecation warning from upstream since
# this path is internal to XTTS streaming.
if getattr(self, "generation_config", None) is not None and self.generation_config._from_model_config:
new_generation_config = StreamGenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation)"
)
self.generation_config = new_generation_config
# Mark as not coming directly from the original model config anymore.
self.generation_config._from_model_config = False
generation_config = self.generation_config

generation_config = copy.deepcopy(generation_config)
@@ -169,11 +166,18 @@ def generate(
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]

# Prepare special token tensors (bos/eos/pad, etc.) on the correct device.
# This mirrors the upstream GenerationMixin.generate behavior and is required
# for attributes like `generation_config._eos_token_tensor` to exist.
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
@@ -801,6 +805,12 @@ def sample_stream(
# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

# Initialize `cache_position` for compatibility with newer `transformers`
# caching logic. This mirrors the behavior in the upstream generation
# utilities and prevents KeyError on `cache_position` when updating
# model kwargs during streaming generation.
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

this_peer_finished = False # used by synced_gpus only
# auto-regressive generation
while True:
@@ -880,6 +890,18 @@ def sample_stream(

def init_stream_support():
"""Overload PreTrainedModel for streaming."""
# Add compatibility shim for newer `transformers` versions (>=4.46) where
# `_get_logits_warper` was removed and its behavior folded into
# `_get_logits_processor`. Our streaming generator still calls
# `self._get_logits_warper(...)`, so provide a no-op implementation that
# returns an empty `LogitsProcessorList`. Sampling-related warpers are
# already included in `_get_logits_processor` in these versions.
if not hasattr(PreTrainedModel, "_get_logits_warper"):
def _get_logits_warper(self, generation_config):
return LogitsProcessorList()

PreTrainedModel._get_logits_warper = _get_logits_warper

PreTrainedModel.generate_stream = NewGenerationMixin.generate
PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream

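
The shim only takes effect once the streaming entry points are patched in; a short sketch of the intended call order (as the XTTS code does before streaming generation):

from TTS.tts.layers.xtts.stream_generator import init_stream_support

init_stream_support()  # patches PreTrainedModel with generate_stream / sample_stream
                       # and, on transformers builds lacking _get_logits_warper, the no-op shim
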
62 changes: 54 additions & 8 deletions TTS/tts/models/xtts.py
@@ -1,7 +1,6 @@
import os
from dataclasses import dataclass

import librosa
import torch
import torch.nn.functional as F
import torchaudio
@@ -70,7 +69,20 @@ def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark

# torchaudio should choose the proper backend to load audio depending on the platform
audio, lsr = torchaudio.load(audiopath)
# For PyTorch 2.9+, use soundfile directly to avoid torchcodec requirement
try:
import soundfile as sf
audio, lsr = sf.read(audiopath)
# Convert to torch tensor and ensure correct shape (channels, samples)
audio = torch.as_tensor(audio, dtype=torch.float32)
# soundfile returns (samples,) for mono or (samples, channels) for multi-channel
if audio.ndim == 1:
audio = audio.unsqueeze(0) # (1, samples)
elif audio.ndim == 2:
audio = audio.transpose(0, 1) # (channels, samples)
except ImportError:
# Fallback to torchaudio if soundfile not available
audio, lsr = torchaudio.load(audiopath)

# stereo to mono if needed
if audio.size(0) != 1:
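
The reshaping compensates for the two loaders' differing conventions; a brief standalone sketch of that assumption (the file path is illustrative):

import soundfile as sf
import torch

data, sr = sf.read("ref.wav")                       # numpy float64, (samples,) or (samples, channels)
audio = torch.as_tensor(data, dtype=torch.float32)
audio = audio.unsqueeze(0) if audio.ndim == 1 else audio.transpose(0, 1)   # -> (channels, samples)
# torchaudio.load("ref.wav") would already return a float32 (channels, samples) tensor
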
@@ -254,6 +266,15 @@ def init_models(self):
def device(self):
return next(self.parameters()).device

def to(self, device):
"""Override to() to ensure FP32 on MPS and keep device in sync."""
super().to(device)
# Ensure FP32 on MPS (MPS doesn't support FP16 well for all operations)
device_obj = torch.device(device) if isinstance(device, str) else device
if device_obj.type == "mps":
self.float()
return self

@torch.inference_mode()
def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
"""Compute the conditioning latents for the GPT model from the given audio.
@@ -316,11 +337,29 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int =
@torch.inference_mode()
def get_speaker_embedding(self, audio, sr):
audio_16k = torchaudio.functional.resample(audio, sr, 16000)
return (
self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
.unsqueeze(-1)
.to(self.device)
)

# Run speaker encoder on CPU if model is on MPS to avoid conv/STFT issues
if self.device.type == "mps":
# Move speaker encoder to CPU temporarily
original_device = next(self.hifigan_decoder.speaker_encoder.parameters()).device
self.hifigan_decoder.speaker_encoder.to("cpu")

spk = self.hifigan_decoder.speaker_encoder.forward(
audio_16k.to("cpu"),
l2_norm=True
)

# Move speaker encoder back to original device
self.hifigan_decoder.speaker_encoder.to(original_device)

# Move result to model device
return spk.to(self.device).unsqueeze(-1)
else:
spk = self.hifigan_decoder.speaker_encoder.forward(
audio_16k.to(self.device),
l2_norm=True
)
return spk.to(self.device).unsqueeze(-1)

@torch.inference_mode()
def get_conditioning_latents(
@@ -359,7 +398,13 @@ def get_conditioning_latents(
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
# Apply trimming via librosa lazily to avoid importing it
# (and its deprecated pkg_resources usage) unless needed.
import librosa

audio_np = audio.squeeze(0).detach().cpu().numpy()
audio_trimmed, _ = librosa.effects.trim(audio_np, top_db=librosa_trim_db)
audio = torch.from_numpy(audio_trimmed).unsqueeze(0).to(self.device)

# compute latents for the decoder
speaker_embedding = self.get_speaker_embedding(audio, load_sr)
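
The round trip through NumPy reflects librosa's array-based API; a small standalone sketch (the signal is illustrative):

import librosa
import torch

audio = torch.randn(1, 22050) * 0.1                            # (1, samples) tensor
trimmed, _ = librosa.effects.trim(audio.squeeze(0).numpy(), top_db=30)
audio = torch.from_numpy(trimmed).unsqueeze(0)                 # back to (1, samples)
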
@@ -649,6 +694,7 @@ def inference_stream(
gpt_cond_latent.to(self.device),
text_tokens,
)

gpt_generator = self.gpt.get_generator(
fake_inputs=fake_inputs,
top_k=top_k,
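
The MPS-specific paths above boil down to two precautions: keep parameters in float32, and run the speaker encoder's fragile ops on the CPU before moving the result back. A generic, self-contained sketch of that CPU-detour pattern (the module and input are stand-ins, not the real speaker encoder):

import torch
from torch import nn

target = "mps" if torch.backends.mps.is_available() else "cpu"
encoder = nn.Conv1d(1, 4, kernel_size=3)            # stand-in for the speaker encoder, kept on CPU
x = torch.randn(1, 1, 160)                          # stand-in audio chunk

out = encoder(x.to("cpu"))                          # run the fragile op on CPU
out = out.to(target)                                # result joins the rest of the model on MPS
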
4 changes: 4 additions & 0 deletions TTS/utils/io.py
@@ -41,6 +41,10 @@ def load_fsspec(
Returns:
Object stored in path.
"""
# Set weights_only=False for PyTorch 2.6+ compatibility with TTS model checkpoints
if "weights_only" not in kwargs:
kwargs["weights_only"] = False

is_local = os.path.isdir(path) or os.path.isfile(path)
if cache and not is_local:
with fsspec.open(
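
For context, PyTorch 2.6 flipped the default of `torch.load` to `weights_only=True`, which rejects the pickled Python objects stored in older Coqui checkpoints; a brief sketch of the behavior being guarded against (the path is illustrative, and `weights_only=False` should only be used on trusted files):

import torch

ckpt = torch.load("model.pth", map_location="cpu", weights_only=False)  # legacy-style checkpoint loads
# torch.load("model.pth", map_location="cpu")  # on PyTorch >= 2.6 this may raise an UnpicklingError
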