
Conversation

@ebezzam (Contributor) commented Nov 21, 2025

What does this PR do?

Processor usage within the TTS pipeline is faulty.

Moreover, it does not support chat-template inputs with/without audio, which is of interest for the following models for generating conversations:

Example usage for CSM, and inputs like the following for VibeVoice:

conversation = [
    {"role": "0", "content": [
        {"type": "text", "text": "Hello everyone, and welcome to the VibeVoice podcast. I'm your host, Linda, and today we're getting into one of the biggest debates in all of sports: who's the greatest basketball player of all time? I'm so excited to have Thomas here to talk about it with me."},
        {"type": "audio", "path": https://hf.co/datasets/bezzam/vibevoice_samples/resolve/main/voices/en-Alice_woman.wav}
    ]},
    {"role": "1", "content": [
        {"type": "text", "text": "Thanks so much for having me, Linda. You're absolutely right—this question always brings out some seriously strong feelings."},
        {"type": "audio", "path": https://hf.co/datasets/bezzam/vibevoice_samples/resolve/main/voices/en-Frank_man.wav}
    ]},
    {"role": "0", "content": [
        {"type": "text", "text": "Okay, so let's get right into it. For me, it has to be Michael Jordan. Six trips to the Finals, six championships. That kind of perfection is just incredible."},
    ]},
    {"role": "1", "content": [
        {"type": "text", "text": "Oh man, the first thing that always pops into my head is that shot against the Cleveland Cavaliers back in '89. Jordan just rises, hangs in the air forever, and just sinks it"},
    ]},
]

This PR is related to #39796, but this one is different/simpler (and I'd say more urgent) in its objective 👉 fixing processor usage and enabling chat_template inputs like above.

So I think it's worth a separate PR as #39796 requires more testing/review.

Context for current error

This line fails when trying to use the processor because it isn’t loaded (so it is None instead)

Minimal failing example:

import torch
from transformers import pipeline

model_id = "sesame/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline("text-to-speech", model=model_id, device=device, no_processor=False)

# prepare the inputs
text = "[0]Hello from Sesame." # `[0]` for speaker id 0

# apply pipeline
output = pipe(text)

"""
Traceback (most recent call last):
  File "/home/eric_bezzam/transformers/src/transformers/pipelines/test_pipeline_chat_template.py", line 34, in <module>
    output = pipe(text)
             ^^^^^^^^^^
  File "/home/eric_bezzam/transformers/src/transformers/pipelines/text_to_audio.py", line 218, in __call__
    return super().__call__(text_inputs, **forward_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eric_bezzam/transformers/src/transformers/pipelines/base.py", line 1261, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eric_bezzam/transformers/src/transformers/pipelines/base.py", line 1267, in run_single
    model_inputs = self.preprocess(inputs, **preprocess_params)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eric_bezzam/transformers/src/transformers/pipelines/text_to_audio.py", line 153, in preprocess
    output = preprocessor(text, **kwargs, return_tensors="pt")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not callable
"""

After changes

No need to explicitly ask for processor usage (auto-detected internally).

TTS example

import torch
from transformers import pipeline
import soundfile as sf

model_id = "sesame/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline("text-to-speech", model=model_id, device=device)

# prepare the inputs
text = "[0]Hello from Sesame." # `[0]` for speaker id 0

# apply pipeline
output = pipe(text, generate_kwargs={"output_audio": True})

# save the audio to a file
audio = output["audio"][0].squeeze()
fn = "csm_pipeline_output.wav"
sf.write(fn, audio, output["sampling_rate"])
print(f"Audio saved to {fn}")

Conversation example with a chat template, which was not possible before! Chat inputs are again auto-detected, like in the text-generation pipeline. The example below mimics this usage from CSM:

import torch
from transformers import pipeline
import soundfile as sf
from datasets import load_dataset, Audio

model_id = "sesame/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline("text-to-speech", model=model_id, device=device)

# prepare the inputs like here: https://huggingface.co/sesame/csm-1b#csm-sounds-best-when-provided-with-context
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
conversation = []
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
    conversation.append(
        {
            "role": f"{speaker_id}",
            "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        }
    )
conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})

# apply pipeline
output = pipe(conversation, generate_kwargs={"output_audio": True})

# save the audio to a file
audio = output["audio"][0].squeeze()
fn = "csm_pipeline_output_chat.wav"
sf.write(fn, audio, output["sampling_rate"])
print(f"Audio saved to {fn}")

@Rocketknight1 since it is pipeline-related you can take a look if you want! But will definitely ask @vasqu and @eustlb for audio-specific feedback 🙂

@ebezzam ebezzam added the Audio label Nov 21, 2025
Contributor Author

@ebezzam ebezzam left a comment

@vasqu Self-review with pointers that hopefully help!

Comment on lines 33 to 49
# Copied from transformers.pipelines.text_generation
ChatType = list[dict[str, str]]


# Copied from transformers.pipelines.text_generation
class Chat:
    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation."""

    def __init__(self, messages: dict):
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        self.messages = messages



Comment on lines 254 to 274
if isinstance(
    text_inputs,
    (list, tuple, types.GeneratorType)
    if is_torch_available()
    else (list, tuple, types.GeneratorType),
):
    if isinstance(text_inputs, types.GeneratorType):
        text_inputs, _ = itertools.tee(text_inputs)
        text_inputs, first_item = (x for x in text_inputs), next(_)
    else:
        first_item = text_inputs[0]
    if isinstance(first_item, (list, tuple, dict)):
        # We have one or more prompts in list-of-dicts format, so this is chat mode
        if isinstance(first_item, dict):
            return super().__call__(Chat(text_inputs), **forward_params)
        else:
            chats = (Chat(chat) for chat in text_inputs)
            if isinstance(text_inputs, types.GeneratorType):
                return super().__call__(chats, **forward_params)
            else:
                return super().__call__(list(chats), **forward_params)
Contributor Author

Similarly copied from text generation.

elif isinstance(audio, tuple):
    waveform = audio[0]
else:
    waveform = self.processor.decode(audio)
Contributor Author

This was breaking for CSM.

Contributor

Iirc it was for Dia but it has been broken way too often.

Contributor

Reopening because I want to verify that Dia works with this current version. I'm pretty sure we need the processor to decode for Dia, which is why I wrote the initial long message on how we plan to standardize:

  • Everything handled by the model, with the audio tokenizer already inside it
  • Separate model / tokenizer, where the processor handles encoding/decoding into codebooks/waveform

Contributor Author

@ebezzam ebezzam Nov 25, 2025

Good point, Dia does need the processor for decoding and I'll also add a unit test for Dia so we don't miss this in the future.

However, I feel like a blanket self.processor.decode might be too broad. For example, CSM and VibeVoice don't require the processor to decode.

Since there is no standard approach (yet), how about something like below (which works):

  if isinstance(audio, dict):
      waveform = audio[waveform_key]
  elif isinstance(audio, tuple):
      waveform = audio[0]
  elif self.model.config.model_type in ["dia"]:
      # models that require decoding, e.g. with codec
      waveform = self.processor.decode(audio)
  else:
      waveform = audio

Contributor Author

@ebezzam ebezzam Nov 25, 2025

Example usage:

from transformers import pipeline
import soundfile as sf

model_checkpoint = "nari-labs/Dia-1.6B-0626"  # assumed Dia checkpoint, for illustration

dia_pipeline = pipeline(
    "text-to-audio", model=model_checkpoint,
)
outputs = dia_pipeline(
    "[S1] Dia is an open weights text to dialogue model.",
    generate_kwargs={"max_new_tokens": 512},
)
assert outputs["sampling_rate"] == 44100

audio = outputs["audio"].squeeze()
fn = "dia_pipeline_output.wav"
sf.write(fn, audio, outputs["sampling_rate"])
print(f"Audio saved to {fn}")

I'm reluctant to allow voice cloning through the pipeline, as this would require passing an audios input to the pipeline (since Dia doesn't support chat templates).

Moreover, allowing inputs like audios is exactly what they are trying to phase out with image-text-to-text in #42359 (to only support chat template usage).

Contributor

That's a good point on voice cloning, we should maybe update Dia with a chat template in the future. I did not have that in mind at that point, that's on me.

Re: standards. Yea, we have no choice atm - it's more of a question on how we handle this in the future

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ebezzam ebezzam requested a review from vasqu November 21, 2025 15:03
@ebezzam (Contributor Author) commented Nov 21, 2025

Failed tests are unrelated; concatenated below:

# https://app.circleci.com/pipelines/github/huggingface/transformers/154274/workflows/48600cd6-5d2a-455a-a2e0-96f3b9bae8ae/jobs/2028466
FAILED tests/models/efficientloftr/test_image_processing_efficientloftr.py::EfficientLoFTRImageProcessingTest::test_post_processing_keypoint_matching_with_padded_match_indices - AssertionError: 2 != 1
===== 1 failed, 549 passed, 362 skipped, 24 warnings in 115.22s (0:01:55) ======

# https://app.circleci.com/pipelines/github/huggingface/transformers/154274/workflows/48600cd6-5d2a-455a-a2e0-96f3b9bae8ae/jobs/2028474
FAILED tests/models/resnet/test_modeling_resnet.py::ResNetModelTest::test_can_load_ignoring_mismatched_shapes - AssertionError: 0.14109472930431366 not less than or equal to 0.1 : Issue with classifier.1.bias
==== 1 failed, 3959 passed, 6611 skipped, 93 warnings in 187.89s (0:03:07) =====

# https://app.circleci.com/pipelines/github/huggingface/transformers/154274/workflows/48600cd6-5d2a-455a-a2e0-96f3b9bae8ae/jobs/2028467
FAILED tests/models/olmo/test_modeling_olmo.py::OlmoModelTest::test_generate_with_static_cache - AssertionError: False is not true
===== 1 failed, 625 passed, 218 skipped, 16 warnings in 118.43s (0:01:58) ======

Contributor

@vasqu vasqu left a comment

Left some comments 🤗 I think it's overall fine, mostly nits to have more alignment with other pipelines.

My only gripe is that we still don't have a standard way for how models generate audio. For example, CSM directly generates the audio waveform, but that's because it uses the audio tokenizer directly within the model itself. Dia does not do that and depends on the processor to decode into a waveform. This is something we have to properly enforce at some point before we get too many exceptions. cc @eustlb

Comment on lines 36 to 46
# Copied from transformers.pipelines.text_generation.Chat
class Chat:
    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation."""

    def __init__(self, messages: dict):
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        self.messages = messages
Contributor

I feel like chat templates with any modality are increasingly important; it might be better to move this somewhere more general (and let it be imported). We're gonna have audio, image, and text already at this point.

Contributor

And I would like to avoid copied from tbh

Contributor Author

@ebezzam ebezzam Nov 21, 2025

Currently two other pipelines are using this (text-generation and image-text-to-text).

We could combine them into a single Chat object in base.py, like so?

class Chat:
    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation."""

    def __init__(
        self, messages: dict, images: Union[str, list[str], "Image.Image", list["Image.Image"]] | None = None
    ):
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        if images is not None:
            messages = add_images_to_messages(messages, images)

        self.messages = messages

For audio models with @eustlb, our chat templates already allow audio in the message, and the Jinja template (like this and this) handles extracting audio when calling apply_chat_template (which then calls the processor internally, see here).
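For illustration, a rough sketch of that flow with CSM (mirroring the model-card usage referenced above; the exact kwargs here are assumptions, so treat this as illustrative rather than as this PR's code):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("sesame/csm-1b")

conversation = [
    {"role": "0", "content": [
        {"type": "text", "text": "Hello from Sesame."},
        # audio entries are pulled out by the chat template and fed to the audio processor
        {"type": "audio", "path": "https://hf.co/datasets/bezzam/vibevoice_samples/resolve/main/voices/en-Alice_woman.wav"},
    ]},
]

# render + tokenize in one call: the jinja template extracts the audio and the
# processor returns token ids alongside the audio features
inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
print(inputs.keys())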

Contributor

SGTM, @Rocketknight1 @zucchini-nlp wdyt about this? Not sure who has a better overview on (image) pipelines.

Member

Yes, you can move it to base.py or utils/chat_template_utils.py, whichever is a cleaner import. Not a huge issue either way, though!

Contributor Author

Thanks @Rocketknight1! I've put it in utils/chat_template_utils.py for now.
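For illustration, a minimal sketch of what the shared import could look like (the module path and class name just follow the comment above and may differ in the final diff):

# hypothetical import location, per the comment above; check the merged PR for the final path
from transformers.utils.chat_template_utils import Chat

messages = [{"role": "0", "content": [{"type": "text", "text": "Hello from Sesame."}]}]
chat = Chat(messages)  # raises ValueError if a message is missing "role" or "content"
print(chat.messages)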

Contributor Author

For image-text-to-text's current Chat object (here), it could be nice to do something like below in the general purpose Chat to support the images input:

class Chat:
    """This class is intended to just be used internally for pipelines and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation."""

    def __init__(self, messages: dict, images: Union[str, list[str], "Image.Image", list["Image.Image"]] | None = None):
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        
        if images is not None:
            messages = add_images_to_messages(messages, images)
        self.messages = messages

But actually this fails for a current usage 👉 when there is an image URL in the chat template (this code path).

FYI, I found out about this edge case because this test would fail when I tried a modified Chat object like the one above (the image wouldn't be properly loaded).

@zucchini-nlp do you have an idea of how image-text-to-text could also use the general-purpose Chat object without having to call add_images_to_messages, so that the current double for-loop is avoided when there is indeed no image input?

Contributor Author

@vasqu, @Rocketknight1 for reference, @zucchini-nlp started a PR to handle image-text-to-text, so I won't touch it in this PR.

Comment on lines 257 to 271
if isinstance(text_inputs, types.GeneratorType):
    text_inputs, _ = itertools.tee(text_inputs)
    text_inputs, first_item = (x for x in text_inputs), next(_)
else:
    first_item = text_inputs[0]
if isinstance(first_item, (list, tuple, dict)):
    # We have one or more prompts in list-of-dicts format, so this is chat mode
    if isinstance(first_item, dict):
        return super().__call__(Chat(text_inputs), **forward_params)
    else:
        chats = (Chat(chat) for chat in text_inputs)
        if isinstance(text_inputs, types.GeneratorType):
            return super().__call__(chats, **forward_params)
        else:
            return super().__call__(list(chats), **forward_params)
Contributor

I honestly like this, it's short and sweet, but I feel like we should maybe align with the image pipeline, e.g.

def _is_chat(arg):
    return isinstance(arg, (list, tuple, KeyDataset)) and isinstance(arg[0], (list, tuple, dict))


if _is_chat(text):
    # We have one or more prompts in list-of-dicts format, so this is chat mode
    if isinstance(text[0], dict):
        return super().__call__(Chat(text, images), **kwargs)
    else:
        if images is None:
            images = [None] * len(text)
        chats = [Chat(chat, image) for chat, image in zip(text, images)]  # 🐈 🐈 🐈
        return super().__call__(chats, **kwargs)
# Same as above, but the `images` argument contains the chat. This can happen e.g. if the user only passes a
# chat as a positional argument.
elif text is None and _is_chat(images):
    # We have one or more prompts in list-of-dicts format, so this is chat mode
    if isinstance(images[0], dict):
        return super().__call__(Chat(images), **kwargs)
    else:
        chats = [Chat(image) for image in images]  # 🐈 🐈 🐈
        return super().__call__(chats, **kwargs)

We're cooking all our own soup 😢

Only thing I'd change would be to avoid using a wildcard _ and give it an explicit name instead

Contributor Author

As discussed on Slack, the __call__ logic of image-text-to-text may be more complicated because they allow users to pass images as separate arguments rather than keeping everything within the chat template?

Contributor

I thought we were breaking it tho for v5? Or did I misunderstand something?

Contributor Author

@ebezzam ebezzam Nov 25, 2025

What they are planning to break for v5 in image-text-to-text is removing inputs like images for the pipeline, so that users stick to the chat template. See #42359

But perhaps the __call__ logic could still be simplified and shifted to base.py. Let me see...


@vasqu (Contributor) commented Nov 21, 2025

Failing tests are flaky (the loftr one is handled elsewhere)

@ebezzam ebezzam requested a review from vasqu November 24, 2025 14:14
Contributor

@vasqu vasqu left a comment

Added some smaller comments. Imo we should still check how we standardize (not necessarily in this PR, but we should keep this in mind / discuss before we lose sight of it)

I think there are 2 somewhat bigger things:

  • seamless should be a different PR?
  • check if Dia works

Comment on lines 243 to 244
text_inputs, _ = itertools.tee(text_inputs)
text_inputs, first_item = (x for x in text_inputs), next(_)
Contributor

Let's exchange the wildcard tho, i.e. _ - trying to be explicit here



Contributor Author

@ebezzam ebezzam left a comment

@vasqu thanks for your comments!

  • I tried going directly for an approach that puts common chat template logic into the base pipeline object. Let me know if we should rather focus just on text-to-audio and do such standardization in a separate PR.
  • Double-checked Dia. You're right self.processor.decode was needed for it, but I've adapted the (previous) logic so it doesn't assume such a call is needed for all models that have a processor (e.g. CSM and VibeVoice don't need to call this)

Comment on lines 1208 to 1229
# Detect if inputs is a chat-style input and cast as `Chat` or list of `Chat`
if isinstance(
    inputs,
    (list, tuple, types.GeneratorType, KeyDataset)
    if is_torch_available()
    else (list, tuple, types.GeneratorType),
):
    if isinstance(inputs, types.GeneratorType):
        gen_copy1, gen_copy2 = itertools.tee(inputs)
        inputs = (x for x in gen_copy1)
        first_item = next(gen_copy2)
    else:
        first_item = inputs[0]
    if isinstance(first_item, (list, tuple, dict)):
        if isinstance(first_item, dict):
            inputs = Chat(inputs)
        else:
            chats = (Chat(chat) for chat in inputs)
            if isinstance(inputs, types.GeneratorType):
                inputs = chats
            else:
                inputs = list(chats)
Contributor Author

@vasqu an idea for standardizing chat template usage in the base class.

Essentially, this was in text-generation and is what I had copied into text-to-audio (in what you last saw), and it could potentially be used by image-text-to-text once it drops support for the images input (@zucchini-nlp)?

The following tests run as before:

RUN_SLOW=1 pytest tests/pipelines/test_pipelines_text_to_audio.py
RUN_SLOW=1 pytest tests/pipelines/test_pipelines_text_generation.py
RUN_SLOW=1 pytest tests/pipelines/test_pipelines_image_text_to_text.py

I could also run the below if you think we should check to be safe? And do you have anything else in mind that I should double-check?

RUN_SLOW=1 pytest tests/pipelines

Contributor

I think our pipeline tests should all be non-slow tests, but it doesn't hurt to check with the env. I'm pro this!

Waiting for @zucchini-nlp in case she has anything to add; if I see it correctly it depends on #42359

Comment on lines 257 to 258
# ensure audio and not codes
self.assertEqual(len(audio.shape), 1)
Contributor Author

Added such checks to make sure audio decoding is actually working

Contributor

@vasqu vasqu left a comment

I think it's already looking pretty good. We should sync with the other pipelines PR so as not to have some weird conflicting behavior

Other than that, we really should work on enforcing good standards so that we do not have to add so many exceptions (especially now with CSM/Dia). With v5, we IMO have the opportunity to break things and make it unified. cc @eustlb wdyt?



Contributor Author

@ebezzam ebezzam left a comment

@vasqu thanks for previous comments, it's getting better!

No more exceptions (for newer audio models) in this iteration. Going forward we don't need to add special cases for newer audio models if:

  • they properly use return_dict_in_generate
  • they write the audio into the generation output dict, and otherwise write codec codes that need decoding into sequences (see the sketch below)
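To make that convention concrete, here is a rough sketch of the kind of postprocessing branch it enables (the key names and helper below are illustrative assumptions, not the exact code in this diff):

def extract_waveform(generate_output, processor=None):
    """Illustrative convention: `generate` returns a dict-like output where either
    the waveform is ready under "audio", or codec codes live under "sequences"."""
    if isinstance(generate_output, dict) and "audio" in generate_output:
        # models with the audio tokenizer inside (e.g. CSM): waveform is already decoded
        return generate_output["audio"]
    # models that emit codec codes (e.g. Dia): decode them into a waveform
    codes = generate_output["sequences"] if isinstance(generate_output, dict) else generate_output
    if processor is not None:
        return processor.decode(codes)
    return codes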

Comment on lines 1241 to 1252
if isinstance(first_item, dict):
    if is_valid_chat(inputs):
        inputs = Chat(inputs)
elif isinstance(first_item, (list, tuple)):
    # materialize generator if needed
    items = list(inputs) if isinstance(inputs, types.GeneratorType) else inputs
    if all(is_valid_chat(chat) for chat in items):
        chats = (Chat(chat) for chat in items)
        if isinstance(inputs, types.GeneratorType):
            inputs = chats
        else:
            inputs = list(chats)
Contributor Author

Changed previous logic for backward compatibility: some pipelines pass a list of objects which are not necessarily a chat template (e.g. key point matching). So I only convert if the object is a valid chat.

cc @Rocketknight1 who has more experience with pipelines and may spot an edge case 🙃
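For reference, one plausible shape of such a validity check (the actual is_valid_chat helper in the diff may differ):

def is_valid_chat(messages) -> bool:
    """Return True only if `messages` looks like a chat: a non-empty sequence of
    dicts that each carry a "role" and a "content" key."""
    if not isinstance(messages, (list, tuple)) or len(messages) == 0:
        return False
    return all(
        isinstance(message, dict) and "role" in message and "content" in message
        for message in messages
    )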

Contributor Author

FYI no new failures when I run RUN_SLOW=1 pytest tests/pipelines

Member

I don't see anything it's missing! You could maybe simplify it by just always calling list(inputs) and removing the extra conditionals for generators, though? Python lists only store pointers to elements so it should be basically free in terms of speed/memory, even if the chats are big.
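A sketch of what that simplification could look like (assuming we are fine materializing generators up front and reusing the Chat / is_valid_chat helpers from the hunk above; not the exact code that landed):

import types

# materialize any generator once; then the chat detection needs no special cases
if isinstance(inputs, (list, tuple, types.GeneratorType)):
    inputs = list(inputs)
    if inputs and isinstance(inputs[0], dict) and is_valid_chat(inputs):
        inputs = Chat(inputs)
    elif inputs and isinstance(inputs[0], (list, tuple)) and all(is_valid_chat(chat) for chat in inputs):
        inputs = [Chat(chat) for chat in inputs]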

Contributor Author

@Rocketknight1 yes, I can cast as a list from the start! FYI I'll have to remove this check, which errors

Member

Yep, makes sense to me! Once we're materializing the entire generator output there's not much point in pretending we're streaming anymore.

Contributor

@vasqu vasqu left a comment

LGTM overall, just make sure to sync with main and that nothing is broken based on that

The remaining comments are nits

Contributor Author

@ebezzam ebezzam left a comment

@vasqu I removed the exceptions for bark/musicgen in _forward. I needed to fix the handling of return_dict_in_generate in modeling_bark.py

Comment on lines +654 to +657
if kwargs.get("return_dict_in_generate", False):
    semantic_output = semantic_output.sequences[:, max_input_semantic_length + 1 :]
else:
    semantic_output = semantic_output[:, max_input_semantic_length + 1 :]
Contributor Author

This and the change below are so that Bark doesn't error out when return_dict_in_generate=True is passed

Comment on lines +196 to +197
# ensure dict output to facilitate postprocessing
forward_params.update({"return_dict_in_generate": True})
Contributor Author

No more exception 🙂

Comment on lines +288 to +289
if needs_decoding and self.processor is not None:
    audio = self.processor.decode(audio)
Contributor Author

Just had to add self.processor is not None for MusicGen to pass. Makes sense since I set the processor to None 🤦


outputs = music_generator("This is a test")
self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 32000}, outputs)
self.assertEqual(len(outputs["audio"].shape), n_ch)
Contributor Author

Added checks for MusicGen and Bark to make sure they output audio, with the cleaned-up postprocess and for the future

@ebezzam ebezzam changed the title Fix processor usage and add chat_template support to TTS pipeline. Fix processor usage, add chat_template support to TTS pipeline, and shift common chat template logic to base class. Nov 26, 2025
@ebezzam ebezzam changed the title Fix processor usage, add chat_template support to TTS pipeline, and shift common chat template logic to base class. Fix processor usage + add chat_template support to TTS pipeline, and shift common chat template logic to base class. Nov 26, 2025
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: bark

Member

@Cyrilvallez Cyrilvallez left a comment

Thanks!

@Cyrilvallez Cyrilvallez merged commit 5458d81 into huggingface:main Nov 27, 2025
21 of 23 checks passed