
Commit 604e96a

Support emb/reranker/seq_cls padding_free (#6007)
1 parent 9fce30c commit 604e96a

File tree

12 files changed: +265 -81 lines


README.md

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ You can contact us and communicate with us by adding our group:
 
 
 ## 🎉 News
+- 🎁 2025.09.29: Support padding_free for embedding/reranker/seq_cls tasks, use `--padding_free true --task_type embedding/reranker/generative_reranker/seq_cls` to begin!
 - 🎁 2025.09.07: Added support for CHORD training algorithm. See the [documentation](./docs/source_en/Instruction/GRPO/AdvancedResearch/CHORD.md)
 - 🎁 2025.09.06: Ulysses can now be used with ring-attention, allowing sequences to be sharded into any number of chunks (no longer limited by the number of heads). The argument remains `--sequence_parallel_size N`.
 - 🎁 2025.09.02: Megatron-SWIFT now supports multimodal model training. Documentation can be found [here](./docs/source_en/Megatron-SWIFT/Multimodal-Model.md).

README_CN.md

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@
 - **Model quantization**: Supports quantized export with AWQ, GPTQ, FP8, and BNB; exported models can be served with vLLM/SGLang/LmDeploy for accelerated inference and support continued training.
 
 ## 🎉 News
+- 🎁 2025.09.29: Support padding_free for embedding/reranker/seq_cls tasks, use `--padding_free true --task_type embedding/reranker/generative_reranker/seq_cls` to start training!
 - 🎁 2025.09.07: Support for the CHORD training algorithm, see the [documentation](docs/source/Instruction/GRPO/AdvancedResearch/CHORD.md)
 - 🎁 2025.09.06: Ulysses can now be combined with ring-attention, allowing input sequences to be split into any number of chunks (no longer limited by num_heads); the argument is still `--sequence_parallel_size N`
 - 🎁 2025.09.02: Megatron-SWIFT supports multimodal model training. See the documentation [here](./docs/source/Megatron-SWIFT/多模态模型.md)

examples/train/reranker/train_generative_reranker.sh

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 nproc_per_node=4
 # 4*47G
 # losses: plugin/loss.py
+# only support --padding_side left
 NPROC_PER_NODE=$nproc_per_node \
 swift sft \
     --model Qwen/Qwen3-Reranker-4B \
@@ -11,6 +12,7 @@ swift sft \
     --load_from_cache_file true \
     --split_dataset_ratio 0.05 \
     --eval_strategy steps \
+    --padding_side left \
     --output_dir output \
     --eval_steps 100 \
     --num_train_epochs 1 \

examples/train/reranker/train_generative_reranker_listwise.sh

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 nproc_per_node=4
 # 4*47G
 # losses: plugin/loss.py
+# only support --padding_side left
 NPROC_PER_NODE=$nproc_per_node \
 swift sft \
     --model Qwen/Qwen3-Reranker-4B \
@@ -10,6 +11,7 @@ swift sft \
     --dataset MTEB/scidocs-reranking \
     --load_from_cache_file true \
     --split_dataset_ratio 0.05 \
+    --padding_side left \
     --eval_strategy steps \
     --output_dir output \
     --eval_steps 100 \

swift/llm/model/patcher.py

Lines changed: 147 additions & 67 deletions
@@ -4,22 +4,24 @@
 from contextlib import contextmanager
 from functools import wraps
 from types import MethodType
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import accelerate
 import torch
 import torch.nn as nn
 import transformers
 from accelerate.utils import find_device
 from packaging import version
+from peft import PeftModel
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn.parallel import DistributedDataParallel as DDP
 from transformers import PreTrainedModel, dynamic_module_utils, trainer
 from transformers.modeling_outputs import SequenceClassifierOutputWithPast
 
 from swift.llm import deep_getattr, to_device, to_float_dtype
 from swift.utils import get_dist_setting, get_logger, is_mp, is_mp_ddp, safe_ddp_context
-from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count
+from swift.utils.torch_utils import (_get_max_memory, _sync_max_memory, get_cu_seqlens_from_position_ids,
+                                     get_device_count, get_position_ids_from_cu_seqlens)
 from .utils import HfConfigFactory
 
 logger = get_logger()
@@ -151,6 +153,8 @@ def _check_imports(filename) -> List[str]:
 
 
 def get_lm_head_model(model, model_meta=None, lm_heads=None):
+    if isinstance(model, PeftModel):
+        model = model.model
     model_meta = model_meta or model.model_meta
     lm_heads = lm_heads or ['lm_head']
     llm_prefix_list = getattr(model_meta.model_arch, 'language_model', None)
@@ -167,6 +171,81 @@ def get_lm_head_model(model, model_meta=None, lm_heads=None):
     return model
 
 
+def transformers_seq_cls_forward(self, *args, origin_forward, **kwargs):
+    labels = kwargs.pop('labels', None)
+    return_dict = kwargs.pop('return_dict', None)
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    input_ids = kwargs.get('input_ids')
+    inputs_embeds = kwargs.get('inputs_embeds')
+
+    output = origin_forward(*args, **kwargs)
+    if hasattr(output, 'logits'):
+        output.logits = output.logits.to(self.score.weight.dtype)
+    elif 'last_hidden_state' in output:
+        output.logits = output['last_hidden_state'].to(self.score.weight.dtype)
+    logits = self.score(output.logits)
+    if input_ids is not None:
+        batch_size = input_ids.shape[0]
+    else:
+        batch_size = inputs_embeds.shape[0]
+
+    if self.config.pad_token_id is None and batch_size != 1:
+        raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
+    if self.config.pad_token_id is None:
+        sequence_lengths = -1
+    else:
+        if output.get('attention_mask') is not None:
+            # When using padding_free for seq_cls tasks, `revert_padding_free` adds an attention_mask to the output
+            batch_size = output.get('attention_mask').shape[0]
+            sequence_lengths = output.get('attention_mask').sum(dim=1) - 1
+        elif input_ids is not None:
+            # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+            sequence_lengths = sequence_lengths % input_ids.shape[-1]
+            sequence_lengths = sequence_lengths.to(logits.device)
+        elif kwargs.get('attention_mask') is not None:
+            sequence_lengths = kwargs['attention_mask'].sum(dim=1) - 1
+        else:
+            sequence_lengths = -1
+
+    pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+    loss = None
+    if labels is not None:
+        labels = labels.to(logits.device)
+        if self.config.problem_type is None:
+            if self.num_labels == 1:
+                self.config.problem_type = 'regression'
+            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                self.config.problem_type = 'single_label_classification'
+            else:
+                self.config.problem_type = 'multi_label_classification'
+
+        if self.config.problem_type == 'regression':
+            loss_fct = MSELoss()
+            if self.num_labels == 1:
+                loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+            else:
+                loss = loss_fct(pooled_logits, labels)
+        elif self.config.problem_type == 'single_label_classification':
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+        elif self.config.problem_type == 'multi_label_classification':
+            loss_fct = BCEWithLogitsLoss()
+            loss = loss_fct(pooled_logits, labels)
+    if not return_dict:
+        output = (pooled_logits, ) + output[1:]
+        return ((loss, ) + output) if loss is not None else output
+
+    return SequenceClassifierOutputWithPast(
+        loss=loss,
+        logits=pooled_logits,
+        past_key_values=output.past_key_values,
+        hidden_states=output.hidden_states,
+        attentions=output.attentions,
+    )
+
+
 def _patch_sequence_classification(model, model_meta):
     hidden_size = HfConfigFactory.get_config_attr(model.config, 'hidden_size')
     initializer_range = HfConfigFactory.get_config_attr(model.config, 'initializer_range')
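
The pooling step above gathers one logit vector per sequence at index `attention_mask.sum(dim=1) - 1`; for a right-padded batch that is the last real token. A minimal sketch of that gather in plain PyTorch, with invented values (not code from this commit):

```python
import torch

# Toy head output: 2 right-padded sequences, 4 token positions, 3 labels (values invented).
logits = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
attention_mask = torch.tensor([
    [1, 1, 1, 0],  # sequence 0 has 3 real tokens
    [1, 1, 0, 0],  # sequence 1 has 2 real tokens
])

# Index of the last non-padding token per sequence, as in the patched forward.
sequence_lengths = attention_mask.sum(dim=1) - 1  # tensor([2, 1])

# One pooled logit vector per sequence.
pooled_logits = logits[torch.arange(logits.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```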
@@ -183,73 +262,11 @@ def _patch_sequence_classification(model, model_meta):
             setattr(llm_model, lm_head, nn.Identity())
             break
 
-    origin_forward = llm_model.forward.__func__
+    origin_forward = llm_model.forward
 
-    @wraps(origin_forward)
+    @wraps(origin_forward.__func__)
     def new_forward(self, *args, **kwargs):
-        labels = kwargs.pop('labels', None)
-        return_dict = kwargs.pop('return_dict', None)
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        input_ids = kwargs.get('input_ids')
-        inputs_embeds = kwargs.get('inputs_embeds')
-
-        output = origin_forward(self, *args, **kwargs)
-        output.logits = output.logits.to(self.score.weight.dtype)
-        logits = self.score(output.logits)
-        if input_ids is not None:
-            batch_size = input_ids.shape[0]
-        else:
-            batch_size = inputs_embeds.shape[0]
-
-        if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
-        if self.config.pad_token_id is None:
-            sequence_lengths = -1
-        else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
-        loss = None
-        if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = 'regression'
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = 'single_label_classification'
-                else:
-                    self.config.problem_type = 'multi_label_classification'
-
-            if self.config.problem_type == 'regression':
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == 'single_label_classification':
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == 'multi_label_classification':
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
-        if not return_dict:
-            output = (pooled_logits, ) + output[1:]
-            return ((loss, ) + output) if loss is not None else output
-
-        return SequenceClassifierOutputWithPast(
-            loss=loss,
-            logits=pooled_logits,
-            past_key_values=output.past_key_values,
-            hidden_states=output.hidden_states,
-            attentions=output.attentions,
-        )
+        return transformers_seq_cls_forward(self, *args, origin_forward=origin_forward, **kwargs)
 
     llm_model.forward = MethodType(new_forward, llm_model)
 
@@ -454,6 +471,69 @@ def patch_tp_plan(load_model: bool):
     os.environ['WORLD_SIZE'] = WORLD_SIZE
 
 
+def revert_padding_free(outputs: Dict[str, Any], inputs: Dict[str, Any], padding_side='left'):
+    hidden_state_key = None
+    if 'last_hidden_state' in outputs:
+        hidden_state_key = 'last_hidden_state'
+    elif 'logits' in outputs:
+        hidden_state_key = 'logits'
+    elif 'token_embeddings' in outputs:
+        hidden_state_key = 'token_embeddings'
+
+    if hidden_state_key is None:
+        raise NotImplementedError()
+    last_hidden_state = outputs[hidden_state_key]
+    last_hidden_state = last_hidden_state.squeeze(dim=0)
+    if 'cu_seq_lens_q' in inputs:
+        position_ids = get_position_ids_from_cu_seqlens(inputs['cu_seq_lens_q'])
+    elif 'position_ids' in inputs and inputs['position_ids'].shape[0] == 1:
+        position_ids = inputs['position_ids']
+    else:
+        raise ValueError(
+            "revert_padding_free requires 'cu_seq_lens_q' or 'position_ids' in inputs, but neither was found.")
+
+    seq_lengths = []
+    pos = position_ids[0]
+    resets = torch.where(pos[1:] < pos[:-1])[0] + 1
+
+    if len(resets) == 0:
+        # Only one sequence in this batch item
+        seq_lengths = [pos.max().item() + 1]
+    else:
+        # Multiple sequences
+        start = 0
+        for end in resets:
+            seq_lengths.append(end - start)
+            start = end
+        seq_lengths.append(pos.shape[0] - start)
+
+    max_length = max(seq_lengths)
+    unpacked_logits = []
+    attention_mask = []
+
+    start = 0
+    for length in seq_lengths:
+        seq_state = last_hidden_state[start:start + length]
+        mask = torch.ones((seq_state.shape[0])).to(last_hidden_state.device)
+        padding = torch.zeros(
+            (max_length - length, last_hidden_state.shape[-1])).to(last_hidden_state.dtype).to(last_hidden_state.device)
+        attention_padding = torch.zeros((max_length - length)).to(last_hidden_state.device)
+        # re-padding
+        if padding_side == 'left':
+            seq_state = torch.cat((padding, seq_state), dim=0)
+            mask = torch.cat((attention_padding, mask), dim=0)
+        else:
+            seq_state = torch.cat((seq_state, padding), dim=0)
+            mask = torch.cat((mask, attention_padding), dim=0)
+        unpacked_logits.append(seq_state)
+        attention_mask.append(mask)
+        start += length
+    outputs[hidden_state_key] = torch.stack(unpacked_logits, dim=0)
+    inputs['attention_mask'] = torch.stack(attention_mask, dim=0).to(torch.int64)
+    outputs['attention_mask'] = inputs['attention_mask']
+    return outputs
+
+
 @contextmanager
 def patch_attach_align_device_hook_on_blocks():
     from accelerate import big_modeling
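
The heart of `revert_padding_free` is recovering per-sequence lengths from the packed `position_ids`: a new sequence starts wherever the position counter drops. A small standalone sketch of that boundary detection with invented values (plain `torch` only, not the swift helpers):

```python
import torch

# Packed (padding_free) batch: three sequences of lengths 3, 2 and 4 flattened into one row.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])

pos = position_ids[0]
# A reset is any index where the position value drops, i.e. a new sequence begins.
resets = torch.where(pos[1:] < pos[:-1])[0] + 1  # tensor([3, 5])

seq_lengths = []
start = 0
for end in resets.tolist():
    seq_lengths.append(end - start)
    start = end
seq_lengths.append(pos.shape[0] - start)

print(seq_lengths)  # [3, 2, 4]
# revert_padding_free then slices the packed hidden states with these lengths and
# re-pads each slice to max(seq_lengths) on the requested padding_side, emitting a
# matching attention_mask.
```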

swift/llm/model/register.py

Lines changed: 0 additions & 8 deletions
@@ -357,14 +357,6 @@ def make_inputs_require_grads(module, input, output):
 
             model.enable_input_require_grads = MethodType(enable_input_require_grads, model)
             tokenizer = model.tokenizer
-
-            def forward(self, **kwargs):
-                output = self._forward_origin(input=kwargs)
-                return {'last_hidden_state': output['sentence_embedding']}
-
-            if not hasattr(model, '_forward_origin'):
-                model._forward_origin = model.forward
-                model.forward = MethodType(forward, model)
         else:
             model = None
             tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

swift/llm/template/base.py

Lines changed: 10 additions & 4 deletions
@@ -521,7 +521,7 @@ def encode(self,
         if chosen.channel is not None:
             encoded['channel'] = chosen.channel
 
-        lengths = [0]
+        lengths = [0] if self.task_type not in {'reranker', 'generative_reranker'} else []
         for key in list(encoded.keys()):
             if encoded[key] is None:
                 encoded.pop(key)
@@ -532,7 +532,10 @@ def encode(self,
             elif isinstance(value, (tuple, list)):
                 lengths += value
         if return_length:
-            encoded['length'] = sum(lengths)
+            if self.task_type in {'reranker', 'generative_reranker'}:
+                encoded['length'] = lengths
+            else:
+                encoded['length'] = sum(lengths)
         else:
             encoded.pop('length', None)
         if return_template_inputs:
@@ -1542,10 +1545,13 @@ def _reranker_data_collator(self,
             max_positive = min(positive_num, max_positive_samples)
             max_negative = min(negative_num, max_negative_samples)
             for i in random.sample(range(positive_num), max_positive):
-                new_batch.append({'input_ids': b['input_ids'][i]})
+                new_batch.append({'input_ids': b['input_ids'][i], 'length': b['length'][i]})
                 labels_list.append(1)
            for j in random.sample(range(negative_num), max_negative):
-                new_batch.append({'input_ids': b['input_ids'][j + positive_num]})
+                new_batch.append({
+                    'input_ids': b['input_ids'][j + positive_num],
+                    'length': b['length'][j + positive_num]
+                })
                 labels_list.append(0)
 
         res = self._data_collator(new_batch, padding_to=padding_to)
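
For concreteness, a toy sketch of the data shape the changes above preserve: for reranker/generative_reranker tasks, `encoded['length']` stays a per-document list rather than a summed int, so each candidate sampled by `_reranker_data_collator` carries its own length next to its `input_ids`. Values below are invented:

```python
# One query with 1 positive + 1 negative document, already tokenized (toy values).
b = {
    'input_ids': [[101, 7, 8], [101, 9, 10, 11]],  # one token list per candidate document
    'length': [3, 4],                              # kept as a per-document list, not sum(lengths)
}

# What the collator now appends for each sampled candidate (cf. the diff above):
new_batch = [
    {'input_ids': b['input_ids'][0], 'length': b['length'][0]},  # positive -> label 1
    {'input_ids': b['input_ids'][1], 'length': b['length'][1]},  # negative -> label 0
]
labels_list = [1, 0]
```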

swift/trainers/arguments.py

Lines changed: 3 additions & 0 deletions
@@ -154,6 +154,9 @@ class SwiftArgumentsMixin(RLHFArgumentsMixin, TrainArgumentsMixin):
     train_type: Optional[str] = None
     local_repo_path: Optional[str] = None
     galore_config: Optional[GaLoreConfig] = None
+    padding_side: Optional[str] = None
+    padding_free: Optional[bool] = None
+    task_type: Optional[str] = None
 
     def __post_init__(self):
         if hasattr(self, 'output_dir'):
