60 changes: 60 additions & 0 deletions vllm/model_executor/models/gemma3_mm.py
@@ -7,6 +7,7 @@

import torch
from torch import nn
import habana_frameworks.torch as htorch
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs

@@ -33,6 +34,7 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.worker.hpu_model_runner import (BaseMultimodalHandler,
                                          VisionBuckets,
                                          register_mm_handler)

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
@@ -46,6 +48,64 @@

is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '0') == '1' if is_hpu else False

@register_mm_handler("gemma3")
class Gemma3MultimodalHandler(BaseMultimodalHandler):

def init_mm_buckets(self, model):
        model.vision_buckets = VisionBuckets(mm_buckets=[1, 2, 4, 8],
                                             validate_mm_buckets=False)
# model.audio_buckets = AudioBuckets()

def wrap_mm_modules_in_hpu_graph(self, model):
if hasattr(model, 'vision_tower'):
model.vision_tower = htorch.hpu.wrap_in_hpu_graph(
model.vision_tower, disable_tensor_cache=False)
if hasattr(model, 'multi_modal_projector'):
            model.multi_modal_projector = htorch.hpu.wrap_in_hpu_graph(
                model.multi_modal_projector, disable_tensor_cache=True)

def create_mm_dummy_input(self, model, img_args, **kwargs):
s = model.config.vision_config.image_size
pixel_values = torch.randn([img_args, 3, s, s])
        num_image_tokens = model.config.mm_tokens_per_image * img_args
multi_modal_data = {
"pixel_values": pixel_values,
"num_crops": torch.zeros([img_args], dtype=torch.int32)
}
return multi_modal_data, num_image_tokens
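
    # Shape note: assuming Gemma3's published defaults (image_size=896 and
    # mm_tokens_per_image=256), img_args=2 yields a dummy pixel_values tensor
    # of shape [2, 3, 896, 896] and num_image_tokens = 2 * 256 = 512.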

def compute_input_embedding(self, model, dtype, warmup_mode, **kwargs):
input_ids = kwargs['input_ids']
vision_embeddings = model.get_multimodal_embeddings(**kwargs)
inputs_embeds = model.get_input_embeddings(
input_ids, vision_embeddings)

        # TODO: During warmup the multimodal model needs to be warmed up with
        # dummy image data for the prompt. Instead of generating a dummy image
        # here, we only generate the attention masks for the images and pass
        # them via attn_metadata, so the HPU graph can be reused without
        # running the whole vision tower.
        if vision_embeddings is not None or (
                warmup_mode and kwargs['attn_metadata'].is_prompt):
input_ids = kwargs['input_ids']
positions = kwargs['positions']
kwargs = model.prepare_attn_masks(
mask_dtype=dtype,
**kwargs,
)
kwargs['input_ids'] = input_ids
kwargs['positions'] = positions

kwargs.update({
'inputs_embeds': inputs_embeds,
})

kwargs.pop('pixel_values', None)
kwargs.pop("num_crops", None)
kwargs.pop("graphed_multimodal_buckets", None)
return kwargs
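
# A minimal usage sketch, assuming a hypothetical `get_mm_handler` registry
# lookup and placeholder `execute_kwargs`; only the handler methods defined
# above are taken as given here.
#
#     handler = get_mm_handler("gemma3")()          # hypothetical lookup
#     handler.init_mm_buckets(model)                # attach vision buckets
#     handler.wrap_mm_modules_in_hpu_graph(model)   # graph the vision modules
#     kwargs = handler.compute_input_embedding(
#         model, dtype=model.dtype, warmup_mode=False, **execute_kwargs)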

class Gemma3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]