diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index fe3afdc7bbd2..2dfc16b7c113 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -651,8 +651,10 @@ def generate(
         )  # size: 10048
 
         # take the generated semantic tokens
-        semantic_output = semantic_output[:, max_input_semantic_length + 1 :]
-
+        if kwargs.get("return_dict_in_generate", False):
+            semantic_output = semantic_output.sequences[:, max_input_semantic_length + 1 :]
+        else:
+            semantic_output = semantic_output[:, max_input_semantic_length + 1 :]
         return semantic_output
@@ -865,7 +867,10 @@ def generate(
 
             input_coarse_len = input_coarse.shape[1]
 
-            x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]])
+            if kwargs.get("return_dict_in_generate", False):
+                x_coarse = torch.hstack([x_coarse, output_coarse.sequences[:, input_coarse_len:]])
+            else:
+                x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]])
             total_generated_len = x_coarse.shape[1] - len_coarse_history
 
             del output_coarse
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 1635f379c5d2..6721bcfbb67b 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -51,6 +51,7 @@
     is_torch_xpu_available,
     logging,
 )
+from ..utils.chat_template_utils import Chat, is_valid_message
 
 
 GenericTensor = Union[list["GenericTensor"], "torch.Tensor"]
@@ -60,6 +61,7 @@
     from torch.utils.data import DataLoader, Dataset
 
     from ..modeling_utils import PreTrainedModel
+    from .pt_utils import KeyDataset
 else:
     Dataset = None
@@ -1202,6 +1204,18 @@ def __call__(self, inputs, *args, num_workers=None, batch_size=None, **kwargs):
         if args:
             logger.warning(f"Ignoring args : {args}")
 
+        # Detect chat-style input(s) and cast to `Chat` or a list of `Chat`
+        container_types = (list, tuple, types.GeneratorType)
+        if is_torch_available():
+            container_types = (*container_types, KeyDataset)
+        if isinstance(inputs, container_types):
+            if isinstance(inputs, types.GeneratorType):
+                inputs = list(inputs)
+            if is_valid_message(inputs[0]):
+                inputs = Chat(inputs)
+            elif isinstance(inputs[0], (list, tuple)) and all(chat and is_valid_message(chat[0]) for chat in inputs):
+                inputs = [Chat(chat) for chat in inputs]
+
         if num_workers is None:
             if self._num_workers is None:
                 num_workers = 0
diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py
index 537c8dac491e..7d0aa2d47f2d 100644
--- a/src/transformers/pipelines/image_text_to_text.py
+++ b/src/transformers/pipelines/image_text_to_text.py
@@ -25,6 +25,7 @@
     logging,
     requires_backends,
 )
+from ..utils.chat_template_utils import Chat
 from .base import Pipeline, build_pipeline_init_args
 
 
@@ -49,18 +50,6 @@ class ReturnType(enum.Enum):
     FULL_TEXT = 2
 
 
-class Chat:
-    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
-    to this format because the rest of the pipeline code tends to assume that lists of messages are
-    actually a batch of samples rather than messages in the same conversation."""
-
-    def __init__(self, messages: dict):
-        for message in messages:
-            if not ("role" in message and "content" in message):
-                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
-        self.messages = messages
-
-
 @add_end_docstrings(build_pipeline_init_args(has_processor=True))
 class ImageTextToTextPipeline(Pipeline):
     """
diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index 7950e6faf2da..dc1e2035393e 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -1,10 +1,9 @@
 import enum
-import itertools
-import types
 from typing import Any, overload
 
 from ..generation import GenerationConfig
 from ..utils import ModelOutput, add_end_docstrings, is_torch_available
+from ..utils.chat_template_utils import Chat, ChatType
 from .base import Pipeline, build_pipeline_init_args
 
 
@@ -12,9 +11,6 @@
     import torch
 
     from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-    from .pt_utils import KeyDataset
-
-ChatType = list[dict[str, str]]
 
 
 class ReturnType(enum.Enum):
