44 changes: 35 additions & 9 deletions vllm/model_executor/models/deepseek_ocr.py
@@ -37,6 +37,12 @@
from .deepencoder import DeepCLIPVisionTransformer, build_sam_vit_b
from .deepseek_vl2 import MlpProjector

from vllm.platforms import current_platform
is_hpu = current_platform.is_hpu()

if is_hpu:
import habana_frameworks.torch.core as htcore

# The image token id may vary
_IMAGE_TOKEN = "<image>"

@@ -265,6 +271,24 @@ def get_replacement_deepseek_vl2(item_idx: int):
)
]

class DeepseekOCRVisual(nn.Module):
def __init__(
self,
sam_model,
vision_model,
):
super().__init__()
self.sam_model = sam_model
self.vision_model = vision_model

def forward(
self,
image_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# mark_step() flushes the ops accumulated in HPU lazy mode, placing a graph
# boundary before each vision stage; skip it off-HPU, where htcore is not imported.
if is_hpu:
    htcore.mark_step()
features_1 = self.sam_model(image_tensor)
if is_hpu:
    htcore.mark_step()
features_2 = self.vision_model(image_tensor, features_1)
return features_1, features_2

@MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor,
@@ -327,6 +351,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=maybe_prefix(prefix, "vision_model"),
)

self.visual = DeepseekOCRVisual(self.sam_model, self.vision_model)

self.projector = MlpProjector(self.projector_config)
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
@@ -387,20 +413,21 @@ def _parse_and_validate_image_input(
if images_spatial_crop is not None:
assert batch_sz == images_spatial_crop.shape[0]
if images_crop is not None:
assert batch_sz == images_crop.shape[0]
if images_crop.dtype != model_dtype:
images_crop = images_crop.to(model_dtype)
assert batch_sz == (len(images_crop)
if isinstance(images_crop, list)
else images_crop.shape[0])

ret_list = []
have_image_data = False
for i in range(batch_sz):
base_size = self.vision_config.image_size
if pixel_values[i] is not None:
images_crop_data = images_crop[i].to(model_dtype) \
if images_crop is not None else None

pixel_input = DeepseekOCRImagePixelInputs(
type="pixel_values",
data=pixel_values[i],
images_crop=images_crop[i] if images_crop \
is not None else None,
images_crop=images_crop_data,
images_spatial_crop=images_spatial_crop[i] \
if images_spatial_crop is not None else None,
resolve_bindings={
@@ -419,8 +446,8 @@ def _parse_and_validate_image_input(

def _encode_global_features(self,
image_tensor: torch.Tensor) -> torch.Tensor:
global_features_1 = self.sam_model(image_tensor)
global_features_2 = self.vision_model(image_tensor, global_features_1)
global_features_1, global_features_2 = \
self.visual(image_tensor)
features = torch.cat(
(
global_features_2[:, 1:],
@@ -444,8 +471,7 @@ def _encode_local_features(
if torch.sum(patches).item() == 0:
return None

local_features_1 = self.sam_model(patches)
local_features_2 = self.vision_model(patches, local_features_1)
local_features_1, local_features_2 = self.visual(patches)
features = torch.cat(
(
local_features_2[:, 1:],
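Note on the DeepseekOCRVisual wrapper added above: bundling the SAM encoder and the CLIP vision tower into one nn.Module lets the HPU runner graph-compile the whole vision path with a single wrap_in_hpu_graph call (see the hpu_model_runner.py change below). A minimal, self-contained sketch of that pattern; TwoStageVisual and its conv stages are placeholders, not the real vLLM modules:

import torch
import torch.nn as nn

class TwoStageVisual(nn.Module):
    """Placeholder analogue of DeepseekOCRVisual: stage_a feeds stage_b."""
    def __init__(self, stage_a: nn.Module, stage_b: nn.Module):
        super().__init__()
        self.stage_a = stage_a
        self.stage_b = stage_b

    def forward(self, x: torch.Tensor):
        feats_a = self.stage_a(x)
        feats_b = self.stage_b(feats_a)
        return feats_a, feats_b

visual = TwoStageVisual(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
# On Gaudi the wrapper is then captured as one HPU graph, mirroring what
# hpu_model_runner.py does for model.visual:
#   visual = htorch.hpu.wrap_in_hpu_graph(visual, disable_tensor_cache=True)
out_a, out_b = visual(torch.randn(1, 3, 64, 64))
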
32 changes: 28 additions & 4 deletions vllm/worker/hpu_model_runner.py
@@ -107,14 +107,18 @@
This class is used to bucket image tokens
'''

def __init__(self, is_batch_based):
def __init__(self, is_batch_based, sub_image_list=None):
self.is_batch_based = is_batch_based
envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "").lower()
if envvar == 'none':
self.multimodal_buckets = None
else:
if envvar == "":
if is_batch_based:
if sub_image_list is not None:
assert isinstance(sub_image_list, list), \
"sub_image_list must be a list"
multimodal_buckets = sub_image_list
elif is_batch_based:
multimodal_buckets = [1, 2, 4, 8] # batch sizes for gemma3
else:
multimodal_buckets = [1600, 3136, 4096, 6400]
@@ -205,6 +209,10 @@
'Gemma3ForConditionalGeneration' in str(type(model)) or \
'DeepseekOCRForCausalLM' in str(type(model))

def fixed_sub_image_list(model):
# DeepseekOCR restricts the sub-image (crop) count buckets to [0, 2, 3, 4, 5, 6];
# other models return None and keep the default multimodal buckets.
return [i for i in range(6 + 1) if i != 1] \
if 'DeepseekOCRForCausalLM' in str(type(model)) else None


def pad_flat_tensor(tensor, desired_size):
assert tensor.dim() == 1, 'Only flat tensors are supported'
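The fixed sub-image list above feeds the VisionBuckets mechanism: the number of crops observed at runtime is padded up to the nearest bucket so that only a handful of input shapes ever reach the graph-compiled vision model. A rough, self-contained illustration of that idea; the padding helper below is an assumption for exposition, not the vLLM implementation:

import bisect
import torch

CROP_BUCKETS = [0, 2, 3, 4, 5, 6]  # what fixed_sub_image_list returns for DeepseekOCR

def pad_crops_to_bucket(images_crop: torch.Tensor) -> torch.Tensor:
    """Pad a (num_crops, 3, 640, 640) tensor up to the nearest bucket size."""
    num_crops = images_crop.shape[0]
    idx = bisect.bisect_left(CROP_BUCKETS, num_crops)
    target = CROP_BUCKETS[min(idx, len(CROP_BUCKETS) - 1)]
    if target <= num_crops:
        return images_crop
    pad = images_crop.new_zeros((target - num_crops, *images_crop.shape[1:]))
    return torch.cat([images_crop, pad], dim=0)

print(pad_crops_to_bucket(torch.randn(1, 3, 640, 640)).shape[0])  # 2
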
@@ -411,9 +419,13 @@
if self.model_is_mrope and hasattr(self.model, 'visual') and \
model_config is not None and \
model_config.model_type != "glm4v_moe":
logger.info("[Multimodal] Wrapping Visual Model")

Check failure (GitHub Actions / pre-commit): Ruff (E501), vllm/worker/hpu_model_runner.py:422:81: Line too long (129 > 80)
self.model.visual = htorch.hpu.wrap_in_hpu_graph(
self.model.visual, disable_tensor_cache=True)
elif 'DeepseekOCRForCausalLM' in str(type(self.model)):
logger.info("[Multimodal] Wrapping Visual Model")
self.model.visual = htorch.hpu.wrap_in_hpu_graph(
self.model.visual, disable_tensor_cache=True)

if self.model_is_mrope and hasattr(self.model, 'audio_tower'):
logger.info("[Multimodal] Wrapping Audio Model")
@@ -1717,8 +1729,10 @@
def add_vision_buckets_to_mrope_mm_optimized(self):
model = self.get_model()
self.is_mm_optimized = is_mm_optimized(model)
sub_image_list = fixed_sub_image_list(model)
if self.model_is_mrope or self.is_mm_optimized:
model.vision_buckets = VisionBuckets(self.is_mm_optimized)
model.vision_buckets = \
VisionBuckets(self.is_mm_optimized, sub_image_list)
model.vision_buckets.graphed_buckets = \
self.graphed_multimodal_buckets
model.audio_buckets = AudioBuckets()
@@ -3007,6 +3021,16 @@
"pixel_values": pixel_values,
"image_grid_thw": image_grid_thw,
}
elif 'DeepseekOCRForCausalLM' in str(type(self.get_model())):
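# Dummy warm-up inputs: one 1024x1024 global view and img_args 640x640 crops.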
pixel_values = torch.randn(1, 3, 1024, 1024)
images_crop = torch.randn(img_args, 3, 640, 640)
images_spatial_crop = torch.tensor([[1, img_args]])
multi_modal_data = {
"pixel_values": pixel_values,
"images_crop": images_crop,
"images_spatial_crop": images_spatial_crop,
}
num_image_tokens = 1
elif self.model_is_mrope:
if not hasattr(self.get_model().config, "vision_config"):
raise ValueError("Expect mrope model to have vision_config")
@@ -3061,7 +3085,7 @@

if 'ernie4_5_moe_vl' in self.get_model().config.model_type:
image_token_id = self.get_model().config.im_patch_id
elif 'deepseek_vl_v2' in self.get_model().config.model_type:
elif 'DeepseekOCRForCausalLM' in str(type(self.get_model())):
image_token_id = self.get_model().image_token_id
else:
image_token_id = self.get_model().config.image_token_id