
Conversation

@hengtaoguo
Collaborator

@hengtaoguo hengtaoguo commented Nov 6, 2025

Description

  • Add image/video/audio preprocessing utils for Qwen3-Omni in MaxText.multimodal.qwen3_omni_preprocessor.preprocess_mm_data_qwen3_omni(), returning a Qwen3OmniPreprocessorOutput dataclass that bundles all preprocessed data (pixel_values, pixel_grid_thw, video_values, video_grid_thw, video_second_per_grid, audio_values, audio_mask); see the dataclass sketch after this list.
  • Add unit test comparing MaxText implementation with Qwen3-Omni's processor on HuggingFace.
  • [WIP] Refactor multimodal_utils.py:
    • MaxText.multimodal.utils: Commonly used basic functions such as image loading and normalization.
    • MaxText.multimodal.{MODEL}_preprocessor.py: Model-specific preprocessing utils.
    • MaxText.multimodal.preprocessor.py: A centralized preprocess_mm_data() function will route to model-specific preprocessing logic based on the model name.
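
For reference, a minimal sketch of the returned dataclass (field names taken from the list above; the exact types are assumptions, not the PR's actual code):

from dataclasses import dataclass

import numpy as np


@dataclass
class Qwen3OmniPreprocessorOutput:
  """Preprocessed Qwen3-Omni inputs (sketch; types assumed)."""

  pixel_values: None | np.ndarray = None        # flattened image patches
  pixel_grid_thw: None | np.ndarray = None      # (t, h, w) patch grid per image
  video_values: None | np.ndarray = None        # flattened video patches
  video_grid_thw: None | np.ndarray = None      # (t, h, w) patch grid per video
  video_second_per_grid: None | float = None    # seconds spanned by one temporal grid step
  audio_values: None | np.ndarray = None        # audio features (e.g. a log-mel spectrogram)
  audio_mask: None | np.ndarray = None          # mask of valid audio frames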

Tests

Passing unit tests for MaxText preprocess_mm_data_qwen3_omni vs HuggingFace Qwen3OmniMoeProcessor:

python -m unittest tests.check_qwen3_embedding_vs_reference.TestQwen3OmniPreprocessing

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@hengtaoguo hengtaoguo force-pushed the hengtaoguo-pre branch 2 times, most recently from e6ff3dd to 674d9d9 Compare November 13, 2025 21:55
@hengtaoguo hengtaoguo force-pushed the hengtaoguo-pre branch 3 times, most recently from abd6cf5 to 71ba0b8 Compare November 19, 2025 06:14
@eitanporat
Collaborator

Is the functionality implemented on CPU in numpy in the torch variant? If so, is there a reason not to reuse it?

@eitanporat
Collaborator

could you add the new requirements to the pyproject toml (decord and librosa)?

@eitanporat eitanporat left a comment

hi @hengtaoguo, I left some comments

image_masks=image_masks,
)
# TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed
elif cfg.model_name in ["qwen3-omni-30b-a3b"]:

what is the reasoning behind this change?

if config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:

images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
processor_outputs = multimodal_utils.pre_process_gemma3_image(images)

maybe rename the functions to preprocess_mm_data_gemma3?

maybe it would be better to use a factory pattern here, e.g.:
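
A hypothetical registry-based sketch (the preprocess_mm_data_* names follow the rename suggested above and are assumptions, not the PR's code):

# Hypothetical registry mapping model names to their preprocessors.
_PREPROCESSORS = {
    "gemma3-4b": preprocess_mm_data_gemma3,
    "gemma3-12b": preprocess_mm_data_gemma3,
    "gemma3-27b": preprocess_mm_data_gemma3,
    "llama4-17b-16e": preprocess_mm_data_llama4,
    "llama4-17b-128e": preprocess_mm_data_llama4,
    "qwen3-omni-30b-a3b": preprocess_mm_data_qwen3_omni,
}


def preprocess_mm_data(config):
  """Route to the model-specific preprocessor registered for config.model_name."""
  try:
    preprocess_fn = _PREPROCESSORS[config.model_name]
  except KeyError as e:
    raise ValueError(f"No multimodal preprocessor registered for {config.model_name!r}") from e
  return preprocess_fn(config)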

elif config.model_name in ["qwen3-omni-30b-a3b"]:
from MaxText.multimodal.qwen3_omni_processor import preprocess_mm_data_qwen3_omni # pylint: disable=import-outside-toplevel

processor_outputs = preprocess_mm_data_qwen3_omni(config)

why does it accept a config?

FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768

should move to the config

max_pixels=max_pixels,
)

with jax.default_device(jax.devices("cpu")[0]):

Why is this necessary? I am curious.
Could it support multiple CPUs, for example?


I have the same question. Could we replace jax.image.resize with other alternatives, to avoid jax grabbing a TPU for preprocessing? Perhaps we can apply the resize function from PIL on video frames? We did the same for gemma3 (https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/multimodal_utils.py#L359-L367) and llama4
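
For illustration, a minimal CPU-only sketch with PIL (the helper name and the (num_frames, H, W, C) layout are assumptions):

import numpy as np
from PIL import Image


def resize_video_frames(frames: np.ndarray, height: int, width: int) -> np.ndarray:
  """Resize (num_frames, H, W, C) uint8 frames with PIL, entirely on CPU."""
  resized = [
      # PIL's resize takes (width, height), in that order.
      np.asarray(Image.fromarray(frame).resize((width, height), resample=Image.Resampling.BILINEAR))
      for frame in frames
  ]
  return np.stack(resized)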

The normalized images.
"""
images -= np.asarray(mean)
images /= np.asarray(std)

Curious why these are called mean and std, as they aren't used as the mean and std of the images.

They transform the mean and std in the following way: if mean[images] = m and std[images] = s, and x = normalize_images(images, mean, std), then mean[x] = (m - mean) / std and std[x] = s / std.
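
A toy numpy illustration of that point (hypothetical values):

import numpy as np

images = np.array([0.0, 2.0, 4.0])  # mean m = 2.0, std s ~= 1.63
mean, std = 1.0, 2.0

x = (images - np.asarray(mean)) / np.asarray(std)  # what normalize_images computes
print(x.mean())  # 0.5   == (m - mean) / std
print(x.std())   # ~0.82 == s / std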

"""
if device != "cpu":
raise ValueError(
f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "

I guess this was copied from torch, but this error message should be changed.

raise ValueError("db_range must be greater than zero")
spectrogram_array = np.clip(spectrogram_array, a_min=spectrogram_array.max() - db_range, a_max=None)

return spectrogram

you probably have an error here... it will raise: the clipped result is assigned to spectrogram_array, but the function returns spectrogram.
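
Presumably the intent was something like this sketch (assuming spectrogram_array holds the dB-scaled spectrogram):

spectrogram_array = np.clip(spectrogram_array, a_min=spectrogram_array.max() - db_range, a_max=None)

return spectrogram_array  # return the clipped array, not the undefined `spectrogram`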

mt_audio, mt_audio_mask = pre_process_audio_qwen3_omni(mt_audio)
processor_outputs.audio_values = mt_audio
processor_outputs.audio_mask = mt_audio_mask


this will break if you want audio and not video!

processor_outputs.pixel_grid_thw = pixel_grid_thw
processor_outputs.num_images = len(images)

if config.video_path is not None:

Did it work for you? In base.yml I see config.video_path = "", so this doesn't actually work for me.

@github-actions

🤖 Hi @hengtaoguo, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions bot left a comment

📋 Review Summary

This Pull Request introduces comprehensive preprocessing utilities for Qwen3-Omni multimodal data, including image, video, and audio. The changes involve adding new configuration parameters, a new centralized preprocessor, and model-specific preprocessing logic with corresponding unit tests. The new test coverage for Qwen3-Omni preprocessing is good, and the overall structure for multimodal handling is improving.

🔍 General Feedback

  • The refactoring aims to deprecate the old multimodal_utils.py in favor of a new MaxText.multimodal.utils.py and a centralized preprocessor.py. While the direction is positive, some remnants of the old multimodal_utils.py are still in use, and there are duplicate dataclass definitions. A clearer plan for complete deprecation and migration would be beneficial.
  • The Qwen3-Omni embedding in decoders.py is currently a placeholder (pass), indicating future work is needed for full integration.
  • Specific implementation details in the qwen3_omni_processor.py, such as hardcoded temporal dimensions for image processing, could benefit from further clarification or configurability.

mask=bidirectional_mask,
image_masks=image_masks,
)
# TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed


🟡 This pass statement indicates that multimodal embedding for qwen3-omni-30b-a3b is not yet implemented. While it unblocks the current integration, it's crucial to implement the actual embedding logic for this model to ensure full multimodal functionality. Consider creating a follow-up task to address this.

"""Multimodal data preprocessor router."""

from MaxText import multimodal_utils # TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py


🟡 The TODO comment on line 18 indicates a future refactoring to deprecate MaxText/multimodal_utils.py. It would be beneficial to have a clear plan or follow-up issue for this refactoring to ensure multimodal_utils.py is eventually removed and its relevant functions are moved to MaxText/multimodal/utils.py to keep the codebase clean and organized.


images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
processor_outputs = multimodal_utils.pre_process_gemma3_image(images)
elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:


🟡 The PreprocessorOutput is imported from MaxText.multimodal_utils. Since a new MaxText.multimodal.utils has been introduced, it would be more consistent to use PreprocessorOutput from MaxText.multimodal.utils instead of the old multimodal_utils. This also aligns with the TODO to deprecate multimodal_utils.py.

Suggested change
elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:

from MaxText.multimodal import utils as mm_utils
from MaxText.multimodal.qwen3_omni_processor import Qwen3OmniPreprocessorOutput  # To resolve a potential circular dependency
# TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py


def preprocess_mm_data(config):
  """Preprocesses multimodal data based on the provided configuration.

  Routes to the appropriate preprocessing function based on the model name.

  Args:
    config: A `pyconfig.Config` object containing configuration parameters.

  Returns:
    A `PreprocessorOutput` object containing the processed multimodal data.
  """
  processor_outputs = mm_utils.PreprocessorOutput()  # Using the new utils


images_in = [image] if isinstance(image, np.ndarray) else image
images_out = []
grids_thw = []


🟡 In pre_process_qwen3_image, the line grid_t = 2 // temporal_patch_size hardcodes the initial temporal dimension to 2. It's unclear why 2 is chosen as the dividend here without further explanation. This might limit flexibility or be a potential source of error if temporal_patch_size is greater than 2, resulting in grid_t becoming 0.

Consider adding a comment explaining the rationale behind this fixed value, or make it configurable if different temporal dimensions are expected for image processing in the future.
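
A hedged sketch of one such comment-plus-fix (this mirrors the HuggingFace Qwen2-VL convention, where a still image is duplicated to temporal_patch_size frames; it is an assumption, not the PR's code):

# A still image is duplicated to temporal_patch_size frames before patching,
# so it always occupies exactly one temporal grid step.
grid_t = 1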

# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
mel_filters *= np.expand_dims(enorm, 0)


🟠 In preprocess_mm_data_qwen3_omni, the load_image_from_path function is called from mm_utils (which refers to MaxText.multimodal_utils). To align with the ongoing refactoring, this should be updated to use the load_image_from_path function from the new MaxText.multimodal.utils.

Suggested change
if config.image_path is not None:
  images = [mm_utils.load_image_from_path(p) for p in config.image_path.split(",")]
  pixel_values, pixel_grid_thw = pre_process_qwen3_image(images, config)
  processor_outputs.pixel_values = pixel_values
  processor_outputs.pixel_grid_thw = pixel_grid_thw
  processor_outputs.num_images = len(images)

pixel_values: None | np.ndarray = None
pixel_mask: None | np.ndarray = None
aspect_ratios: None | np.ndarray = None
num_images: int = 0


🟠 The PreprocessorOutput dataclass in src/MaxText/multimodal_utils.py seems to be a duplicate of the one introduced in src/MaxText/multimodal/utils.py. Given the TODO to deprecate multimodal_utils.py, it would be cleaner to remove this duplicated PreprocessorOutput and ensure all parts of the codebase use MaxText.multimodal.utils.PreprocessorOutput.

pixel_values=image_tiles,
pixel_mask=image_mask,
aspect_ratios=aspect_ratios_array,
)


🟡 The pre_process_image function in MaxText/multimodal_utils.py now accepts an optional config parameter, but its functionality is largely superseded by MaxText.multimodal.preprocessor.preprocess_mm_data. As part of the refactoring to deprecate multimodal_utils.py, consider either removing this function entirely if it's no longer needed, or clearly defining its role and ensuring it delegates to the new preprocessor if it must remain for backward compatibility during the transition.

@hengtaoguo
Collaborator Author

Is the functionality implemented on CPU in numpy in the torch variant? If so, is there a reason not to reuse it?

This has been a long-standing constraint: we intentionally exclude torch from our dependencies, so we cannot reuse torch's resize functions and need to reimplement everything in numpy/jnp.


IMAGE_MEAN = 127.5
IMAGE_STD = 127.5
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28

Add some comments to explain the constants? Same for the video constants. It's not easy to tell from the names.



-def pre_process_image(image, model_name):
+def pre_process_image(image, model_name, config=None):

Why is config needed here?

"video_start": video_start,
"video_end": video_end,
}
vr = decord.VideoReader(video_path)

Add try/except to handle errors?
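
For example (a sketch; decord's exact exception type is not assumed here, so this catches broadly):

try:
  vr = decord.VideoReader(video_path)
except Exception as e:  # decord raises its own error type for missing/corrupt files
  raise ValueError(f"Failed to open video file: {video_path}") from e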

Returns:
np.ndarray: The loaded audio waveform.
"""
audio = librosa.load(data_path, sr=SAMPLE_RATE)[0]

Add try/except to handle errors?
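
For example (a sketch along the same lines):

try:
  audio = librosa.load(data_path, sr=SAMPLE_RATE)[0]
except Exception as e:  # librosa/soundfile raise on missing or undecodable files
  raise ValueError(f"Failed to load audio from: {data_path}") from e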

processor_outputs.num_videos = 1 # Only one video for now.

if config.audio_path is not None or (config.video_path is not None and config.use_audio_in_video):
mt_audio = _load_audio(config.video_path)

Should this be _load_audio(config.audio_path) when audio_path is provided?
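
Something like this sketch, perhaps (assuming audio_path should take precedence when it is set):

if config.audio_path is not None:
  mt_audio = _load_audio(config.audio_path)
elif config.video_path is not None and config.use_audio_in_video:
  mt_audio = _load_audio(config.video_path)  # fall back to the video's audio track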

@eitanporat eitanporat mentioned this pull request Nov 20, 2025
images_out.append(img_np)
grids_thw.append(img_grid_thw)

return images_out[0][0, :, :], grids_thw[0]

Why only return the first item here?
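
If all items are meant to be returned, a hedged sketch (assuming each images_out entry carries the leading singleton dimension implied by the [0, :, :] indexing above):

return (
    np.concatenate([img[0] for img in images_out], axis=0),  # patches from every image
    np.stack(grids_thw),                                     # one (t, h, w) row per image
)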