@@ -23,18 +19,6 @@ class ReturnType(enum.Enum):
     FULL_TEXT = 2
 
 
-class Chat:
-    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
-    to this format because the rest of the pipeline code tends to assume that lists of messages are
-    actually a batch of samples rather than messages in the same conversation."""
-
-    def __init__(self, messages: dict):
-        for message in messages:
-            if not ("role" in message and "content" in message):
-                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
-        self.messages = messages
-
-
 @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
 class TextGenerationPipeline(Pipeline):
     """
@@ -263,7 +247,7 @@ def __call__(self, text_inputs, **kwargs):
         Complete the prompt(s) given as inputs.
 
         Args:
-            text_inputs (`str`, `list[str]`, list[dict[str, str]], or `list[list[dict[str, str]]]`):
+            text_inputs (`str`, `list[str]`, `ChatType`, or `list[ChatType]`):
                 One or several prompts (or one list of prompts) to complete. If strings or a list of string are
                 passed, this pipeline will continue each prompt. Alternatively, a "chat", in the form of a list
                 of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed,
@@ -308,27 +292,6 @@ def __call__(self, text_inputs, **kwargs):
             - **generated_token_ids** (`torch.Tensor`, present when `return_tensors=True`) -- The token ids
               of the generated text.
""" - if isinstance( - text_inputs, - (list, tuple, types.GeneratorType, KeyDataset) - if is_torch_available() - else (list, tuple, types.GeneratorType), - ): - if isinstance(text_inputs, types.GeneratorType): - text_inputs, _ = itertools.tee(text_inputs) - text_inputs, first_item = (x for x in text_inputs), next(_) - else: - first_item = text_inputs[0] - if isinstance(first_item, (list, tuple, dict)): - # We have one or more prompts in list-of-dicts format, so this is chat mode - if isinstance(first_item, dict): - return super().__call__(Chat(text_inputs), **kwargs) - else: - chats = (Chat(chat) for chat in text_inputs) # 🐈 🐈 🐈 - if isinstance(text_inputs, types.GeneratorType): - return super().__call__(chats, **kwargs) - else: - return super().__call__(list(chats), **kwargs) return super().__call__(text_inputs, **kwargs) def preprocess( diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py index 99334eff468a..be7a9b9bc0c8 100644 --- a/src/transformers/pipelines/text_to_audio.py +++ b/src/transformers/pipelines/text_to_audio.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.from typing import List, Union -from typing import Any, overload +from typing import Any, TypedDict, overload + +from ..audio_utils import AudioInput from ..generation import GenerationConfig from ..utils import is_torch_available +from ..utils.chat_template_utils import Chat, ChatType from .base import Pipeline @@ -27,6 +30,18 @@ DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan" +class AudioOutput(TypedDict, total=False): + """ + audio (`AudioInput`): + The generated audio waveform. + sampling_rate (`int`): + The sampling rate of the generated audio waveform. + """ + + audio: AudioInput + sampling_rate: int + + class TextToAudioPipeline(Pipeline): """ Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. 
This @@ -81,7 +96,7 @@ class TextToAudioPipeline(Pipeline): """ _pipeline_calls_generate = True - _load_processor = False + _load_processor = None # prioritize processors as some models require it _load_image_processor = False _load_feature_extractor = False _load_tokenizer = True @@ -91,12 +106,9 @@ class TextToAudioPipeline(Pipeline): max_new_tokens=256, ) - def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, **kwargs): + def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs): super().__init__(*args, **kwargs) - # Legacy behaviour just uses the tokenizer while new models use the processor as a whole at any given time - self.no_processor = no_processor - self.vocoder = None if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values(): self.vocoder = ( @@ -105,6 +117,10 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, * else vocoder ) + if self.model.config.model_type in ["musicgen"]: + # MusicGen expect to use the tokenizer + self.processor = None + self.sampling_rate = sampling_rate if self.vocoder is not None: self.sampling_rate = self.vocoder.config.sampling_rate @@ -127,7 +143,7 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, * self.sampling_rate = sampling_rate # last fallback to get the sampling rate based on processor - if self.sampling_rate is None and not self.no_processor and hasattr(self.processor, "feature_extractor"): + if self.sampling_rate is None and self.processor is not None and hasattr(self.processor, "feature_extractor"): self.sampling_rate = self.processor.feature_extractor.sampling_rate def preprocess(self, text, **kwargs): @@ -141,16 +157,22 @@ def preprocess(self, text, **kwargs): "add_special_tokens": False, "return_attention_mask": True, "return_token_type_ids": False, - "padding": "max_length", } # priority is given to kwargs new_kwargs.update(kwargs) - kwargs = new_kwargs - preprocessor = self.tokenizer if self.no_processor else self.processor - output = preprocessor(text, **kwargs, return_tensors="pt") + preprocessor = self.processor if self.processor is not None else self.tokenizer + if isinstance(text, Chat): + output = preprocessor.apply_chat_template( + text.messages, + tokenize=True, + return_dict=True, + **kwargs, + ) + else: + output = preprocessor(text, **kwargs, return_tensors="pt") return output @@ -171,6 +193,9 @@ def _forward(self, model_inputs, **kwargs): # generate_kwargs get priority over forward_params forward_params.update(generate_kwargs) + # ensure dict output to facilitate postprocessing + forward_params.update({"return_dict_in_generate": True}) + output = self.model.generate(**model_inputs, **forward_params) else: if len(generate_kwargs): @@ -188,18 +213,27 @@ def _forward(self, model_inputs, **kwargs): return output @overload - def __call__(self, text_inputs: str, **forward_params: Any) -> dict[str, Any]: ... + def __call__(self, text_inputs: str, **forward_params: Any) -> AudioOutput: ... + + @overload + def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[AudioOutput]: ... + + @overload + def __call__(self, text_inputs: ChatType, **forward_params: Any) -> AudioOutput: ... @overload - def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[dict[str, Any]]: ... + def __call__(self, text_inputs: list[ChatType], **forward_params: Any) -> list[AudioOutput]: ... 
-    def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, Any] | list[dict[str, Any]]:
+    def __call__(self, text_inputs, **forward_params):
         """
         Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information.
 
         Args:
