From 84319bb5707ee458bcedb56369508bca82bcc65c Mon Sep 17 00:00:00 2001 From: Jing Date: Thu, 25 Sep 2025 21:22:18 +0800 Subject: [PATCH] refactor(hpu_model_runner): restructure multimodal-related code --- vllm/model_executor/models/gemma3_mm.py | 60 +++ vllm/worker/hpu_model_runner.py | 466 +++++++++++------------- 2 files changed, 273 insertions(+), 253 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 7ca4fd95347a..427b168d84c6 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -7,6 +7,7 @@ import torch from torch import nn +import habana_frameworks.torch as htorch from transformers import BatchFeature, Gemma3Config, Gemma3Processor from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs @@ -33,6 +34,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.worker.hpu_model_runner import BaseMultimodalHandler, VisionBuckets, register_mm_handler from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -46,6 +48,64 @@ is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '0') == '1' if is_hpu else False +@register_mm_handler("gemma3") +class Gemma3MultimodalHandler(BaseMultimodalHandler): + + def init_mm_buckets(self, model): + model.vision_buckets = VisionBuckets(mm_buckets=[1, 2, 4, 8], validate_mm_buckets=False) + # model.audio_buckets = AudioBuckets() + + def wrap_mm_modules_in_hpu_graph(self, model): + if hasattr(model, 'vision_tower'): + model.vision_tower = htorch.hpu.wrap_in_hpu_graph( + model.vision_tower, disable_tensor_cache=False) + if hasattr(model, 'multi_modal_projector'): + model.multi_modal_projector = \ + htorch.hpu.wrap_in_hpu_graph( \ + model.multi_modal_projector, \ + disable_tensor_cache=True) + + def create_mm_dummy_input(self, model, mm_len, **kwargs): + s = model.config.vision_config.image_size + pixel_values = torch.randn([mm_len, 3, s, s]) + num_image_tokens = model.config.mm_tokens_per_image \ + * mm_len + multi_modal_data = { + "pixel_values": pixel_values, + "num_crops": torch.zeros([mm_len], dtype=torch.int32) + } + return multi_modal_data, num_image_tokens + + def compute_input_embedding(self, model, dtype, warmup_mode, **kwargs): + input_ids = kwargs['input_ids'] + vision_embeddings = model.get_multimodal_embeddings(**kwargs) + inputs_embeds = model.get_input_embeddings( + input_ids, vision_embeddings) + + # TODO: During warmup, multimodal models must be warmed up with dummy + # image data for the prompt. Instead of generating a dummy image here, + # we only build the attention mask for the images and pass it via + # attn_metadata, so the HPU graph can be reused without running + # the whole vision tower. 
+ if vision_embeddings is not None or ( + warmup_mode & kwargs['attn_metadata'].is_prompt): + input_ids = kwargs['input_ids'] + positions = kwargs['positions'] + kwargs = model.prepare_attn_masks( + mask_dtype=dtype, + **kwargs, + ) + kwargs['input_ids'] = input_ids + kwargs['positions'] = positions + + kwargs.update({ + 'inputs_embeds': inputs_embeds, + }) + + kwargs.pop('pixel_values', None) + kwargs.pop("num_crops", None) + kwargs.pop("graphed_multimodal_buckets", None) + return kwargs class Gemma3ImagePixelInputs(TypedDict): type: Literal["pixel_values"] diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 348b85ed9f59..e4bc725203d9 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -5,6 +5,7 @@ # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company ############################################################################### +from abc import ABC, abstractmethod import collections import contextlib import dataclasses @@ -96,7 +97,7 @@ LORA_WARMUP_RANK = 8 DUMMY_TOKEN_ID = -1 -UNSET_IMG_ARGS = 9999999 +UNSET_MM_LEN = 9999999 shutdown_inc_called = False @@ -105,95 +106,175 @@ class PhaseType(Enum): PREFIX_PREFILL = 'prefix_prefill' DECODE = 'decode' +MULTIMODAL_HANDLER_REGISTRY = {} -class VisionBuckets: - ''' - This class is used to bucket image tokens - ''' +def register_mm_handler(model_type: str): + def wrapper(cls): + MULTIMODAL_HANDLER_REGISTRY[model_type] = cls + return cls + return wrapper - def __init__(self, is_batch_based): - self.is_batch_based = is_batch_based - envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "").lower() - if envvar == 'none': - self.multimodal_buckets = None - else: - if envvar == "": - if is_batch_based: - multimodal_buckets = [1, 2, 4, 8] # batch sizes for gemma3 - else: - multimodal_buckets = [ - 1600, 3136, 4096, 6400, 7744, 9216, 12544 - ] - else: - multimodal_buckets = [int(i) for i in envvar.split(',')] - self.multimodal_buckets = self._process_buckets(multimodal_buckets) +class MultimodalBuckets(ABC): + """ + Base class for multimodal buckets. + + Bucket sources (priority order): + 1. Environment variable (defined by `_get_env_var_name()` in subclasses). + 2. Constructor argument `mm_buckets`. 
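+ Example: `VLLM_MULTIMODAL_BUCKETS=1600,3136,4096` (comma-separated integers) takes precedence over any `mm_buckets` passed in by the handler.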
+ """ + def __init__(self, mm_buckets=None, validate_mm_buckets=False): + self.multimodal_buckets = None + envvar = os.environ.get(self._get_env_var_name(), "") + if envvar != "": + self.multimodal_buckets = [int(i) for i in envvar.split(',')] + elif mm_buckets is not None: + self.multimodal_buckets = mm_buckets + + if validate_mm_buckets: + self._validate_buckets() + self.multimodal_buckets = sorted(self.multimodal_buckets) self.graphed_buckets = set() self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true' - def _process_buckets(self, buckets): - if not self.is_batch_based: - for bucket in buckets: - assert bucket % 8 == 0, ( - 'Buckets needs to be multiples 8 (slices of 64)') - return sorted(buckets) - - def get_multimodal_bucket(self, curr_num_image_patches): - if self.multimodal_buckets is not None: - for mm_bucket in self.multimodal_buckets: - if curr_num_image_patches <= mm_bucket: - return mm_bucket - return curr_num_image_patches - else: - return 0 + @abstractmethod + def _get_env_var_name(self): + """Return the env var name for this bucket type""" + pass + + def _validate_buckets(self): + pass + + def get_multimodal_bucket(self, curr_size): + if self.multimodal_buckets is None: + return None + + for mm_bucket in self.multimodal_buckets: + if curr_size <= mm_bucket: + return mm_bucket + return curr_size def __repr__(self): - return str(self.multimodal_buckets) + return f"{self.__class__.__name__}:{str(self.multimodal_buckets)}" - def use_graph(self, seq_len): - if self.skip_warmup and \ - self.multimodal_buckets is not None and \ - seq_len in self.multimodal_buckets: - return True - return seq_len in self.graphed_buckets + def use_graph(self, curr_size): + in_bucket = ( + self.multimodal_buckets is not None + and curr_size in self.multimodal_buckets + ) + return (self.skip_warmup and in_bucket) or (curr_size in self.graphed_buckets) -class AudioBuckets: - ''' - This class is used to bucket audio tokens - ''' +class VisionBuckets(MultimodalBuckets): + """Default vision buckets for models with mrope""" - def __init__(self): - envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS_AUDIO', "").lower() - if envvar == 'none': - self.multimodal_buckets = None - else: - if envvar == "": - self.multimodal_buckets = list(range(0, 12801, 1600)) - else: - self.multimodal_buckets = [int(i) for i in envvar.split(',')] - self.graphed_buckets = set() - self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', - 'false').lower() == 'true' + def __init__(self, mm_buckets=None, validate_mm_buckets=False): + default_vision_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544] + super().__init__(mm_buckets=mm_buckets if mm_buckets is not None else default_vision_buckets, + validate_mm_buckets=validate_mm_buckets) - def get_multimodal_bucket(self, curr_num_audio_patches): - if self.multimodal_buckets is not None: - for mm_bucket in self.multimodal_buckets: - if curr_num_audio_patches <= mm_bucket: - return mm_bucket - return curr_num_audio_patches - else: - return 0 + def _get_env_var_name(self): + return "VLLM_MULTIMODAL_BUCKETS" + + def _validate_buckets(self): + for bucket in self.multimodal_buckets: + assert bucket % 8 == 0, ( + 'Buckets needs to be multiples 8 (slices of 64)') - def __repr__(self): - return str(self.multimodal_buckets) - def use_graph(self, seq_len): - if self.skip_warmup and \ - self.multimodal_buckets is not None and \ - seq_len in self.multimodal_buckets: - return True - return seq_len in self.graphed_buckets +class AudioBuckets(MultimodalBuckets): + def 
__init__(self, mm_buckets=None, validate_mm_buckets=False): + default_audio_buckets = list(range(0, 12801, 1600)) + super().__init__(mm_buckets=mm_buckets if mm_buckets is not None else default_audio_buckets, + validate_mm_buckets=validate_mm_buckets) + + def _get_env_var_name(self): + return "VLLM_MULTIMODAL_BUCKETS_AUDIO" + + +class BaseMultimodalHandler: + """ + Helper class for HpuModelAdapter to handle multimodal functionality. + + Assumes mrope-based models by default. + For other types, extend by: + - Subclassing this base class + - Registering with @register_mm_handler(model_type) + - Overriding necessary methods + """ + def init_mm_buckets(self, model): + model.vision_buckets = VisionBuckets(validate_mm_buckets=True) + model.audio_buckets = AudioBuckets() + + def wrap_mm_modules_in_hpu_graph(self, model): + # This applies exclusively to Qwen2/2.5-VL models + # both use mrope. We wrap the visual and language + # models separately with HPU graph. + # This is to ensure that we keeps + # the static and dynamic parts distinct. + if hasattr(model, 'visual'): + logger.info("[Multimodal] Wrapping Visual Model") + model.visual = htorch.hpu.wrap_in_hpu_graph( + model.visual, disable_tensor_cache=True) + if hasattr(model, 'audio_tower'): + logger.info("[Multimodal] Wrapping Audio Model") + model.audio_tower = htorch.hpu.wrap_in_hpu_graph( + model.audio_tower) + + def create_mm_dummy_input(self, model, mm_len, **kwargs): + if not hasattr(model.config, "vision_config"): + raise ValueError("Expect mrope model to have vision_config") + vision_config = model.config.vision_config + if not hasattr(vision_config, "spatial_merge_size"): + raise ValueError( + "Expect mrope model to have spatial_merge_size") + + spatial_merge_unit = vision_config.spatial_merge_size**2 + num_image_tokens = mm_len // spatial_merge_unit + assert mm_len % 8 == 0, ( + f"Expects mm_len to be multiples of 8, got: {mm_len}") + image_h = mm_len // 8 + image_grid_thw = torch.tensor( + [[1, image_h, int(mm_len / image_h)]]) + pixel_values = torch.randn( + image_grid_thw[0].prod(), + 1176) # TODO: figure out the variable name + + assert pixel_values.shape[0] % 64 == 0, ( + f"pixel_values must be sliced in 64 chunks, " + f"got: {pixel_values.shape}") + + multi_modal_data = { + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + } + return multi_modal_data, num_image_tokens + + def compute_input_embedding(self, model, **kwargs): + input_ids = kwargs['input_ids'] + if model.config.model_type == 'qwen2_5_omni_thinker': + multimodal_embeddings = \ + model.get_multimodal_embeddings_v0(**kwargs) + inputs_embeds = model.get_input_embeddings_v0( + input_ids, multimodal_embeddings) + else: + image_input = model._parse_and_validate_image_input( + **kwargs) + video_input = model._parse_and_validate_video_input( + **kwargs) + inputs_embeds = model.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + kwargs.update({ + # 'input_ids': None, + 'inputs_embeds': inputs_embeds, + }) + # done compute the visual tokens + kwargs.pop('pixel_values', None) + kwargs.pop('image_grid_thw', None) + return kwargs class Singleton(type): @@ -392,7 +473,15 @@ def __init__(self, model, vllm_config, is_causal, sampler): model_config = getattr(self.model, "config", None) self.model_is_mrope = uses_mrope(model_config) - self.is_mm_optimized = is_mm_optimized(self.model) + + multimodal_handler_cls = MULTIMODAL_HANDLER_REGISTRY.get(model_config.model_type, None) + if multimodal_handler_cls is 
None and self.model_is_mrope: + multimodal_handler_cls = BaseMultimodalHandler + + if multimodal_handler_cls is not None: + self.mm_handler = multimodal_handler_cls() + self.mm_handler.init_mm_buckets(self.model) + else: + self.mm_handler = None + text_config = vllm_config.model_config.hf_config.get_text_config() self.interleaved_sliding_window = getattr( text_config, "interleaved_sliding_window", @@ -406,30 +495,8 @@ def __init__(self, model, vllm_config, is_causal, sampler): self.sliding_window_thld = int( os.environ.get('VLLM_FUSEDSDPA_SLIDE_THLD', '8192')) - # This applies exclusively to Qwen2/2.5-VL models - # both use mrope. We wrap the visual and language - # models separately with HPU graph. - # This is to ensure that we keeps - # the static and dynamic parts distinct. if htorch.utils.internal.is_lazy(): - if self.model_is_mrope and hasattr(self.model, 'visual'): - logger.info("[Multimodal] Wrapping Visual Model") - self.model.visual = htorch.hpu.wrap_in_hpu_graph( - self.model.visual, disable_tensor_cache=True) - if self.model_is_mrope and hasattr(self.model, 'audio_tower'): - logger.info("[Multimodal] Wrapping Audio Model") - self.model.audio_tower = htorch.hpu.wrap_in_hpu_graph( - self.model.audio_tower) - - if self.is_mm_optimized: - if hasattr(self.model, 'vision_tower'): - self.model.vision_tower = htorch.hpu.wrap_in_hpu_graph( - self.model.vision_tower, disable_tensor_cache=False) - if hasattr(self.model, 'multi_modal_projector'): - self.model.multi_modal_projector = \ - htorch.hpu.wrap_in_hpu_graph( \ - self.model.multi_modal_projector, \ - disable_tensor_cache=True) + if self.mm_handler is not None: + self.mm_handler.wrap_mm_modules_in_hpu_graph(self.model) self._rotary_embed_module = self._get_rotary_embedding_module( self.model) @@ -671,41 +738,13 @@ def _update_metadata(self, device, dtype, True) return attn_metadata - def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs): - input_ids = kwargs['input_ids'] - vision_embeddings = self.model.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.model.get_input_embeddings( - input_ids, vision_embeddings) - - # TODO: In warmup, we need to warmup the model with dummy image data for - # multimodal model for prompt, here instead of generating a dummy image, - # we are just generating attn_mask for the images and pass with - # attn_metadata, so we can reuse HPU graph without running - # the whole vision tower. 
- if vision_embeddings is not None or ( - warmup_mode & kwargs['attn_metadata'].is_prompt): - input_ids = kwargs['input_ids'] - positions = kwargs['positions'] - kwargs = self.model.prepare_attn_masks( - mask_dtype=self.dtype, - **kwargs, - ) - kwargs['input_ids'] = input_ids - kwargs['positions'] = positions - - kwargs.update({'inputs_embeds': inputs_embeds}) - # done compute the visual tokens and others - kwargs.pop('pixel_values', None) - kwargs.pop("num_crops", None) - kwargs.pop("graphed_multimodal_buckets", None) - return kwargs def compute_input_embeddings_for_mrope_mm_optimized( self, warmup_mode, **kwargs): if 'inputs_embeds' in kwargs: return kwargs - if not self.model_is_mrope and not self.is_mm_optimized: + if self.mm_handler is None: return None # For Qwen2.5-VL/Gemma3 VL multimodal embedding, # this embedding part should be executed @@ -720,34 +759,11 @@ def compute_input_embeddings_for_mrope_mm_optimized( compile_only_mode_context_false = functools.partial( bc.env_setting, "PT_COMPILE_ONLY_MODE", False) - input_ids = kwargs['input_ids'] with compile_only_mode_context_false(): - if self.model_is_mrope: - if self.model.config.model_type == 'qwen2_5_omni_thinker': - multimodal_embeddings = \ - self.model.get_multimodal_embeddings_v0(**kwargs) - inputs_embeds = self.model.get_input_embeddings_v0( - input_ids, multimodal_embeddings) - else: - image_input = self.model._parse_and_validate_image_input( - **kwargs) - video_input = self.model._parse_and_validate_video_input( - **kwargs) - inputs_embeds = self.model.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input) - input_ids = None - kwargs.update({ - 'inputs_embeds': inputs_embeds, - }) - # done compute the visual tokens - kwargs.pop('pixel_values', None) - kwargs.pop('image_grid_thw', None) - return kwargs - else: - return self.compute_input_embeddings_for_mm_optimized( - warmup_mode, **kwargs) + return self.mm_handler.compute_input_embedding(self.model, + dtype=self.dtype, + warmup_mode=warmup_mode, + **kwargs) def forward(self, *args, **kwargs): kwargs = kwargs.copy() @@ -773,14 +789,16 @@ def forward(self, *args, **kwargs): if self._rotary_prepare_cos_sin is not None and not self.model_is_mrope: self._rotary_prepare_cos_sin( kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin) - if self.model_is_mrope or self.is_mm_optimized: - # inputs_embeds was computed on execute_model - # now we always want to use the inputs_embeds - # even if the prompt is text only - # that keeps all the shapes consistent with warmup - kwargs.update({ - 'input_ids': None, + + if self.mm_handler is not None: + # All supported multimodal models precompute input_embeds, + # since handling multimodal modules with separate HPU graphs is more flexible. + # Presence of mm_handler is used as the condition (may refine later). + kwargs.update({ + 'input_ids': None }) + # Always use inputs_embeds (even for text-only) to keep shapes consistent with warmup. 
+ attn_meta = kwargs.pop('attn_metadata') if 'kv_caches' in kwargs: kwargs.pop('kv_caches') @@ -1112,7 +1130,6 @@ def __init__( self.multi_modal_input_mapper = self.mm_registry \ .create_input_mapper(self.model_config) self.mm_registry.init_mm_limits_per_prompt(self.model_config) - self.is_mm_optimized = False # Lazy initialization self.lora_manager: LRUCacheWorkerLoRAManager = None self.model: torch.nn.Module = None @@ -1420,9 +1437,6 @@ def move_model_to_hpu(model): msg = f"Loading model weights took in total {m.get_summary_string()}" logger.info(msg) - # Models that process images at different resolutions - # need to be warmed up. Current tested for MRoPE models only. - self.add_vision_buckets_to_mrope_mm_optimized() def _add_dummy_seq(self, seq_group_metadata_list, @@ -1648,15 +1662,6 @@ def move_to_device(self, tensor): return tensor if tensor is None else tensor.to(self.device, non_blocking=True) - def add_vision_buckets_to_mrope_mm_optimized(self): - model = self.get_model() - self.is_mm_optimized = is_mm_optimized(model) - if self.model_is_mrope or self.is_mm_optimized: - model.vision_buckets = VisionBuckets(self.is_mm_optimized) - model.vision_buckets.graphed_buckets = \ - self.graphed_multimodal_buckets - model.audio_buckets = AudioBuckets() - def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -2921,51 +2926,18 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: ]) return attention_metadata - def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args, + def create_dummy_multi_modal_seq_group_metadata(self, group_id, mm_len, sampling_params, lora_request, seq_len): - assert self.model_is_mrope or self.is_mm_optimized, \ + assert self.model.mm_handler is not None, \ ("Warmup compatible with Qwen2vl/Gemma3 models") - if img_args == UNSET_IMG_ARGS: + # assert self.model_is_mrope or self.is_mm_optimized, \ + # ("Warmup compatible with Qwen2vl/Gemma3 models") + if mm_len == UNSET_MM_LEN: # Using the largest bucket - img_args = self.get_model().vision_buckets.multimodal_buckets[-1] + mm_len = self.get_model().vision_buckets.multimodal_buckets[-1] - if self.model_is_mrope: - if not hasattr(self.get_model().config, "vision_config"): - raise ValueError("Expect mrope model to have vision_config") - vision_config = self.get_model().config.vision_config - if not hasattr(vision_config, "spatial_merge_size"): - raise ValueError( - "Expect mrope model to have spatial_merge_size") - - spatial_merge_unit = vision_config.spatial_merge_size**2 - num_image_tokens = img_args // spatial_merge_unit - assert img_args % 8 == 0, ( - f"Expects img_args to be multiples of 8, got: {img_args}") - image_h = img_args // 8 - image_grid_thw = torch.tensor( - [[1, image_h, int(img_args / image_h)]]) - pixel_values = torch.randn( - image_grid_thw[0].prod(), - 1176) # TODO: figure out the variable name - - assert pixel_values.shape[0] % 64 == 0, ( - f"pixel_values must be sliced in 64 chunks, " - f"got: {pixel_values.shape}") - - multi_modal_data = { - "pixel_values": pixel_values, - "image_grid_thw": image_grid_thw, - } - else: - s = self.model.model.config.vision_config.image_size - pixel_values = torch.randn([img_args, 3, s, s]) - num_image_tokens = self.model.model.config.mm_tokens_per_image \ - * img_args - multi_modal_data = { - "pixel_values": pixel_values, - "num_crops": torch.zeros([img_args], dtype=torch.int32) - } + multi_modal_data, num_image_tokens = self.model.mm_handler.create_mm_dummy_input(self.get_model(), mm_len) image_token_id = 
self.get_model().config.image_token_id prompt_token_ids_image = [image_token_id] * num_image_tokens @@ -2997,7 +2969,7 @@ def create_dummy_seq_group_metadata(self, seq_len, is_prompt, lora_request=None, - img_args=None, + mm_len=None, temperature=0, ctx=0): if self.is_pooler: @@ -3008,10 +2980,11 @@ def create_dummy_seq_group_metadata(self, seq_len = max(seq_len, 1) computed_block_nums = None if is_prompt: - if self.is_mm_run() and img_args is not None: + if mm_len is not None: + assert self.is_mm_run() return self.create_dummy_multi_modal_seq_group_metadata( group_id=group_id, - img_args=img_args, + mm_len=mm_len, sampling_params=sampling_params, lora_request=lora_request, seq_len=seq_len, @@ -3044,8 +3017,9 @@ def create_dummy_seq_group_metadata(self, lora_request=lora_request) def is_mm_run(self) -> bool: - return (self.is_mm_optimized or self.model_is_mrope) and \ - (self.multimodal_buckets is not None) + return self.model.mm_handler is not None and self.multimodal_buckets is not None def profile_run(self) -> None: # Skip profile run on decode instances @@ -3061,14 +3035,12 @@ def profile_run(self) -> None: max_seq_len = self.bucketing_manager.get_max_prompt_shape() max_batch_size = min(self.max_num_seqs, self.max_num_batched_tokens // max_seq_len) - - if self.model_is_mrope or self.is_mm_optimized: + if self.model.mm_handler is not None: # Using batch_size 1 is profile multimodal models max_batch_size = 1 model = self.get_model() self.multimodal_buckets = model.vision_buckets.multimodal_buckets - logger_msg = "Multimodal bucket : " + str(self.multimodal_buckets) - logger.info(logger_msg) + logger.info(model.vision_buckets) logger.info("Profile run with bs=%s, seq_len=%s", \ max_batch_size, max_seq_len) @@ -3080,7 +3052,7 @@ def profile_run(self) -> None: is_prompt=True, kv_caches=kv_caches, is_pt_profiler_run=False, - img_args=UNSET_IMG_ARGS if self.is_mm_run() else None, + mm_len=UNSET_MM_LEN if self.is_mm_run() else None, is_lora_profile_run=True, ) @@ -3095,7 +3067,7 @@ def _dummy_run(self, max_num_batched_tokens: int) -> None: is_prompt=False, kv_caches=None, is_pt_profiler_run=False, - img_args=UNSET_IMG_ARGS if self.is_mm_run() else None, + mm_len=UNSET_MM_LEN if self.is_mm_run() else None, is_lora_profile_run=True, num_iters=1, align_worker=True, @@ -3153,7 +3125,7 @@ def warmup_scenario(self, is_pt_profiler_run=False, is_lora_profile_run=False, temperature=0, - img_args=None, + mm_len=None, num_iters=3, align_worker=False, is_dummy_run=False) -> None: @@ -3165,7 +3137,7 @@ def warmup_scenario(self, f"bs{batch_size}_" f"seq{seq_len}_" f"ctx{ctx}_" - f"multimodal{img_args if img_args else 'F'}_" + f"multimodal{mm_len if mm_len else 'F'}_" f"graphs{'T' if use_graphs else 'F'}") # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory @@ -3200,7 +3172,7 @@ def warmup_scenario(self, is_prompt, lora_request=dummy_lora_requests_per_seq[i] if dummy_lora_requests_per_seq else None, - img_args=img_args, + mm_len=mm_len, temperature=temperature, ctx=ctx) for i in range(batch_size) ] @@ -3322,14 +3294,14 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len, ctx): logger.info(msg) def log_warmup_multimodal(self, phase, i, max_i, batch_size, seq_len, - img_args): + mm_len): free_mem = format_bytes( HabanaMemoryProfiler.current_free_device_memory()) dim = "seq_len" msg = (f"[Warmup][{phase}][{i+1}/{max_i}] "
f"batch_size:{batch_size} " f"{dim}:{seq_len} " - f"img_args:{img_args} " + f"mm_len:{mm_len} " f"free_mem:{free_mem}") logger.info(msg) @@ -3406,17 +3378,17 @@ def _warmup_multimodal_graph(self, num_candidates = len(self.multimodal_buckets) captured_all = True - for idx, img_args in enumerate(self.multimodal_buckets): + for idx, mm_len in enumerate(self.multimodal_buckets): batch_size = 1 # Note: Multimodal buckets do not change with bs max_seq_len = self.bucketing_manager.get_max_prompt_shape() seq_len = max_seq_len - batch_seq = 1 * img_args - graphed_multimodal_bucket = img_args + batch_seq = 1 * mm_len + graphed_multimodal_bucket = mm_len if graphed_multimodal_bucket in self.graphed_multimodal_buckets: continue self.graphed_multimodal_buckets.add(graphed_multimodal_bucket) self.log_warmup_multimodal(phase, idx, num_candidates, batch_size, - seq_len, img_args) + seq_len, mm_len) with HabanaMemoryProfiler() as mem_prof: self.warmup_scenario(batch_size=batch_size, @@ -3424,7 +3396,7 @@ def _warmup_multimodal_graph(self, ctx=0, is_prompt=True, kv_caches=kv_caches, - img_args=img_args) + mm_len=mm_len) used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX) @@ -3473,18 +3445,14 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: graphs = graph == 't' if graphs: self.graphed_buckets.add(cfg) - if self.is_mm_run(): - img_args = (int(seq_len) // - self.model.model.config.mm_tokens_per_image - if self.is_mm_optimized else int(seq_len)) + self.warmup_scenario( int(bs), int(seq_len), ctx, is_prompt, kv_caches, - is_pt_profiler_run=True, - img_args=img_args if self.is_mm_run() else None) + is_pt_profiler_run=True) raise AssertionError("Finished profiling") if not htorch.utils.internal.is_lazy() and not self.enforce_eager: multiplier = 3 if os.getenv('VLLM_REGIONAL_COMPILATION', @@ -4206,22 +4174,14 @@ def try_revert_dummy_output_tokens(): execute_model_kwargs['attn_metadata'] = attn_metadata if not bypass_model_exec: - if self.model_is_mrope or self.is_mm_optimized: - if 'pixel_values' in execute_model_kwargs and \ - self.is_mm_optimized: - if warmup_mode and not is_pt_profiler_run: - bypass_model_exec = True - execute_model_kwargs[ - 'graphed_multimodal_buckets'] = \ - list(self.graphed_multimodal_buckets) - # set is unhasable and causes friction with - # hpu graphs, hence turning it to a list + + if self.model.mm_handler is not None: execute_model_kwargs = \ self.model.compute_input_embeddings_for_mrope_mm_optimized( warmup_mode, **execute_model_kwargs ) - if warmup_mode and bypass_model_exec: + if warmup_mode and is_pt_profiler_run: return [] with self.profiler.record_event('internal',