 from contextlib import contextmanager
 from functools import wraps
 from types import MethodType
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import accelerate
 import torch
 import torch.nn as nn
 import transformers
 from accelerate.utils import find_device
 from packaging import version
+from peft import PeftModel
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn.parallel import DistributedDataParallel as DDP
 from transformers import PreTrainedModel, dynamic_module_utils, trainer
 from transformers.modeling_outputs import SequenceClassifierOutputWithPast

 from swift.llm import deep_getattr, to_device, to_float_dtype
 from swift.utils import get_dist_setting, get_logger, is_mp, is_mp_ddp, safe_ddp_context
-from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count
+from swift.utils.torch_utils import (_get_max_memory, _sync_max_memory, get_cu_seqlens_from_position_ids,
+                                     get_device_count, get_position_ids_from_cu_seqlens)
 from .utils import HfConfigFactory

 logger = get_logger()
@@ -151,6 +153,8 @@ def _check_imports(filename) -> List[str]:


 def get_lm_head_model(model, model_meta=None, lm_heads=None):
+    if isinstance(model, PeftModel):
+        model = model.model
     model_meta = model_meta or model.model_meta
     lm_heads = lm_heads or ['lm_head']
     llm_prefix_list = getattr(model_meta.model_arch, 'language_model', None)
@@ -167,6 +171,81 @@ def get_lm_head_model(model, model_meta=None, lm_heads=None):
     return model


+def transformers_seq_cls_forward(self, *args, origin_forward, **kwargs):
+    labels = kwargs.pop('labels', None)
+    return_dict = kwargs.pop('return_dict', None)
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    input_ids = kwargs.get('input_ids')
+    inputs_embeds = kwargs.get('inputs_embeds')
+
+    output = origin_forward(*args, **kwargs)
+    if hasattr(output, 'logits'):
+        output.logits = output.logits.to(self.score.weight.dtype)
+    elif 'last_hidden_state' in output:
+        output.logits = output['last_hidden_state'].to(self.score.weight.dtype)
+    logits = self.score(output.logits)
+    if input_ids is not None:
+        batch_size = input_ids.shape[0]
+    else:
+        batch_size = inputs_embeds.shape[0]
+
+    if self.config.pad_token_id is None and batch_size != 1:
+        raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
+    if self.config.pad_token_id is None:
+        sequence_lengths = -1
+    else:
+        if output.get('attention_mask') is not None:
+            # When using padding_free in seq_cls tasks, `revert_padding_free` adds an attention_mask to the output
+            batch_size = output.get('attention_mask').shape[0]
+            sequence_lengths = output.get('attention_mask').sum(dim=1) - 1
+        elif input_ids is not None:
+            # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+            sequence_lengths = sequence_lengths % input_ids.shape[-1]
+            sequence_lengths = sequence_lengths.to(logits.device)
+        elif kwargs.get('attention_mask') is not None:
+            sequence_lengths = kwargs['attention_mask'].sum(dim=1) - 1
+        else:
+            sequence_lengths = -1
+
+    pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+    loss = None
+    if labels is not None:
+        labels = labels.to(logits.device)
+        if self.config.problem_type is None:
+            if self.num_labels == 1:
+                self.config.problem_type = 'regression'
+            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                self.config.problem_type = 'single_label_classification'
+            else:
+                self.config.problem_type = 'multi_label_classification'
+
+        if self.config.problem_type == 'regression':
+            loss_fct = MSELoss()
+            if self.num_labels == 1:
+                loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+            else:
+                loss = loss_fct(pooled_logits, labels)
+        elif self.config.problem_type == 'single_label_classification':
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+        elif self.config.problem_type == 'multi_label_classification':
+            loss_fct = BCEWithLogitsLoss()
+            loss = loss_fct(pooled_logits, labels)
+    if not return_dict:
+        output = (pooled_logits, ) + output[1:]
+        return ((loss, ) + output) if loss is not None else output
+
+    return SequenceClassifierOutputWithPast(
+        loss=loss,
+        logits=pooled_logits,
+        past_key_values=output.past_key_values,
+        hidden_states=output.hidden_states,
+        attentions=output.attentions,
+    )
+
+
 def _patch_sequence_classification(model, model_meta):
     hidden_size = HfConfigFactory.get_config_attr(model.config, 'hidden_size')
     initializer_range = HfConfigFactory.get_config_attr(model.config, 'initializer_range')
@@ -183,73 +262,11 @@ def _patch_sequence_classification(model, model_meta):
             setattr(llm_model, lm_head, nn.Identity())
             break

-    origin_forward = llm_model.forward.__func__
+    origin_forward = llm_model.forward

-    @wraps(origin_forward)
+    @wraps(origin_forward.__func__)
     def new_forward(self, *args, **kwargs):
-        labels = kwargs.pop('labels', None)
-        return_dict = kwargs.pop('return_dict', None)
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        input_ids = kwargs.get('input_ids')
-        inputs_embeds = kwargs.get('inputs_embeds')
-
-        output = origin_forward(self, *args, **kwargs)
-        output.logits = output.logits.to(self.score.weight.dtype)
-        logits = self.score(output.logits)
-        if input_ids is not None:
-            batch_size = input_ids.shape[0]
-        else:
-            batch_size = inputs_embeds.shape[0]
-
-        if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
-        if self.config.pad_token_id is None:
-            sequence_lengths = -1
-        else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
-        loss = None
-        if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = 'regression'
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = 'single_label_classification'
-                else:
-                    self.config.problem_type = 'multi_label_classification'
-
-            if self.config.problem_type == 'regression':
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == 'single_label_classification':
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == 'multi_label_classification':
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
-        if not return_dict:
-            output = (pooled_logits, ) + output[1:]
-            return ((loss, ) + output) if loss is not None else output
-
-        return SequenceClassifierOutputWithPast(
-            loss=loss,
-            logits=pooled_logits,
-            past_key_values=output.past_key_values,
-            hidden_states=output.hidden_states,
-            attentions=output.attentions,
-        )
+        return transformers_seq_cls_forward(self, *args, origin_forward=origin_forward, **kwargs)

     llm_model.forward = MethodType(new_forward, llm_model)

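For illustration only (not part of this patch): the hunk above stops unbinding the original forward via `__func__` and instead keeps the bound `llm_model.forward`, delegating the classification logic to the shared module-level `transformers_seq_cls_forward`. A minimal, self-contained sketch of that bind-and-delegate pattern, with made-up toy names:

from functools import wraps
from types import MethodType


class TinyModel:

    def forward(self, x):
        return x + 1


def shared_forward(self, *args, origin_forward, **kwargs):
    # Post-process the output of the original bound forward; here we just double it.
    return origin_forward(*args, **kwargs) * 2


def patch(model):
    origin_forward = model.forward  # bound method: `self` is already attached

    @wraps(origin_forward.__func__)
    def new_forward(self, *args, **kwargs):
        return shared_forward(self, *args, origin_forward=origin_forward, **kwargs)

    model.forward = MethodType(new_forward, model)


m = TinyModel()
patch(m)
assert m.forward(1) == 4  # (1 + 1) * 2

Because `origin_forward` is now already bound, the helper calls it without passing `self`, which is why the new code path drops the old `origin_forward(self, ...)` call.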
@@ -454,6 +471,69 @@ def patch_tp_plan(load_model: bool):
         os.environ['WORLD_SIZE'] = WORLD_SIZE


+def revert_padding_free(outputs: Dict[str, Any], inputs: Dict[str, Any], padding_side='left'):
+    hidden_state_key = None
+    if 'last_hidden_state' in outputs:
+        hidden_state_key = 'last_hidden_state'
+    elif 'logits' in outputs:
+        hidden_state_key = 'logits'
+    elif 'token_embeddings' in outputs:
+        hidden_state_key = 'token_embeddings'
+
+    if hidden_state_key is None:
+        raise NotImplementedError()
+    last_hidden_state = outputs[hidden_state_key]
+    last_hidden_state = last_hidden_state.squeeze(dim=0)
+    if 'cu_seq_lens_q' in inputs:
+        position_ids = get_position_ids_from_cu_seqlens(inputs['cu_seq_lens_q'])
+    elif 'position_ids' in inputs and inputs['position_ids'].shape[0] == 1:
+        position_ids = inputs['position_ids']
+    else:
+        raise ValueError(
+            "revert_padding_free requires 'cu_seq_lens_q' or 'position_ids' in inputs, but neither was found.")
+
+    seq_lengths = []
+    pos = position_ids[0]
+    resets = torch.where(pos[1:] < pos[:-1])[0] + 1
+
+    if len(resets) == 0:
+        # Only one sequence in this batch item
+        seq_lengths = [pos.max().item() + 1]
+    else:
+        # Multiple sequences
+        start = 0
+        for end in resets:
+            seq_lengths.append(end - start)
+            start = end
+        seq_lengths.append(pos.shape[0] - start)
+
+    max_length = max(seq_lengths)
+    unpacked_logits = []
+    attention_mask = []
+
+    start = 0
+    for length in seq_lengths:
+        seq_state = last_hidden_state[start:start + length]
+        mask = torch.ones((seq_state.shape[0])).to(last_hidden_state.device)
+        padding = torch.zeros(
+            (max_length - length, last_hidden_state.shape[-1])).to(last_hidden_state.dtype).to(last_hidden_state.device)
+        attention_padding = torch.zeros((max_length - length)).to(last_hidden_state.device)
+        # re-padding
+        if padding_side == 'left':
+            seq_state = torch.cat((padding, seq_state), dim=0)
+            mask = torch.cat((attention_padding, mask), dim=0)
+        else:
+            seq_state = torch.cat((seq_state, padding), dim=0)
+            mask = torch.cat((mask, attention_padding), dim=0)
+        unpacked_logits.append(seq_state)
+        attention_mask.append(mask)
+        start += length
+    outputs[hidden_state_key] = torch.stack(unpacked_logits, dim=0)
+    inputs['attention_mask'] = torch.stack(attention_mask, dim=0).to(torch.int64)
+    outputs['attention_mask'] = inputs['attention_mask']
+    return outputs
+
+
 @contextmanager
 def patch_attach_align_device_hook_on_blocks():
     from accelerate import big_modeling
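For illustration only (not from the patch): a toy, self-contained sketch of the repacking that `revert_padding_free` performs, splitting a packed [1, total_tokens, hidden] tensor at the points where `position_ids` restart and re-padding each sequence to a common length. All names and shapes below are made up.

import torch

hidden = 4
# two packed sequences of lengths 3 and 2
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
last_hidden_state = torch.randn(1, 5, hidden)

pos = position_ids[0]
resets = torch.where(pos[1:] < pos[:-1])[0] + 1  # boundaries where position_ids restart
lengths = torch.diff(torch.cat([torch.tensor([0]), resets, torch.tensor([pos.shape[0]])])).tolist()
max_len = max(lengths)

rows, masks, start = [], [], 0
for length in lengths:
    seq = last_hidden_state[0, start:start + length]
    pad = torch.zeros(max_len - length, hidden)
    rows.append(torch.cat([pad, seq]))  # left padding, matching the default padding_side
    masks.append(torch.cat([torch.zeros(max_len - length), torch.ones(length)]))
    start += length

unpacked = torch.stack(rows)  # shape [2, 3, 4]
attention_mask = torch.stack(masks).long()
print(unpacked.shape, attention_mask)

The real function additionally writes the rebuilt `attention_mask` into both `inputs` and `outputs`, which is what `transformers_seq_cls_forward` later uses to locate the last non-padded token for pooling.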