diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index a9e162e3f..01e3a7268 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -1,69 +1,21 @@ import os -import re import json import torch import torch.nn.functional as F from PIL import Image -from typing import Any, Dict, List, Optional, Tuple, Union -from torchvision import transforms as T -from torchvision.transforms.functional import InterpolationMode -from transformers import AutoModel, AutoTokenizer +from typing import List, Optional from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from io import BytesIO -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_utils import PreTrainedModel import torch.nn as nn -from torch.nn import LayerNorm from transformers.activations import ACT2FN -import math -from lightllm.models.qwen2_vl.vision_process import get_image, Qwen2VLImageProcessor -from transformers import AutoProcessor +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor from safetensors import safe_open -from transformers.utils import TensorType -from lightllm.server.multimodal_params import MultimodalParams, ImageItem +from lightllm.server.multimodal_params import ImageItem from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding +from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - -# adapted from -# https://github.com/huggingface/transformers/blob/ -# be37d34f44ff1bc928e59ffb8a30adecab8835a8/src -# /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1 -class Qwen2_5_VLVisionConfig(PretrainedConfig): - model_type = "qwen2_5_vl" - - def __init__( - self, - depth=32, - hidden_size=3584, - hidden_act="silu", - intermediate_size=3420, - num_heads=16, - in_channels=3, - patch_size=14, - spatial_merge_size=2, - temporal_patch_size=2, - tokens_per_second=4, - window_size=112, - out_hidden_size=3584, - fullatt_block_indexes=[7, 15, 23, 31], - **kwargs, - ): - super().__init__(**kwargs) - - self.depth = depth - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size - self.tokens_per_second = tokens_per_second - self.window_size = window_size - self.fullatt_block_indexes = fullatt_block_indexes - self.out_hidden_size = out_hidden_size +from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton class Qwen2RMSNorm(nn.Module): @@ -76,11 +28,7 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return rms_norm(hidden_states, self.weight, eps=self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" @@ -104,27 +52,6 @@ def forward(self, hidden_state): 
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - orig_q_dtype = q.dtype - orig_k_dtype = k.dtype - q, k = q.float(), k.float() - cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - q_embed = q_embed.to(orig_q_dtype) - k_embed = k_embed.to(orig_k_dtype) - return q_embed, k_embed - - class Qwen2_5_VLVisionFlashAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() @@ -137,21 +64,15 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + max_seqlen: int = 0, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - if position_embeddings is None: - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - cos = emb.cos() - sin = emb.sin() - else: - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + q = apply_rotary_pos_emb_triton(q, rotary_cos, rotary_sin) + k = apply_rotary_pos_emb_triton(k, rotary_cos, rotary_sin) - cu_seqlens = cu_seqlens.to(q.device, torch.int32) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) @@ -183,14 +104,16 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + max_seqlen: int = 0, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - position_embeddings=position_embeddings, + max_seqlen=max_seqlen, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -215,6 +138,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2_5_VisionTransformerPretrainedModel(nn.Module): def __init__( self, + kvargs, depth=32, hidden_size=3584, hidden_act="silu", @@ -231,6 +155,8 @@ def __init__( **kwargs, ): super().__init__() + self.weight_dir = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") self.depth = depth self.hidden_size = hidden_size @@ -278,43 +204,43 @@ def __init__( self.gradient_checkpointing = False - self.device = self.get_device() - self.dtype = self.get_dtype() - - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.down_proj.weight.dtype + processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json") + with open(processor_config_path, "r") as f: + processor_config_dict = json.load(f) + 
self.processor = Qwen2VLImageProcessor(**processor_config_dict) - def get_device(self) -> torch.device: - return self.blocks[0].mlp.down_proj.weight.device + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return def rot_pos_emb(self, grid_thw): pos_ids = [] - for t, h, w in grid_thw: + s = self.spatial_merge_size + for _, h, w in grid_thw: + pos_shape = (h // s, s, w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb + cos_full, sin_full = self.rotary_pos_emb(max_grid_size) + cos = cos_full[pos_ids].flatten(1) + sin = sin_full[pos_ids].flatten(1) + return cos, sin def get_window_index(self, grid_thw): window_index: list = [] @@ -359,56 +285,47 @@ def get_window_index(self, grid_thw): def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw) + rotary_cos = rotary_cos.to("cuda", non_blocking=True) + rotary_sin = rotary_sin.to("cuda", non_blocking=True) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( cu_window_seqlens, device=hidden_states.device, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to("cuda", non_blocking=True) + max_window_seqlen = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item() seq_len, _ = hidden_states.size() - hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - hidden_states = hidden_states[window_index, :, :] - hidden_states = hidden_states.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = 
rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same - # dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 - # for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + pos_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states.reshape(pos_shape)[window_index].view(seq_len, -1) + rotary_cos = rotary_cos.reshape(pos_shape)[window_index].view(seq_len, -1) + rotary_sin = rotary_sin.reshape(pos_shape)[window_index].view(seq_len, -1) for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens + max_seqlen_now = max_seqlen else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, - hidden_states, - cu_seqlens_now, - None, - position_embeddings, - ) - else: - hidden_states = blk( - hidden_states, - cu_seqlens=cu_seqlens_now, - position_embeddings=position_embeddings, - ) + max_seqlen_now = max_window_seqlen + + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + max_seqlen=max_seqlen_now, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) @@ -416,12 +333,23 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch. 
return hidden_states - def load_model(self, weight_dir): + def load_image(self, img: List[ImageItem]): + pixel_values = None + if isinstance(img, ImageItem): + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + elif isinstance(img, dict): + image_data = read_shm(get_shm_name_data(img["uuid"])) + image_data = Image.open(BytesIO(image_data)) + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + return pixel_values.to(dtype=self.data_type), image_grid_thw - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + def load_model(self, weight_dir): bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: @@ -455,10 +383,8 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = get_image(image_data) - image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) - image_grid_thw = image_inputs["image_grid_thw"] + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: @@ -476,10 +402,9 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.cuda().to(dtype=torch.float32) - image_grid_thw = grid_thw.cuda() + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_thw = grid_thw.to("cuda", non_blocking=True) - pixel_values = pixel_values.type(self.get_dtype()) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 70c8bf32e..6d179e6f9 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -31,6 +31,10 @@ class QWen2VLTokenizer(BaseMultiModalTokenizer): def __init__(self, tokenizer=None, image_processor=None, **kwargs): super().__init__(tokenizer) self.image_processor = image_processor + self.min_pixel = self.image_processor.min_pixels + self.max_pixel = self.image_processor.max_pixels + self.patch_size = self.image_processor.patch_size + self.merge_size = self.image_processor.merge_size self.image_start_id = kwargs["model_cfg"]["vision_start_token_id"] self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"] self.image_token_id = kwargs["model_cfg"]["image_token_id"] @@ -46,17 +50,13 @@ def init_audioitem_extral_params( raise NotImplementedError def get_image_token_length(self, img: ImageItem): - width = img.image_w - height = img.image_h - resized_height, resized_width = smart_resize(height=height, width=width) - self.patch_size = self.image_processor.image_processor.patch_size - self.merge_size = self.image_processor.image_processor.merge_size - grid_t = 1 + width, height = img.image_w, img.image_h + resized_height, resized_width = smart_resize( + 
height=height, width=width, min_pixels=self.min_pixel, max_pixels=self.max_pixel + ) grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size - merge_length = self.merge_size ** 2 - self.token_num = (grid_t * grid_h * grid_w) // merge_length - self.image_length = self.token_num - return self.image_length + token_num = (grid_h * grid_w) // (self.merge_size ** 2) + return token_num def get_audio_token_length(self, audio: AudioItem): raise NotImplementedError diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 4a9012518..68e161737 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -19,76 +19,26 @@ # limitations under the License. import os -import re import json import torch import torch.nn.functional as F from PIL import Image -from typing import List, Union +from typing import List from torchvision import transforms as T -from torchvision.transforms.functional import InterpolationMode -from transformers import AutoModel, AutoTokenizer from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from io import BytesIO -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging -from transformers.modeling_utils import PreTrainedModel import torch.nn as nn from torch.nn import LayerNorm from transformers.activations import ACT2FN -import math -from .vision_process import get_image -from transformers import AutoProcessor from safetensors import safe_open -from transformers.utils import TensorType -from lightllm.server.multimodal_params import MultimodalParams, ImageItem -from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor +from lightllm.server.multimodal_params import ImageItem +from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - -from transformers.utils import is_flash_attn_2_available - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - - from transformers.modeling_flash_attention_utils import _flash_attention_forward -else: - flash_attn_varlen_func = None - - -logger = logging.get_logger(__name__) +from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton # adapted from # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py -class Qwen2VLVisionConfig(PretrainedConfig): - model_type = "qwen2_vl" - - def __init__( - self, - depth=32, - embed_dim=1280, - hidden_size=3584, - hidden_act="quick_gelu", - mlp_ratio=4, - num_heads=16, - in_channels=3, - patch_size=14, - spatial_merge_size=2, - temporal_patch_size=2, - **kwargs, - ): - super().__init__(**kwargs) - - self.depth = depth - self.embed_dim = embed_dim - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.mlp_ratio = mlp_ratio - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size class PatchEmbed(nn.Module): @@ -109,11 +59,10 @@ def __init__( self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype 
hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size - ).cuda() - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + ) + hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states @@ -133,38 +82,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - orig_dtype = tensor.dtype - tensor = tensor.float() - cos = freqs.cos() - sin = freqs.sin() - cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() - output = (tensor * cos) + (rotate_half(tensor) * sin) - output = output.to(orig_dtype) - return output - - -class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs - - class VisionMlp(nn.Module): def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: super().__init__() @@ -176,41 +93,34 @@ def forward(self, x) -> torch.Tensor: return self.fc2(self.act(self.fc1(x))) -class VisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: +# copy form vllm +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads # 初始化 head_dim,每个头的维度 - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) + self.dim = dim + self.theta = theta + self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self._seq_len_cached = 0 + self._freqs_cos_cached = None + self._freqs_sin_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim) + ) + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cos_cached = freqs.cos() + self._freqs_sin_cached = freqs.sin() - def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) - - attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - - q = q.transpose(0, 1) - k = 
k.transpose(0, 1) - v = v.transpose(0, 1) - attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) - attn_output = torch.matmul(attn_weights, v) - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(seq_length, -1) - attn_output = self.proj(attn_output) - return attn_output + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cos_cached[:seqlen], self._freqs_sin_cached[:seqlen] -# adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py class VisionFlashAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() @@ -219,17 +129,18 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.proj = nn.Linear(dim, dim) def forward( - self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int = 0, + rotary_cos: torch.Tensor = None, + rotary_sin: torch.Tensor = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb) - q = q.squeeze(0) - k = k.squeeze(0) + q = apply_rotary_pos_emb_triton(q, rotary_cos, rotary_sin) + k = apply_rotary_pos_emb_triton(k, rotary_cos, rotary_sin) - cu_seqlens = cu_seqlens.to(q.device, torch.int32) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) @@ -238,8 +149,6 @@ def forward( return attn_output -# adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py class Qwen2VLVisionBlock(nn.Module): def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None: super().__init__() @@ -250,19 +159,22 @@ def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None: self.attn = VisionFlashAttention(embed_dim, num_heads=num_heads) self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act) - def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) -> torch.Tensor: hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states -# adapted from -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py class Qwen2VisionTransformerPretrainedModel(nn.Module): def __init__( self, + kvargs, depth=32, embed_dim=1280, hidden_size=3584, @@ -276,6 +188,8 @@ def __init__( **kwargs, ): super().__init__() + self.data_type = kvargs.get("data_type", "bfloat16") + self.depth = depth self.embed_dim = embed_dim self.hidden_size = hidden_size @@ -295,7 +209,7 @@ def __init__( ) head_dim = 
self.embed_dim // self.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).cuda() self.blocks = nn.ModuleList( [ @@ -305,64 +219,20 @@ def __init__( ) self.merger = PatchMerger(dim=self.hidden_size, context_dim=self.embed_dim) - self.device = self.get_device() - self.dtype = self.get_dtype() - - def get_dtype(self) -> torch.dtype: - return self.blocks[0].mlp.fc2.weight.dtype - - def get_device(self) -> torch.device: - return self.blocks[0].mlp.fc2.weight.device - - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: - hidden_states = hidden_states.to( - dtype=self.get_dtype(), - device=self.device, - ) - grid_thw = grid_thw.to( - dtype=torch.int32, - device=self.device, - ) - - hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32 - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - for blk in self.blocks: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) - return self.merger(hidden_states) + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return def load_model(self, weight_dir): @@ -379,7 +249,6 @@ def load_model(self, weight_dir): for k, v in f.items(): if "visual" in k: weight_dict[k[len("visual.") :]] = v - else: hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] weight_dict = {} @@ -391,22 +260,59 @@ def load_model(self, weight_dir): self.load_state_dict(weight_dict) + def rot_pos_emb(self, grid_thw): + pos_ids = [] + s = self.spatial_merge_size + for _, h, w in grid_thw: + pos_shape = (h // s, s, w // s, s) + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + cos_full, sin_full = 
self.rotary_pos_emb(max_grid_size) + cos = cos_full[pos_ids].flatten(1) + sin = sin_full[pos_ids].flatten(1) + return cos, sin + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw) + rotary_cos = rotary_cos.to("cuda", non_blocking=True) + rotary_sin = rotary_sin.to("cuda", non_blocking=True) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + cu_seqlens = cu_seqlens.to("cuda", non_blocking=True) + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + ) + return self.merger(hidden_states) + def encode(self, images: List[ImageItem]): img_tensors = [] valid_ids = [] valid_id = 0 img_grids = [] uuids = [] - for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = get_image(image_data) - image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) - image_grid_thw = image_inputs["image_grid_thw"] + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image_data) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: @@ -424,10 +330,9 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.cuda().to(dtype=torch.float32) - image_grid_thw = grid_thw.cuda() + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_thw = grid_thw.to("cuda", non_blocking=True) - pixel_values = pixel_values.type(self.get_dtype()) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py new file mode 100644 index 000000000..07e7c8b3f --- /dev/null +++ b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py @@ -0,0 +1,89 @@ +import math +import torch +import triton +import triton.language as tl + + +@triton.jit +def rotary_kernel( + inp_ptr, + cos_ptr, + sin_ptr, + out_ptr, + stride_l, + stride_h, + stride_d, + stride_cos_l, + stride_cos_d, + stride_sin_l, + stride_sin_d, + D: tl.constexpr, + HALF_D: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_h = tl.program_id(0).to(tl.int64) + pid_l = tl.program_id(1).to(tl.int64) + pid_blk = tl.program_id(2).to(tl.int64) + + offs_d = tl.arange(0, BLOCK_D) + d = pid_blk * BLOCK_D + offs_d + mask = d < D + + base = pid_l * stride_l + pid_h * stride_h + + in_ptr = inp_ptr + base + d * stride_d + cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d + sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d + + x = tl.load(in_ptr, mask=mask) + cos = tl.load(cos_ptr_, mask=mask) + sin = tl.load(sin_ptr_, mask=mask) + + partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D) + partner_ptr = inp_ptr + base + partner_d * stride_d + partner_val = tl.load(partner_ptr, mask=mask) + rotated = tl.where(d < HALF_D, -partner_val, partner_val) + + y = x * cos + rotated * sin + + out_ptr_ = 
out_ptr + base + d + tl.store(out_ptr_, y, mask=mask) + + +def apply_rotary_pos_emb_triton( + tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, BLOCK_D: int = 128 +) -> torch.Tensor: + assert tensor.is_cuda and cos.is_cuda and sin.is_cuda + assert cos.is_contiguous() and sin.is_contiguous() + if tensor.ndim != 3: + raise RuntimeError("tensor shape should be [L, H, D]") + orig_dtype = tensor.dtype + x = tensor.float() + + cos = cos.repeat(1, 2).view(cos.size(0), -1).contiguous().float() + sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float() + + L, H, D = x.shape + HALF_D = D // 2 + y = torch.empty_like(x) + + grid = (H, L, triton.cdiv(D, BLOCK_D)) + + rotary_kernel[grid]( + inp_ptr=x, + cos_ptr=cos, + sin_ptr=sin, + out_ptr=y, + stride_l=x.stride(0), + stride_h=x.stride(1), + stride_d=x.stride(2), + stride_cos_l=cos.stride(0), + stride_cos_d=cos.stride(1), + stride_sin_l=sin.stride(0), + stride_sin_d=sin.stride(1), + D=D, + HALF_D=HALF_D, + BLOCK_D=BLOCK_D, + ) + + return y.to(orig_dtype) diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 9366ca747..692f3aac3 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -1,38 +1,11 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- from __future__ import annotations - -import base64 -from io import BytesIO import math -from typing import Dict, List, Optional, Union -import numpy as np -import requests import torch +import numpy as np from PIL import Image -from torchvision import io, transforms -from torchvision.transforms import InterpolationMode -from transformers import AutoProcessor +from typing import List, Optional, Union, Tuple -from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_processing_utils import BaseImageProcessor from transformers.image_transforms import ( convert_to_rgb, resize, @@ -42,190 +15,48 @@ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension, - ImageInput, PILImageResampling, get_image_size, infer_channel_dimension_format, - is_scaled_image, - is_valid_image, - make_list_of_images, to_numpy_array, - valid_images, - validate_preprocess_arguments, ) -from transformers.video_utils import VideoInput -from transformers.utils import TensorType, is_vision_available, logging - -logger = logging.get_logger(__name__) IMAGE_FACTOR = 28 MIN_PIXELS = 4 * 28 * 28 MAX_PIXELS = 16384 * 28 * 28 MAX_RATIO = 200 - -VIDEO_MIN_PIXELS = 128 * 28 * 28 -VIDEO_MAX_PIXELS = 768 * 28 * 28 -VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 FRAME_FACTOR = 2 FPS = 2.0 FPS_MIN_FRAMES = 4 FPS_MAX_FRAMES = 768 -def make_batched_images(images) -> List[List[ImageInput]]: - """ - Accepts images in list or nested list format, and makes a list of images for preprocessing. - - Args: - images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): - The input image. - - Returns: - list: A list of images. - """ - if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): - return [img for img_list in images for img in img_list] - - elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): - return images - - elif is_valid_image(images): - return [images] - - raise ValueError(f"Could not make batched images from {images}") - - -# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos -def make_batched_videos(videos) -> List[VideoInput]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if isinstance(videos[0], Image.Image): - return [videos] - elif len(videos[0].shape) == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos) and len(videos.shape) == 4: - return [list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") - - -def round_by_factor(number: int, factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor - - -def ceil_by_factor(number: int, factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - -def floor_by_factor(number: int, factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" - return math.floor(number / factor) * factor - - def smart_resize( height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS ) -> tuple[int, int]: - """ - Rescales the image so that the following conditions are met: - - 1. 
Both dimensions (height and width) are divisible by 'factor'. - - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - 3. The aspect ratio of the image is maintained as closely as possible. - """ if max(height, width) / min(height, width) > MAX_RATIO: raise ValueError( f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" ) - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) + h_bar = max(factor, round(height / factor) * factor) + w_bar = max(factor, round(width / factor) * factor) if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) - h_bar = floor_by_factor(height / beta, factor) - w_bar = floor_by_factor(width / beta, factor) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor return h_bar, w_bar -def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image: - if "image" in ele: - image = ele["image"] - else: - image = ele["image_url"] - image_obj = None - if isinstance(image, Image.Image): - image_obj = image - elif image.startswith("http://") or image.startswith("https://"): - image_obj = Image.open(requests.get(image, stream=True).raw) - elif image.startswith("file://"): - image_obj = Image.open(image[7:]) - elif image.startswith("data:image"): - data = image.split(";", 1)[1] - if data.startswith("base64,"): - data = base64.b64decode(data[7:]) - image_obj = Image.open(BytesIO(data)) - else: - image_obj = Image.open(image) - if image_obj is None: - raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") - image = image_obj.convert("RGB") - ## resize - if "resized_height" in ele and "resized_width" in ele: - resized_height, resized_width = smart_resize( - ele["resized_height"], - ele["resized_width"], - factor=size_factor, - ) - else: - width, height = image.size - min_pixels = ele.get("min_pixels", MIN_PIXELS) - max_pixels = ele.get("max_pixels", MAX_PIXELS) - resized_height, resized_width = smart_resize( - height, - width, - factor=size_factor, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - image = image.resize((resized_width, resized_height)) - return image - - -def get_image(image_file: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]: - image_obj = None - - if isinstance(image_file, Image.Image): - image_obj = image_file - elif image_file.startswith("http://") or image_file.startswith("https://"): - image_obj = Image.open(requests.get(image_file, stream=True).raw) - elif image_file.startswith("file://"): - image_obj = Image.open(image_file[7:]) - elif image_file.startswith("data:image"): - data = image_file.split(";", 1)[1] - if data.startswith("base64,"): - data = base64.b64decode(data[7:]) - image_obj = Image.open(BytesIO(data)) - else: - image_obj = Image.open(image_file) - - if image_obj is None: - raise ValueError("Unrecognized image input. 
Supports local path, http url, base64, and PIL.Image.") +def resize_image(image_file: Image.Image, size_factor: int = IMAGE_FACTOR) -> tuple[Image.Image, int, int]: - image = image_obj.convert("RGB") - - # 获取原始宽度和高度 + image = image_file.convert("RGB") width, height = image.size - # 使用默认的最小像素和最大像素调整大小 resized_height, resized_width = smart_resize( height, width, @@ -233,56 +64,12 @@ def get_image(image_file: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS, ) - - # 调整图片大小 image = image.resize((resized_width, resized_height)) return image -def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: - vision_infos = [] - if isinstance(conversations[0], dict): - conversations = [conversations] - for conversation in conversations: - for message in conversation: - if isinstance(message["content"], list): - for ele in message["content"]: - if ( - "image" in ele - or "image_url" in ele - or "video" in ele - or ele["type"] in ("image", "image_url", "video") - ): - vision_infos.append(ele) - return vision_infos - - -def process_vision_info( - conversations: list[dict] | list[list[dict]], -) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]: - vision_infos = extract_vision_info(conversations) - ## Read images or videos - image_inputs = [] - # video_inputs = [] - for vision_info in vision_infos: - if "image" in vision_info or "image_url" in vision_info: - image_inputs.append(fetch_image(vision_info)) - # elif "video" in vision_info: - # video_inputs.append(fetch_video(vision_info)) - else: - raise ValueError("image, image_url or video should in content.") - if len(image_inputs) == 0: - image_inputs = None - return image_inputs - - -# adapted from -# transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py class Qwen2VLImageProcessor(BaseImageProcessor): - - model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"] - def __init__( self, do_resize: bool = True, @@ -306,6 +93,7 @@ def __init__( self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_normalize = do_normalize + self.do_convert_rgb = do_convert_rgb self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.min_pixels = min_pixels @@ -313,70 +101,42 @@ def __init__( self.patch_size = patch_size self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size - self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} - self.do_convert_rgb = do_convert_rgb - - def _preprocess( - self, - images: Union[ImageInput, VideoInput], - do_resize: bool = None, - resample: PILImageResampling = None, - do_rescale: bool = None, - rescale_factor: float = None, - do_normalize: bool = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_convert_rgb: bool = None, - data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - ): - - images = make_list_of_images(images) + self.data_format = ChannelDimension.FIRST - if do_convert_rgb: - images = [convert_to_rgb(image) for image in images] + def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: + if self.do_convert_rgb: + image = convert_to_rgb(image) + image = to_numpy_array(image) + input_data_format = infer_channel_dimension_format(image) + height, width 
= get_image_size(image, channel_dim=input_data_format) - # All transformations expect numpy arrays. - images = [to_numpy_array(image) for image in images] - - if is_scaled_image(images[0]) and do_rescale: - logger.warning_once( - "It looks like you are trying to rescale already rescaled images. If the input" - " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + resized_height, resized_width = height, width + if self.do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, size=(resized_height, resized_width), resample=self.resample, input_data_format=input_data_format ) - if input_data_format is None: - # We assume that all images have the same channel dimension format. - input_data_format = infer_channel_dimension_format(images[0]) - height, width = get_image_size(images[0], channel_dim=input_data_format) - resized_height, resized_width = height, width - processed_images = [] - for image in images: - if do_resize: - resized_height, resized_width = smart_resize( - height, - width, - factor=self.patch_size * self.merge_size, - min_pixels=self.min_pixels, - max_pixels=self.max_pixels, - ) - image = resize( - image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format - ) + if self.do_rescale: + image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format) - if do_rescale: - image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + if self.do_normalize: + image = self.normalize( + image=image, mean=self.image_mean, std=self.image_std, input_data_format=input_data_format + ) - if do_normalize: - image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + image = to_channel_dimension_format(image, self.data_format, input_channel_dim=input_data_format) - image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - processed_images.append(image) + patches = np.array([image]) - patches = np.array(processed_images) - if data_format == ChannelDimension.LAST: - patches = patches.transpose(0, 3, 1, 2) if patches.shape[0] == 1: + # why to copy image 2 times. use self.temporal_patch_size = 2. 
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1)) channel = patches.shape[1] grid_t = patches.shape[0] // self.temporal_patch_size @@ -396,100 +156,8 @@ def _preprocess( flatten_patches = patches.reshape( grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size ) + image_grid_thw = (grid_t, grid_h, grid_w) + pixel_values = torch.as_tensor(flatten_patches) + grid_thw = torch.as_tensor([image_grid_thw]) - return flatten_patches, (grid_t, grid_h, grid_w) - - def preprocess( - self, - images: ImageInput, - videos: VideoInput = None, - do_resize: bool = None, - size: Dict[str, int] = None, - resample: PILImageResampling = None, - do_rescale: bool = None, - rescale_factor: float = None, - do_normalize: bool = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - do_convert_rgb: bool = None, - return_tensors: Optional[Union[str, TensorType]] = "pt", - data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, - input_data_format: Optional[Union[str, ChannelDimension]] = None, - ): - - do_resize = do_resize if do_resize is not None else self.do_resize - size = size if size is not None else self.size - resample = resample if resample is not None else self.resample - do_rescale = do_rescale if do_rescale is not None else self.do_rescale - rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor - do_normalize = do_normalize if do_normalize is not None else self.do_normalize - image_mean = image_mean if image_mean is not None else self.image_mean - image_std = image_std if image_std is not None else self.image_std - do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb - - if images is not None: - images = make_batched_images(images) - if videos is not None: - videos = make_batched_videos(videos) - - if images is not None and not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." 
- ) - - validate_preprocess_arguments( - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - do_resize=do_resize, - size=size, - resample=resample, - ) - - if images is not None: - pixel_values, vision_grid_thws = [], [] - for image in images: - patches, image_grid_thw = self._preprocess( - image, - do_resize=do_resize, - resample=resample, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - data_format=data_format, - do_convert_rgb=do_convert_rgb, - input_data_format=input_data_format, - ) - pixel_values.extend(patches) - vision_grid_thws.append(image_grid_thw) - pixel_values = np.array(pixel_values) - vision_grid_thws = np.array(vision_grid_thws) - data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws} - - if videos is not None: - pixel_values, vision_grid_thws = [], [] - for images in videos: - patches, video_grid_thw = self._preprocess( - images, - do_resize=do_resize, - resample=resample, - do_rescale=do_rescale, - rescale_factor=rescale_factor, - do_normalize=do_normalize, - image_mean=image_mean, - image_std=image_std, - data_format=data_format, - do_convert_rgb=do_convert_rgb, - input_data_format=input_data_format, - ) - pixel_values.extend(patches) - vision_grid_thws.append(video_grid_thw) - pixel_values = np.array(pixel_values) - vision_grid_thws = np.array(vision_grid_thws) - data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws} - - return BatchFeature(data=data, tensor_type=return_tensors) + return pixel_values, grid_thw diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 978b9bd17..9deaf0857 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -16,7 +16,7 @@ from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.server.multimodal_params import ImageItem -from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, get_image +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image def add_split_tokens(image_features, image_newline_embed, image_new_embed): @@ -253,10 +253,9 @@ def encode(self, images: List[ImageItem]): uuids.append(img.uuid) image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) - image_data = get_image(image_data) - image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) - image_grid_thw = image_inputs["image_grid_thw"] + image_data = resize_image(image_data) + pixel_values, image_grid_thw = self.processor.preprocess(image=image_data) + pixel_values = pixel_values.to(dtype=torch.bfloat16) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index ab3770a36..fcb6fbb5d 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -215,7 +215,7 @@ def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen): 统一的 Flash Attention 接口。如果 sgl_kernel 存在, 则使用 sgl_kernel里的接口,否则使用 Triton 版本。 """ - if _flash_attn_v3_available and is_hopper() and False: + if 
_flash_attn_v3_available and is_hopper(): flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen) else: _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 7907380fd..1f10aa5ec 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -31,7 +31,6 @@ from ..models.internvl.model import InternvlTokenizer from ..models.gemma3.model import Gemma3Tokenizer - # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" @@ -89,8 +88,10 @@ def get_tokenizer( elif model_type in ["qwen2_vl", "qwen2_5_vl"] and "vision_config" in model_cfg: from transformers import AutoProcessor - image_processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen2VLTokenizer(tokenizer=tokenizer, image_processor=image_processor, model_cfg=model_cfg) + processor = AutoProcessor.from_pretrained(tokenizer_name) + tokenizer = QWen2VLTokenizer( + tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + ) elif model_type == "internvl_chat": tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d2d45f2fd..a25065e42 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -45,25 +45,29 @@ def exposed_init_model(self, kvargs): model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) try: + kvargs = { + "weight_dir": weight_dir, + "data_type": self.data_type, + "quant_type": kvargs["quant_type"], + "quant_cfg": kvargs["quant_cfg"], + "max_batch_size": kvargs["max_batch_size"], + } self.model_type = model_cfg["model_type"] if self.model_type == "qwen": self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16() elif self.model_type == "qwen2_vl": - self.model = Qwen2VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16() + self.model = ( + Qwen2VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() + ) elif self.model_type == "qwen2_5_vl": - self.model = Qwen2_5_VisionTransformerPretrainedModel(**model_cfg["vision_config"]).eval().bfloat16() + self.model = ( + Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() + ) elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration": self.model = TarsierVisionTransformerPretrainedModel(**model_cfg).eval().bfloat16() elif self.model_type == "llava": self.model = LlavaVisionModel() elif self.model_type == "internvl_chat": - kvargs = { - "weight_dir": weight_dir, - "data_type": self.data_type, - "quant_type": kvargs["quant_type"], - "quant_cfg": kvargs["quant_cfg"], - "max_batch_size": kvargs["max_batch_size"], - } self.model = VisionTransformer(kvargs) # self.model = InternVLVisionModel() elif self.model_type == "gemma3": diff --git a/unit_tests/models/qwen2_vl/test_rotary_pos_emb.py b/unit_tests/models/qwen2_vl/test_rotary_pos_emb.py new file mode 100644 index 000000000..34e41578a --- /dev/null +++ b/unit_tests/models/qwen2_vl/test_rotary_pos_emb.py @@ -0,0 +1,55 @@ +import math +import torch +import pytest + +from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" 
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+    orig_dtype = tensor.dtype
+    tensor = tensor.float()
+    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+    output = (tensor * cos) + (rotate_half(tensor) * sin)
+    output = output.to(orig_dtype)
+    return output
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (1296, 64, 80),
+        (1024, 2, 192),
+        (1024, 1, 256),
+        (1024, 3, 160),
+    ],
+)
+def test_triton_matches_reference(shape):
+    L, H, D = shape
+    assert D % 2 == 0
+
+    torch.manual_seed(0)
+
+    freqs = torch.randn(L, D // 2, device="cuda", dtype=torch.bfloat16)
+    cos = freqs.cos()
+    sin = freqs.sin()
+
+    tensor = torch.randn(L, H, D, device="cuda", dtype=torch.bfloat16)
+
+    ref = apply_rotary_pos_emb_vision(tensor.unsqueeze(0), cos, sin).squeeze(0)
+    out = apply_rotary_pos_emb_triton(tensor, cos, sin)
+
+    assert out.dtype == tensor.dtype, "output dtype should match the input"
+    assert out.shape == tensor.shape, "output shape should match the input"
+    assert torch.allclose(out, ref, rtol=1e-2, atol=1e-2), "Triton output should match the reference implementation"
+
+
+if __name__ == "__main__":
+    pytest.main()
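
Editor's note: as a quick reference for how the new Triton helper is meant to be called (a sketch, not part of the diff): q and k come out of the qkv reshape as [seq_len, num_heads, head_dim], while rot_pos_emb produces cos/sin tables of shape [seq_len, head_dim // 2] that apply_rotary_pos_emb_triton expands to the full head_dim internally. The sizes below are illustrative only.

    import torch
    from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton

    # Illustrative sizes only: 1024 patch tokens, 16 heads, head_dim 80.
    L, H, D = 1024, 16, 80

    # rot_pos_emb() yields per-token cos/sin tables of shape [L, D // 2];
    # the kernel repeats them across both halves of head_dim.
    freqs = torch.randn(L, D // 2, device="cuda")
    cos, sin = freqs.cos().contiguous(), freqs.sin().contiguous()

    q = torch.randn(L, H, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(L, H, D, device="cuda", dtype=torch.bfloat16)

    q_rot = apply_rotary_pos_emb_triton(q, cos, sin)  # same shape and dtype as q
    k_rot = apply_rotary_pos_emb_triton(k, cos, sin)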
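
The refactor also moves the cu_seqlens / max_seqlen computation out of the attention layers and into the transformer forward, so each block receives them precomputed. A small worked sketch of that computation with made-up grids:

    import torch
    import torch.nn.functional as F

    # Hypothetical grids for two images: (t, h, w) = (1, 4, 6) and (1, 2, 8).
    grid_thw = torch.tensor([[1, 4, 6], [1, 2, 8]])

    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
        dim=0, dtype=torch.int32
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

    print(cu_seqlens)   # tensor([ 0, 24, 40], dtype=torch.int32)
    print(max_seqlen)   # 24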
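
Finally, the slimmed-down image path now returns (pixel_values, grid_thw) directly for a single PIL image, and get_image_token_length mirrors the same smart_resize arithmetic. A hedged end-to-end sketch, assuming a processor constructed with stock Qwen2-VL settings (patch_size=14, merge_size=2, temporal_patch_size=2) rather than a real preprocessor_config.json, and a hypothetical 1000x750 input:

    from PIL import Image
    from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image

    # Hypothetical input; in the server the processor is built from preprocessor_config.json.
    image = Image.new("RGB", (1000, 750))
    processor = Qwen2VLImageProcessor()

    image = resize_image(image)                     # smart_resize -> 1008x756, both sides divisible by 28
    pixel_values, grid_thw = processor.preprocess(image)

    # Expected under the assumed defaults:
    #   grid_thw == tensor([[ 1, 54, 72]])          (grid_h = 756 // 14, grid_w = 1008 // 14)
    #   pixel_values.shape == (1 * 54 * 72, 3 * 2 * 14 * 14) == (3888, 1176)
    #   image tokens after the 2x2 merge: (54 * 72) // 4 == 972, matching get_image_token_length.
    print(pixel_values.shape, grid_thw)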