Add preprocessing utils for Qwen3-Omni #2613
Conversation
Is the functionality implemented on CPU in NumPy in the torch variant? If so, is there a reason not to reuse it?
Could you add the new requirements to the pyproject.toml?
eitanporat left a comment:
Hi @hengtaoguo, I left some comments.
src/MaxText/layers/decoders.py (outdated)

```python
          image_masks=image_masks,
      )
    # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed
    elif cfg.model_name in ["qwen3-omni-30b-a3b"]:
```
What is the reasoning behind this change?
```python
if config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
  images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
  processor_outputs = multimodal_utils.pre_process_gemma3_image(images)
```
Maybe rename the functions to `preprocess_mm_data_gemma3`? It might also be better to use a factory pattern here (see the sketch below).
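A minimal sketch of the factory idea, assuming a hypothetical registry layout; the module and function names (other than `preprocess_mm_data_qwen3_omni` from this PR) are illustrative, not the actual MaxText API:

```python
import importlib

# Hypothetical registry mapping model names to preprocessor import paths.
_PREPROCESSOR_REGISTRY = {
    "gemma3-4b": ("MaxText.multimodal.gemma3_processor", "preprocess_mm_data_gemma3"),
    "gemma3-12b": ("MaxText.multimodal.gemma3_processor", "preprocess_mm_data_gemma3"),
    "gemma3-27b": ("MaxText.multimodal.gemma3_processor", "preprocess_mm_data_gemma3"),
    "qwen3-omni-30b-a3b": ("MaxText.multimodal.qwen3_omni_processor", "preprocess_mm_data_qwen3_omni"),
}


def preprocess_mm_data(config):
  """Look up and invoke the preprocessor registered for config.model_name."""
  try:
    module_name, fn_name = _PREPROCESSOR_REGISTRY[config.model_name]
  except KeyError as e:
    raise ValueError(f"No multimodal preprocessor registered for {config.model_name!r}") from e
  # Import lazily so each model's dependencies are only pulled in when used.
  preprocess_fn = getattr(importlib.import_module(module_name), fn_name)
  return preprocess_fn(config)
```

This keeps the router free of per-model `elif` chains; adding a model becomes a one-line registry entry.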
```python
elif config.model_name in ["qwen3-omni-30b-a3b"]:
  from MaxText.multimodal.qwen3_omni_processor import preprocess_mm_data_qwen3_omni  # pylint: disable=import-outside-toplevel

  processor_outputs = preprocess_mm_data_qwen3_omni(config)
```
Why does it accept a config?
```python
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
```
These should move to the config.
```python
      max_pixels=max_pixels,
  )

  with jax.default_device(jax.devices("cpu")[0]):
```
Why is this necessary? I'm curious whether it could support multiple CPUs, for example.
I have the same question. Could we replace `jax.image.resize` with an alternative to avoid JAX grabbing the TPU for preprocessing? Perhaps we can apply the resize function from PIL on video frames? We did the same for gemma3 (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/multimodal_utils.py#L359-L367) and llama4. A rough sketch of the PIL approach is below.
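A sketch of what a PIL-based resize on host-side NumPy frames could look like, assuming uint8 frames in (T, H, W, C) layout; the helper name is illustrative:

```python
import numpy as np
from PIL import Image


def resize_frames_with_pil(frames: np.ndarray, height: int, width: int) -> np.ndarray:
  """Resize a (T, H, W, C) uint8 video array frame-by-frame on the host CPU.

  Avoids jax.image.resize entirely, so preprocessing never touches the TPU.
  """
  resized = [
      np.asarray(Image.fromarray(frame).resize((width, height), resample=Image.BICUBIC))
      for frame in frames
  ]
  return np.stack(resized, axis=0)
```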
```python
    The normalized images.
  """
  images -= np.asarray(mean)
  images /= np.asarray(std)
```
Curious why these are called `mean` and `std`, since they aren't used as the mean and std of the images. The function transforms the statistics as follows: if the input images have mean m and std s, then `x = normalize_images(images, mean, std)` has mean (m - mean) / std and std s / std. A quick check is below.
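A standalone NumPy check of that claim (not code from this PR), using the PR's 127.5 shift/scale:

```python
import numpy as np

rng = np.random.default_rng(0)
images = rng.normal(loc=5.0, scale=3.0, size=(10_000,))  # m = 5, s = 3
mean, std = 127.5, 127.5

x = (images - mean) / std
print(np.mean(x), (5.0 - mean) / std)  # both approx. -0.961
print(np.std(x), 3.0 / std)            # both approx. 0.0235
```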
| """ | ||
| if device != "cpu": | ||
| raise ValueError( | ||
| f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator " |
I guess this was copied from torch, but this error message should be updated; it refers to a CUDA accelerator in a NumPy/JAX codebase.
```python
      raise ValueError("db_range must be greater than zero")
    spectrogram_array = np.clip(spectrogram_array, a_min=spectrogram_array.max() - db_range, a_max=None)

  return spectrogram
```
You probably have an error here: the clipped result is assigned to `spectrogram_array`, but the function returns `spectrogram`, so this will either raise or silently drop the clipping.
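A sketch of the presumed fix, clipping and returning the same array (helper name and signature are illustrative):

```python
import numpy as np


def _clip_db_range(spectrogram: np.ndarray, db_range: float) -> np.ndarray:
  """Keep only the top `db_range` decibels of a dB-scale spectrogram."""
  if db_range <= 0:
    raise ValueError("db_range must be greater than zero")
  # Clip and return the SAME array that was clipped, resolving the
  # spectrogram/spectrogram_array name mismatch flagged above.
  return np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
```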
```python
    mt_audio, mt_audio_mask = pre_process_audio_qwen3_omni(mt_audio)
    processor_outputs.audio_values = mt_audio
    processor_outputs.audio_mask = mt_audio_mask
```
This will break if you want audio and not video!
```python
    processor_outputs.pixel_grid_thw = pixel_grid_thw
    processor_outputs.num_images = len(images)

  if config.video_path is not None:
```
Did it work for you? In base.yml I see `video_path: ""`, so the `is not None` check is always true and this doesn't actually work for me. A truthiness check would be more robust (sketch below).
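A minimal sketch of the guard, assuming `video_path` can be either `None` or the empty-string default from base.yml:

```python
# `config.video_path` defaults to "" in base.yml, so `is not None` is always
# true. A truthiness check covers both None and the empty-string default.
if config.video_path:
  ...  # run video preprocessing only when a path was actually provided
```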
🤖 Hi @hengtaoguo, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary

This Pull Request introduces comprehensive preprocessing utilities for Qwen3-Omni multimodal data, including image, video, and audio. The changes involve adding new configuration parameters, a new centralized preprocessor, and model-specific preprocessing logic with corresponding unit tests. The new test coverage for Qwen3-Omni preprocessing is good, and the overall structure for multimodal handling is improving.

🔍 General Feedback

- The refactoring aims to deprecate the old `multimodal_utils.py` in favor of a new `MaxText/multimodal/utils.py` and a centralized `preprocessor.py`. While the direction is positive, some remnants of the old `multimodal_utils.py` are still in use, and there are duplicate dataclass definitions. A clearer plan for complete deprecation and migration would be beneficial.
- The Qwen3-Omni embedding in `decoders.py` is currently a placeholder (`pass`), indicating future work is needed for full integration.
- Specific implementation details in `qwen3_omni_processor.py`, such as hardcoded temporal dimensions for image processing, could benefit from further clarification or configurability.
```python
          mask=bidirectional_mask,
          image_masks=image_masks,
      )
    # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed
```
🟡 This pass statement indicates that multimodal embedding for qwen3-omni-30b-a3b is not yet implemented. While it unblocks the current integration, it's crucial to implement the actual embedding logic for this model to ensure full multimodal functionality. Consider creating a follow-up task to address this.
| """Multimodal data preprocessor router.""" | ||
|
|
||
| from MaxText import multimodal_utils # TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py | ||
|
|
🟡 The TODO comment on line 18 indicates a future refactoring to deprecate MaxText/multimodal_utils.py. It would be beneficial to have a clear plan or follow-up issue for this refactoring to ensure multimodal_utils.py is eventually removed and its relevant functions are moved to MaxText/multimodal/utils.py to keep the codebase clean and organized.
```python
  images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
  processor_outputs = multimodal_utils.pre_process_gemma3_image(images)
elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:
```
🟡 The PreprocessorOutput is imported from MaxText.multimodal_utils. Since a new MaxText.multimodal.utils has been introduced, it would be more consistent to use PreprocessorOutput from MaxText.multimodal.utils instead of the old multimodal_utils. This also aligns with the TODO to deprecate multimodal_utils.py.
Suggested change:

```python
elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:

from MaxText.multimodal import utils as mm_utils
from MaxText.multimodal.qwen3_omni_processor import Qwen3OmniPreprocessorOutput  # To resolve a potential circular dependency
# TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py


def preprocess_mm_data(config):
  """Preprocesses multimodal data based on the provided configuration.

  Routes to the appropriate preprocessing function based on the model name.

  Args:
    config: A `pyconfig.Config` object containing configuration parameters.

  Returns:
    A `PreprocessorOutput` object containing the processed multimodal data.
  """
  processor_outputs = mm_utils.PreprocessorOutput()  # Using the new utils
```
```python
  images_in = [image] if isinstance(image, np.ndarray) else image
  images_out = []
  grids_thw = []
```
🟡 In `pre_process_qwen3_image`, the line `grid_t = 2 // temporal_patch_size` hardcodes the initial temporal dimension to 2. It's unclear why 2 is chosen as the dividend here without further explanation. This might limit flexibility or be a potential source of error if `temporal_patch_size` is greater than 2, resulting in `grid_t` becoming 0.

Consider adding a comment explaining the rationale behind this fixed value, or make it configurable if different temporal dimensions are expected for image processing in the future.
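If the value mirrors the Qwen2-VL-style convention of repeating a still image along time so it fills one temporal patch, a clarifying comment like the following might help (a sketch; the stated rationale is an assumption to be verified):

```python
# Assumed rationale: a still image is treated as a 2-frame clip so it spans
# exactly one temporal patch when temporal_patch_size == 2. Note grid_t
# collapses to 0 for temporal_patch_size > 2, so a config knob or an
# assertion would be safer than this hardcoded dividend.
grid_t = 2 // temporal_patch_size
```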
```python
  # Slaney-style mel is scaled to be approx constant energy per channel
  enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
  mel_filters *= np.expand_dims(enorm, 0)
```
🟠 In preprocess_mm_data_qwen3_omni, the load_image_from_path function is called from mm_utils (which refers to MaxText.multimodal_utils). To align with the ongoing refactoring, this should be updated to use the load_image_from_path function from the new MaxText.multimodal.utils.
Suggested change:

```python
if config.image_path is not None:
  images = [mm_utils.load_image_from_path(p) for p in config.image_path.split(",")]
  pixel_values, pixel_grid_thw = pre_process_qwen3_image(images, config)
  processor_outputs.pixel_values = pixel_values
  processor_outputs.pixel_grid_thw = pixel_grid_thw
  processor_outputs.num_images = len(images)
```
```python
pixel_values: None | np.ndarray = None
pixel_mask: None | np.ndarray = None
aspect_ratios: None | np.ndarray = None
num_images: int = 0
```
🟠 The PreprocessorOutput dataclass in src/MaxText/multimodal_utils.py seems to be a duplicate of the one introduced in src/MaxText/multimodal/utils.py. Given the TODO to deprecate multimodal_utils.py, it would be cleaner to remove this duplicated PreprocessorOutput and ensure all parts of the codebase use MaxText.multimodal.utils.PreprocessorOutput.
```python
    pixel_values=image_tiles,
    pixel_mask=image_mask,
    aspect_ratios=aspect_ratios_array,
)
```
🟡 The pre_process_image function in MaxText/multimodal_utils.py now accepts an optional config parameter, but its functionality is largely superseded by MaxText.multimodal.preprocessor.preprocess_mm_data. As part of the refactoring to deprecate multimodal_utils.py, consider either removing this function entirely if it's no longer needed, or clearly defining its role and ensuring it delegates to the new preprocessor if it must remain for backward compatibility during the transition.
This has been a long-standing constraint; we intentionally exclude …
```python
IMAGE_MEAN = 127.5
IMAGE_STD = 127.5
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
```
Add some comments to explain the constants? Same for the video constants; it's not easy to tell from the names alone. Something like the sketch below.
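For example (the explanations are my reading of Qwen-VL-style conventions, so please verify before adopting them):

```python
IMAGE_MEAN = 127.5        # with IMAGE_STD, maps uint8 pixels from [0, 255] to roughly [-1, 1]
IMAGE_STD = 127.5
IMAGE_FACTOR = 28         # image sides are snapped to multiples of 28 (the effective patch stride?)
MIN_PIXELS = 4 * 28 * 28  # smallest accepted image area, i.e. at least 4 such patches
```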
src/MaxText/multimodal_utils.py (outdated)
```diff
-def pre_process_image(image, model_name):
+def pre_process_image(image, model_name, config=None):
```
Why is config needed here?
| "video_start": video_start, | ||
| "video_end": video_end, | ||
| } | ||
| vr = decord.VideoReader(video_path) |
Add try/except to handle errors? For example:
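A sketch of such a guard around the decord call (the helper name is illustrative):

```python
import decord


def open_video(video_path: str) -> decord.VideoReader:
  """Open a video file, surfacing a clearer error on failure."""
  try:
    return decord.VideoReader(video_path)
  except Exception as e:  # decord raises its own error type for missing files / bad codecs
    raise ValueError(f"Failed to open video file {video_path!r}") from e
```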
```python
  Returns:
    np.ndarray: The loaded audio waveform.
  """
  audio = librosa.load(data_path, sr=SAMPLE_RATE)[0]
```
Same here: add try/except to handle errors? For example:
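A sketch for the audio loader (helper name and signature are illustrative; `librosa.load` returns a `(waveform, sample_rate)` tuple):

```python
import librosa
import numpy as np


def load_audio(data_path: str, sample_rate: int) -> np.ndarray:
  """Load an audio waveform, wrapping loader failures with context."""
  try:
    audio, _ = librosa.load(data_path, sr=sample_rate)
  except Exception as e:  # librosa/soundfile raise on missing or corrupt files
    raise ValueError(f"Failed to load audio from {data_path!r}") from e
  return audio
```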
```python
  processor_outputs.num_videos = 1  # Only one video for now.

  if config.audio_path is not None or (config.video_path is not None and config.use_audio_in_video):
    mt_audio = _load_audio(config.video_path)
```
Should this be `_load_audio(config.audio_path)` when `audio_path` is provided? Something like the sketch below.
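A sketch of the presumed intent, preferring a standalone audio file and falling back to the video's audio track only when `use_audio_in_video` is set (attribute names follow the quoted snippet):

```python
# Prefer an explicit audio file; only fall back to extracting audio from
# the video when use_audio_in_video is enabled.
if config.audio_path:
  mt_audio = _load_audio(config.audio_path)
elif config.video_path and config.use_audio_in_video:
  mt_audio = _load_audio(config.video_path)
```

This also addresses the earlier comment that the current code breaks when audio is supplied without video.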
```python
    images_out.append(img_np)
    grids_thw.append(img_grid_thw)

  return images_out[0][0, :, :], grids_thw[0]
```
Why only return the first item here?
Description
This PR adds:

- `MaxText.multimodal.qwen3_omni_preprocessor.preprocess_mm_data_qwen3_omni()`, returning dataclass `Qwen3OmniPreprocessorOutput` containing all preprocessed data (`pixel_values`, `pixel_grid_thw`, `video_values`, `video_grid_thw`, `video_second_per_grid`, `audio_values`, `audio_mask`).
- `MaxText.multimodal.utils`: commonly used basic functions such as image loading and normalization.
- `MaxText.multimodal.{MODEL}_preprocessor.py`: model-specific preprocessing utils.
- `MaxText.multimodal.preprocessor.py`: centralized function `preprocess_mm_data()` that routes to model-specific preprocessing logic based on model name.

Tests
Passing unit tests for MaxText `preprocess_mm_data_qwen3_omni` vs HuggingFace `Qwen3OmniMoeProcessor`.

Checklist
Before submitting this PR, please make sure (put X in square brackets):

- [ ] … `gemini-review` label.