
Commit 2497407

SangChengC authored and liujiacheng committed
0811-fix-qwen2-5
1 parent 88577b3 commit 2497407

File tree: 4 files changed (+110 −156 lines)


lightllm/models/qwen2_5_vl/qwen2_5_visual.py

Lines changed: 80 additions & 137 deletions
@@ -24,48 +24,12 @@
 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/
 # be37d34f44ff1bc928e59ffb8a30adecab8835a8/src
 # /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1
-class Qwen2_5_VLVisionConfig(PretrainedConfig):
-    model_type = "qwen2_5_vl"
-
-    def __init__(
-        self,
-        depth=32,
-        hidden_size=3584,
-        hidden_act="silu",
-        intermediate_size=3420,
-        num_heads=16,
-        in_channels=3,
-        patch_size=14,
-        spatial_merge_size=2,
-        temporal_patch_size=2,
-        tokens_per_second=4,
-        window_size=112,
-        out_hidden_size=3584,
-        fullatt_block_indexes=[7, 15, 23, 31],
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.depth = depth
-        self.hidden_size = hidden_size
-        self.hidden_act = hidden_act
-        self.intermediate_size = intermediate_size
-        self.num_heads = num_heads
-        self.in_channels = in_channels
-        self.patch_size = patch_size
-        self.spatial_merge_size = spatial_merge_size
-        self.temporal_patch_size = temporal_patch_size
-        self.tokens_per_second = tokens_per_second
-        self.window_size = window_size
-        self.fullatt_block_indexes = fullatt_block_indexes
-        self.out_hidden_size = out_hidden_size
-
-
 class Qwen2RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -104,54 +68,46 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-    return q_embed, k_embed
-
-
 class Qwen2_5_VLVisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.proj = nn.Linear(dim, dim)
+        try:
+            from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+
+            self.has_vllm = True
+            self.apply_rotary_emb = apply_rotary_emb
+        except ImportError:
+            print("Failed to import apply_rotary_emb from vllm; falling back to the Triton rotary kernel.")
+            self.has_vllm = False
+            self.apply_rotary_emb = apply_rotary_pos_emb_triton
+
+    def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+        t_ = t.float()
+        cos = freqs.cos()
+        sin = freqs.sin()
+        output = self.apply_rotary_emb(t_, cos, sin).type_as(t)
+        return output
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+        # if position_embeddings is None:
+        #     position_embeddings = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
+        k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
+        q = q.squeeze(0)
+        k = k.squeeze(0)
 
-        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
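
Note: the deleted pure-PyTorch rotary helper doubles as a reference for both new backends, since duplicating freqs along the last dim and rotating the half-split tensor computes the same non-interleaved rotation as apply_rotary_emb(t, freqs.cos(), freqs.sin()). A minimal self-check sketch with made-up shapes (64 patch tokens, 16 heads, head_dim 80); the backend parity itself is an assumption to verify, not something this diff asserts:

    import torch

    def rotate_half(x):
        # split the head dim in half and rotate: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rotary_ref(t, freqs):
        # reference rotation over the full head dim, as in the deleted code path
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos().unsqueeze(-2), emb.sin().unsqueeze(-2)
        return (t.float() * cos) + (rotate_half(t.float()) * sin)

    t = torch.randn(1, 64, 16, 80)   # (batch, tokens, heads, head_dim)
    freqs = torch.randn(64, 40)      # (tokens, head_dim // 2)
    out = rotary_ref(t, freqs).type_as(t)
    print(out.shape)                 # torch.Size([1, 64, 16, 80])
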
@@ -183,14 +139,14 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
             rotary_pos_emb=rotary_pos_emb,
-            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -215,7 +171,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Qwen2_5VLTransformer(nn.Module):
     def __init__(
         self,
-        weight_dir,
+        kvargs,
         depth=32,
         hidden_size=3584,
         hidden_act="silu",
@@ -232,7 +188,13 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
-
+        self.weight_dir = kvargs["weight_dir"]
+        self.data_type = kvargs.get("data_type", "bfloat16")
+        # self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])]
+        # self.weight_dict = kvargs.get("weight_dict", None)
+        # self.quant_type = kvargs.get("quant_type", None)
+        # self.quant_cfg_path = kvargs.get("quant_cfg", None)
+        # self.max_batch_size = kvargs.get("max_batch_size", 1)
         self.depth = depth
         self.hidden_size = hidden_size
         self.hidden_act = hidden_act
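
Note: switching the first constructor argument from weight_dir to a kvargs dict mirrors how LightLLM's language models are built. A hedged sketch of a call site under that assumption (only weight_dir is required; data_type falls back to "bfloat16", and the commented-out keys above stay optional; the model path is an example):

    kvargs = {
        "weight_dir": "/models/Qwen2.5-VL-7B-Instruct",
        "data_type": "bf16",
    }
    vision_model = Qwen2_5VLTransformer(kvargs)
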
@@ -279,46 +241,42 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+        processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json")
         with open(processor_config_path, "r") as f:
             processor_config_dict = json.load(f)
         self.processor = Qwen2VLImageProcessor(**processor_config_dict)
 
-        self.device = self.get_device()
-        self.dtype = self.get_dtype()
-
-    def get_dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.down_proj.weight.dtype
-
-    def get_device(self) -> torch.device:
-        return self.blocks[0].mlp.down_proj.weight.device
+        self._init_datatype()
+        self.load_model(kvargs["weight_dir"])
+        self.cuda()
+
+    def _init_datatype(self):
+        if isinstance(self.data_type, torch.dtype):
+            return
+        if self.data_type in ["fp16", "float16"]:
+            self.data_type = torch.float16
+        elif self.data_type in ["bf16", "bfloat16"]:
+            self.data_type = torch.bfloat16
+        elif self.data_type in ["fp32", "float32"]:
+            self.data_type = torch.float32
+        else:
+            raise ValueError(f"Unsupported datatype {self.data_type}!")
+        return
 
     def rot_pos_emb(self, grid_thw):
         pos_ids = []
-        for t, h, w in grid_thw:
+        s = self.spatial_merge_size
+        for _, h, w in grid_thw:
+            pos_shape = (h // s, s, w // s, s)
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
-
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
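Note: the rot_pos_emb rewrite is a pure tidy-up for images: with t == 1 the dropped .repeat(t, 1) was a no-op, and the precomputed pos_shape path yields exactly the ids of the old reshape/permute chain. A quick equivalence check on a toy 8×6 patch grid with spatial_merge_size 2:

    import torch

    def pos_ids_old(h, w, s):
        hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
        hpos = hpos.reshape(h // s, s, w // s, s).permute(0, 2, 1, 3).flatten()
        wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
        wpos = wpos.reshape(h // s, s, w // s, s).permute(0, 2, 1, 3).flatten()
        return torch.stack([hpos, wpos], dim=-1)

    def pos_ids_new(h, w, s):
        pos_shape = (h // s, s, w // s, s)
        hpos = torch.arange(h).unsqueeze(1).expand(-1, w).reshape(pos_shape).permute(0, 2, 1, 3).flatten()
        wpos = torch.arange(w).unsqueeze(0).expand(h, -1).reshape(pos_shape).permute(0, 2, 1, 3).flatten()
        return torch.stack([hpos, wpos], dim=-1)

    assert torch.equal(pos_ids_old(8, 6, 2), pos_ids_new(8, 6, 2))
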
@@ -365,14 +323,22 @@ def get_window_index(self, grid_thw):
 
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True)
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0, dtype=torch.int32
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)
+
         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
             device=hidden_states.device,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        max_window_seqlen = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item()
 
         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
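
Note: hoisting cu_seqlens and max_seqlen up here also removes a per-block GPU→CPU sync; the old attention forward called (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() on every layer, whereas the .item() now runs once per batch. Worked through on a made-up grid_thw (one 4×6 image plus a 2-frame video of 2×2 patches):

    import torch
    import torch.nn.functional as F

    grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]])  # (t, h, w) per input
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    print(cu_seqlens)  # tensor([ 0, 24, 28, 32], dtype=torch.int32)
    print(max_seqlen)  # 24
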
@@ -381,40 +347,21 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
-
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
-            dim=0,
-            # Select dtype based on the following factors:
-            #  - FA2 requires that cu_seqlens_q must have dtype int32
-            #  - torch.onnx.export requires that cu_seqlens_q must have same
-            #    dtype as grid_thw
-            # See https://github.com/huggingface/transformers/pull/34852
-            # for more information
-            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
-        )
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
 
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
+                max_seqlen_now = max_seqlen
             else:
                 cu_seqlens_now = cu_window_seqlens
-            if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__,
-                    hidden_states,
-                    cu_seqlens_now,
-                    None,
-                    position_embeddings,
-                )
-            else:
-                hidden_states = blk(
-                    hidden_states,
-                    cu_seqlens=cu_seqlens_now,
-                    position_embeddings=position_embeddings,
-                )
+                max_seqlen_now = max_window_seqlen
+
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens_now,
+                max_seqlen=max_seqlen_now,
+                rotary_pos_emb=rotary_pos_emb,
+            )
 
         hidden_states = self.merger(hidden_states)
         reverse_indices = torch.argsort(window_index)
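
Note: reverse_indices = torch.argsort(window_index) works because the argsort of a permutation is its inverse, so it undoes the window-major shuffle applied before the block loop. A toy illustration:

    import torch

    window_index = torch.tensor([2, 0, 3, 1])        # window-major token order
    x = torch.tensor([10.0, 20.0, 30.0, 40.0])
    shuffled = x[window_index]                       # tensor([30., 10., 40., 20.])
    restored = shuffled[torch.argsort(window_index)]
    assert torch.equal(restored, x)
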
@@ -428,19 +375,15 @@ def load_image(self, img: List[ImageItem]):
             image_data = read_shm(get_shm_name_data(img.uuid))
             image_data = Image.open(BytesIO(image_data))
             image_data = resize_image(image_data)
-            image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
-            pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
-            image_grid_thw = image_inputs["image_grid_thw"]
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
         elif isinstance(img, dict):
             image_data = read_shm(get_shm_name_data(img["uuid"]))
             image_data = Image.open(BytesIO(image_data))
             image_data = resize_image(image_data)
-            image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
-            pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
-            image_grid_thw = image_inputs["image_grid_thw"]
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
         else:
             raise Exception("Unsupport input types: {} for {}".format(type(img), img))
-        return pixel_values.to(dtype=self.get_dtype()), image_grid_thw
+        return pixel_values.to(dtype=self.data_type), image_grid_thw
 
     def load_model(self, weight_dir):
 
lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 11 additions & 4 deletions
@@ -46,7 +46,6 @@
 from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
-from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -136,12 +135,22 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.num_heads = num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.proj = nn.Linear(dim, dim)
+        self.has_vllm = False
+        try:
+            from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+
+            self.has_vllm = True
+            self.apply_rotary_emb = apply_rotary_emb
+        except ImportError:
+            print("Failed to import apply_rotary_emb from vllm; falling back to the Triton rotary kernel.")
+            self.has_vllm = False
+            self.apply_rotary_emb = apply_rotary_pos_emb_triton
 
     def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
         t_ = t.float()
         cos = freqs.cos()
         sin = freqs.sin()
-        output = apply_rotary_emb(t_, cos, sin).type_as(t)
+        output = self.apply_rotary_emb(t_, cos, sin).type_as(t)
         return output
 
     def forward(
@@ -321,13 +330,11 @@ def load_image(self, img: List[ImageItem]):
             image_data = Image.open(BytesIO(image_data))
             image_data = resize_image(image_data)
             pixel_values, image_grid_thw = self.processor.preprocess(image_data)
-            # pixel_values, image_grid_thw = tensor["pixel_values"], tensor["image_grid_thw"]
         elif isinstance(img, dict):
             image_data = read_shm(get_shm_name_data(img["uuid"]))
             image_data = Image.open(BytesIO(image_data))
             image_data = resize_image(image_data)
             pixel_values, image_grid_thw = self.processor.preprocess(image_data)
-            # pixel_values, image_grid_thw = tensor["pixel_values"], tensor["image_grid_thw"]
         else:
             raise Exception("Unsupport input types: {} for {}".format(type(img), img))
         return pixel_values.to(dtype=self.data_type), image_grid_thw
