Merged

28 commits
d4f89a0
Fix processor usage and add chat_template support to TTS pipeline.
ebezzam Nov 21, 2025
f4dbca1
Fallback to tokenizer for musicgen.
ebezzam Nov 21, 2025
25f4816
Fallback to tokenizer for musicgen.
ebezzam Nov 21, 2025
29185d4
Make style
ebezzam Nov 21, 2025
27a56cf
style/quality after update?
ebezzam Nov 21, 2025
a654122
FIx copied from
ebezzam Nov 21, 2025
13234d9
Smaller things.
ebezzam Nov 21, 2025
729844b
Update src/transformers/pipelines/text_to_audio.py
ebezzam Nov 21, 2025
bb96d9b
Shift common utilities to chat template utils.
ebezzam Nov 21, 2025
7da8240
Merge branch 'main' into fix/tts_pipepine
ebezzam Nov 24, 2025
f801b96
Type nits
ebezzam Nov 24, 2025
716c3cf
Remove tied weights.
ebezzam Nov 24, 2025
c0060d1
Keep seamless error
ebezzam Nov 25, 2025
f8e90cf
Better audio output object.
ebezzam Nov 25, 2025
f5fb635
Properly handle DIa and add test.
ebezzam Nov 25, 2025
ed25458
Shift chat template prep to base, test dia batch.
ebezzam Nov 25, 2025
97158c6
Backward compatibility for dicts passed to pipelines
ebezzam Nov 26, 2025
5b80b03
Simplify postprocessing and tests.
ebezzam Nov 26, 2025
0b17759
Merge branch 'main' into fix/tts_pipepine
ebezzam Nov 26, 2025
838daf7
Make quality/style
ebezzam Nov 26, 2025
e3b882d
Remove chat from image text to text.
ebezzam Nov 26, 2025
661cafb
Only check first item, to not consume first item of generator inputs.
ebezzam Nov 26, 2025
94410f4
Nit
ebezzam Nov 26, 2025
79d7e73
Simplify
ebezzam Nov 26, 2025
7a438ff
Add checks for bark/musicgen to ensure output is audio.
ebezzam Nov 26, 2025
1428092
Better var
ebezzam Nov 26, 2025
4fde390
Merge branch 'main' into fix/tts_pipepine
ebezzam Nov 26, 2025
80364e3
Merge branch 'main' into fix/tts_pipepine
ebezzam Nov 26, 2025
@@ -1981,7 +1981,6 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
"text_encoder",
"text_decoder",
]
_tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"}

def __init__(
self,
17 changes: 2 additions & 15 deletions src/transformers/pipelines/text_generation.py
@@ -5,6 +5,7 @@

from ..generation import GenerationConfig
from ..utils import ModelOutput, add_end_docstrings, is_torch_available
from ..utils.chat_template_utils import Chat, ChatType
from .base import Pipeline, build_pipeline_init_args


@@ -14,27 +15,13 @@
from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from .pt_utils import KeyDataset

ChatType = list[dict[str, str]]


class ReturnType(enum.Enum):
TENSORS = 0
NEW_TEXT = 1
FULL_TEXT = 2


class Chat:
"""This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
to this format because the rest of the pipeline code tends to assume that lists of messages are
actually a batch of samples rather than messages in the same conversation."""

def __init__(self, messages: dict):
for message in messages:
if not ("role" in message and "content" in message):
raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
self.messages = messages


@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TextGenerationPipeline(Pipeline):
"""
@@ -263,7 +250,7 @@ def __call__(self, text_inputs, **kwargs):
Complete the prompt(s) given as inputs.

Args:
text_inputs (`str`, `list[str]`, list[dict[str, str]], or `list[list[dict[str, str]]]`):
text_inputs (`str`, `list[str]`, `ChatType`, or `list[ChatType]`):
One or several prompts (or one list of prompts) to complete. If strings or a list of string are
passed, this pipeline will continue each prompt. Alternatively, a "chat", in the form of a list
of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed,
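For illustration, a minimal sketch of the chat-format input described in the docstring above (the checkpoint name is an assumption; any model with a chat template works):

from transformers import pipeline

generator = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-135M-Instruct")

chat = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What does a vocoder do?"},
]
# A single chat returns one result; a list of chats returns a list of results.
result = generator(chat, max_new_tokens=40)
print(result[0]["generated_text"])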
85 changes: 60 additions & 25 deletions src/transformers/pipelines/text_to_audio.py
@@ -11,13 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.from typing import List, Union

import itertools
import types
from typing import Any, overload

from ..generation import GenerationConfig
from ..utils import is_torch_available
from ..utils.chat_template_utils import Chat, ChatType
from .base import Pipeline


AudioOutput = dict[str, Any] # {"audio": np.ndarray, "sampling_rate": int}


if is_torch_available():
import torch

@@ -81,7 +88,7 @@ class TextToAudioPipeline(Pipeline):
"""

_pipeline_calls_generate = True
_load_processor = False
_load_processor = None # prioritize processors as some models require it
_load_image_processor = False
_load_feature_extractor = False
_load_tokenizer = True
@@ -91,12 +98,9 @@ class TextToAudioPipeline(Pipeline):
max_new_tokens=256,
)

def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, **kwargs):
def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
super().__init__(*args, **kwargs)

# Legacy behaviour just uses the tokenizer while new models use the processor as a whole at any given time
self.no_processor = no_processor

self.vocoder = None
if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
self.vocoder = (
@@ -105,6 +109,10 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, *
else vocoder
)

if self.model.config.model_type in ["musicgen"]:
# MusicGen expects the tokenizer rather than the processor
self.processor = None

self.sampling_rate = sampling_rate
if self.vocoder is not None:
self.sampling_rate = self.vocoder.config.sampling_rate
@@ -127,7 +135,7 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, *
self.sampling_rate = sampling_rate

# last fallback to get the sampling rate based on processor
if self.sampling_rate is None and not self.no_processor and hasattr(self.processor, "feature_extractor"):
if self.sampling_rate is None and self.processor is not None and hasattr(self.processor, "feature_extractor"):
self.sampling_rate = self.processor.feature_extractor.sampling_rate

def preprocess(self, text, **kwargs):
@@ -141,16 +149,22 @@ def preprocess(self, text, **kwargs):
"add_special_tokens": False,
"return_attention_mask": True,
"return_token_type_ids": False,
"padding": "max_length",
}

# priority is given to kwargs
new_kwargs.update(kwargs)

kwargs = new_kwargs

preprocessor = self.tokenizer if self.no_processor else self.processor
output = preprocessor(text, **kwargs, return_tensors="pt")
preprocessor = self.processor if self.processor is not None else self.tokenizer
if isinstance(text, Chat):
output = preprocessor.apply_chat_template(
text.messages,
tokenize=True,
return_dict=True,
**kwargs,
)
else:
output = preprocessor(text, **kwargs, return_tensors="pt")

return output

@@ -188,18 +202,27 @@ def _forward(self, model_inputs, **kwargs):
return output

@overload
def __call__(self, text_inputs: str, **forward_params: Any) -> dict[str, Any]: ...
def __call__(self, text_inputs: str, **forward_params: Any) -> AudioOutput: ...

@overload
def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[dict[str, Any]]: ...
def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[AudioOutput]: ...

def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, Any] | list[dict[str, Any]]:
@overload
def __call__(self, text_inputs: ChatType, **forward_params: Any) -> AudioOutput: ...

@overload
def __call__(self, text_inputs: list[ChatType], **forward_params: Any) -> list[AudioOutput]: ...

def __call__(self, text_inputs, **forward_params):
"""
Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information.

Args:
text_inputs (`str` or `list[str]`):
The text(s) to generate.
text_inputs (`str`, `list[str]`, `ChatType`, or `list[ChatType]`):
One or several texts to generate audio from. If strings or a list of strings are passed, this pipeline
will generate audio for each of them. Alternatively, a "chat", in the form of a list of dicts with "role"
and "content" keys, can be passed, or a list of such chats. When chats are passed, the model's chat
template will be used to format them before passing them to the model.
forward_params (`dict`, *optional*):
Parameters passed to the model generation/forward method. `forward_params` are always passed to the
underlying model.
Expand All @@ -215,6 +238,23 @@ def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str,
- **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform.
- **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform.
"""
if isinstance(text_inputs, (list, tuple, types.GeneratorType)):
if isinstance(text_inputs, types.GeneratorType):
text_inputs, _ = itertools.tee(text_inputs)
text_inputs, first_item = (x for x in text_inputs), next(_)
Contributor:
Let's exchange the wildcard tho, i.e. _ - trying to be explicit here
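For context, a standalone sketch of the tee-based peek used above (names are illustrative, not part of the diff):

import itertools

def peek_first(items):
    # Duplicate the stream: read the first element from the probe copy while
    # leaving the main copy intact for downstream iteration.
    main, probe = itertools.tee(items)
    return main, next(probe)

inputs = ({"role": "user", "content": f"prompt {i}"} for i in range(3))
inputs, first_item = peek_first(inputs)
print(first_item)         # {'role': 'user', 'content': 'prompt 0'}
print(len(list(inputs)))  # 3 - nothing was consumed from the main copy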

else:
first_item = text_inputs[0]
if isinstance(first_item, (list, tuple, dict)):
# We have one or more prompts in list-of-dicts format, so this is chat mode
if isinstance(first_item, dict):
return super().__call__(Chat(text_inputs), **forward_params)
else:
chats = (Chat(chat) for chat in text_inputs)
if isinstance(text_inputs, types.GeneratorType):
return super().__call__(chats, **forward_params)
else:
return super().__call__(list(chats), **forward_params)
Contributor:
I honestly like this, it's short and sweet but I feel like we should maybe align with image e.g.

def _is_chat(arg):
return isinstance(arg, (list, tuple, KeyDataset)) and isinstance(arg[0], (list, tuple, dict))
if _is_chat(text):
# We have one or more prompts in list-of-dicts format, so this is chat mode
if isinstance(text[0], dict):
return super().__call__(Chat(text, images), **kwargs)
else:
if images is None:
images = [None] * len(text)
chats = [Chat(chat, image) for chat, image in zip(text, images)] # 🐈 🐈 🐈
return super().__call__(chats, **kwargs)
# Same as above, but the `images` argument contains the chat. This can happen e.g. is the user only passes a
# chat as a positional argument.
elif text is None and _is_chat(images):
# We have one or more prompts in list-of-dicts format, so this is chat mode
if isinstance(images[0], dict):
return super().__call__(Chat(images), **kwargs)
else:
chats = [Chat(image) for image in images] # 🐈 🐈 🐈
return super().__call__(chats, **kwargs)

We cooking all our own soup 😢

Only thing I'd change would be to avoid using a wildcard _ and give an explicit name instead

Contributor Author:
As discussed on Slack, the __call__ logic of image-text-to-text may be more complicated because they allow users to pass images as separate arguments rather than keeping everything within the chat template?

Contributor:
I thought we were breaking it tho for v5? Or did I misunderstand something?

Contributor Author (@ebezzam, Nov 25, 2025):
What they are planning to break for v5 in image-text-to-text is removing inputs like images for pipeline, so that users stick to chat template. See #42359

But perhaps the __call__ logic could still be simplified and shifted to base.py. Let me see...


return super().__call__(text_inputs, **forward_params)

def _sanitize_parameters(
@@ -248,17 +288,12 @@ def postprocess(self, audio):
else:
waveform_key = "waveform"

# We directly get the waveform
if self.no_processor:
if isinstance(audio, dict):
waveform = audio[waveform_key]
elif isinstance(audio, tuple):
waveform = audio[0]
else:
waveform = audio
# Or we need to postprocess to get the waveform
if isinstance(audio, dict):
waveform = audio[waveform_key]
elif isinstance(audio, tuple):
waveform = audio[0]
else:
waveform = self.processor.decode(audio)
Contributor Author:
Was breaking for CSM

Contributor:
Iirc it was for Dia but it has been broken way too often.

Contributor:
Reopening because I want to verify that Dia works with this current version - I'm pretty sure we need the processor to decode for Dia which is why I wrote the initial long message on how we plan to standardize

  • Everything handled by the model, audio tokenizer within it already
  • Separate model / tokenizer, processor handles encoding/decoding into codebooks/waveform

Contributor Author (@ebezzam, Nov 25, 2025):
Good point, Dia does need the processor for decoding and I'll also add a unit test for Dia so we don't miss this in the future.

However, I feel like a blanket self.processor.decode might be too broad. For example, CSM and VibeVoice don't require the processor to decode.

Since there is no standard approach (yet), how about something like below (which is working):

  if isinstance(audio, dict):
      waveform = audio[waveform_key]
  elif isinstance(audio, tuple):
      waveform = audio[0]
  elif self.model.config.model_type in ["dia"]:
      # models that require decoding, e.g. with codec
      waveform = self.processor.decode(audio)
  else:
      waveform = audio

Contributor Author (@ebezzam, Nov 25, 2025):
Example usage:

from transformers import pipeline
import soundfile as sf


dia_pipeline = pipeline(
    "text-to-audio", model=model_checkpoint,
)
outputs = dia_pipeline(
    "[S1] Dia is an open weights text to dialogue model.",
    generate_kwargs={"max_new_tokens": 512},
)
assert outputs["sampling_rate"] == 44100

audio = outputs["audio"].squeeze()
fn = "dia_pipeline_output.wav"
sf.write(fn, audio, outputs["sampling_rate"])
print(f"Audio saved to {fn}")

I'm reluctant to allow voice cloning through pipeline, as this would require passing an audios input to pipeline (since Dia doesn't support chat templates).

Moreover, allowing inputs like audios is exactly what they are trying to phase out with image-text-to-text in #42359 (to only support chat template usage).

Contributor:
That's a good point on voice cloning, we should maybe update Dia with a chat template in the future. I did not have that in mind at that point, that's on me.

Re: standards. Yea, we have no choice atm - it's more of a question on how we handle this in the future

waveform = audio

if isinstance(audio, list):
output_dict["audio"] = [el.to(device="cpu", dtype=torch.float).numpy() for el in waveform]
19 changes: 17 additions & 2 deletions src/transformers/utils/chat_template_utils.py
@@ -53,6 +53,9 @@
from torch import Tensor


ChatType = list[dict[str, Any]]


BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
# Extracts the initial segment of the docstring, containing the function description
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
@@ -459,9 +462,9 @@ def strftime_now(format):


def render_jinja_template(
conversations: list[list[dict[str, str]]],
conversations: list[ChatType],
tools: list[dict | Callable] | None = None,
documents: list[dict[str, str]] | None = None,
documents: ChatType | None = None,
chat_template: str | None = None,
return_assistant_tokens_mask: bool = False,
continue_final_message: bool = False,
@@ -558,3 +561,15 @@ def render_jinja_template(
rendered.append(rendered_chat)

return rendered, all_generation_indices


class Chat:
"""This class is intended to just be used internally for pipelines and not exposed to users. We convert chats
to this format because the rest of the pipeline code tends to assume that lists of messages are
actually a batch of samples rather than messages in the same conversation."""

def __init__(self, messages: dict):
for message in messages:
if not ("role" in message and "content" in message):
raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
self.messages = messages
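A small sketch of how pipelines use this internal wrapper (illustrative, not part of the diff):

from transformers.utils.chat_template_utils import Chat

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]
chat = Chat(messages)      # wraps one conversation so it is not mistaken for a batch of prompts
print(len(chat.messages))  # 2

# Each message must carry "role" and "content" keys, otherwise a ValueError is raised.
try:
    Chat([{"text": "no role key"}])
except ValueError as err:
    print(err)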
1 change: 0 additions & 1 deletion tests/pipelines/test_pipelines_text_to_audio.py
@@ -151,7 +151,6 @@ def test_conversion_additional_tensor(self):
"add_special_tokens": False,
"return_attention_mask": True,
"return_token_type_ids": False,
"padding": "max_length",
}
outputs = speech_generator(
"This is a test",