-
Notifications
You must be signed in to change notification settings - Fork 31.4k
Fix processor usage + add chat_template support to TTS pipeline, and shift common chat template logic to base class. #42326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
d4f89a0
f4dbca1
25f4816
29185d4
27a56cf
a654122
13234d9
729844b
bb96d9b
7da8240
f801b96
716c3cf
c0060d1
f8e90cf
f5fb635
ed25458
97158c6
5b80b03
0b17759
838daf7
e3b882d
661cafb
94410f4
79d7e73
7a438ff
1428092
4fde390
80364e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -11,13 +11,20 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License.from typing import List, Union | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| import itertools | ||||||||||||||||||||||||||||||||||||||||||||||
| import types | ||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, overload | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| from ..generation import GenerationConfig | ||||||||||||||||||||||||||||||||||||||||||||||
| from ..utils import is_torch_available | ||||||||||||||||||||||||||||||||||||||||||||||
| from ..utils.chat_template_utils import Chat, ChatType | ||||||||||||||||||||||||||||||||||||||||||||||
| from .base import Pipeline | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| AudioOutput = dict[str, Any] # {"audio": np.ndarray, "sampling_rate": int} | ||||||||||||||||||||||||||||||||||||||||||||||
ebezzam marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if is_torch_available(): | ||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -81,7 +88,7 @@ class TextToAudioPipeline(Pipeline): | |||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| _pipeline_calls_generate = True | ||||||||||||||||||||||||||||||||||||||||||||||
| _load_processor = False | ||||||||||||||||||||||||||||||||||||||||||||||
| _load_processor = None # prioritize processors as some models require it | ||||||||||||||||||||||||||||||||||||||||||||||
| _load_image_processor = False | ||||||||||||||||||||||||||||||||||||||||||||||
| _load_feature_extractor = False | ||||||||||||||||||||||||||||||||||||||||||||||
| _load_tokenizer = True | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -91,12 +98,9 @@ class TextToAudioPipeline(Pipeline): | |||||||||||||||||||||||||||||||||||||||||||||
| max_new_tokens=256, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(*args, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Legacy behaviour just uses the tokenizer while new models use the processor as a whole at any given time | ||||||||||||||||||||||||||||||||||||||||||||||
| self.no_processor = no_processor | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| self.vocoder = None | ||||||||||||||||||||||||||||||||||||||||||||||
| if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values(): | ||||||||||||||||||||||||||||||||||||||||||||||
| self.vocoder = ( | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -105,6 +109,10 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, * | |||||||||||||||||||||||||||||||||||||||||||||
| else vocoder | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if self.model.config.model_type in ["musicgen"]: | ||||||||||||||||||||||||||||||||||||||||||||||
| # MusicGen expect to use the tokenizer | ||||||||||||||||||||||||||||||||||||||||||||||
| self.processor = None | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| self.sampling_rate = sampling_rate | ||||||||||||||||||||||||||||||||||||||||||||||
| if self.vocoder is not None: | ||||||||||||||||||||||||||||||||||||||||||||||
| self.sampling_rate = self.vocoder.config.sampling_rate | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -127,7 +135,7 @@ def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, * | |||||||||||||||||||||||||||||||||||||||||||||
| self.sampling_rate = sampling_rate | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # last fallback to get the sampling rate based on processor | ||||||||||||||||||||||||||||||||||||||||||||||
| if self.sampling_rate is None and not self.no_processor and hasattr(self.processor, "feature_extractor"): | ||||||||||||||||||||||||||||||||||||||||||||||
| if self.sampling_rate is None and self.processor is not None and hasattr(self.processor, "feature_extractor"): | ||||||||||||||||||||||||||||||||||||||||||||||
| self.sampling_rate = self.processor.feature_extractor.sampling_rate | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def preprocess(self, text, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -141,16 +149,22 @@ def preprocess(self, text, **kwargs): | |||||||||||||||||||||||||||||||||||||||||||||
| "add_special_tokens": False, | ||||||||||||||||||||||||||||||||||||||||||||||
| "return_attention_mask": True, | ||||||||||||||||||||||||||||||||||||||||||||||
| "return_token_type_ids": False, | ||||||||||||||||||||||||||||||||||||||||||||||
| "padding": "max_length", | ||||||||||||||||||||||||||||||||||||||||||||||
ebezzam marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # priority is given to kwargs | ||||||||||||||||||||||||||||||||||||||||||||||
| new_kwargs.update(kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| kwargs = new_kwargs | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| preprocessor = self.tokenizer if self.no_processor else self.processor | ||||||||||||||||||||||||||||||||||||||||||||||
| output = preprocessor(text, **kwargs, return_tensors="pt") | ||||||||||||||||||||||||||||||||||||||||||||||
| preprocessor = self.processor if self.processor is not None else self.tokenizer | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(text, Chat): | ||||||||||||||||||||||||||||||||||||||||||||||
| output = preprocessor.apply_chat_template( | ||||||||||||||||||||||||||||||||||||||||||||||
| text.messages, | ||||||||||||||||||||||||||||||||||||||||||||||
| tokenize=True, | ||||||||||||||||||||||||||||||||||||||||||||||
| return_dict=True, | ||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| output = preprocessor(text, **kwargs, return_tensors="pt") | ||||||||||||||||||||||||||||||||||||||||||||||
vasqu marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -188,18 +202,27 @@ def _forward(self, model_inputs, **kwargs): | |||||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @overload | ||||||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, text_inputs: str, **forward_params: Any) -> dict[str, Any]: ... | ||||||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, text_inputs: str, **forward_params: Any) -> AudioOutput: ... | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @overload | ||||||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[dict[str, Any]]: ... | ||||||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[AudioOutput]: ... | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, Any] | list[dict[str, Any]]: | ||||||||||||||||||||||||||||||||||||||||||||||
| @overload | ||||||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, text_inputs: ChatType, **forward_params: Any) -> AudioOutput: ... | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @overload | ||||||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, text_inputs: list[ChatType], **forward_params: Any) -> list[AudioOutput]: ... | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, text_inputs, **forward_params): | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||
| text_inputs (`str` or `list[str]`): | ||||||||||||||||||||||||||||||||||||||||||||||
| The text(s) to generate. | ||||||||||||||||||||||||||||||||||||||||||||||
| text_inputs (`str`, `list[str]`, `ChatType`, or `list[ChatType]`): | ||||||||||||||||||||||||||||||||||||||||||||||
| One or several texts to generate. If strings or a list of string are passed, this pipeline will | ||||||||||||||||||||||||||||||||||||||||||||||
| generate the corresponding text. Alternatively, a "chat", in the form of a list of dicts with "role" | ||||||||||||||||||||||||||||||||||||||||||||||
| and "content" keys, can be passed, or a list of such chats. When chats are passed, the model's chat | ||||||||||||||||||||||||||||||||||||||||||||||
| template will be used to format them before passing them to the model. | ||||||||||||||||||||||||||||||||||||||||||||||
ebezzam marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||
| forward_params (`dict`, *optional*): | ||||||||||||||||||||||||||||||||||||||||||||||
| Parameters passed to the model generation/forward method. `forward_params` are always passed to the | ||||||||||||||||||||||||||||||||||||||||||||||
| underlying model. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -215,6 +238,23 @@ def __call__(self, text_inputs: str | list[str], **forward_params) -> dict[str, | |||||||||||||||||||||||||||||||||||||||||||||
| - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform. | ||||||||||||||||||||||||||||||||||||||||||||||
| - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform. | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(text_inputs, (list, tuple, types.GeneratorType)): | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(text_inputs, types.GeneratorType): | ||||||||||||||||||||||||||||||||||||||||||||||
| text_inputs, _ = itertools.tee(text_inputs) | ||||||||||||||||||||||||||||||||||||||||||||||
| text_inputs, first_item = (x for x in text_inputs), next(_) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| first_item = text_inputs[0] | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(first_item, (list, tuple, dict)): | ||||||||||||||||||||||||||||||||||||||||||||||
| # We have one or more prompts in list-of-dicts format, so this is chat mode | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(first_item, dict): | ||||||||||||||||||||||||||||||||||||||||||||||
| return super().__call__(Chat(text_inputs), **forward_params) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| chats = (Chat(chat) for chat in text_inputs) | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(text_inputs, types.GeneratorType): | ||||||||||||||||||||||||||||||||||||||||||||||
| return super().__call__(chats, **forward_params) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| return super().__call__(list(chats), **forward_params) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _is_chat(arg): | |
| return isinstance(arg, (list, tuple, KeyDataset)) and isinstance(arg[0], (list, tuple, dict)) | |
| if _is_chat(text): | |
| # We have one or more prompts in list-of-dicts format, so this is chat mode | |
| if isinstance(text[0], dict): | |
| return super().__call__(Chat(text, images), **kwargs) | |
| else: | |
| if images is None: | |
| images = [None] * len(text) | |
| chats = [Chat(chat, image) for chat, image in zip(text, images)] # 🐈 🐈 🐈 | |
| return super().__call__(chats, **kwargs) | |
| # Same as above, but the `images` argument contains the chat. This can happen e.g. is the user only passes a | |
| # chat as a positional argument. | |
| elif text is None and _is_chat(images): | |
| # We have one or more prompts in list-of-dicts format, so this is chat mode | |
| if isinstance(images[0], dict): | |
| return super().__call__(Chat(images), **kwargs) | |
| else: | |
| chats = [Chat(image) for image in images] # 🐈 🐈 🐈 | |
| return super().__call__(chats, **kwargs) |
We cooking all our own soup 😢
Only thing I'd change would be avoid using a wildcard _ and give an explicit name instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed on Slack, the __call__ logic of image-text-to-text may be more complicated because they allow users to pass images as separate arguments rather than keeping everything within the chat template?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we were breaking it tho for v5? Or did I misunderstand something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What they are planning to break for v5 in image-text-to-text is removing inputs like images for pipeline, so that users stick to chat template. See #42359
But perhaps the __call__ logic could still be simplified and shifted to base.py. Let me see...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was breaking for CSM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Iirc it was for Dia but it has been broken way too often.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reopening because I want to verify that Dia works with this current version - I'm pretty sure we need the processor to decode for Dia which is why I wrote the initial long message on how we plan to standardize
- Everything handled by the model, audio tokenizer within it already
- Separate model / tokenizer, processor handles encoding/decoding into codebooks/waveform
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, Dia does need the processor for decoding and I'll also add a unit test for Dia so we don't miss this in the future.
However, I feel like a blanket self.processor.decode might be too broad. For example, CSM and VibeVoice don't require the processor to decode.
Since there is not standard approach (yet), how about something like below (which is working):
if isinstance(audio, dict):
waveform = audio[waveform_key]
elif isinstance(audio, tuple):
waveform = audio[0]
elif self.model.config.model_type in ["dia"]:
# models that require decoding, e.g. with codec
waveform = self.processor.decode(audio)
else:
waveform = audioThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Example usage:
from transformers import pipeline
import soundfile as sf
dia_pipeline = pipeline(
"text-to-audio", model=model_checkpoint,
)
outputs = dia_pipeline(
"[S1] Dia is an open weights text to dialogue model.",
generate_kwargs={"max_new_tokens": 512},
)
assert outputs["sampling_rate"] == 44100
audio = outputs["audio"].squeeze()
fn = "dia_pipeline_output.wav"
sf.write(fn, audio, outputs["sampling_rate"])
print(f"Audio saved to {fn}")I'm reluctant to allow voice cloning through pipeline, as this would require passing an audios input to pipeline (since Dia doesn't support chat templates).
Moreover, allowing inputs like audios is exactly what they are trying to phase out with image-text-to-text in #42359 (to only support chat template usage).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good point on voice cloning, we should maybe update Dia with a chat template in the future. I did not have that in mind at that point, that's on me.
Re: standards. Yea, we have no choice atm - it's more of a question on how we handle this in the future
Uh oh!
There was an error while loading. Please reload this page.