diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 275e9a18efe4..aaac656f1ad0 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1032,6 +1032,35 @@ def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData: ) +# PaddleOCR-VL +def run_paddleocr_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "PaddlePaddle/PaddleOCR-VL" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 147384, + "max_pixels": 2822400, + }, + limit_mm_per_prompt={modality: 1}, + ) + + placeholder = "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>" + + prompts = [ + f"<|begin_of_sentence|>User:{placeholder}\n{question}" for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # PaliGemma def run_paligemma(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1474,6 +1503,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "NVLM_D": run_nvlm_d, "ovis": run_ovis, "ovis2_5": run_ovis2_5, + "paddleocr_vl": run_paddleocr_vl, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, diff --git "a/scripts/docs/vllm\346\216\250\347\220\206\346\211\213\345\206\214.md" "b/scripts/docs/vllm\346\216\250\347\220\206\346\211\213\345\206\214.md" index 2646c68be61a..8c098e70a19a 100644 --- "a/scripts/docs/vllm\346\216\250\347\220\206\346\211\213\345\206\214.md" +++ "b/scripts/docs/vllm\346\216\250\347\220\206\346\211\213\345\206\214.md" @@ -39,7 +39,8 @@ - [3.4.2 client 端请求格式样例](#342-client-端请求格式样例) - [3.4.3 FP8 static quant](#343-fp8-static-quant) - [3.4.4 FP8 dynamic quant](#344-fp8-dynamic-quant) - - [3.4.5 问题解答](#345-问题解答) + - [3.4.5 PaddleOCR-VL 模型](#345-paddleocr-vl-模型) + - [3.4.6 问题解答](#346-问题解答) ## 1.0 环境部署 @@ -959,7 +960,41 @@ PT_HPU_LAZY_MODE=1 VLLM_GRAPH_RESERVED_MEM=0.5 vllm serve \ --mm_processor_kwargs max_pixels=1003520,min_pixels=3136 ``` -#### 3.4.5 问题解答 +#### 3.4.5 PaddleOCR-VL 模型 +**启动服务** + +```bash +PT_HPU_LAZY_MODE=1 vllm serve \ + PaddlePaddle/PaddleOCR-VL \ + --host 0.0.0.0 \ + --port 8080 \ + --trust-remote-code \ + --gpu-memory-utilization 0.5 \ + --max-model-len 16384 \ + --served-model-name 'PaddleOCR-VL-0.9B' +``` + +**client 端请求格式样例**\ +PaddleOCR-VL 模型的client端依赖于PaddleOCR pipeline, 先安装必要的paddle相关库: + +```bash +pip install paddlepaddle==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/ +pip install paddlex==3.3.4 +pip install "paddleocr[doc-parser]" +``` + +然后,使用PaddleOCR CLI 命令发送请求: + +```bash +paddleocr doc_parser \ + -i https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/paddleocr_vl_demo.png \ + --enable_mkldnn False \ + --vl_rec_backend vllm-server \ + --vl_rec_server_url http://127.0.0.1:8080/v1 \ + --save_path ./output +``` + +#### 3.4.6 问题解答 - 如果 server 端出现获取图像音视频超时错误,可以通过设置环境变量`VLLM_IMAGE_FETCH_TIMEOUT` `VLLM_VIDEO_FETCH_TIMEOUT` `VLLM_AUDIO_FETCH_TIMEOUT` 来提高超时时间。默认为 5/30/10 - 过大的输入图像要求更多的设备内存,可以通过设置更小的参数`--gpu-memory-utilization` (默认 0.9)来解决。例如参考脚本`openai_chat_completion_client_for_multimodal.py`中的图像分辨率最高达到 7952x5304,这会导致 server 端推理出错。可以通过设置`--gpu-memory-utilization`至 0.6~0.7 来解决。 diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 39d5c9346272..6db6e760a473 100644 --- a/vllm/entrypoints/chat_utils.py +++ 
b/vllm/entrypoints/chat_utils.py @@ -549,6 +549,8 @@ def _placeholder_str(self, modality: ModalityStr, return "" if model_type == "kimi_vl": return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501 + if model_type == "paddleocr_vl": + return None raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py new file mode 100644 index 000000000000..9b9a5f381db3 --- /dev/null +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -0,0 +1,1189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange, repeat +from transformers import BatchFeature +from transformers.activations import GELUActivation +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.utils import torch_int + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.vision import get_vit_attn_backend +from vllm.platforms import _Backend, current_platform + +try: + from vllm.model_executor.models.ernie45 import Ernie4_5_ForCausalLM +except ImportError: + from vllm.model_executor.models.ernie45 import ( + Ernie4_5ForCausalLM as Ernie4_5_ForCausalLM, ) + +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.utils import (AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + merge_multimodal_embeddings) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +is_hpu = current_platform.is_hpu() + +if 
is_hpu: + from habana_frameworks.torch.hpex.kernels import FusedSDPA + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 28 * 28 * 130, + max_pixels: int = 28 * 28 * 1280, +): + """Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range + ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + + """ + + if height < factor: + print(f"smart_resize: height={height} < factor={factor}, \ + reset height=factor") + width = round((width * factor) / height) + height = factor + + if width < factor: + print(f"smart_resize: width={width} < factor={factor}, \ + reset width=factor") + height = round((height * factor) / width) + width = factor + + if max(height, width) / min(height, width) > 200: + raise ValueError(f"absolute aspect ratio must be smaller than 200, \ + got {max(height, width) / min(height, width)}") + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class PaddleOCRVLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(**kwargs) + + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor + + def get_supported_mm_limits(self): + return {"image": None} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor, + ) -> int: + if image_processor is None: + image_processor = self.get_image_processor() + do_resize = True + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + grid_t = 1 + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_image_tokens = num_patches // (merge_size**2) + + return num_image_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + image_size = hf_config.vision_config.image_size + return ImageSize(height=image_size, width=image_size) + + +class PaddleOCRVLDummyInputsBuilder( + BaseDummyInputsBuilder[PaddleOCRVLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + 
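# ---------------------------------------------------------------------------
# Editor's illustrative aside (not part of the patch): how the token count in
# get_num_image_tokens() above falls out of smart_resize(). The values
# patch_size=14 and spatial_merge_size=2 are assumptions for the sake of the
# example; the real values come from the checkpoint's vision_config. The
# min/max pixel bounds are the ones the offline example in this diff passes
# via mm_processor_kwargs.
#
#   factor = patch_size * spatial_merge_size = 14 * 2 = 28
#   smart_resize(height=768, width=1024, factor=28,
#                min_pixels=147384, max_pixels=2822400) -> (756, 1036)
#   grid_h, grid_w = 756 // 14, 1036 // 14 = 54, 74
#   num_patches = 54 * 74 = 3996
#   num_image_tokens = 3996 // (2 ** 2) = 999
#
# i.e. one image placeholder token per 2x2 block of 14x14 patches after the
# spatial merge.
# ---------------------------------------------------------------------------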
num_images = mm_counts.get("image", 0) + + (target_width, + target_height) = (self.info.get_image_size_with_most_features()) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class PaddleOCRVLMultiModalProcessor( + BaseMultiModalProcessor[PaddleOCRVLProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + mm_kwargs, + ) + processed_outputs["pixel_values"] = processed_outputs[ + "pixel_values"].unsqueeze(0) + else: + tokenizer = self.info.get_tokenizer() + processed_outputs = tokenizer(prompt, + add_special_tokens=True, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_grid_thw=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_id + + def get_replacement(item_idx: int, image_processor): + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + image_processor=image_processor, + ) + + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=partial(get_replacement, + image_processor=image_processor), + ), + ] + + +class Projector(nn.Module): + + def __init__( + self, + text_config, + vision_config, + prefix: str = "", + ): + super().__init__() + self.text_config = text_config + self.vision_config = vision_config + self.merge_kernel_size = (2, 2) + + self.hidden_size = (self.vision_config.hidden_size * + self.merge_kernel_size[0] * + self.merge_kernel_size[1]) + + self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, + eps=1e-05) + self.linear_1 = nn.Linear(self.hidden_size, + self.hidden_size, + bias=True) + self.act = GELUActivation() + self.linear_2 = nn.Linear(self.hidden_size, + self.text_config.hidden_size, + bias=True) + + def forward( + self, + image_features: torch.Tensor, + image_grid_thw: list[tuple[int, int, int]], + ) -> torch.Tensor: + m1, m2 = self.merge_kernel_size + if isinstance(image_features, (list, tuple)): + processed_features = list() + for image_feature, image_grid in zip(image_features, + image_grid_thw): + image_feature = self.pre_norm(image_feature) + t, h, w = image_grid + + image_feature = rearrange( + image_feature, + "(t h p1 w p2) d -> (t h w) (p1 p2 d)", + t=t, + h=h // m1, + p1=m1, + w=w // m2, + p2=m2, + ) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + processed_features.append(hidden_states) + + return processed_features + + dims = image_features.shape[:-1] + dim = image_features.shape[-1] + image_features = 
image_features.view(np.prod(dims), dim) + hidden_states = self.pre_norm(image_features).view( + -1, self.hidden_size) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states.view(*dims, -1) + + +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + self.cache_position_embedding = dict() + self.cache_position_count = dict() + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) + + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def interpolate_pos_encoding( + self, + embeddings: torch.Tensor, + height: int, + width: int, + is_after_patchify: bool = False, + ) -> torch.Tensor: + + num_positions = self.position_embedding.weight.shape[0] + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + dim = embeddings.shape[-1] + + if is_after_patchify: + new_height = height + new_width = width + else: + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, + sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bilinear", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, + 1).reshape(1, -1, dim) + return patch_pos_embed + + def fetch_position_embedding_lfu_cache(self, + embeddings, + h, + w, + max_cache: int = 20): + grid = (h, w) + if grid in self.cache_position_embedding: + self.cache_position_count[grid] += 1 + return self.cache_position_embedding[grid] + + if len(self.cache_position_embedding) >= max_cache: + min_hit_grid = min( + self.cache_position_count, + key=self.cache_position_count.get, + ) + self.cache_position_count.pop(min_hit_grid) + self.cache_position_embedding.pop(min_hit_grid) + + position_embedding = self.interpolate_pos_encoding( + embeddings, h, w, True) + self.cache_position_count[grid] = 1 + self.cache_position_embedding[grid] = position_embedding + return position_embedding + + def forward( + self, + pixel_values: torch.FloatTensor, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ]]] = None, + interpolate_pos_encoding=False, + ) -> torch.Tensor: + if pixel_values.dim() == 4: + pixel_values = pixel_values.unsqueeze(0) + if pixel_values.dim() == 5: + if position_ids is None: + raise ValueError( + "position_ids cannot be None when pixel_values.dim() is 5." 
+ ) + ( + batch_size, + squence_len, + channel, + height, + width, + ) = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + embeddings = patch_embeds.flatten(-2).squeeze(-1) + + if interpolate_pos_encoding and image_grid_thw is not None: + start = 0 + tmp_embeddings = list() + for image_grid in image_grid_thw: + t, h, w = image_grid + end = start + t * h * w + image_embeddings = embeddings[start:end, :] + position_embedding = (self.interpolate_pos_encoding( + image_embeddings, h, w, True).squeeze(0).repeat(t, 1)) + image_embeddings = image_embeddings + position_embedding + tmp_embeddings.append(image_embeddings) + start = end + embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) + else: + embeddings = embeddings + self.packing_position_embedding( + position_ids) + return embeddings + else: + raise ValueError("Unsupported pixel_values dimension:" + f" {pixel_values.dim()}. Expected 4 or 5.") + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_flashatt( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + + apply_rotary_emb = apply_rotary_emb_torch + if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + + q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) + return q_embed, k_embed + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You + Need' paper.""" + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + hidden_size = config.hidden_size + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_attention_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scale = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + input_size=hidden_size, + output_size=hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + }: + raise RuntimeError(f"PaddleOCR-VL does not support \ + {self.attn_backend} backend now.") + + self.softmax_mode = 'fp32' if os.environ.get( + 'VLLM_FP32_SOFTMAX_VISION', 'false').lower() in ['true', '1' + ] else 'None' + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[list[torch.Tensor]] = None, + rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + batch_size, seq_length, embed_dim = hidden_states.shape + + qkv_states, _ = self.qkv_proj(hidden_states) + q, k, v = qkv_states.chunk(3, dim=-1) + + q = q.reshape(batch_size, seq_length, self.num_heads, self.head_dim) + k = k.reshape(batch_size, seq_length, self.num_heads, self.head_dim) + v = v.reshape(batch_size, seq_length, self.num_heads, self.head_dim) + + if rope_emb is not None: + cos, sin = rope_emb + q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin) + + if self.attn_backend == _Backend.FLASH_ATTN: + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + output = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + ) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA and is_hpu: + # Execute attention entry by entry for speed & less VRAM. + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0, False, + None, self.softmax_mode) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None, + device=q.device) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + + context_layer = rearrange(context_layer, + "b s h d -> b s (h d)").contiguous() + + output, _ = self.out_proj(context_layer) + return output + + +class SigLIPRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + self.rope_init() + + def rope_init(self): + inv_freq = 1.0 / (self.theta**( + torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange( + seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype, + ) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class SiglipMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + # Special handling for BNB and torchao quantization + if quant_config and quant_config.get_name() in [ + "bitsandbytes", "torchao" + ]: + quantizable = True + else: + # For other quantization, we require the hidden size to be a + # multiple of 64 + quantizable = (config.hidden_size % 64 == 0 + and config.intermediate_size % 64 == 0) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config if quantizable else None, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config if quantizable else None, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + 
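# ---------------------------------------------------------------------------
# Editor's illustrative aside (not part of the patch): the three attention
# paths in SiglipAttention.forward above all restrict attention to per-image
# blocks of one packed patch sequence, described by cu_seqlens. For two
# images whose (t, h, w) grids are (1, 54, 74) and (1, 28, 28) -- made-up
# sizes for the example -- the packed length is 3996 + 784 = 4780 and
#
#   cu_seqlens = [0, 3996, 4780]
#
# FLASH_ATTN consumes cu_seqlens directly via flash_attn_varlen_func, the HPU
# TORCH_SDPA path slices [start:end) per image and runs FusedSDPA on each
# slice, and XFORMERS builds the equivalent BlockDiagonalMask from the
# per-image lengths [3996, 784].
# ---------------------------------------------------------------------------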
self.self_attn = SiglipAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[list[torch.Tensor]] = None, + rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.FloatTensor]: + + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + rope_emb=rope_emb, + ) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + return hidden_states + + +class SiglipEncoder(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + num_heads = config.num_attention_heads + head_dim = embed_dim // num_heads + self.layers = nn.ModuleList([ + SiglipEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) for layer_idx in range(config.num_hidden_layers) + ]) + self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2) + + @staticmethod + def flatten_list(image_grid_thw): + tmp_image_grid_thw = list() + for image_grid in image_grid_thw: + if isinstance(image_grid, list): + tmp_image_grid_thw.extend(image_grid) + else: + tmp_image_grid_thw.append(image_grid) + return tmp_image_grid_thw + + def forward( + self, + inputs_embeds, + cu_seqlens: Optional[list[torch.Tensor]] = None, + image_grid_thw: Optional[list[Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ]]] = None, + height_position_ids: Optional[torch.Tensor] = None, + width_position_ids: Optional[torch.Tensor] = None, + ) -> BaseModelOutput: + device = inputs_embeds.device + hidden_states = inputs_embeds + + flatten_image_grid_thw = self.flatten_list(image_grid_thw) + + if width_position_ids is None or height_position_ids is None: + split_hids = list() + split_wids = list() + for t, h, w in flatten_image_grid_thw: + image_pids = torch.arange(t * h * w, device=device) % (h * w) + sample_hids = image_pids // w + sample_wids = image_pids % w + split_hids.append(sample_hids) + split_wids.append(sample_wids) + width_position_ids = torch.concat(split_wids, dim=0) + height_position_ids = torch.concat(split_hids, dim=0) + + pids = torch.stack( + [height_position_ids, width_position_ids], + dim=-1, + ) + max_grid_size = pids.max() + 1 + rope_emb_max_grid = self.rotary_pos_emb(max_grid_size) + rope_emb = rope_emb_max_grid[pids].flatten(1) + rope_emb = rope_emb.repeat(1, 2) + rope_emb = (rope_emb.cos(), rope_emb.sin()) + + attn_cu_seqlens = cu_seqlens + hidden_states = inputs_embeds + + for encoder_layer in self.layers: + hidden_states = encoder_layer( + hidden_states, + cu_seqlens=attn_cu_seqlens, + rope_emb=rope_emb, + ) + return hidden_states + + +class SiglipVisionTransformer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder( + config, + quant_config=quant_config, 
+ prefix=f"{prefix}.encoder", + ) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + pixel_values, + interpolate_pos_encoding: Optional[bool] = False, + position_ids: Optional[torch.Tensor] = None, + height_position_ids: Optional[torch.Tensor] = None, + width_position_ids: Optional[torch.Tensor] = None, + cu_seqlens: Optional[list[torch.Tensor]] = None, + image_grid_thw: Optional[list[Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ]]] = None, + ) -> BaseModelOutputWithPooling: + + hidden_states = self.embeddings( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + position_ids=position_ids, + image_grid_thw=image_grid_thw, + ) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + cu_seqlens=cu_seqlens, + image_grid_thw=image_grid_thw, + height_position_ids=height_position_ids, + width_position_ids=width_position_ids, + ) + + last_hidden_state = self.post_layernorm(last_hidden_state) + + sample_hidden_state = list() + if cu_seqlens is None: + raise ValueError("cu_seqlens cannot be None for " + "SiglipVisionTransformer output processing.") + for i in range(cu_seqlens.shape[0] - 1): + start = cu_seqlens[i] + end = cu_seqlens[i + 1] + tensor = last_hidden_state[:, start:end, :].squeeze(0) + sample_hidden_state.append(tensor) + + return sample_hidden_state + + +class SiglipVisionModel(nn.Module): + config_class = "PaddleOCRVisionConfig" + main_input_name = "pixel_values" + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.vision_model = SiglipVisionTransformer( + config, + quant_config=quant_config, + prefix=f"{prefix}.vision_model", + ) + self.quant_config = quant_config + + @property + def dtype(self) -> torch.dtype: + return self.vision_model.embeddings.patch_embedding.weight.dtype + + @property + def device(self) -> torch.device: + return self.vision_model.embeddings.patch_embedding.weight.device + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values, + interpolate_pos_encoding: bool = False, + position_ids: Optional[torch.Tensor] = None, + image_grid_thw: Optional[list[Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ]]] = None, + cu_seqlens: Optional[list[torch.Tensor]] = None, + ) -> BaseModelOutputWithPooling: + + return self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + position_ids=position_ids, + image_grid_thw=image_grid_thw, + cu_seqlens=cu_seqlens, + ) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "head.attention" in name or "head.layernorm" in name: + continue + if "head.mlp" in name or "head.probe" in name: + continue + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name)): + param = params_dict[scale_name] + weight_loader = getattr( + param, + "weight_loader", + default_weight_loader, + ) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) 
+ continue + for ( + param_name, + weight_name, + shard_id, + ) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr( + param, + "weight_loader", + default_weight_loader, + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@MULTIMODAL_REGISTRY.register_processor( + PaddleOCRVLMultiModalProcessor, + info=PaddleOCRVLProcessingInfo, + dummy_inputs=PaddleOCRVLDummyInputsBuilder, +) +@support_torch_compile( + # set dynamic_arg_dims to support mrope + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class PaddleOCRVLForConditionalGeneration(Ernie4_5_ForCausalLM, + SupportsMultiModal): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = self.config + + self.mlp_AR = Projector(config, config.vision_config) + self.visual = SiglipVisionModel(config=config.vision_config) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + + for layer in self.model.layers: + if not isinstance(layer, PPMissingLayer): + layer.self_attn.rotary_emb.is_neox_style = True + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + @property + def language_model(self): + return self.model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + if intermediate_tensors is not None: + inputs_embeds = None + + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + return self.language_model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>" + + raise ValueError("Only image modality is supported") + + def encode_image(self, pixel_values, image_grid_thw): + pixel_values = pixel_values.type(self.visual.dtype) + siglip_position_ids = list() + image_grid_hws = list() + cu_seqlens = [0] + + for idx, grid_thw in enumerate(image_grid_thw): + thw_tuple = tuple(grid_thw.detach().cpu().numpy().tolist()) + numel = np.prod(thw_tuple) + image_grid_hws.append(thw_tuple) + image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) + siglip_position_ids.append(image_position_ids) + cu_seqlens.append(cu_seqlens[-1] + numel) + + siglip_position_ids = torch.concat(siglip_position_ids, + dim=0).to(pixel_values.device) + cu_seqlens = torch.tensor(cu_seqlens, + 
dtype=torch.int32).to(pixel_values.device) + + vision_outputs = self.visual( + pixel_values=pixel_values, + image_grid_thw=image_grid_hws, + position_ids=siglip_position_ids, + interpolate_pos_encoding=True, + cu_seqlens=cu_seqlens, + ) + image_embeds = self.mlp_AR(vision_outputs, image_grid_thw) + + return image_embeds + + def get_multimodal_embeddings(self, **kwargs): + pixel_values = kwargs.pop("pixel_values", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None or image_grid_thw is None: + return None + + multimodal_embeddings = [] + for pv, ig in zip(pixel_values, image_grid_thw): + if pv is not None: + image_embeds = self.encode_image(pv, ig) + multimodal_embeddings += image_embeds + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None and len( + multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_id, + ) + + return inputs_embeds + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights) + return autoloaded_weights diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b52a750cfb4d..e5deedb5185d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -219,6 +219,8 @@ "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), "Ovis2_5": ("ovis2_5", "Ovis2_5"), + "PaddleOCRVLForConditionalGeneration": ("paddleocr_vl", + "PaddleOCRVLForConditionalGeneration"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 @@ -422,6 +424,8 @@ def register_model( raise TypeError(msg) if model_arch in self.models: + if model_arch == "PaddleOCRVLForConditionalGeneration": + return logger.warning( "Model architecture %s is already registered, and will be " "overwritten by the new model class %s.", model_arch, diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index fad57e9d2e89..31c4b812e12e 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -46,6 +46,9 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: plugins = dict[str, Callable[[], Any]]() for plugin in discovered_plugins: + if plugin.name == "register_paddlex_genai_models": + logger.info("Skipping plugin %s", plugin.name) + continue if allowed_plugins is None or plugin.name in allowed_plugins: if allowed_plugins is not None: log_level("Loading plugin %s", plugin.name) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index a4477868fed8..4a3becf691e0 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -410,7 +410,7 @@ def __init__(self, model, vllm_config, is_causal, sampler): if htorch.utils.internal.is_lazy(): if self.model_is_mrope and hasattr(self.model, 'visual') and \ model_config is not None and \ - model_config.model_type != "glm4v_moe": + model_config.model_type not in ("glm4v_moe", "paddleocr_vl"): logger.info("[Multimodal] Wrapping Visual Model") 
self.model.visual = htorch.hpu.wrap_in_hpu_graph( self.model.visual, disable_tensor_cache=True) @@ -729,6 +729,11 @@ def compute_input_embeddings_for_mrope_mm_optimized( self.model.get_multimodal_embeddings_v0(**kwargs) inputs_embeds = self.model.get_input_embeddings_v0( input_ids, multimodal_embeddings) + elif self.model.config.model_type == 'paddleocr_vl': + multimodal_embeddings = \ + self.model.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.model.get_input_embeddings( + input_ids, multimodal_embeddings) else: image_input = self.model._parse_and_validate_image_input( **kwargs) @@ -3089,6 +3094,66 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args, ) return seq_group + def create_dummy_paddleocr_multi_modal_seq_group_metadata( + self, group_id, num_patches, sampling_params, lora_request, + seq_len): + if not hasattr(self.get_model().config, "vision_config"): + raise ValueError("Expect paddleocr_vl model to have vision_config") + + if num_patches == UNSET_IMG_ARGS: + # Using the largest bucket + num_patches = self.get_model( + ).vision_buckets.multimodal_buckets[-1] + model_config = self.get_model().config + vision_config = model_config.vision_config + compression_ratio = model_config.compression_ratio + num_channels = vision_config.num_channels + image_size = vision_config.image_size + patch_size = vision_config.patch_size + + num_patches = (image_size // patch_size + 1)**2 + if num_patches == UNSET_IMG_ARGS: + # Using the largest bucket + num_patches = self.get_model( + ).vision_buckets.multimodal_buckets[-1] + + num_image_tokens = int(num_patches * (compression_ratio**2)) + + # for dummy input construction + image_token_id = 100295 + prompt_token_ids = [image_token_id] * min(seq_len, num_image_tokens) + prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821 + placeholders_by_modality = { + 'image': + [PlaceholderRange(offset=0, length=len(prompt_token_ids))] + } + seq_data = SequenceData(prompt_token_ids_array) + + pixel_values = torch.randn(num_patches, num_channels, patch_size, + patch_size) + image_grid_thw = torch.tensor( + [[1, image_size // patch_size + 1, image_size // patch_size + 1]]) + + multi_modal_data = { + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thw, + "image_num_patches": torch.Tensor([num_patches]), + "image_token_id": torch.tensor(image_token_id, dtype=torch.long), + } + multi_modal_data = MultiModalKwargs(multi_modal_data) + + seq_group = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=lora_request[group_id] if lora_request else None, + multi_modal_data=multi_modal_data, + multi_modal_placeholders=placeholders_by_modality, + ) + return seq_group + def create_dummy_seq_group_metadata(self, group_id, seq_len, @@ -3106,13 +3171,23 @@ def create_dummy_seq_group_metadata(self, computed_block_nums = None if is_prompt: if self.is_mm_run() and img_args is not None: - return self.create_dummy_multi_modal_seq_group_metadata( - group_id=group_id, - img_args=img_args, - sampling_params=sampling_params, - lora_request=lora_request, - seq_len=seq_len, - ) + if self.get_model().config.model_type == "paddleocr_vl": + return \ + self.create_dummy_paddleocr_multi_modal_seq_group_metadata( + group_id=group_id, + num_patches=img_args, + sampling_params=sampling_params, + lora_request=lora_request, + seq_len=seq_len, + ) + else: + return self.create_dummy_multi_modal_seq_group_metadata( + 
group_id=group_id, + img_args=img_args, + sampling_params=sampling_params, + lora_request=lora_request, + seq_len=seq_len, + ) else: input_len = seq_len output_len = 0
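
For reviewers who want to exercise the new model path end to end without the PaddleOCR client stack, a minimal offline sketch follows. It mirrors run_paddleocr_vl() from the example file in this diff (same model name, prompt template, and processor kwargs); the question string, image path, and sampling settings are placeholders rather than recommended defaults.

# minimal_paddleocr_vl_offline.py -- editor's sketch, not part of the patch
from PIL import Image

from vllm import LLM, SamplingParams

# On Gaudi, the manual above launches the server with PT_HPU_LAZY_MODE=1;
# the same environment applies to offline runs.
llm = LLM(
    model="PaddlePaddle/PaddleOCR-VL",
    trust_remote_code=True,
    max_model_len=8192,
    mm_processor_kwargs={"min_pixels": 147384, "max_pixels": 2822400},
    limit_mm_per_prompt={"image": 1},
)

placeholder = "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
question = "OCR this page."  # placeholder question
prompt = f"<|begin_of_sentence|>User:{placeholder}\n{question}"

image = Image.open("demo_page.png").convert("RGB")  # placeholder image path

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=1024),
)
print(outputs[0].outputs[0].text)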