 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/
 # be37d34f44ff1bc928e59ffb8a30adecab8835a8/src
 # /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1
-class Qwen2_5_VLVisionConfig(PretrainedConfig):
-    model_type = "qwen2_5_vl"
-
-    def __init__(
-        self,
-        depth=32,
-        hidden_size=3584,
-        hidden_act="silu",
-        intermediate_size=3420,
-        num_heads=16,
-        in_channels=3,
-        patch_size=14,
-        spatial_merge_size=2,
-        temporal_patch_size=2,
-        tokens_per_second=4,
-        window_size=112,
-        out_hidden_size=3584,
-        fullatt_block_indexes=[7, 15, 23, 31],
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.depth = depth
-        self.hidden_size = hidden_size
-        self.hidden_act = hidden_act
-        self.intermediate_size = intermediate_size
-        self.num_heads = num_heads
-        self.in_channels = in_channels
-        self.patch_size = patch_size
-        self.spatial_merge_size = spatial_merge_size
-        self.temporal_patch_size = temporal_patch_size
-        self.tokens_per_second = tokens_per_second
-        self.window_size = window_size
-        self.fullatt_block_indexes = fullatt_block_indexes
-        self.out_hidden_size = out_hidden_size
-
-
 class Qwen2RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -104,54 +68,46 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-    return q_embed, k_embed
-
-
 class Qwen2_5_VLVisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.proj = nn.Linear(dim, dim)
+        try:
+            from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+
+            self.has_vllm = True
+            self.apply_rotary_emb = apply_rotary_emb
+        except ImportError:
+ print ("Failed to import _flash_attn_forward from hopper.flash_attn_interface." )
+            self.has_vllm = False
+            self.apply_rotary_emb = apply_rotary_pos_emb_triton
+
+    def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+        t_ = t.float()
+        cos = freqs.cos()
+        sin = freqs.sin()
+        output = self.apply_rotary_emb(t_, cos, sin).type_as(t)
+        return output
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+        # if position_embeddings is None:
+        #     position_embeddings = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
+        k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
+        q = q.squeeze(0)
+        k = k.squeeze(0)
 
-        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
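Review note: both rotary paths share one contract here. The vllm kernel and the Triton fallback receive `t` of shape `(batch, seqlen, heads, head_dim)` plus half-width `cos`/`sin` of shape `(seqlen, head_dim // 2)` from `rot_pos_emb`, and rotate in the non-interleaved "rotate half" convention of the deleted eager helper. A minimal eager sketch of that contract, assuming the non-interleaved convention (the helper name below is hypothetical):

```python
import torch


def apply_rotary_emb_eager(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Broadcast the (seqlen, head_dim // 2) factors over batch and head dims.
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    d = t.shape[-1] // 2
    t1, t2 = t[..., :d], t[..., d:]
    # Equivalent to t * cat(cos, cos) + rotate_half(t) * cat(sin, sin)
    # from the removed apply_rotary_pos_emb_vision helper.
    return torch.cat((t1 * cos - t2 * sin, t2 * cos + t1 * sin), dim=-1)
```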
@@ -183,14 +139,14 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
             rotary_pos_emb=rotary_pos_emb,
-            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -215,7 +171,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Qwen2_5VLTransformer(nn.Module):
     def __init__(
         self,
-        weight_dir,
+        kvargs,
         depth=32,
         hidden_size=3584,
         hidden_act="silu",
@@ -232,7 +188,13 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
-
+        self.weight_dir = kvargs["weight_dir"]
+        self.data_type = kvargs.get("data_type", "bfloat16")
+        # self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])]
+        # self.weight_dict = kvargs.get("weight_dict", None)
+        # self.quant_type = kvargs.get("quant_type", None)
+        # self.quant_cfg_path = kvargs.get("quant_cfg", None)
+        # self.max_batch_size = kvargs.get("max_batch_size", 1)
         self.depth = depth
         self.hidden_size = hidden_size
         self.hidden_act = hidden_act
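For reference, a hypothetical construction sketch of the kvargs the reworked `__init__` reads: only `weight_dir` is required, and `data_type` defaults to bfloat16 (the path is a placeholder):

```python
kvargs = {
    "weight_dir": "/path/to/Qwen2.5-VL-7B-Instruct",  # placeholder checkpoint dir
    "data_type": "bf16",  # accepts fp16/float16, bf16/bfloat16, fp32/float32
}
vision_model = Qwen2_5VLTransformer(kvargs)  # loads weights and moves to CUDA
```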
@@ -279,46 +241,42 @@ def __init__(
         self.gradient_checkpointing = False
 
-        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+        processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json")
         with open(processor_config_path, "r") as f:
             processor_config_dict = json.load(f)
         self.processor = Qwen2VLImageProcessor(**processor_config_dict)
 
-        self.device = self.get_device()
-        self.dtype = self.get_dtype()
-
-    def get_dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.down_proj.weight.dtype
-
-    def get_device(self) -> torch.device:
-        return self.blocks[0].mlp.down_proj.weight.device
+        self._init_datatype()
+        self.load_model(kvargs["weight_dir"])
+        self.cuda()
+
+    def _init_datatype(self):
+        if isinstance(self.data_type, torch.dtype):
+            return
+        if self.data_type in ["fp16", "float16"]:
+            self.data_type = torch.float16
+        elif self.data_type in ["bf16", "bfloat16"]:
+            self.data_type = torch.bfloat16
+        elif self.data_type in ["fp32", "float32"]:
+            self.data_type = torch.float32
+        else:
+ raise ValueError (f"Unsupport datatype { self .data_type } !" )
+        return
 
     def rot_pos_emb(self, grid_thw):
         pos_ids = []
-        for t, h, w in grid_thw:
+        s = self.spatial_merge_size
+        for _, h, w in grid_thw:
+            pos_shape = (h // s, s, w // s, s)
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
-
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
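To see what the condensed reshape/permute in `rot_pos_emb` produces: with toy sizes h = w = 4 and spatial_merge_size s = 2, the row ids come out grouped 2x2 merge window by merge window instead of in raster order:

```python
import torch

h, w, s = 4, 4, 2
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(h // s, s, w // s, s).permute(0, 2, 1, 3).flatten()
# hpos_ids -> tensor([0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3])
```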
@@ -365,14 +323,22 @@ def get_window_index(self, grid_thw):
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True)
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0, dtype=torch.int32
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)
+
         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
             device=hidden_states.device,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        max_window_seqlen = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item()
 
         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
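A small worked check of the cu_seqlens computation hoisted into forward, assuming two single-frame images of 16x16 and 8x8 patches (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 16, 16], [1, 8, 8]])  # (t, h, w) per image
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
# cu_seqlens -> tensor([0, 256, 320], dtype=torch.int32)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 256
```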
@@ -381,40 +347,21 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
-
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
-            dim=0,
-            # Select dtype based on the following factors:
-            #  - FA2 requires that cu_seqlens_q must have dtype int32
-            #  - torch.onnx.export requires that cu_seqlens_q must have same
-            #    dtype as grid_thw
-            # See https://github.com/huggingface/transformers/pull/34852
-            # for more information
-            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
-        )
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
 
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
+                max_seqlen_now = max_seqlen
             else:
                 cu_seqlens_now = cu_window_seqlens
-            if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__,
-                    hidden_states,
-                    cu_seqlens_now,
-                    None,
-                    position_embeddings,
-                )
-            else:
-                hidden_states = blk(
-                    hidden_states,
-                    cu_seqlens=cu_seqlens_now,
-                    position_embeddings=position_embeddings,
-                )
+                max_seqlen_now = max_window_seqlen
+
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens_now,
+                max_seqlen=max_seqlen_now,
+                rotary_pos_emb=rotary_pos_emb,
+            )
 
         hidden_states = self.merger(hidden_states)
         reverse_indices = torch.argsort(window_index)
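A quick sanity check with toy sizes that `torch.argsort(window_index)` inverts the window permutation applied before the blocks:

```python
import torch

window_index = torch.tensor([2, 0, 3, 1])         # permutation of row indices
x = torch.arange(4)
permuted = x[window_index]                        # rows in window order
restored = permuted[torch.argsort(window_index)]  # back to raster order
assert torch.equal(restored, x)
```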
@@ -428,19 +375,15 @@ def load_image(self, img: List[ImageItem]):
             image_data = read_shm(get_shm_name_data(img.uuid))
             image_data = Image.open(BytesIO(image_data))
             image_data = resize_image(image_data)
-            image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
-            pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
-            image_grid_thw = image_inputs["image_grid_thw"]
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
         elif isinstance(img, dict):
             image_data = read_shm(get_shm_name_data(img["uuid"]))
             image_data = Image.open(BytesIO(image_data))
             image_data = resize_image(image_data)
-            image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
-            pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
-            image_grid_thw = image_inputs["image_grid_thw"]
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
         else:
             raise Exception("Unsupport input types: {} for {}".format(type(img), img))
-        return pixel_values.to(dtype=self.get_dtype()), image_grid_thw
+        return pixel_values.to(dtype=self.data_type), image_grid_thw
 
     def load_model(self, weight_dir):