Merged
29 commits
33ec4aa
refactor rejected_messages
Jintao-Huang Aug 27, 2025
1b39d1f
update
Jintao-Huang Aug 27, 2025
8324d48
Merge branch 'main' into refactor_rejected_messages
Jintao-Huang Aug 27, 2025
bc9dafb
Merge branch 'main' into refactor_rejected_messages
Jintao-Huang Aug 28, 2025
2108d0e
Merge branch 'main' into refactor_rejected_messages
Jintao-Huang Aug 28, 2025
2a10813
update
Jintao-Huang Aug 28, 2025
d14e2f7
fix
Jintao-Huang Aug 28, 2025
cc6fbdc
fix
Jintao-Huang Aug 28, 2025
7884f22
fix
Jintao-Huang Aug 28, 2025
1a2773f
fix
Jintao-Huang Aug 28, 2025
6d65f2a
lint pass
Jintao-Huang Aug 28, 2025
8711a69
fix
Jintao-Huang Aug 28, 2025
29865e9
update
Jintao-Huang Aug 28, 2025
9a24218
Merge remote-tracking branch 'refs/remotes/origin/refactor_rejected_m…
Jintao-Huang Aug 28, 2025
9e374b1
fix
Jintao-Huang Aug 28, 2025
ea78160
fix
Jintao-Huang Aug 28, 2025
d22c514
Merge remote-tracking branch 'refs/remotes/origin/refactor_rejected_m…
Jintao-Huang Aug 28, 2025
34074ec
update
Jintao-Huang Aug 28, 2025
a539821
Merge remote-tracking branch 'refs/remotes/origin/refactor_rejected_m…
Jintao-Huang Aug 28, 2025
df3c124
update
Jintao-Huang Aug 28, 2025
a8f3217
Merge branch 'main' into refactor_rejected_messages
Jintao-Huang Aug 28, 2025
7df2976
update docs
Jintao-Huang Aug 28, 2025
2e4bdfa
fix
Jintao-Huang Aug 28, 2025
5092dd9
fix docs
Jintao-Huang Aug 28, 2025
d74956d
fix
Jintao-Huang Aug 28, 2025
eafdd5d
Merge remote-tracking branch 'refs/remotes/origin/refactor_rejected_m…
Jintao-Huang Aug 28, 2025
060efdc
fix
Jintao-Huang Aug 29, 2025
f68cadb
update
Jintao-Huang Aug 29, 2025
ae3c698
fix
Jintao-Huang Aug 29, 2025
14 changes: 14 additions & 0 deletions docs/source/Customization/自定义数据集.md
Expand Up @@ -88,6 +88,20 @@ alpaca格式:

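> Note: RM additionally supports a margin column; see the [RM documentation](../Instruction/人类对齐.md#rm)

You can also use `rejected_messages` directly instead of providing only `rejected_response`/`rejected_images` (requires ms-swift>=3.8), which offers greater flexibility (for example, in multimodal or agent scenarios). In multimodal scenarios, if you use `rejected_messages`, you must additionally pass "rejected_images", "rejected_audios", "rejected_videos", and so on. An example of the data format: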

```jsonl
{"messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "images": ["cat.png"], "rejected_messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小狗。"}], "rejected_images": ["cat.png"]}
{"messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "images": ["cat.png"], "rejected_messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "rejected_images": ["dog.png"]}
```

The above format is equivalent to:
```jsonl
{"messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "images": ["cat.png"], "rejected_response": "这是一只小狗。"}
{"messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "images": ["cat.png"], "rejected_images": ["dog.png"]}
```


#### KTO

```jsonl
14 changes: 14 additions & 0 deletions docs/source_en/Customization/Custom-dataset.md
Expand Up @@ -88,6 +88,20 @@ The format of multimodal data should follow the specifications in [Multimodal Da

> Note: RM additionally supports the margin column. For details, refer to the [RM documentation](../Instruction/RLHF.md#rm).

You can also use `rejected_messages` directly instead of providing only `rejected_response` / `rejected_images` (requires ms-swift>=3.8), which offers greater flexibility (e.g., for multimodal or agent scenarios). In multimodal cases, if you use `rejected_messages`, you must additionally provide fields such as `"rejected_images"`, `"rejected_audios"`, and `"rejected_videos"`. An example of the data format is as follows:

```jsonl
{"messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "images": ["kitten.png"], "rejected_messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a puppy."}], "rejected_images": ["kitten.png"]}
{"messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "images": ["kitten.png"], "rejected_messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "rejected_images": ["puppy.png"]}
```

The above format is equivalent to:

```jsonl
{"messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "images": ["kitten.png"], "rejected_response": "This is a puppy."}
{"messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "images": ["kitten.png"], "rejected_images": ["puppy.png"]}
```
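
The equivalence stated above can be sketched as a small conversion helper. This is only an illustration and not part of ms-swift: the function name `to_rejected_messages` is hypothetical, and it assumes that the last message is the assistant reply and that `rejected_images` falls back to `images` when the row does not supply it, as in the two example rows.

```python
from copy import deepcopy
from typing import Any, Dict


def to_rejected_messages(row: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: rewrite a legacy preference row to use `rejected_messages`."""
    out = {k: deepcopy(v) for k, v in row.items() if k != 'rejected_response'}
    rejected = deepcopy(row['messages'])
    # Assumption: the final message is the assistant reply being compared.
    if row.get('rejected_response') is not None:
        rejected[-1]['content'] = row['rejected_response']
    out['rejected_messages'] = rejected
    # Assumption: multimodal rows need explicit rejected media, so reuse the
    # chosen images unless the row already provides rejected ones.
    if 'images' in row:
        out['rejected_images'] = list(row.get('rejected_images') or row['images'])
    return out


legacy = {
    'messages': [
        {'role': 'user', 'content': '<image>What is this?'},
        {'role': 'assistant', 'content': 'This is a kitten.'},
    ],
    'images': ['kitten.png'],
    'rejected_response': 'This is a puppy.',
}
converted = to_rejected_messages(legacy)
print(converted['rejected_messages'][-1]['content'])
print(converted['rejected_images'])
```

The helper copies the row instead of mutating it, so the chosen `messages` stay unchanged.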

#### KTO

```jsonl
54 changes: 23 additions & 31 deletions swift/llm/dataset/preprocessor/core.py
Expand Up @@ -3,6 +3,7 @@
 import os
 from collections import Counter
 from contextlib import contextmanager
+from itertools import chain
 from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
Expand All @@ -18,21 +19,18 @@

 logger = get_logger()

+_pair_keys = ['messages', 'images', 'videos', 'audios', 'tools', 'objects']
+

 class RowPreprocessor:
-    standard_keys = [
-        'messages',
-        'rejected_response',
-        'rejected_images',
-        'label',
-        'images',
-        'videos',
-        'audios',
-        'tools',
-        'objects',
-        'channel',
-        'margin',
-    ]
+    standard_keys = _pair_keys + list(
+        chain.from_iterable([f'{prefix}_{k}' for k in _pair_keys]
+                            for prefix in ['rejected', 'positive', 'negative'])) + [
+                                'rejected_response',
+                                'label',
+                                'channel',
+                                'margin',
+                            ]

     def __init__(self,
                  *,
Expand Down Expand Up @@ -100,25 +98,11 @@ def _cast_mm_data(row: Dict[str, Any]) -> None:

     @staticmethod
     def _check_rejected_response(row: Dict[str, Any]) -> None:
-        if 'rejected_messages' in row:
-            chosen_messages = row['messages']
-            rejected_messages = row['rejected_messages']
-            messages = []
-            rejected_response = None
-            for chosen_user, chosen_assistant, rejected_user, rejected_assistant in zip(
-                    chosen_messages[::2], chosen_messages[1::2], rejected_messages[::2], rejected_messages[1::2]):
-                assert chosen_user == rejected_user
-                messages.append(chosen_user)
-                messages.append(chosen_assistant)
-                if chosen_assistant != rejected_assistant:
-                    rejected_response = rejected_assistant['content']
-            row['messages'] = messages
-            row['rejected_response'] = rejected_response
-
         if 'rejected_response' in row:
             messages = row['messages']
             rejected_response = row['rejected_response']
-            if rejected_response is None or rejected_response == messages[-1]['content']:
+            if (rejected_response is None
+                    or isinstance(rejected_response, str) and rejected_response == messages[-1]['content']):
                 raise ValueError(f'rejected_response: {rejected_response}')

     def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
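
The tightened check relies on Python's operator precedence: `and` binds more tightly than `or`, so the condition reads as `A or (B and C)`. The sketch below reproduces the same rule as a free function so it can be run in isolation; the names `check_rejected_response` and `is_valid` are hypothetical and exist only for this example.

```python
from typing import Any, Dict


def check_rejected_response(row: Dict[str, Any]) -> None:
    """Standalone copy of the new validation rule (illustrative only)."""
    if 'rejected_response' in row:
        messages = row['messages']
        rejected_response = row['rejected_response']
        # Parsed as: is None  or  (is a str and equals the chosen reply).
        if (rejected_response is None
                or isinstance(rejected_response, str) and rejected_response == messages[-1]['content']):
            raise ValueError(f'rejected_response: {rejected_response}')


def is_valid(row: Dict[str, Any]) -> bool:
    try:
        check_rejected_response(row)
        return True
    except ValueError:
        return False


messages = [
    {'role': 'user', 'content': 'What is this?'},
    {'role': 'assistant', 'content': 'This is a kitten.'},
]
print(is_valid({'messages': messages, 'rejected_response': 'This is a puppy.'}))
print(is_valid({'messages': messages, 'rejected_response': 'This is a kitten.'}))
print(is_valid({'messages': messages, 'rejected_response': None}))
```

A non-string `rejected_response` skips the equality comparison and is rejected only when it is `None`.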
Expand Down Expand Up @@ -197,8 +181,8 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool,
                     row = [row]
                 for r in row:
                     self._check_objects(r)
-                    self._check_messages(r)
                     self._check_rejected_response(r)
+                    self._check_messages(r)
                     self._cast_mm_data(r)
             except Exception as e:
                 if strict:
Expand Down Expand Up @@ -275,11 +259,19 @@ def _patch_arrow_writer():
     def _new_init(self, schema=None, features=None, *args, **kwargs):

         if features is not None:
-            features['messages'] = [{
+            messages_feature = [{
                 'role': Value(dtype='string'),
                 'content': Value(dtype='string'),
+            }]
+            messages_feature_with_loss = [{
+                'role': Value(dtype='string'),
+                'content': Value(dtype='string'),
                 'loss': Value(dtype='float64'),
             }]
+            features['messages'] = messages_feature_with_loss
+            features['rejected_messages'] = messages_feature_with_loss
+            features['positive_messages'] = [messages_feature]
+            features['negative_messages'] = [messages_feature]
             features['images'] = [{'bytes': Value(dtype='binary'), 'path': Value(dtype='string')}]
             features['objects'] = {
                 'ref': Sequence(feature=Value(dtype='string'), length=-1),
120 changes: 60 additions & 60 deletions swift/llm/template/base.py
Expand Up @@ -87,7 +87,7 @@ def __init__(
         from .template_meta import TemplateMeta
         from swift.plugin import agent_templates, loss_scale_map
         self._processor_inited = False
-        self._version = 'v2'  # Avoid compatibility issues caused by load_from_cache_file caching.
+        self._version = 'v3'  # Avoid compatibility issues caused by load_from_cache_file caching.
         self.max_length = max_length
         self.model = None

Expand Down Expand Up @@ -232,6 +232,13 @@ def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None:
             else:
                 i += 1

+    def _preprocess_inputs_reranker(
+        self,
+        inputs: StdTemplateInputs,
+    ) -> None:
+        # TODO: remove
+        return
+
     def _preprocess_inputs(
         self,
         inputs: StdTemplateInputs,
Contributor review comment (medium):

The `inputs` argument can be `None` when called from `_encode_truncated` (e.g., with a `None` rejected sample). The type hint should be updated to `Optional[StdTemplateInputs]` to reflect this accurately.

Expand All @@ -241,31 +248,31 @@ def _preprocess_inputs(
         self._replace_image_tags(inputs)
         self._replace_start_image_tags(inputs)

-        for img_field in ['images', 'rejected_images']:
-            images = getattr(inputs, img_field, None)
-            if not images:
-                continue
-            load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
-            load_images_origin = load_images
-            if self.max_pixels is not None or inputs.objects:
-                load_images = True
-            if images:
-                for i, image in enumerate(images):
-                    images[i] = self._load_image(image, load_images)
-            if inputs.objects:
-                self._get_height_width(inputs)
-            if self.max_pixels is not None and images:
-                images = [rescale_image(img, self.max_pixels) for img in images]
-            if images and not load_images_origin:  # fix pt & qwen-vl
-                for i, image in enumerate(images):
-                    if isinstance(image, Image.Image):
-                        images[i] = self._save_pil_image(image)
-            setattr(inputs, img_field, images)
+        images = inputs.images
+        load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
+        load_images_origin = load_images
+        if self.max_pixels is not None or inputs.objects:
+            load_images = True
+        if images:
+            for i, image in enumerate(images):
+                images[i] = self._load_image(images[i], load_images)
+        if inputs.objects:
+            self._get_height_width(inputs)
+        if self.max_pixels is not None:
+            # Scale the image proportionally without affecting the scaled objects.
+            images = [rescale_image(img, self.max_pixels) for img in images]
+        if images and not load_images_origin:  # fix pt & qwen-vl
+            for i, image in enumerate(images):
+                if isinstance(image, Image.Image):
+                    images[i] = self._save_pil_image(image)
+        inputs.images = images

         if self.mode == 'vllm' and inputs.audios:
             sampling_rate = get_env_args('sampling_rate', int, None)
             inputs.audios = load_batch(
                 inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate, return_sr=True))
+        if inputs.is_multimodal:
+            self._add_default_tags(inputs)

     @staticmethod
     def _replace_image_tags(inputs: StdTemplateInputs):
Expand Down Expand Up @@ -331,17 +338,11 @@ def get_base_model(model):
         else:
             return model

-    def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
-        margin = inputs.margin
-        chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
-
-        assert chosen_inputs.rejected_response or chosen_inputs.rejected_images, f'inputs: {inputs}'
-        if chosen_inputs.rejected_response:
-            rejected_inputs.messages[-1]['content'] = chosen_inputs.rejected_response
-        if chosen_inputs.rejected_images:
-            rejected_inputs.images = chosen_inputs.rejected_images
-        chosen_encoded = self._encode_truncated(chosen_inputs)
-        rejected_encoded = self._encode_truncated(rejected_inputs)
+    def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
+        chosen = inputs.chosen
+        margin = chosen.margin
+        chosen_encoded = self._encode_truncated(chosen)
+        rejected_encoded = self._encode_truncated(inputs.rejected)

         encoded = {}
         for prefix in ['chosen', 'rejected']:
Expand All @@ -352,10 +353,9 @@ def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             encoded['margin'] = float(margin)
         return encoded

-    def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
-        label, inputs.label = inputs.label, None
+    def _kto_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
         encoded = self._rlhf_encode(inputs)
-        encoded['label'] = bool(label)
+        encoded['label'] = bool(inputs.chosen.label)
         return encoded

     def _gkd_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
Expand All @@ -366,7 +366,8 @@ def _gkd_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
                 encoded.pop(k, None)
         return encoded

-    def _embedding_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+    def _embedding_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
+        inputs = inputs.chosen  # TODO: refactor
         _encoded = {}
         labels = []
         inference = len(inputs.messages) == 1
Expand Down Expand Up @@ -429,7 +430,9 @@ def split_multi_medias(_inputs):
             _encoded.pop('labels', None)
         return _encoded

-    def _reranker_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+    def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
+        inputs = inputs.chosen  # TODO: refactor
+        self._preprocess_inputs_reranker(inputs)
         _encoded = {}
         labels = []

Expand Down Expand Up @@ -486,45 +489,46 @@ def encode(self,
         """
         assert self._processor_inited, ('Please initialize the processor before calling the template.encode method: '
                                         'template.init_processor(processor).')
-        if isinstance(inputs, (InferRequest, TemplateInputs)):
+        if isinstance(inputs, InferRequest):
             inputs = asdict(inputs)

         if isinstance(inputs, dict):
             inputs = deepcopy(inputs)
             if self.task_type == 'causal_lm' and not self.is_training:
                 InferRequest.remove_response(inputs['messages'])
-            inputs = StdTemplateInputs.from_dict(inputs)
-        elif isinstance(inputs, StdTemplateInputs):
+            inputs = TemplateInputs.from_dict(inputs)
+        elif isinstance(inputs, TemplateInputs):
             inputs = deepcopy(inputs)
-        assert isinstance(inputs, StdTemplateInputs)
-        self._preprocess_inputs(inputs)
+        assert isinstance(inputs, TemplateInputs)

+        chosen = inputs.chosen
         if self.task_type == 'causal_lm':
             if self.mode in {'train', 'pt', 'vllm', 'lmdeploy', 'sglang'}:
-                encoded = self._encode_truncated(inputs)
+                encoded = self._encode_truncated(chosen)
             elif self.mode == 'rlhf':
                 encoded = self._rlhf_encode(inputs)
             elif self.mode == 'kto':
                 encoded = self._kto_encode(inputs)
             elif self.mode == 'gkd':
-                encoded = self._gkd_encode(inputs)
+                encoded = self._gkd_encode(chosen)
         elif self.task_type == 'seq_cls':
             if self.mode == 'rlhf':
                 encoded = self._rlhf_encode(inputs)
                 for prefix in ['chosen', 'rejected']:
                     encoded.pop(f'{prefix}_labels', None)
                     encoded.pop(f'{prefix}_loss_scale', None)
             else:
-                encoded = self._seq_cls_encode(inputs)
+                encoded = self._seq_cls_encode(chosen)
         elif self.task_type == 'prm':
-            encoded = self._encode_truncated(inputs)
+            encoded = self._encode_truncated(chosen)
         elif self.task_type == 'embedding':
             encoded = self._embedding_encode(inputs)
         elif self.task_type in {'reranker', 'generative_reranker'}:
             encoded = self._reranker_encode(inputs)
         else:
             raise ValueError(f'task_type: {self.task_type} is not supported.')

-        if inputs.channel is not None:
-            encoded['channel'] = inputs.channel
+        if chosen.channel is not None:
+            encoded['channel'] = chosen.channel

         lengths = [0]
         for key in list(encoded.keys()):
Expand All @@ -541,9 +545,9 @@ def encode(self,
         else:
             encoded.pop('length', None)
         if return_template_inputs:
-            encoded['template_inputs'] = inputs
+            encoded['template_inputs'] = chosen
         if not self.remove_unused_columns:
-            encoded['_extra_kwargs'] = inputs.extra_kwargs
+            encoded['_extra_kwargs'] = chosen.extra_kwargs
         return encoded

     def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]:
Expand Down Expand Up @@ -932,11 +936,6 @@ def _add_default_tags(inputs: StdTemplateInputs):
             elif message['role'] == 'assistant':
                 continue
             total_content.append(content)
-        if inputs.rejected_response:
-            rejected_response = inputs.rejected_response
-            if isinstance(inputs.rejected_response, str):
-                rejected_response = [rejected_response]
-            total_content += rejected_response
         total_content = '\n'.join(total_content)
         if inputs.system:
             total_content = f'{inputs.system}\n{total_content}'
Expand Down Expand Up @@ -1104,6 +1103,7 @@ def _swift_encode(self, inputs: StdTemplateInputs):
             prefix = template_meta.system_prefix
         self._concat_context_list(prefix, res_context_list, res_context_types, system=system)

+        assert len(inputs.messages) > 0, f'inputs.messages: {inputs.messages}'
         n_round = len(inputs.messages) // 2
         for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
             query_role, query = query_message['role'], query_message['content']
Expand Down Expand Up @@ -1145,7 +1145,7 @@ def _swift_encode(self, inputs: StdTemplateInputs):
                     if isinstance(stop_word, str))
                 # self.is_training needed because we may want to continue generation from
                 # the current response
-                if (self.is_training and not sep_token or self.task_type == 'embedding') and not endswith_stop_words:
+                if (self.is_training or self.task_type != 'causal_lm') and not sep_token and not endswith_stop_words:
                     extra_context_list = template_meta.suffix
                     extra_context_type = ContextType.SUFFIX
                 elif template_meta.response_prefix:
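
The rewritten suffix condition is not a pure refactor. The sketch below enumerates a small truth table to show where the old and new conditions disagree. It is an illustration under stated assumptions: `sep_token` is reduced to its truthiness, only three task types are sampled, and the function names `old_adds_suffix` and `new_adds_suffix` are hypothetical.

```python
from itertools import product


def old_adds_suffix(is_training: bool, sep_token: bool, task_type: str, endswith_stop_words: bool) -> bool:
    # Condition before this PR, copied from the removed line.
    return (is_training and not sep_token or task_type == 'embedding') and not endswith_stop_words


def new_adds_suffix(is_training: bool, sep_token: bool, task_type: str, endswith_stop_words: bool) -> bool:
    # Condition after this PR, copied from the added line.
    return (is_training or task_type != 'causal_lm') and not sep_token and not endswith_stop_words


# Assumption: these values are enough to exercise every branch of both conditions.
cases = product([True, False], [False, True], ['causal_lm', 'embedding', 'seq_cls'], [False, True])
changed = [case for case in cases if old_adds_suffix(*case) != new_adds_suffix(*case)]

for is_training, sep_token, task_type, endswith_stop_words in changed:
    print(f'is_training={is_training} sep_token={sep_token} task_type={task_type} '
          f'endswith_stop_words={endswith_stop_words}: '
          f'old={old_adds_suffix(is_training, sep_token, task_type, endswith_stop_words)} '
          f'new={new_adds_suffix(is_training, sep_token, task_type, endswith_stop_words)}')
```

Under these assumptions the two conditions differ in three cases. For `embedding` with a truthy `sep_token`, the suffix is no longer appended, whether training or not. For a non-`causal_lm` task such as `seq_cls` outside training with no `sep_token`, the suffix is now appended.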
Expand Down Expand Up @@ -1193,9 +1193,7 @@ def _truncate(self, input_ids: List[int], labels: Optional[List[int]], loss_mask
         return input_ids, labels, loss_mask

     def _encode_truncated(self, inputs: StdTemplateInputs):
Contributor review comment (medium):

The type hint for `inputs` should be `Optional[StdTemplateInputs]`, since it can receive `None` (e.g., from `inputs.rejected` in `_rlhf_encode`). This will be addressed in a subsequent comment, but updating the type hint is the first step for correctness.

Suggested change
-    def _encode_truncated(self, inputs: StdTemplateInputs):
+    def _encode_truncated(self, inputs: Optional[StdTemplateInputs]):

-        if inputs.is_multimodal:
-            self._add_default_tags(inputs)
-
+        self._preprocess_inputs(inputs)
Contributor review comment (high):

The `inputs` argument can be `None`, which will cause a crash later in this method when `Template._encode(self, inputs)` is called. A `None` check should be added at the beginning of the function to handle this case gracefully.

        if inputs is None:
            return {}
        self._preprocess_inputs(inputs)

         if self.mode in {'vllm', 'lmdeploy', 'sglang'}:
             encoded = Template._encode(self, inputs)
             keys = ['images', 'audios', 'videos']
Expand Down Expand Up @@ -1425,6 +1423,8 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int
             res = self._embedding_data_collator(batch, padding_to=padding_to)
         elif self.task_type in {'reranker', 'generative_reranker'}:
             res = self._reranker_data_collator(batch, padding_to=padding_to)
+        else:
+            raise ValueError(f'task_type: {self.task_type} is not supported.')
         if not self.remove_unused_columns:
             extra_kwargs = [b['_extra_kwargs'] for b in batch if b.get('_extra_kwargs') is not None]
             extra_kwargs = RowPreprocessor.rows_to_batched(extra_kwargs)