3 changes: 3 additions & 0 deletions src/MaxText/configs/base.yml
@@ -875,6 +875,9 @@ remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision encoder
image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, e.g. "/path/image1.jpg,/path/image2.jpg"
image_placeholder: "<|image|>"
video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, e.g. "/path/video1.mp4,/path/video2.mp4"
audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, e.g. "/path/audio1.wav,/path/audio2.wav"
use_audio_in_video: False # Whether to extract and use the audio track from video files
posemb_type_for_vit: "learn"
# max_num_images_per_example only applies for training when your image column is a list of images.
# -1 means no limit, and will pad to the max possible number of images determined by sequence length.
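A small sketch of the comma-separated convention these new flags share with image_path (the example values are taken from the config comments above; how each path is actually loaded is model-specific and not shown in this diff):

# Sketch only: split the comma-separated decode paths documented above.
# Example values come from the config comments; the loaders are model-specific.
video_path = "/path/video1.mp4,/path/video2.mp4"
audio_path = "/path/audio1.wav,/path/audio2.wav"

video_paths = [p for p in video_path.split(",") if p]
audio_paths = [p for p in audio_path.split(",") if p]
assert video_paths == ["/path/video1.mp4", "/path/video2.mp4"]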
12 changes: 11 additions & 1 deletion src/MaxText/configs/types.py
@@ -1180,6 +1180,9 @@ class MultimodalGeneral(BaseModel):
-1,
description="Maximum number of images per example for training with image lists. -1 means no limit.",
)
video_path: PathStr = Field("", description="Path to a video for decoding.")
audio_path: PathStr = Field("", description="Path to an audio file for decoding.")
use_audio_in_video: bool = Field(False, description="Extract and use audio from video files.")


class VisionTower(BaseModel):
@@ -1850,7 +1853,14 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_devices
if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
if self.use_multimodal:
valid_mm_models = ("gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e")
valid_mm_models = (
"gemma3-4b",
"gemma3-12b",
"gemma3-27b",
"llama4-17b-16e",
"llama4-17b-128e",
"qwen3-omni-30b-a3b",
)
if self.model_name not in valid_mm_models and self.model_name != "default":
raise ValueError(f"Multimodal is only supported for {valid_mm_models}, not {self.model_name}")
if self.use_sft:
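To make the added rule concrete, a standalone sketch of the check (it mirrors the diff above; check_multimodal is a hypothetical helper for illustration, not the real Config validator):

# Standalone illustration of the multimodal model-name validation above.
VALID_MM_MODELS = (
    "gemma3-4b",
    "gemma3-12b",
    "gemma3-27b",
    "llama4-17b-16e",
    "llama4-17b-128e",
    "qwen3-omni-30b-a3b",
)


def check_multimodal(model_name: str, use_multimodal: bool) -> None:
  """Raises ValueError if multimodal is requested for an unsupported model."""
  if use_multimodal and model_name not in VALID_MM_MODELS and model_name != "default":
    raise ValueError(f"Multimodal is only supported for {VALID_MM_MODELS}, not {model_name}")


check_multimodal("qwen3-omni-30b-a3b", use_multimodal=True)  # accepted after this change
check_multimodal("default", use_multimodal=True)             # "default" still bypasses the check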
10 changes: 6 additions & 4 deletions src/MaxText/decode.py
@@ -28,6 +28,7 @@
from MaxText import pyconfig
from MaxText import profiler
from MaxText import multimodal_utils
from MaxText.multimodal import preprocessor
# Placeholder: internal

# Number of text sequences to process in a single batch.
@@ -100,14 +101,15 @@ def main(argv: Sequence[str]) -> None:
prefill_length = config.max_prefill_predict_length
processor_outputs = multimodal_utils.PreprocessorOutput()
if config.use_multimodal:
image_path = config.image_path.split(",")
images = [multimodal_utils.load_image_from_path(p) for p in image_path]
processor_outputs = multimodal_utils.pre_process_image(images, model_name=config.model_name)
processor_outputs = preprocessor.preprocess_mm_data(config)
image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)

prefill_length -= image_offsets
text = multimodal_utils.reformat_prompt(
text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images)
text,
image_placeholder=config.image_placeholder,
model_name=config.model_name,
num_images=processor_outputs.num_images,
)

metadata = engine.get_tokenizer()
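For intuition on the prefill_length adjustment above, a worked sketch with illustrative numbers (the 256-token image offset is an assumption for the example, not a value from this PR):

# Illustrative numbers only: the actual offset comes from get_image_offsets().
max_prefill_predict_length = 1024   # config.max_prefill_predict_length
image_offsets = 256                 # e.g. placeholder expansion for one image
prefill_length = max_prefill_predict_length - image_offsets
assert prefill_length == 768        # remaining budget for text prefill tokens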
2 changes: 2 additions & 0 deletions src/MaxText/layers/decoders.py
@@ -580,6 +580,8 @@ def _apply_embedding(
image_masks=image_masks,
)
# TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed

🟡 This pass statement indicates that multimodal embedding for qwen3-omni-30b-a3b is not yet implemented. While it unblocks the current integration, it's crucial to implement the actual embedding logic for this model to ensure full multimodal functionality. Consider creating a follow-up task to address this.

elif cfg.model_name in ["qwen3-omni-30b-a3b"]:

Collaborator: what is the reasoning behind this change?

pass
else:
raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")

13 changes: 13 additions & 0 deletions src/MaxText/multimodal/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
47 changes: 47 additions & 0 deletions src/MaxText/multimodal/preprocessor.py
@@ -0,0 +1,47 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multimodal data preprocessor router."""

from MaxText import multimodal_utils # TODO(hengtaoguo): deprecate this file and refactor to MaxText/multimodal/utils.py

🟡 The TODO comment on line 18 indicates a future refactoring to deprecate MaxText/multimodal_utils.py. It would be beneficial to have a clear plan or follow-up issue for this refactoring to ensure multimodal_utils.py is eventually removed and its relevant functions are moved to MaxText/multimodal/utils.py to keep the codebase clean and organized.


def preprocess_mm_data(config):
  """Preprocesses multimodal data based on the provided configuration.
  Routes to the appropriate preprocessing function based on the model name.
  Args:
    config: A `pyconfig.Config` object containing configuration parameters.
  Returns:
    A `PreprocessorOutput` object containing the processed multimodal data.
  """
  processor_outputs = multimodal_utils.PreprocessorOutput()

  if config.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b"]:
    images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
    processor_outputs = multimodal_utils.pre_process_gemma3_image(images)

Collaborator: maybe rename the functions to preprocess_mm_data_gemma3?

maybe it would be better to use a factory pattern here

  elif config.model_name in ["llama4-17b-16e", "llama4-17b-128e"]:

🟡 The PreprocessorOutput is imported from MaxText.multimodal_utils. Since a new MaxText.multimodal.utils has been introduced, it would be more consistent to use PreprocessorOutput from MaxText.multimodal.utils instead of the old multimodal_utils. This also aligns with the TODO to deprecate multimodal_utils.py.

Suggested change (imports and initialization only; the function definition and docstring in between are unchanged):

from MaxText.multimodal import utils as mm_utils
from MaxText.multimodal.qwen3_omni_processor import Qwen3OmniPreprocessorOutput  # To resolve a potential circular dependency

  processor_outputs = mm_utils.PreprocessorOutput()  # Using the new utils


    images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
    processor_outputs = multimodal_utils.pre_process_llama4_image(images)
  elif config.model_name in ["qwen3-omni-30b-a3b"]:
    from MaxText.multimodal.qwen3_omni_processor import preprocess_mm_data_qwen3_omni  # pylint: disable=import-outside-toplevel

    processor_outputs = preprocess_mm_data_qwen3_omni(config)

Collaborator: why does it accept a config?

  else:
    raise ValueError(f"Model {config.model_name} not supported for multimodal preprocessing.")

  return processor_outputs
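Picking up the reviewer's factory-pattern suggestion above, a minimal sketch of what a registry-based router could look like (the registry, decorator, and per-model function names are illustrative assumptions, not part of this PR; the multimodal_utils calls mirror the gemma3 branch above):

# Sketch of a registry/factory alternative to the if/elif chain; names are illustrative.
from typing import Callable, Dict

from MaxText import multimodal_utils

_MM_PREPROCESSORS: Dict[str, Callable] = {}


def register_mm_preprocessor(*model_names):
  """Maps model names to a preprocessing function via a decorator."""
  def wrap(fn):
    for name in model_names:
      _MM_PREPROCESSORS[name] = fn
    return fn
  return wrap


@register_mm_preprocessor("gemma3-4b", "gemma3-12b", "gemma3-27b")
def _preprocess_gemma3(config):
  images = [multimodal_utils.load_image_from_path(p) for p in config.image_path.split(",")]
  return multimodal_utils.pre_process_gemma3_image(images)


def preprocess_mm_data(config):
  """Routes to the preprocessor registered for config.model_name."""
  try:
    return _MM_PREPROCESSORS[config.model_name](config)
  except KeyError as exc:
    raise ValueError(f"Model {config.model_name} not supported for multimodal preprocessing.") from exc

New model integrations (such as qwen3-omni-30b-a3b here) would then register their own function instead of extending the if/elif chain.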