From d4f89a0a822d2eb7446ba73a06c3bfe57dcf1bac Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 21 Nov 2025 14:27:57 +0100 Subject: [PATCH 01/24] Fix processor usage and add chat_template support to TTS pipeline. --- src/transformers/pipelines/text_to_audio.py | 91 ++++++++++++++----- .../pipelines/test_pipelines_text_to_audio.py | 1 - 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 99334eff468a..625da2495791 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.from typing import List, Union + +import itertools +import types from typing import Any, overload from ..generation import GenerationConfig @@ -27,6 +30,23 @@ DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan" +# Copied from transformers.pipelines.text_generation +ChatType = list[dict[str, str]] + + +# Copied from transformers.pipelines.text_generation +class Chat: + """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats + to this format because the rest of the pipeline code tends to assume that lists of messages are + actually a batch of samples rather than messages in the same conversation.""" + + def __init__(self, messages: dict): + for message in messages: + if not ("role" in message and "content" in message): + raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") + self.messages = messages + + class TextToAudioPipeline(Pipeline): """ Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. 
This @@ -81,7 +101,7 @@ class TextToAudioPipeline(Pipeline): """ _pipeline_calls_generate = True - _load_processor = False + _load_processor = None _load_image_processor = False _load_feature_extractor = False _load_tokenizer = True @@ -91,12 +111,9 @@ class TextToAudioPipeline(Pipeline): max_new_tokens=256, ) - def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, **kwargs): + def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs): super().__init__(*args, **kwargs) - # Legacy behaviour just uses the tokenizer while new models use the processor as a whole at any given time - self.no_processor = no_processor - self.vocoder = None if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values(): self.vocoder = ( @@ -127,7 +144,7 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, * self.sampling_rate = sampling_rate # last fallback to get the sampling rate based on processor - if self.sampling_rate is None and not self.no_processor and hasattr(self.processor, "feature_extractor"): + if self.sampling_rate is None and self.processor is not None and hasattr(self.processor, "feature_extractor"): self.sampling_rate = self.processor.feature_extractor.sampling_rate def preprocess(self, text, **kwargs): @@ -141,16 +158,22 @@ def preprocess(self, text, **kwargs): "add_special_tokens": False, "return_attention_mask": True, "return_token_type_ids": False, - "padding": "max_length", } # priority is given to kwargs new_kwargs.update(kwargs) - kwargs = new_kwargs - preprocessor = self.tokenizer if self.no_processor else self.processor - output = preprocessor(text, **kwargs, return_tensors="pt") + preprocessor = self.processor if self.processor is not None else self.tokenizer + if isinstance(text, Chat): + output = preprocessor.apply_chat_template( + text.messages, + tokenize=True, + return_dict=True, + **kwargs, + ) + else: + output = preprocessor(text, **kwargs, return_tensors="pt") return output @@ -193,13 +216,22 @@ def __call__(self, text_inputs: str, **forward_params: Any) -> dict[str, Any]: . @overload def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[dict[str, Any]]: ... + @overload + def __call__(self, text_inputs: ChatType, **forward_params: Any) -> list[dict[str, ChatType]]: ... + + @overload + def __call__(self, text_inputs: list[ChatType], **forward_params: Any) -> list[list[dict[str, ChatType]]]: ... + def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, Any] | list[dict[str, Any]]: """ Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information. Args: text_inputs (`str` or `list[str]`): - The text(s) to generate. + One or several texts to generate. If strings or a list of string are passed, this pipeline will + generate the corresponding text. Alternatively, a "chat", in the form of a list of dicts with "role" + and "content" keys, can be passed, or a list of such chats. When chats are passed, the model's chat + template will be used to format them before passing them to the model. forward_params (`dict`, *optional*): Parameters passed to the model generation/forward method. `forward_params` are always passed to the underlying model. @@ -215,6 +247,28 @@ def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform. - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform. 
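+
+        Example (an illustrative sketch; the checkpoint choice is an assumption, and the chat call
+        assumes the loaded checkpoint ships a chat template):
+
+        ```python
+        >>> from transformers import pipeline
+
+        >>> pipe = pipeline(task="text-to-audio", model="suno/bark-small")
+        >>> output = pipe("Hey it's HuggingFace on the phone!")
+        >>> audio, sampling_rate = output["audio"], output["sampling_rate"]
+
+        >>> # chat-style input, formatted with the model's chat template before generation
+        >>> chat = [{"role": "user", "content": "Read this sentence aloud."}]
+        >>> output = pipe(chat)  # doctest: +SKIP
+        ```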
""" + if isinstance( + text_inputs, + (list, tuple, types.GeneratorType) + if is_torch_available() + else (list, tuple, types.GeneratorType), + ): + if isinstance(text_inputs, types.GeneratorType): + text_inputs, _ = itertools.tee(text_inputs) + text_inputs, first_item = (x for x in text_inputs), next(_) + else: + first_item = text_inputs[0] + if isinstance(first_item, (list, tuple, dict)): + # We have one or more prompts in list-of-dicts format, so this is chat mode + if isinstance(first_item, dict): + return super().__call__(Chat(text_inputs), **forward_params) + else: + chats = (Chat(chat) for chat in text_inputs) + if isinstance(text_inputs, types.GeneratorType): + return super().__call__(chats, **forward_params) + else: + return super().__call__(list(chats), **forward_params) + return super().__call__(text_inputs, **forward_params) def _sanitize_parameters( @@ -248,17 +302,12 @@ def postprocess(self, audio): else: waveform_key = "waveform" - # We directly get the waveform - if self.no_processor: - if isinstance(audio, dict): - waveform = audio[waveform_key] - elif isinstance(audio, tuple): - waveform = audio[0] - else: - waveform = audio - # Or we need to postprocess to get the waveform + if isinstance(audio, dict): + waveform = audio[waveform_key] + elif isinstance(audio, tuple): + waveform = audio[0] else: - waveform = self.processor.decode(audio) + waveform = audio if isinstance(audio, list): output_dict["audio"] = [el.to(device="cpu", dtype=torch.float).numpy() for el in waveform] diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index c13d0830c6e6..c4870bc1dff6 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -151,7 +151,6 @@ def test_conversion_additional_tensor(self): "add_special_tokens": False, "return_attention_mask": True, "return_token_type_ids": False, - "padding": "max_length", } outputs = speech_generator( "This is a test", From f4dbca1d932dd5aff5e95931d6519856e89a17ee Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 21 Nov 2025 15:19:02 +0100 Subject: [PATCH 02/24] Fallback to tokenizer for musicgen. --- src/transformers/pipelines/text_to_audio.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 625da2495791..b3b19e209f97 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -165,6 +165,9 @@ def preprocess(self, text, **kwargs): kwargs = new_kwargs preprocessor = self.processor if self.processor is not None else self.tokenizer + if self.model.config.model_type == "musicgen": + preprocessor = self.tokenizer + if isinstance(text, Chat): output = preprocessor.apply_chat_template( text.messages, From 25f481699af52f687a532d9472adb793b3205d40 Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 21 Nov 2025 15:31:56 +0100 Subject: [PATCH 03/24] Fallback to tokenizer for musicgen. 
--- src/transformers/pipelines/text_to_audio.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index b3b19e209f97..d71ac585448d 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -165,7 +165,8 @@ def preprocess(self, text, **kwargs): kwargs = new_kwargs preprocessor = self.processor if self.processor is not None else self.tokenizer - if self.model.config.model_type == "musicgen": + if self.model.config.model_type in ["musicgen"]: + # Fallback to legacy models that prefer tokenizer preprocessor = self.tokenizer if isinstance(text, Chat): From 29185d42e6b930bce7b58a4266c1fbd6c6a9a559 Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 21 Nov 2025 15:45:05 +0100 Subject: [PATCH 04/24] Make style --- src/transformers/pipelines/text_to_audio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index d71ac585448d..a9acb7dadd3b 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -174,7 +174,7 @@ def preprocess(self, text, **kwargs): text.messages, tokenize=True, return_dict=True, - **kwargs, + **kwargs, ) else: output = preprocessor(text, **kwargs, return_tensors="pt") @@ -272,7 +272,7 @@ def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, return super().__call__(chats, **forward_params) else: return super().__call__(list(chats), **forward_params) - + return super().__call__(text_inputs, **forward_params) def _sanitize_parameters( From 27a56cfc20ca4416a43e02b8783c61c89581a77f Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 21 Nov 2025 15:52:43 +0100 Subject: [PATCH 05/24] style/quality after update? --- src/transformers/pipelines/text_to_audio.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index a9acb7dadd3b..a33726081537 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -253,9 +253,7 @@ def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, """ if isinstance( text_inputs, - (list, tuple, types.GeneratorType) - if is_torch_available() - else (list, tuple, types.GeneratorType), + (list, tuple, types.GeneratorType) if is_torch_available() else (list, tuple, types.GeneratorType), ): if isinstance(text_inputs, types.GeneratorType): text_inputs, _ = itertools.tee(text_inputs) From a65412200a93372efb7df86738e9eb381d45ee67 Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 21 Nov 2025 16:02:46 +0100 Subject: [PATCH 06/24] FIx copied from --- src/transformers/pipelines/text_to_audio.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index a33726081537..d458ad44a2eb 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -30,11 +30,10 @@ DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan" -# Copied from transformers.pipelines.text_generation ChatType = list[dict[str, str]] -# Copied from transformers.pipelines.text_generation +# Copied from transformers.pipelines.text_generation.Chat class Chat: """This class is intended to just be used internally in this pipeline and not exposed to users. 
We convert chats to this format because the rest of the pipeline code tends to assume that lists of messages are From 13234d9802bf22f14df3746bf2e0eef5d3ecb7c8 Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 21 Nov 2025 17:48:54 +0100 Subject: [PATCH 07/24] Smaller things. --- src/transformers/pipelines/text_to_audio.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index d458ad44a2eb..2e2bf56367fc 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -100,7 +100,7 @@ class TextToAudioPipeline(Pipeline): """ _pipeline_calls_generate = True - _load_processor = None + _load_processor = None # prioritize processors as some models require it _load_image_processor = False _load_feature_extractor = False _load_tokenizer = True @@ -121,6 +121,10 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs): else vocoder ) + if self.model.config.model_type in ["musicgen"]: + # MusicGen expect to use the tokenizer + self.processor = None + self.sampling_rate = sampling_rate if self.vocoder is not None: self.sampling_rate = self.vocoder.config.sampling_rate @@ -164,10 +168,6 @@ def preprocess(self, text, **kwargs): kwargs = new_kwargs preprocessor = self.processor if self.processor is not None else self.tokenizer - if self.model.config.model_type in ["musicgen"]: - # Fallback to legacy models that prefer tokenizer - preprocessor = self.tokenizer - if isinstance(text, Chat): output = preprocessor.apply_chat_template( text.messages, @@ -230,7 +230,7 @@ def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information. Args: - text_inputs (`str` or `list[str]`): + text_inputs (`str`, `list[str]`, list[dict[str, str]], or `list[list[dict[str, str]]]`): One or several texts to generate. If strings or a list of string are passed, this pipeline will generate the corresponding text. Alternatively, a "chat", in the form of a list of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed, the model's chat From 729844b8a69148b9a54a8cb14301fa3cf771ea91 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Fri, 21 Nov 2025 17:51:29 +0100 Subject: [PATCH 08/24] Update src/transformers/pipelines/text_to_audio.py Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> --- src/transformers/pipelines/text_to_audio.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 2e2bf56367fc..352dc0dc3418 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -250,10 +250,7 @@ def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform. - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform. 
""" - if isinstance( - text_inputs, - (list, tuple, types.GeneratorType) if is_torch_available() else (list, tuple, types.GeneratorType), - ): + if isinstance(text_inputs, (list, tuple, types.GeneratorType)): if isinstance(text_inputs, types.GeneratorType): text_inputs, _ = itertools.tee(text_inputs) text_inputs, first_item = (x for x in text_inputs), next(_) From bb96d9bb6d6c2764004d8df9ad3a4a16aaffae3c Mon Sep 17 00:00:00 2001 From: Eric B Date: Fri, 21 Nov 2025 19:42:15 +0100 Subject: [PATCH 09/24] Shift common utilities to chat template utils. --- src/transformers/pipelines/text_generation.py | 15 +-------------- src/transformers/pipelines/text_to_audio.py | 17 +---------------- src/transformers/utils/chat_template_utils.py | 19 +++++++++++++++++-- 3 files changed, 19 insertions(+), 32 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 7950e6faf2da..68b16e32746b 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -5,6 +5,7 @@ from ..generation import GenerationConfig from ..utils import ModelOutput, add_end_docstrings, is_torch_available +from ..utils.chat_template_utils import Chat, ChatType from .base import Pipeline, build_pipeline_init_args @@ -14,8 +15,6 @@ from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from .pt_utils import KeyDataset -ChatType = list[dict[str, str]] - class ReturnType(enum.Enum): TENSORS = 0 @@ -23,18 +22,6 @@ class ReturnType(enum.Enum): FULL_TEXT = 2 -class Chat: - """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats - to this format because the rest of the pipeline code tends to assume that lists of messages are - actually a batch of samples rather than messages in the same conversation.""" - - def __init__(self, messages: dict): - for message in messages: - if not ("role" in message and "content" in message): - raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") - self.messages = messages - - @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True)) class TextGenerationPipeline(Pipeline): """ diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 352dc0dc3418..8699edf3f2c0 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -18,6 +18,7 @@ from ..generation import GenerationConfig from ..utils import is_torch_available +from ..utils.chat_template_utils import Chat, ChatType from .base import Pipeline @@ -30,22 +31,6 @@ DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan" -ChatType = list[dict[str, str]] - - -# Copied from transformers.pipelines.text_generation.Chat -class Chat: - """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats - to this format because the rest of the pipeline code tends to assume that lists of messages are - actually a batch of samples rather than messages in the same conversation.""" - - def __init__(self, messages: dict): - for message in messages: - if not ("role" in message and "content" in message): - raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") - self.messages = messages - - class TextToAudioPipeline(Pipeline): """ Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. 
This diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index e07228163c2b..e89c13d97604 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -53,6 +53,9 @@ from torch import Tensor +ChatType = list[dict[str, str]] + + BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) # Extracts the initial segment of the docstring, containing the function description description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL) @@ -459,9 +462,9 @@ def strftime_now(format): def render_jinja_template( - conversations: list[list[dict[str, str]]], + conversations: list[ChatType], tools: list[dict | Callable] | None = None, - documents: list[dict[str, str]] | None = None, + documents: ChatType | None = None, chat_template: str | None = None, return_assistant_tokens_mask: bool = False, continue_final_message: bool = False, @@ -558,3 +561,15 @@ def render_jinja_template( rendered.append(rendered_chat) return rendered, all_generation_indices + + +class Chat: + """This class is intended to just be used internally for pipelines and not exposed to users. We convert chats + to this format because the rest of the pipeline code tends to assume that lists of messages are + actually a batch of samples rather than messages in the same conversation.""" + + def __init__(self, messages: dict): + for message in messages: + if not ("role" in message and "content" in message): + raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") + self.messages = messages From f801b96e11146cdcab05943d42ed6a1f6e5bd708 Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 24 Nov 2025 15:14:28 +0100 Subject: [PATCH 10/24] Type nits --- src/transformers/pipelines/text_generation.py | 2 +- src/transformers/pipelines/text_to_audio.py | 15 +++++++++------ src/transformers/utils/chat_template_utils.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 68b16e32746b..d49b26f59963 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -250,7 +250,7 @@ def __call__(self, text_inputs, **kwargs): Complete the prompt(s) given as inputs. Args: - text_inputs (`str`, `list[str]`, list[dict[str, str]], or `list[list[dict[str, str]]]`): + text_inputs (`str`, `list[str]`, `ChatType`, or `list[ChatType]`): One or several prompts (or one list of prompts) to complete. If strings or a list of string are passed, this pipeline will continue each prompt. Alternatively, a "chat", in the form of a list of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed, diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 8699edf3f2c0..c2f985423e5b 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -22,6 +22,9 @@ from .base import Pipeline +AudioOutput = dict[str, Any] # {"audio": np.ndarray, "sampling_rate": int} + + if is_torch_available(): import torch @@ -199,23 +202,23 @@ def _forward(self, model_inputs, **kwargs): return output @overload - def __call__(self, text_inputs: str, **forward_params: Any) -> dict[str, Any]: ... + def __call__(self, text_inputs: str, **forward_params: Any) -> AudioOutput: ... 
@overload - def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[dict[str, Any]]: ... + def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[AudioOutput]: ... @overload - def __call__(self, text_inputs: ChatType, **forward_params: Any) -> list[dict[str, ChatType]]: ... + def __call__(self, text_inputs: ChatType, **forward_params: Any) -> AudioOutput: ... @overload - def __call__(self, text_inputs: list[ChatType], **forward_params: Any) -> list[list[dict[str, ChatType]]]: ... + def __call__(self, text_inputs: list[ChatType], **forward_params: Any) -> list[AudioOutput]: ... - def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, Any] | list[dict[str, Any]]: + def __call__(self, text_inputs, **forward_params): """ Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information. Args: - text_inputs (`str`, `list[str]`, list[dict[str, str]], or `list[list[dict[str, str]]]`): + text_inputs (`str`, `list[str]`, `ChatType`, or `list[ChatType]`): One or several texts to generate. If strings or a list of string are passed, this pipeline will generate the corresponding text. Alternatively, a "chat", in the form of a list of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed, the model's chat diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index e89c13d97604..2024caca870c 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -53,7 +53,7 @@ from torch import Tensor -ChatType = list[dict[str, str]] +ChatType = list[dict[str, Any]] BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) From 716c3cfda633103a5cf2111ecb1a2746cdcc54b1 Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 24 Nov 2025 16:22:47 +0100 Subject: [PATCH 11/24] Remove tied weights. 
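
For context: `_tied_weights_keys` asks the loader to share the LM head's weight matrix with the
decoder's input embedding. A simplified sketch of what the removed mapping did (toy shapes, not
the real SeamlessM4T modules):

    import torch.nn as nn

    embed_tokens = nn.Embedding(100, 16)      # stands in for model.decoder.embed_tokens
    lm_head = nn.Linear(16, 100, bias=False)  # stands in for lm_head
    lm_head.weight = embed_tokens.weight      # the sharing that the mapping requests at load time
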
--- src/transformers/models/seamless_m4t/modeling_seamless_m4t.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 13ebc98ef56d..885acd90f02c 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1981,7 +1981,6 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel, "text_encoder", "text_decoder", ] - _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"} def __init__( self, From c0060d1a1567d29d606d4b80350c8e32cf3e4643 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 25 Nov 2025 07:58:22 +0100 Subject: [PATCH 12/24] Keep seamless error --- src/transformers/models/seamless_m4t/modeling_seamless_m4t.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 885acd90f02c..13ebc98ef56d 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1981,6 +1981,7 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel, "text_encoder", "text_decoder", ] + _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"} def __init__( self, From f8e90cf72a1754418397170decf0928ce93beb27 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 25 Nov 2025 09:01:10 +0100 Subject: [PATCH 13/24] Better audio output object. --- src/transformers/pipelines/text_to_audio.py | 28 ++++++++++++++------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index c2f985423e5b..51251110f702 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -14,17 +14,15 @@ import itertools import types -from typing import Any, overload +from typing import Any, TypedDict, overload +from ..audio_utils import AudioInput from ..generation import GenerationConfig from ..utils import is_torch_available from ..utils.chat_template_utils import Chat, ChatType from .base import Pipeline -AudioOutput = dict[str, Any] # {"audio": np.ndarray, "sampling_rate": int} - - if is_torch_available(): import torch @@ -34,6 +32,17 @@ DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan" +class AudioOutput(TypedDict, total=False): + """ + audio (`AudioInput`): + The generated audio waveform. + sampling_rate (`int`): + The sampling rate of the generated audio waveform. + """ + audio: AudioInput + sampling_rate: int + + class TextToAudioPipeline(Pipeline): """ Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. 
This @@ -281,7 +290,6 @@ def _sanitize_parameters( return preprocess_params, params, postprocess_params def postprocess(self, audio): - output_dict = {} if self.model.config.model_type == "csm": waveform_key = "audio" @@ -296,9 +304,11 @@ def postprocess(self, audio): waveform = audio if isinstance(audio, list): - output_dict["audio"] = [el.to(device="cpu", dtype=torch.float).numpy() for el in waveform] + audio = [el.to(device="cpu", dtype=torch.float).numpy() for el in waveform] else: - output_dict["audio"] = waveform.to(device="cpu", dtype=torch.float).numpy() - output_dict["sampling_rate"] = self.sampling_rate + audio = waveform.to(device="cpu", dtype=torch.float).numpy() - return output_dict + return AudioOutput( + audio=audio, + sampling_rate=self.sampling_rate, + ) From f5fb635f9d56a243853fb9b6d10b75e6e2ce5194 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 25 Nov 2025 11:45:23 +0100 Subject: [PATCH 14/24] Properly handle DIa and add test. --- src/transformers/pipelines/text_to_audio.py | 30 +++++++++++-------- .../pipelines/test_pipelines_text_to_audio.py | 29 ++++++++++++++++-- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 51251110f702..08f858a7bf3c 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -249,8 +249,9 @@ def __call__(self, text_inputs, **forward_params): """ if isinstance(text_inputs, (list, tuple, types.GeneratorType)): if isinstance(text_inputs, types.GeneratorType): - text_inputs, _ = itertools.tee(text_inputs) - text_inputs, first_item = (x for x in text_inputs), next(_) + gen_copy1, gen_copy2 = itertools.tee(text_inputs) + text_inputs = (x for x in gen_copy1) + first_item = next(gen_copy2) else: first_item = text_inputs[0] if isinstance(first_item, (list, tuple, dict)): @@ -291,22 +292,25 @@ def _sanitize_parameters( def postprocess(self, audio): - if self.model.config.model_type == "csm": - waveform_key = "audio" + if self.model.config.model_type in ["csm"]: + audio_key = "audio" else: - waveform_key = "waveform" + audio_key = "waveform" if isinstance(audio, dict): - waveform = audio[waveform_key] + audio = audio[audio_key] elif isinstance(audio, tuple): - waveform = audio[0] + audio = audio[0] + elif self.model.config.model_type in ["dia"]: + # models that require decoding, e.g. 
with codec + audio = self.processor.decode(audio) + + if isinstance(audio, list) and len(audio) > 1: + audio = [el.to(device="cpu", dtype=torch.float).numpy() for el in audio] + elif isinstance(audio, list): + audio = audio[0].to(device="cpu", dtype=torch.float).numpy() else: - waveform = audio - - if isinstance(audio, list): - audio = [el.to(device="cpu", dtype=torch.float).numpy() for el in waveform] - else: - audio = waveform.to(device="cpu", dtype=torch.float).numpy() + audio = audio.to(device="cpu", dtype=torch.float).numpy() return AudioOutput( audio=audio, diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index c4870bc1dff6..700ceedd03c3 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -248,14 +248,16 @@ def test_generative_model_kwargs(self): def test_csm_model_pt(self): speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b") - outputs = speech_generator("[0]This is a test") + outputs = speech_generator("[0]This is a test", generate_kwargs={"output_audio": True, "max_new_tokens": 32}) self.assertEqual(outputs["sampling_rate"], 24000) audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) + # ensure audio and not codes + self.assertEqual(len(audio.shape), 1) # test two examples side-by-side - outputs = speech_generator(["[0]This is a test", "[0]This is a second test"]) + outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs={"output_audio": True, "max_new_tokens": 32}) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) @@ -263,6 +265,29 @@ def test_csm_model_pt(self): outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], batch_size=2) self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) + @slow + @require_torch + def test_dia_model(self): + speech_generator = pipeline(task="text-to-audio", model="nari-labs/Dia-1.6B-0626") + + outputs = speech_generator( + "[S1] Dia is an open weights text to dialogue model.", + generate_kwargs={"max_new_tokens": 32}, + ) + self.assertEqual(outputs["sampling_rate"], 44100) + + audio = outputs["audio"] + self.assertEqual(ANY(np.ndarray), audio) + # ensure audio and not codes + self.assertEqual(len(audio.shape), 1) + + # test batch + outputs = speech_generator( + ["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."], + generate_kwargs={"max_new_tokens": 32}, + ) + self.assertEqual(len(outputs), 2) + def get_test_pipeline( self, model, From ed25458bdf6ad4b907e58744aea21c7b63042762 Mon Sep 17 00:00:00 2001 From: Eric B Date: Tue, 25 Nov 2025 16:35:11 +0100 Subject: [PATCH 15/24] Shift chat template prep to base, test dia batch. 
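
Chat detection now lives in `Pipeline.__call__`, so pipelines built on the base class share the
same handling of conversational inputs. A sketch of the shapes the base class now recognizes
(the checkpoint and message contents are assumptions):

    from transformers import pipeline

    pipe = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct")
    chat = [{"role": "user", "content": "Hi there"}]
    pipe(chat)          # a single conversation -> wrapped in `Chat`
    pipe([chat, chat])  # a list of conversations -> wrapped as a list of `Chat`
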
--- src/transformers/pipelines/base.py | 26 ++++++++++++++ src/transformers/pipelines/text_generation.py | 24 ------------- src/transformers/pipelines/text_to_audio.py | 28 ++++----------- .../pipelines/test_pipelines_text_to_audio.py | 35 ++++++++++++++----- 4 files changed, 58 insertions(+), 55 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 92867eee1529..82a3bea9c35d 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -16,6 +16,7 @@ import copy import csv import importlib +import itertools import json import os import pickle @@ -51,6 +52,7 @@ is_torch_xpu_available, logging, ) +from ..utils.chat_template_utils import Chat GenericTensor = Union[list["GenericTensor"], "torch.Tensor"] @@ -60,6 +62,7 @@ from torch.utils.data import DataLoader, Dataset from ..modeling_utils import PreTrainedModel + from .pt_utils import KeyDataset else: Dataset = None @@ -1202,6 +1205,29 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs): if args: logger.warning(f"Ignoring args : {args}") + # Detect if inputs is a chat-style input and cast as `Chat` or list of `Chat` + if isinstance( + inputs, + (list, tuple, types.GeneratorType, KeyDataset) + if is_torch_available() + else (list, tuple, types.GeneratorType), + ): + if isinstance(inputs, types.GeneratorType): + gen_copy1, gen_copy2 = itertools.tee(inputs) + inputs = (x for x in gen_copy1) + first_item = next(gen_copy2) + else: + first_item = inputs[0] + if isinstance(first_item, (list, tuple, dict)): + if isinstance(first_item, dict): + inputs = Chat(inputs) + else: + chats = (Chat(chat) for chat in inputs) + if isinstance(inputs, types.GeneratorType): + inputs = chats + else: + inputs = list(chats) + if num_workers is None: if self._num_workers is None: num_workers = 0 diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index d49b26f59963..dc1e2035393e 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -1,6 +1,4 @@ import enum -import itertools -import types from typing import Any, overload from ..generation import GenerationConfig @@ -13,7 +11,6 @@ import torch from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - from .pt_utils import KeyDataset class ReturnType(enum.Enum): @@ -295,27 +292,6 @@ def __call__(self, text_inputs, **kwargs): - **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token ids of the generated text. 
""" - if isinstance( - text_inputs, - (list, tuple, types.GeneratorType, KeyDataset) - if is_torch_available() - else (list, tuple, types.GeneratorType), - ): - if isinstance(text_inputs, types.GeneratorType): - text_inputs, _ = itertools.tee(text_inputs) - text_inputs, first_item = (x for x in text_inputs), next(_) - else: - first_item = text_inputs[0] - if isinstance(first_item, (list, tuple, dict)): - # We have one or more prompts in list-of-dicts format, so this is chat mode - if isinstance(first_item, dict): - return super().__call__(Chat(text_inputs), **kwargs) - else: - chats = (Chat(chat) for chat in text_inputs) # 🐈 🐈 🐈 - if isinstance(text_inputs, types.GeneratorType): - return super().__call__(chats, **kwargs) - else: - return super().__call__(list(chats), **kwargs) return super().__call__(text_inputs, **kwargs) def preprocess( diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 08f858a7bf3c..0abb69c45cf1 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License.from typing import List, Union -import itertools -import types from typing import Any, TypedDict, overload from ..audio_utils import AudioInput @@ -242,29 +240,11 @@ def __call__(self, text_inputs, **forward_params): only passed to the underlying model if the latter is a generative model. Return: - A `dict` or a list of `dict`: The dictionaries have two keys: + `AudioOutput` or a list of `AudioOutput`, which is a `TypedDict` with two keys: - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform. - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform. """ - if isinstance(text_inputs, (list, tuple, types.GeneratorType)): - if isinstance(text_inputs, types.GeneratorType): - gen_copy1, gen_copy2 = itertools.tee(text_inputs) - text_inputs = (x for x in gen_copy1) - first_item = next(gen_copy2) - else: - first_item = text_inputs[0] - if isinstance(first_item, (list, tuple, dict)): - # We have one or more prompts in list-of-dicts format, so this is chat mode - if isinstance(first_item, dict): - return super().__call__(Chat(text_inputs), **forward_params) - else: - chats = (Chat(chat) for chat in text_inputs) - if isinstance(text_inputs, types.GeneratorType): - return super().__call__(chats, **forward_params) - else: - return super().__call__(list(chats), **forward_params) - return super().__call__(text_inputs, **forward_params) def _sanitize_parameters( @@ -294,6 +274,9 @@ def postprocess(self, audio): if self.model.config.model_type in ["csm"]: audio_key = "audio" + elif self.model.config.model_type in ["dia"]: + # codes that need decoding + audio_key = "sequences" else: audio_key = "waveform" @@ -301,7 +284,8 @@ def postprocess(self, audio): audio = audio[audio_key] elif isinstance(audio, tuple): audio = audio[0] - elif self.model.config.model_type in ["dia"]: + + if self.model.config.model_type in ["dia"]: # models that require decoding, e.g. 
with codec audio = self.processor.decode(audio) diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index 700ceedd03c3..ed6391c8c49e 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -246,9 +246,10 @@ def test_generative_model_kwargs(self): @slow @require_torch def test_csm_model_pt(self): - speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b") + speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", device=torch_device) + max_new_tokens = 10 - outputs = speech_generator("[0]This is a test", generate_kwargs={"output_audio": True, "max_new_tokens": 32}) + outputs = speech_generator("[0]This is a test", generate_kwargs={"output_audio": True, "max_new_tokens": max_new_tokens}) self.assertEqual(outputs["sampling_rate"], 24000) audio = outputs["audio"] @@ -257,22 +258,28 @@ def test_csm_model_pt(self): self.assertEqual(len(audio.shape), 1) # test two examples side-by-side - outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs={"output_audio": True, "max_new_tokens": 32}) + outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs={"output_audio": True, "max_new_tokens": max_new_tokens}) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) + # ensure audio and not codes + self.assertEqual(len(audio[0].shape), 1) # test batching - outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], batch_size=2) + generate_kwargs = {"output_audio": True, "max_new_tokens": max_new_tokens, "return_dict_in_generate": True} + outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs, batch_size=2) self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) + # ensure audio and not codes + self.assertEqual(len(outputs[0]["audio"].squeeze().shape), 1) @slow @require_torch def test_dia_model(self): - speech_generator = pipeline(task="text-to-audio", model="nari-labs/Dia-1.6B-0626") + speech_generator = pipeline(task="text-to-audio", model="nari-labs/Dia-1.6B-0626", device=torch_device) + max_new_tokens = 20 outputs = speech_generator( "[S1] Dia is an open weights text to dialogue model.", - generate_kwargs={"max_new_tokens": 32}, + generate_kwargs={"max_new_tokens": max_new_tokens}, ) self.assertEqual(outputs["sampling_rate"], 44100) @@ -281,12 +288,22 @@ def test_dia_model(self): # ensure audio and not codes self.assertEqual(len(audio.shape), 1) - # test batch + # test two examples side-by-side outputs = speech_generator( ["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."], - generate_kwargs={"max_new_tokens": 32}, + generate_kwargs={"max_new_tokens": max_new_tokens}, ) - self.assertEqual(len(outputs), 2) + audio = [output["audio"] for output in outputs] + self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) + # ensure audio and not codes + self.assertEqual(len(audio[0].shape), 1) + + # test batching + generate_kwargs = {"max_new_tokens": max_new_tokens, "return_dict_in_generate": True} + outputs = speech_generator(["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."], generate_kwargs=generate_kwargs, batch_size=2) + self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) + # ensure audio and not codes + self.assertEqual(len(outputs[0]["audio"].squeeze().shape), 1) def 
get_test_pipeline( self, From 97158c6e0d0f9db5816a7fb65232b70718cdd5f3 Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 26 Nov 2025 12:07:55 +0100 Subject: [PATCH 16/24] Backward compatibility for dicts passed to pipelines --- src/transformers/pipelines/base.py | 45 ++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 82a3bea9c35d..c53c73926ea6 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -76,6 +76,26 @@ def no_collate_fn(items): return items[0] +def is_valid_chat(chat): + """ + Check that input is a valid chat, namely list of messages dicts that have "role" and "content" keys. + """ + is_iterable = isinstance( + chat, + (list, tuple, types.GeneratorType, KeyDataset) + if is_torch_available() + else (list, tuple, types.GeneratorType), + ) + if not is_iterable: + return False + for message in chat: + if not isinstance(message, dict): + return False + if not ("role" in message and "content" in message): + return False + return True + + def _pad(items, key, padding_value, padding_side): batch_size = len(items) if isinstance(items[0][key], torch.Tensor): @@ -1205,24 +1225,27 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs): if args: logger.warning(f"Ignoring args : {args}") - # Detect if inputs is a chat-style input and cast as `Chat` or list of `Chat` - if isinstance( - inputs, - (list, tuple, types.GeneratorType, KeyDataset) - if is_torch_available() - else (list, tuple, types.GeneratorType), - ): + # Detect if inputs are a chat-style input(s) and cast as `Chat` or list of `Chat` + container_types = (list, tuple, types.GeneratorType) + if is_torch_available(): + container_types = (*container_types, KeyDataset) + if isinstance(inputs, container_types): + # get first item to see if a single chat or list of chats if isinstance(inputs, types.GeneratorType): gen_copy1, gen_copy2 = itertools.tee(inputs) inputs = (x for x in gen_copy1) first_item = next(gen_copy2) else: first_item = inputs[0] - if isinstance(first_item, (list, tuple, dict)): - if isinstance(first_item, dict): + + if isinstance(first_item, dict): + if is_valid_chat(inputs): inputs = Chat(inputs) - else: - chats = (Chat(chat) for chat in inputs) + elif isinstance(first_item, (list, tuple)): + # materialize generator is needed + items = list(inputs) if isinstance(inputs, types.GeneratorType) else inputs + if all(is_valid_chat(chat) for chat in items): + chats = (Chat(chat) for chat in items) if isinstance(inputs, types.GeneratorType): inputs = chats else: From 5b80b0317abf255a65a96a9f852739e2e8e43687 Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 26 Nov 2025 12:16:45 +0100 Subject: [PATCH 17/24] Simplify postprocessing and tests. 
--- src/transformers/pipelines/text_to_audio.py | 23 ++++---- .../pipelines/test_pipelines_text_to_audio.py | 53 +++++++++---------- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 0abb69c45cf1..e640c185186d 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -192,6 +192,10 @@ def _forward(self, model_inputs, **kwargs): # generate_kwargs get priority over forward_params forward_params.update(generate_kwargs) + # ensure dict output to facilitate postprocessing + if self.model.config.model_type not in ["bark", "musicgen"]: + forward_params.update({"return_dict_in_generate": True}) + output = self.model.generate(**model_inputs, **forward_params) else: if len(generate_kwargs): @@ -253,6 +257,7 @@ def _sanitize_parameters( forward_params=None, generate_kwargs=None, ): + if getattr(self, "assistant_model", None) is not None: generate_kwargs["assistant_model"] = self.assistant_model if getattr(self, "assistant_tokenizer", None) is not None: @@ -272,21 +277,17 @@ def _sanitize_parameters( def postprocess(self, audio): - if self.model.config.model_type in ["csm"]: - audio_key = "audio" - elif self.model.config.model_type in ["dia"]: - # codes that need decoding - audio_key = "sequences" - else: - audio_key = "waveform" - + needs_decoding = False if isinstance(audio, dict): - audio = audio[audio_key] + if "audio" in audio: + audio = audio["audio"] + else: + needs_decoding = True + audio = audio["sequences"] elif isinstance(audio, tuple): audio = audio[0] - if self.model.config.model_type in ["dia"]: - # models that require decoding, e.g. with codec + if needs_decoding: audio = self.processor.decode(audio) if isinstance(audio, list) and len(audio) > 1: diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index ed6391c8c49e..e2e13bf63f4e 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -247,63 +247,60 @@ def test_generative_model_kwargs(self): @require_torch def test_csm_model_pt(self): speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", device=torch_device) - max_new_tokens = 10 + generate_kwargs = {"max_new_tokens": 10, "output_audio": True} + n_channel = 1 # model generates mono audio - outputs = speech_generator("[0]This is a test", generate_kwargs={"output_audio": True, "max_new_tokens": max_new_tokens}) + outputs = speech_generator("[0]This is a test", generate_kwargs=generate_kwargs) self.assertEqual(outputs["sampling_rate"], 24000) - audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) - # ensure audio and not codes - self.assertEqual(len(audio.shape), 1) + # ensure audio and not discrete codes + self.assertEqual(len(audio.shape), n_channel) # test two examples side-by-side - outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs={"output_audio": True, "max_new_tokens": max_new_tokens}) + outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - # ensure audio and not codes - self.assertEqual(len(audio[0].shape), 1) + self.assertEqual(len(audio[0].shape), n_channel) # test batching - generate_kwargs = {"output_audio": True, "max_new_tokens": max_new_tokens, 
"return_dict_in_generate": True} - outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs, batch_size=2) - self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) - # ensure audio and not codes - self.assertEqual(len(outputs[0]["audio"].squeeze().shape), 1) + batch_size = 2 + outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size) + self.assertEqual(len(outputs), batch_size) + audio = [output["audio"] for output in outputs] + self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) + self.assertEqual(len(outputs[0]["audio"].squeeze().shape), n_channel) @slow @require_torch def test_dia_model(self): speech_generator = pipeline(task="text-to-audio", model="nari-labs/Dia-1.6B-0626", device=torch_device) - max_new_tokens = 20 + generate_kwargs = {"max_new_tokens": 20} + n_channel = 1 # model generates mono audio - outputs = speech_generator( - "[S1] Dia is an open weights text to dialogue model.", - generate_kwargs={"max_new_tokens": max_new_tokens}, - ) + outputs = speech_generator("[S1] Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs) self.assertEqual(outputs["sampling_rate"], 44100) - audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) - # ensure audio and not codes - self.assertEqual(len(audio.shape), 1) + # ensure audio (with one channel) and not discrete codes + self.assertEqual(len(audio.shape), n_channel) # test two examples side-by-side outputs = speech_generator( ["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."], - generate_kwargs={"max_new_tokens": max_new_tokens}, + generate_kwargs=generate_kwargs, ) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - # ensure audio and not codes - self.assertEqual(len(audio[0].shape), 1) + self.assertEqual(len(audio[0].shape), n_channel) # test batching - generate_kwargs = {"max_new_tokens": max_new_tokens, "return_dict_in_generate": True} + batch_size = 2 outputs = speech_generator(["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."], generate_kwargs=generate_kwargs, batch_size=2) - self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) - # ensure audio and not codes - self.assertEqual(len(outputs[0]["audio"].squeeze().shape), 1) + self.assertEqual(len(outputs), batch_size) + audio = [output["audio"] for output in outputs] + self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) + self.assertEqual(len(outputs[0]["audio"].squeeze().shape), n_channel) def get_test_pipeline( self, From 838daf748ca3053e60e6d9520d83ea1cf765b9bc Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 26 Nov 2025 13:57:41 +0100 Subject: [PATCH 18/24] Make quality/style --- src/transformers/pipelines/base.py | 4 +--- src/transformers/pipelines/text_to_audio.py | 3 +-- .../pipelines/test_pipelines_text_to_audio.py | 18 +++++++++++++----- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index f88f73844ba6..9a40a64d9d2f 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -82,9 +82,7 @@ def is_valid_chat(chat): """ is_iterable = isinstance( chat, - (list, tuple, types.GeneratorType, KeyDataset) - if is_torch_available() - else (list, tuple, types.GeneratorType), + (list, tuple, types.GeneratorType, KeyDataset) if is_torch_available() else 
(list, tuple, types.GeneratorType), ) if not is_iterable: return False diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index e640c185186d..64e749165be1 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -37,6 +37,7 @@ class AudioOutput(TypedDict, total=False): sampling_rate (`int`): The sampling rate of the generated audio waveform. """ + audio: AudioInput sampling_rate: int @@ -257,7 +258,6 @@ def _sanitize_parameters( forward_params=None, generate_kwargs=None, ): - if getattr(self, "assistant_model", None) is not None: generate_kwargs["assistant_model"] = self.assistant_model if getattr(self, "assistant_tokenizer", None) is not None: @@ -276,7 +276,6 @@ def _sanitize_parameters( return preprocess_params, params, postprocess_params def postprocess(self, audio): - needs_decoding = False if isinstance(audio, dict): if "audio" in audio: diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index e2e13bf63f4e..ebc041e25ab0 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -248,7 +248,7 @@ def test_generative_model_kwargs(self): def test_csm_model_pt(self): speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", device=torch_device) generate_kwargs = {"max_new_tokens": 10, "output_audio": True} - n_channel = 1 # model generates mono audio + n_channel = 1 # model generates mono audio outputs = speech_generator("[0]This is a test", generate_kwargs=generate_kwargs) self.assertEqual(outputs["sampling_rate"], 24000) @@ -265,7 +265,9 @@ def test_csm_model_pt(self): # test batching batch_size = 2 - outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size) + outputs = speech_generator( + ["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size + ) self.assertEqual(len(outputs), batch_size) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) @@ -276,9 +278,11 @@ def test_csm_model_pt(self): def test_dia_model(self): speech_generator = pipeline(task="text-to-audio", model="nari-labs/Dia-1.6B-0626", device=torch_device) generate_kwargs = {"max_new_tokens": 20} - n_channel = 1 # model generates mono audio + n_channel = 1 # model generates mono audio - outputs = speech_generator("[S1] Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs) + outputs = speech_generator( + "[S1] Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs + ) self.assertEqual(outputs["sampling_rate"], 44100) audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) @@ -296,7 +300,11 @@ def test_dia_model(self): # test batching batch_size = 2 - outputs = speech_generator(["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."], generate_kwargs=generate_kwargs, batch_size=2) + outputs = speech_generator( + ["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."], + generate_kwargs=generate_kwargs, + batch_size=2, + ) self.assertEqual(len(outputs), batch_size) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) From e3b882dd9818be894db387fc4e58339c9cc9c092 Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 26 Nov 2025 14:33:28 
+0100 Subject: [PATCH 19/24] Remove chat from image text to text. --- src/transformers/pipelines/image_text_to_text.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 537c8dac491e..7d0aa2d47f2d 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -25,6 +25,7 @@ logging, requires_backends, ) +from ..utils.chat_template_utils import Chat from .base import Pipeline, build_pipeline_init_args @@ -49,18 +50,6 @@ class ReturnType(enum.Enum): FULL_TEXT = 2 -class Chat: - """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats - to this format because the rest of the pipeline code tends to assume that lists of messages are - actually a batch of samples rather than messages in the same conversation.""" - - def __init__(self, messages: dict): - for message in messages: - if not ("role" in message and "content" in message): - raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") - self.messages = messages - - @add_end_docstrings(build_pipeline_init_args(has_processor=True)) class ImageTextToTextPipeline(Pipeline): """ From 661cafbd7f7524b5e4326f17e5db72e6cd969c4d Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 26 Nov 2025 15:40:14 +0100 Subject: [PATCH 20/24] Only check first item, to not consume first item of generator inputs. --- src/transformers/pipelines/base.py | 24 +++---------------- src/transformers/utils/chat_template_utils.py | 13 +++++++++- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 9a40a64d9d2f..b1e87da93399 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -52,7 +52,7 @@ is_torch_xpu_available, logging, ) -from ..utils.chat_template_utils import Chat +from ..utils.chat_template_utils import Chat, is_valid_message GenericTensor = Union[list["GenericTensor"], "torch.Tensor"] @@ -76,24 +76,6 @@ def no_collate_fn(items): return items[0] -def is_valid_chat(chat): - """ - Check that input is a valid chat, namely list of messages dicts that have "role" and "content" keys. 
- """ - is_iterable = isinstance( - chat, - (list, tuple, types.GeneratorType, KeyDataset) if is_torch_available() else (list, tuple, types.GeneratorType), - ) - if not is_iterable: - return False - for message in chat: - if not isinstance(message, dict): - return False - if not ("role" in message and "content" in message): - return False - return True - - def _pad(items, key, padding_value, padding_side): batch_size = len(items) if isinstance(items[0][key], torch.Tensor): @@ -1237,12 +1219,12 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs): first_item = inputs[0] if isinstance(first_item, dict): - if is_valid_chat(inputs): + if is_valid_message(first_item): inputs = Chat(inputs) elif isinstance(first_item, (list, tuple)): # materialize generator is needed items = list(inputs) if isinstance(inputs, types.GeneratorType) else inputs - if all(is_valid_chat(chat) for chat in items): + if all(is_valid_message(chat[0]) for chat in items): chats = (Chat(chat) for chat in items) if isinstance(inputs, types.GeneratorType): inputs = chats diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 2024caca870c..ed3a6daee73e 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -563,6 +563,17 @@ def render_jinja_template( return rendered, all_generation_indices +def is_valid_message(message): + """ + Check that input is a valid message in a chat, namely a dict with "role" and "content" keys. + """ + if not isinstance(message, dict): + return False + if not ("role" in message and "content" in message): + return False + return True + + class Chat: """This class is intended to just be used internally for pipelines and not exposed to users. 
We convert chats to this format because the rest of the pipeline code tends to assume that lists of messages are @@ -570,6 +581,6 @@ class Chat: def __init__(self, messages: dict): for message in messages: - if not ("role" in message and "content" in message): + if not is_valid_message(message): raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") self.messages = messages From 94410f4d06d8a2e526887474cd4feed66788d5b6 Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 26 Nov 2025 15:54:30 +0100 Subject: [PATCH 21/24] Nit --- src/transformers/pipelines/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index b1e87da93399..785ce396a846 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -1218,9 +1218,8 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs): else: first_item = inputs[0] - if isinstance(first_item, dict): - if is_valid_message(first_item): - inputs = Chat(inputs) + if is_valid_message(first_item): + inputs = Chat(inputs) elif isinstance(first_item, (list, tuple)): # materialize generator is needed items = list(inputs) if isinstance(inputs, types.GeneratorType) else inputs From 79d7e730680b9856b2e5b0732259265349818bf2 Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 26 Nov 2025 17:42:47 +0100 Subject: [PATCH 22/24] Simplify --- src/transformers/models/bark/modeling_bark.py | 11 +++++++--- src/transformers/pipelines/base.py | 22 ++++--------------- src/transformers/pipelines/text_to_audio.py | 14 +++++------- .../test_pipelines_text_generation.py | 3 --- .../pipelines/test_pipelines_text_to_audio.py | 16 +++++++------- 5 files changed, 26 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index fe3afdc7bbd2..2dfc16b7c113 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -651,8 +651,10 @@ def generate( ) # size: 10048 # take the generated semantic tokens - semantic_output = semantic_output[:, max_input_semantic_length + 1 :] - + if kwargs.get("return_dict_in_generate", False): + semantic_output = semantic_output.sequences[:, max_input_semantic_length + 1 :] + else: + semantic_output = semantic_output[:, max_input_semantic_length + 1 :] return semantic_output @@ -865,7 +867,10 @@ def generate( input_coarse_len = input_coarse.shape[1] - x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]]) + if kwargs.get("return_dict_in_generate", False): + x_coarse = torch.hstack([x_coarse, output_coarse.sequences[:, input_coarse_len:]]) + else: + x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]]) total_generated_len = x_coarse.shape[1] - len_coarse_history del output_coarse diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 785ce396a846..6721bcfbb67b 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -16,7 +16,6 @@ import copy import csv import importlib -import itertools import json import os import pickle @@ -1210,25 +1209,12 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs): if is_torch_available(): container_types = (*container_types, KeyDataset) if isinstance(inputs, container_types): - # get first item to see if a single chat or list of chats if isinstance(inputs, 
types.GeneratorType): - gen_copy1, gen_copy2 = itertools.tee(inputs) - inputs = (x for x in gen_copy1) - first_item = next(gen_copy2) - else: - first_item = inputs[0] - - if is_valid_message(first_item): + inputs = list(inputs) + if is_valid_message(inputs[0]): inputs = Chat(inputs) - elif isinstance(first_item, (list, tuple)): - # materialize generator is needed - items = list(inputs) if isinstance(inputs, types.GeneratorType) else inputs - if all(is_valid_message(chat[0]) for chat in items): - chats = (Chat(chat) for chat in items) - if isinstance(inputs, types.GeneratorType): - inputs = chats - else: - inputs = list(chats) + elif isinstance(inputs[0], (list, tuple)) and all(chat and is_valid_message(chat[0]) for chat in inputs): + inputs = [Chat(chat) for chat in inputs] if num_workers is None: if self._num_workers is None: diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 64e749165be1..be7a9b9bc0c8 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -194,8 +194,7 @@ def _forward(self, model_inputs, **kwargs): forward_params.update(generate_kwargs) # ensure dict output to facilitate postprocessing - if self.model.config.model_type not in ["bark", "musicgen"]: - forward_params.update({"return_dict_in_generate": True}) + forward_params.update({"return_dict_in_generate": True}) output = self.model.generate(**model_inputs, **forward_params) else: @@ -286,15 +285,14 @@ def postprocess(self, audio): elif isinstance(audio, tuple): audio = audio[0] - if needs_decoding: + if needs_decoding and self.processor is not None: audio = self.processor.decode(audio) - if isinstance(audio, list) and len(audio) > 1: - audio = [el.to(device="cpu", dtype=torch.float).numpy() for el in audio] - elif isinstance(audio, list): - audio = audio[0].to(device="cpu", dtype=torch.float).numpy() + if isinstance(audio, list): + audio = [el.to(device="cpu", dtype=torch.float).numpy().squeeze() for el in audio] + audio = audio if len(audio) > 1 else audio[0] else: - audio = audio.to(device="cpu", dtype=torch.float).numpy() + audio = audio.to(device="cpu", dtype=torch.float).numpy().squeeze() return AudioOutput( audio=audio, diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 4f7aa91c5094..e814b16eb19e 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -220,8 +220,6 @@ def __getitem__(self, i): @require_torch def test_small_chat_model_with_iterator_pt(self): - from transformers.pipelines.pt_utils import PipelineIterator - text_generator = pipeline( task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", @@ -253,7 +251,6 @@ def data(): yield from [chat1, chat2] outputs = text_generator(data(), do_sample=False, max_new_tokens=10) - assert isinstance(outputs, PipelineIterator) outputs = list(outputs) self.assertEqual( outputs, diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index ebc041e25ab0..2e2172c5887b 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -248,20 +248,20 @@ def test_generative_model_kwargs(self): def test_csm_model_pt(self): speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", device=torch_device) generate_kwargs = {"max_new_tokens": 10, "output_audio": True} - n_channel = 1 # 
model generates mono audio + n_ch = 1 # model generates mono audio outputs = speech_generator("[0]This is a test", generate_kwargs=generate_kwargs) self.assertEqual(outputs["sampling_rate"], 24000) audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) # ensure audio and not discrete codes - self.assertEqual(len(audio.shape), n_channel) + self.assertEqual(len(audio.shape), n_ch) # test two examples side-by-side outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - self.assertEqual(len(audio[0].shape), n_channel) + self.assertEqual(len(audio[0].shape), n_ch) # test batching batch_size = 2 @@ -271,14 +271,14 @@ def test_csm_model_pt(self): self.assertEqual(len(outputs), batch_size) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - self.assertEqual(len(outputs[0]["audio"].squeeze().shape), n_channel) + self.assertEqual(len(outputs[0]["audio"].shape), n_ch) @slow @require_torch def test_dia_model(self): speech_generator = pipeline(task="text-to-audio", model="nari-labs/Dia-1.6B-0626", device=torch_device) generate_kwargs = {"max_new_tokens": 20} - n_channel = 1 # model generates mono audio + n_ch = 1 # model generates mono audio outputs = speech_generator( "[S1] Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs @@ -287,7 +287,7 @@ def test_dia_model(self): audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) # ensure audio (with one channel) and not discrete codes - self.assertEqual(len(audio.shape), n_channel) + self.assertEqual(len(audio.shape), n_ch) # test two examples side-by-side outputs = speech_generator( @@ -296,7 +296,7 @@ def test_dia_model(self): ) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - self.assertEqual(len(audio[0].shape), n_channel) + self.assertEqual(len(audio[0].shape), n_ch) # test batching batch_size = 2 @@ -308,7 +308,7 @@ def test_dia_model(self): self.assertEqual(len(outputs), batch_size) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - self.assertEqual(len(outputs[0]["audio"].squeeze().shape), n_channel) + self.assertEqual(len(outputs[0]["audio"].shape), n_ch) def get_test_pipeline( self, From 7a438ff0598b220c42111e12f67782a79873c510 Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 26 Nov 2025 17:51:56 +0100 Subject: [PATCH 23/24] Add checks for bark/musicgen to ensure output is audio. 
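
Bark and MusicGen return the decoded waveform directly from `generate`, so
the new test assertions verify that the pipeline hands back a 1-D (mono)
numpy array rather than a tensor of discrete audio codes. A rough sketch of
the shape distinction the checks rely on (shapes are illustrative only, not
taken from the actual models):

    import numpy as np

    waveform = np.zeros(24000)    # decoded mono audio: ndim == 1
    codes = np.zeros((4, 250))    # e.g. codebooks x frames: ndim == 2

    print(waveform.ndim == 1)     # True: accepted by the new checks
    print(codes.ndim == 1)        # False: would be flagged as non-audio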
--- tests/pipelines/test_pipelines_text_to_audio.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index 2e2172c5887b..acd4e1de14e1 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -45,9 +45,11 @@ def test_small_musicgen_pt(self): music_generator = pipeline( task="text-to-audio", model="facebook/musicgen-small", do_sample=False, max_new_tokens=5 ) + n_ch = 1 # model generates mono audio outputs = music_generator("This is a test") self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 32000}, outputs) + self.assertEqual(len(outputs["audio"].shape), n_ch) # test two examples side-by-side outputs = music_generator(["This is a test", "This is a second test"]) @@ -88,6 +90,7 @@ def test_medium_seamless_m4t_pt(self): @require_torch def test_small_bark_pt(self): speech_generator = pipeline(task="text-to-audio", model="suno/bark-small") + n_ch = 1 # model generates mono audio forward_params = { # Using `do_sample=False` to force deterministic output @@ -100,6 +103,7 @@ def test_small_bark_pt(self): {"audio": ANY(np.ndarray), "sampling_rate": 24000}, outputs, ) + self.assertEqual(len(outputs["audio"].shape), n_ch) # test two examples side-by-side outputs = speech_generator( From 14280926e6453c5875deb5edacf20e9153ea7575 Mon Sep 17 00:00:00 2001 From: Eric B Date: Wed, 26 Nov 2025 18:10:51 +0100 Subject: [PATCH 24/24] Better var --- .../pipelines/test_pipelines_text_to_audio.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py index acd4e1de14e1..19c7570e9e9d 100644 --- a/tests/pipelines/test_pipelines_text_to_audio.py +++ b/tests/pipelines/test_pipelines_text_to_audio.py @@ -45,11 +45,11 @@ def test_small_musicgen_pt(self): music_generator = pipeline( task="text-to-audio", model="facebook/musicgen-small", do_sample=False, max_new_tokens=5 ) - n_ch = 1 # model generates mono audio + num_channels = 1 # model generates mono audio outputs = music_generator("This is a test") self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 32000}, outputs) - self.assertEqual(len(outputs["audio"].shape), n_ch) + self.assertEqual(len(outputs["audio"].shape), num_channels) # test two examples side-by-side outputs = music_generator(["This is a test", "This is a second test"]) @@ -90,7 +90,7 @@ def test_medium_seamless_m4t_pt(self): @require_torch def test_small_bark_pt(self): speech_generator = pipeline(task="text-to-audio", model="suno/bark-small") - n_ch = 1 # model generates mono audio + num_channels = 1 # model generates mono audio forward_params = { # Using `do_sample=False` to force deterministic output @@ -103,7 +103,7 @@ def test_small_bark_pt(self): {"audio": ANY(np.ndarray), "sampling_rate": 24000}, outputs, ) - self.assertEqual(len(outputs["audio"].shape), n_ch) + self.assertEqual(len(outputs["audio"].shape), num_channels) # test two examples side-by-side outputs = speech_generator( @@ -252,20 +252,20 @@ def test_generative_model_kwargs(self): def test_csm_model_pt(self): speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", device=torch_device) generate_kwargs = {"max_new_tokens": 10, "output_audio": True} - n_ch = 1 # model generates mono audio + num_channels = 1 # model generates mono audio outputs = speech_generator("[0]This is a test", generate_kwargs=generate_kwargs) 
self.assertEqual(outputs["sampling_rate"], 24000) audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) # ensure audio and not discrete codes - self.assertEqual(len(audio.shape), n_ch) + self.assertEqual(len(audio.shape), num_channels) # test two examples side-by-side outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - self.assertEqual(len(audio[0].shape), n_ch) + self.assertEqual(len(audio[0].shape), num_channels) # test batching batch_size = 2 @@ -275,14 +275,14 @@ def test_csm_model_pt(self): self.assertEqual(len(outputs), batch_size) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - self.assertEqual(len(outputs[0]["audio"].shape), n_ch) + self.assertEqual(len(outputs[0]["audio"].shape), num_channels) @slow @require_torch def test_dia_model(self): speech_generator = pipeline(task="text-to-audio", model="nari-labs/Dia-1.6B-0626", device=torch_device) generate_kwargs = {"max_new_tokens": 20} - n_ch = 1 # model generates mono audio + num_channels = 1 # model generates mono audio outputs = speech_generator( "[S1] Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs @@ -291,7 +291,7 @@ def test_dia_model(self): audio = outputs["audio"] self.assertEqual(ANY(np.ndarray), audio) # ensure audio (with one channel) and not discrete codes - self.assertEqual(len(audio.shape), n_ch) + self.assertEqual(len(audio.shape), num_channels) # test two examples side-by-side outputs = speech_generator( @@ -300,7 +300,7 @@ def test_dia_model(self): ) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - self.assertEqual(len(audio[0].shape), n_ch) + self.assertEqual(len(audio[0].shape), num_channels) # test batching batch_size = 2 @@ -312,7 +312,7 @@ def test_dia_model(self): self.assertEqual(len(outputs), batch_size) audio = [output["audio"] for output in outputs] self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio) - self.assertEqual(len(outputs[0]["audio"].shape), n_ch) + self.assertEqual(len(outputs[0]["audio"].shape), num_channels) def get_test_pipeline( self,