-            text_inputs (`str` or `list[str]`):
-                The text(s) to generate.
+            text_inputs (`str`, `list[str]`, `ChatType`, or `list[ChatType]`):
+                One or several texts to generate. If strings or a list of strings are passed, this pipeline will
+                generate the corresponding audio. Alternatively, a "chat", in the form of a list of dicts with "role"
+                and "content" keys, can be passed, or a list of such chats. When chats are passed, the model's chat
+                template will be used to format them before passing them to the model.
             forward_params (`dict`, *optional*):
                 Parameters passed to the model generation/forward method. `forward_params` are always passed to the
                 underlying model.
@@ -210,7 +244,7 @@ def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str,
                 only passed to the underlying model if the latter is a generative model.
 
         Return:
-            A `dict` or a list of `dict`: The dictionaries have two keys:
+            `AudioOutput` or a list of `AudioOutput`, which is a `TypedDict` with two keys:
 
             - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform.
             - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform.
@@ -241,29 +275,26 @@ def _sanitize_parameters(
 
         return preprocess_params, params, postprocess_params
 
     def postprocess(self, audio):
-        output_dict = {}
-
-        if self.model.config.model_type == "csm":
-            waveform_key = "audio"
-        else:
-            waveform_key = "waveform"
-
-        # We directly get the waveform
-        if self.no_processor:
-            if isinstance(audio, dict):
-                waveform = audio[waveform_key]
-            elif isinstance(audio, tuple):
-                waveform = audio[0]
+        needs_decoding = False
+        if isinstance(audio, dict):
+            if "audio" in audio:
+                audio = audio["audio"]
             else:
-                waveform = audio
-        # Or we need to postprocess to get the waveform
-        else:
-            waveform = self.processor.decode(audio)
+                needs_decoding = True
+                audio = audio["sequences"]
+        elif isinstance(audio, tuple):
+            audio = audio[0]
+
+        if needs_decoding and self.processor is not None:
+            audio = self.processor.decode(audio)
 
         if isinstance(audio, list):
-            output_dict["audio"] = [el.to(device="cpu", dtype=torch.float).numpy() for el in waveform]
+            audio = [el.to(device="cpu", dtype=torch.float).numpy().squeeze() for el in audio]
+            audio = audio if len(audio) > 1 else audio[0]
         else:
-            output_dict["audio"] = waveform.to(device="cpu", dtype=torch.float).numpy()
-        output_dict["sampling_rate"] = self.sampling_rate
+            audio = audio.to(device="cpu", dtype=torch.float).numpy().squeeze()
 
-        return output_dict
+        return AudioOutput(
+            audio=audio,
+            sampling_rate=self.sampling_rate,
+        )
diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py
index e07228163c2b..ed3a6daee73e 100644
--- a/src/transformers/utils/chat_template_utils.py
+++ b/src/transformers/utils/chat_template_utils.py
@@ -53,6 +53,9 @@
     from torch import Tensor
 
 
+ChatType = list[dict[str, Any]]
+
+
 BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
 
 # Extracts the initial segment of the docstring, containing the function description
 description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
@@ -459,9 +462,9 @@ def strftime_now(format):
 
 
 def render_jinja_template(
-    conversations: list[list[dict[str, str]]],
+    conversations: list[ChatType],
     tools: list[dict | Callable] | None = None,
-    documents: list[dict[str, str]] | None = None,
+    documents: ChatType | None = None,
     chat_template: str | None = None,
     return_assistant_tokens_mask: bool = False,
     continue_final_message: bool = False,
@@ -558,3 +561,26 @@ def render_jinja_template(
         rendered.append(rendered_chat)
 
     return rendered, all_generation_indices
+
+
+def is_valid_message(message):
+    """
+    Check that input is a valid message in a chat, namely a dict with "role" and "content" keys.
+    """
+    if not isinstance(message, dict):
+        return False
+    if not ("role" in message and "content" in message):
+        return False
+    return True
+
+
+class Chat:
+    """This class is intended to just be used internally for pipelines and not exposed to users. We convert chats
+    to this format because the rest of the pipeline code tends to assume that lists of messages are
+    actually a batch of samples rather than messages in the same conversation."""
+
+    def __init__(self, messages: dict):
+        for message in messages:
+            if not is_valid_message(message):
+                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
+        self.messages = messages
diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index 4f7aa91c5094..e814b16eb19e 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -220,8 +220,6 @@ def __getitem__(self, i):
 
     @require_torch
     def test_small_chat_model_with_iterator_pt(self):
-        from transformers.pipelines.pt_utils import PipelineIterator
-
         text_generator = pipeline(
             task="text-generation",
             model="hf-internal-testing/tiny-gpt2-with-chatml-template",
@@ -253,7 +251,6 @@ def data():
             yield from [chat1, chat2]
 
         outputs = text_generator(data(), do_sample=False, max_new_tokens=10)
-        assert isinstance(outputs, PipelineIterator)
         outputs = list(outputs)
         self.assertEqual(
             outputs,
diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py
index c13d0830c6e6..19c7570e9e9d 100644
--- a/tests/pipelines/test_pipelines_text_to_audio.py
+++ b/tests/pipelines/test_pipelines_text_to_audio.py
@@ -45,9 +45,11 @@ def test_small_musicgen_pt(self):
         music_generator = pipeline(
             task="text-to-audio", model="facebook/musicgen-small", do_sample=False, max_new_tokens=5
         )
+        num_channels = 1  # model generates mono audio
 
         outputs = music_generator("This is a test")
         self.assertEqual({"audio": ANY(np.ndarray), "sampling_rate": 32000}, outputs)
+        self.assertEqual(len(outputs["audio"].shape), num_channels)
 
         # test two examples side-by-side
         outputs = music_generator(["This is a test", "This is a second test"])
@@ -88,6 +90,7 @@ def test_medium_seamless_m4t_pt(self):
     @require_torch
     def test_small_bark_pt(self):
         speech_generator = pipeline(task="text-to-audio", model="suno/bark-small")
+        num_channels = 1  # model generates mono audio
 
         forward_params = {
             # Using `do_sample=False` to force deterministic output
@@ -100,6 +103,7 @@ def test_small_bark_pt(self):
             {"audio": ANY(np.ndarray), "sampling_rate": 24000},
             outputs,
         )
+        self.assertEqual(len(outputs["audio"].shape), num_channels)
 
         # test two examples side-by-side
         outputs = speech_generator(
@@ -151,7 +155,6 @@ def test_conversion_additional_tensor(self):
             "add_special_tokens": False,
             "return_attention_mask": True,
             "return_token_type_ids": False,
-            "padding": "max_length",
         }
         outputs = speech_generator(
             "This is a test",
@@ -247,22 +250,69 @@ def test_generative_model_kwargs(self):
     @slow
     @require_torch
     def test_csm_model_pt(self):
-        speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b")
+        speech_generator = pipeline(task="text-to-audio", model="sesame/csm-1b", device=torch_device)
+        generate_kwargs = {"max_new_tokens": 10, "output_audio": True}
+        num_channels = 1  # model generates mono audio
 
-        outputs = speech_generator("[0]This is a test")
+        outputs = speech_generator("[0]This is a test", generate_kwargs=generate_kwargs)
         self.assertEqual(outputs["sampling_rate"], 24000)
+        audio = outputs["audio"]
+        self.assertEqual(ANY(np.ndarray), audio)
+        # ensure audio and not discrete codes
+        self.assertEqual(len(audio.shape), num_channels)
+
+        # test two examples side-by-side
+        outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs)
+        audio = [output["audio"] for output in outputs]
+        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
+        self.assertEqual(len(audio[0].shape), num_channels)
+
+        # test batching
+        batch_size = 2
+        outputs = speech_generator(
+            ["[0]This is a test", "[0]This is a second test"], generate_kwargs=generate_kwargs, batch_size=batch_size
+        )
+        self.assertEqual(len(outputs), batch_size)
+        audio = [output["audio"] for output in outputs]
+        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
+        self.assertEqual(len(outputs[0]["audio"].shape), num_channels)
+
+    @slow
+    @require_torch
+    def test_dia_model(self):
+        speech_generator = pipeline(task="text-to-audio", model="nari-labs/Dia-1.6B-0626", device=torch_device)
+        generate_kwargs = {"max_new_tokens": 20}
+        num_channels = 1  # model generates mono audio
+
+        outputs = speech_generator(
+            "[S1] Dia is an open weights text to dialogue model.", generate_kwargs=generate_kwargs
+        )
+        self.assertEqual(outputs["sampling_rate"], 44100)
         audio = outputs["audio"]
         self.assertEqual(ANY(np.ndarray), audio)
+        # ensure audio (with one channel) and not discrete codes
+        self.assertEqual(len(audio.shape), num_channels)
 
         # test two examples side-by-side
-        outputs = speech_generator(["[0]This is a test", "[0]This is a second test"])
+        outputs = speech_generator(
+            ["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."],
+            generate_kwargs=generate_kwargs,
+        )
         audio = [output["audio"] for output in outputs]
         self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
+        self.assertEqual(len(audio[0].shape), num_channels)
 
         # test batching
-        outputs = speech_generator(["[0]This is a test", "[0]This is a second test"], batch_size=2)
-        self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
+        batch_size = 2
+        outputs = speech_generator(
+            ["[S1] Dia is an open weights text to dialogue model.", "[S2] This is a second example."],
+            generate_kwargs=generate_kwargs,
+            batch_size=2,
+        )
+        self.assertEqual(len(outputs), batch_size)
+        audio = [output["audio"] for output in outputs]
+        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
+        self.assertEqual(len(outputs[0]["audio"].shape), num_channels)
 
     def get_test_pipeline(
         self,