[dataset/template] refactor rejected_messages #5560
Changes from all commits
```diff
@@ -87,7 +87,7 @@ def __init__(
         from .template_meta import TemplateMeta
         from swift.plugin import agent_templates, loss_scale_map
         self._processor_inited = False
-        self._version = 'v2'  # Avoid compatibility issues caused by load_from_cache_file caching.
+        self._version = 'v3'  # Avoid compatibility issues caused by load_from_cache_file caching.
         self.max_length = max_length
         self.model = None
```
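The version bump matters because `datasets.Dataset.map(..., load_from_cache_file=True)` keys its cache on a fingerprint of the mapping function: if the encoding logic changes but the fingerprint does not, a stale cache produced by the old template is silently reused. A minimal sketch of the mechanism, assuming the Hugging Face `datasets` library (the names below are illustrative, not swift's):

```python
from datasets import Dataset

_VERSION = 'v3'  # bump whenever the encoding changes incompatibly

def encode_row(row):
    # Referencing _VERSION makes it part of the function's serialized state,
    # so changing 'v2' -> 'v3' yields a new fingerprint and a fresh cache.
    return {'text': f"[{_VERSION}] {row['text']}"}

ds = Dataset.from_dict({'text': ['a', 'b']})
ds = ds.map(encode_row, load_from_cache_file=True)
```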
```diff
@@ -232,6 +232,13 @@ def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None:
             else:
                 i += 1
 
+    def _preprocess_inputs_reranker(
+        self,
+        inputs: StdTemplateInputs,
+    ) -> None:
+        # TODO: remove
+        return
+
     def _preprocess_inputs(
         self,
         inputs: StdTemplateInputs,
```
```diff
@@ -241,31 +248,31 @@ def _preprocess_inputs(
         self._replace_image_tags(inputs)
         self._replace_start_image_tags(inputs)
 
-        for img_field in ['images', 'rejected_images']:
-            images = getattr(inputs, img_field, None)
-            if not images:
-                continue
-            load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
-            load_images_origin = load_images
-            if self.max_pixels is not None or inputs.objects:
-                load_images = True
-            if images:
-                for i, image in enumerate(images):
-                    images[i] = self._load_image(image, load_images)
-            if inputs.objects:
-                self._get_height_width(inputs)
-            if self.max_pixels is not None and images:
-                images = [rescale_image(img, self.max_pixels) for img in images]
-            if images and not load_images_origin:  # fix pt & qwen-vl
-                for i, image in enumerate(images):
-                    if isinstance(image, Image.Image):
-                        images[i] = self._save_pil_image(image)
-            setattr(inputs, img_field, images)
+        images = inputs.images
+        load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
+        load_images_origin = load_images
+        if self.max_pixels is not None or inputs.objects:
+            load_images = True
+        if images:
+            for i, image in enumerate(images):
+                images[i] = self._load_image(images[i], load_images)
+        if inputs.objects:
+            self._get_height_width(inputs)
+        if self.max_pixels is not None:
+            # Scale the image proportionally without affecting the scaled objects.
+            images = [rescale_image(img, self.max_pixels) for img in images]
+        if images and not load_images_origin:  # fix pt & qwen-vl
+            for i, image in enumerate(images):
+                if isinstance(image, Image.Image):
+                    images[i] = self._save_pil_image(image)
+        inputs.images = images
 
         if self.mode == 'vllm' and inputs.audios:
             sampling_rate = get_env_args('sampling_rate', int, None)
             inputs.audios = load_batch(
                 inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate, return_sr=True))
         if inputs.is_multimodal:
             self._add_default_tags(inputs)
 
     @staticmethod
     def _replace_image_tags(inputs: StdTemplateInputs):
```
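The `rejected_images` branch disappears here because each chosen/rejected sample is now preprocessed on its own (see `_encode_truncated` below), so `_preprocess_inputs` only ever sees one `images` list. For context, `rescale_image(img, max_pixels)` caps the pixel count while preserving aspect ratio; an illustrative sketch of that behavior, not swift's actual implementation:

```python
import math
from PIL import Image

def rescale_image_sketch(img: Image.Image, max_pixels: int) -> Image.Image:
    """Illustrative stand-in for rescale_image: cap w*h at max_pixels, keep aspect ratio."""
    width, height = img.size
    if max_pixels <= 0 or width * height <= max_pixels:
        return img  # already small enough
    scale = math.sqrt(max_pixels / (width * height))
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return img.resize(new_size)
```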
```diff
@@ -331,17 +338,11 @@ def get_base_model(model):
         else:
             return model
 
-    def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
-        margin = inputs.margin
-        chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
-
-        assert chosen_inputs.rejected_response or chosen_inputs.rejected_images, f'inputs: {inputs}'
-        if chosen_inputs.rejected_response:
-            rejected_inputs.messages[-1]['content'] = chosen_inputs.rejected_response
-        if chosen_inputs.rejected_images:
-            rejected_inputs.images = chosen_inputs.rejected_images
-        chosen_encoded = self._encode_truncated(chosen_inputs)
-        rejected_encoded = self._encode_truncated(rejected_inputs)
+    def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
+        chosen = inputs.chosen
+        margin = chosen.margin
+        chosen_encoded = self._encode_truncated(chosen)
+        rejected_encoded = self._encode_truncated(inputs.rejected)
 
         encoded = {}
         for prefix in ['chosen', 'rejected']:
```
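This hunk is the core of the refactor: instead of smuggling `rejected_response`/`rejected_images` through a single `StdTemplateInputs` and patching a deep copy, `TemplateInputs` now carries an explicit chosen/rejected pair of complete samples. A hedged sketch of the shape this implies (fields beyond `chosen`/`rejected` are illustrative assumptions, not swift's actual definitions):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

@dataclass
class StdTemplateInputs:  # simplified stand-in for the real class
    messages: List[Dict[str, str]]
    images: List[Any] = field(default_factory=list)
    margin: Optional[float] = None
    label: Optional[bool] = None

@dataclass
class TemplateInputs:  # container this refactor introduces (sketch)
    chosen: StdTemplateInputs
    rejected: Optional[StdTemplateInputs] = None  # absent for plain SFT rows
```

Because each branch is a complete sample, `_encode_truncated` runs identically on both, with no deepcopy and no in-place message rewriting.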
```diff
@@ -352,10 +353,9 @@ def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         encoded['margin'] = float(margin)
         return encoded
 
-    def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
-        label, inputs.label = inputs.label, None
+    def _kto_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
         encoded = self._rlhf_encode(inputs)
-        encoded['label'] = bool(label)
+        encoded['label'] = bool(inputs.chosen.label)
         return encoded
 
     def _gkd_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
```
```diff
@@ -366,7 +366,8 @@ def _gkd_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             encoded.pop(k, None)
         return encoded
 
-    def _embedding_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+    def _embedding_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
+        inputs = inputs.chosen  # TODO: refactor
         _encoded = {}
         labels = []
         inference = len(inputs.messages) == 1
```
```diff
@@ -429,7 +430,9 @@ def split_multi_medias(_inputs):
         _encoded.pop('labels', None)
         return _encoded
 
-    def _reranker_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+    def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
+        inputs = inputs.chosen  # TODO: refactor
+        self._preprocess_inputs_reranker(inputs)
         _encoded = {}
         labels = []
```
```diff
@@ -486,45 +489,46 @@ def encode(self,
         """
         assert self._processor_inited, ('Please initialize the processor before calling the template.encode method: '
                                         'template.init_processor(processor).')
-        if isinstance(inputs, (InferRequest, TemplateInputs)):
+        if isinstance(inputs, InferRequest):
             inputs = asdict(inputs)
 
         if isinstance(inputs, dict):
             inputs = deepcopy(inputs)
             if self.task_type == 'causal_lm' and not self.is_training:
                 InferRequest.remove_response(inputs['messages'])
-            inputs = StdTemplateInputs.from_dict(inputs)
-        elif isinstance(inputs, StdTemplateInputs):
+            inputs = TemplateInputs.from_dict(inputs)
+        elif isinstance(inputs, TemplateInputs):
             inputs = deepcopy(inputs)
-        assert isinstance(inputs, StdTemplateInputs)
-        self._preprocess_inputs(inputs)
+        assert isinstance(inputs, TemplateInputs)
 
+        chosen = inputs.chosen
         if self.task_type == 'causal_lm':
             if self.mode in {'train', 'pt', 'vllm', 'lmdeploy', 'sglang'}:
-                encoded = self._encode_truncated(inputs)
+                encoded = self._encode_truncated(chosen)
             elif self.mode == 'rlhf':
                 encoded = self._rlhf_encode(inputs)
             elif self.mode == 'kto':
                 encoded = self._kto_encode(inputs)
             elif self.mode == 'gkd':
-                encoded = self._gkd_encode(inputs)
+                encoded = self._gkd_encode(chosen)
         elif self.task_type == 'seq_cls':
             if self.mode == 'rlhf':
                 encoded = self._rlhf_encode(inputs)
                 for prefix in ['chosen', 'rejected']:
                     encoded.pop(f'{prefix}_labels', None)
                     encoded.pop(f'{prefix}_loss_scale', None)
             else:
-                encoded = self._seq_cls_encode(inputs)
+                encoded = self._seq_cls_encode(chosen)
         elif self.task_type == 'prm':
-            encoded = self._encode_truncated(inputs)
+            encoded = self._encode_truncated(chosen)
         elif self.task_type == 'embedding':
             encoded = self._embedding_encode(inputs)
         elif self.task_type in {'reranker', 'generative_reranker'}:
             encoded = self._reranker_encode(inputs)
         else:
             raise ValueError(f'task_type: {self.task_type} is not supported.')
 
-        if inputs.channel is not None:
-            encoded['channel'] = inputs.channel
+        if chosen.channel is not None:
+            encoded['channel'] = chosen.channel
 
         lengths = [0]
         for key in list(encoded.keys()):
```
```diff
@@ -541,9 +545,9 @@ def encode(self,
         else:
             encoded.pop('length', None)
         if return_template_inputs:
-            encoded['template_inputs'] = inputs
+            encoded['template_inputs'] = chosen
         if not self.remove_unused_columns:
-            encoded['_extra_kwargs'] = inputs.extra_kwargs
+            encoded['_extra_kwargs'] = chosen.extra_kwargs
         return encoded
 
     def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]:
```
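End to end, a caller still passes a single dict and the chosen/rejected split happens inside `TemplateInputs.from_dict`. A hedged usage sketch; the exact dict keys `from_dict` accepts are inferred from this diff, not verified against the source:

```python
# Hypothetical RLHF-mode call; key names are assumptions for illustration.
row = {
    'messages': [
        {'role': 'user', 'content': 'Summarize the plot.'},
        {'role': 'assistant', 'content': 'A concise, accurate summary.'},  # chosen
    ],
    'rejected_response': 'An off-topic rambling answer.',  # assumed to populate inputs.rejected
}
encoded = template.encode(row)
# encoded would then hold chosen_*/rejected_* pairs from _rlhf_encode,
# plus the optional 'margin' and 'channel' fields handled above.
```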
```diff
@@ -932,11 +936,6 @@ def _add_default_tags(inputs: StdTemplateInputs):
             elif message['role'] == 'assistant':
                 continue
             total_content.append(content)
-        if inputs.rejected_response:
-            rejected_response = inputs.rejected_response
-            if isinstance(inputs.rejected_response, str):
-                rejected_response = [rejected_response]
-            total_content += rejected_response
         total_content = '\n'.join(total_content)
         if inputs.system:
             total_content = f'{inputs.system}\n{total_content}'
```
```diff
@@ -1104,6 +1103,7 @@ def _swift_encode(self, inputs: StdTemplateInputs):
             prefix = template_meta.system_prefix
             self._concat_context_list(prefix, res_context_list, res_context_types, system=system)
 
+        assert len(inputs.messages) > 0, f'inputs.messages: {inputs.messages}'
         n_round = len(inputs.messages) // 2
         for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
             query_role, query = query_message['role'], query_message['content']
```
```diff
@@ -1145,7 +1145,7 @@ def _swift_encode(self, inputs: StdTemplateInputs):
                                        if isinstance(stop_word, str))
         # self.is_training needed because we may want to continue generation from
         # the current response
-        if (self.is_training and not sep_token or self.task_type == 'embedding') and not endswith_stop_words:
+        if (self.is_training or self.task_type != 'causal_lm') and not sep_token and not endswith_stop_words:
             extra_context_list = template_meta.suffix
             extra_context_type = ContextType.SUFFIX
         elif template_meta.response_prefix:
```
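The rewritten condition extends suffix appending from "training or embedding" to every non-`causal_lm` task, and makes the `sep_token` and stop-word checks apply uniformly. A self-contained sketch of the behavioral difference (pure illustration of the two boolean forms from the diff):

```python
def append_suffix_old(is_training, sep_token, task_type, ends_with_stop):
    return (is_training and not sep_token or task_type == 'embedding') and not ends_with_stop

def append_suffix_new(is_training, sep_token, task_type, ends_with_stop):
    return (is_training or task_type != 'causal_lm') and not sep_token and not ends_with_stop

# A reranker template at inference time now gets the suffix too:
assert not append_suffix_old(False, None, 'reranker', False)
assert append_suffix_new(False, None, 'reranker', False)

# And an embedding template with a sep_token no longer bypasses the sep check:
assert append_suffix_old(False, '</s>', 'embedding', False)
assert not append_suffix_new(False, '</s>', 'embedding', False)
```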
```diff
@@ -1193,9 +1193,7 @@ def _truncate(self, input_ids: List[int], labels: Optional[List[int]], loss_mask
         return input_ids, labels, loss_mask
 
     def _encode_truncated(self, inputs: StdTemplateInputs):
-        if inputs.is_multimodal:
-            self._add_default_tags(inputs)
-
+        self._preprocess_inputs(inputs)
         if self.mode in {'vllm', 'lmdeploy', 'sglang'}:
             encoded = Template._encode(self, inputs)
             keys = ['images', 'audios', 'videos']
```

Review comment (on the `_encode_truncated` signature): The type hint for `inputs` should be `Optional[StdTemplateInputs]`, since a `None` rejected sample can be passed in. Suggested change:

```diff
-    def _encode_truncated(self, inputs: StdTemplateInputs):
+    def _encode_truncated(self, inputs: Optional[StdTemplateInputs]):
```
```diff
@@ -1425,6 +1423,8 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int
             res = self._embedding_data_collator(batch, padding_to=padding_to)
         elif self.task_type in {'reranker', 'generative_reranker'}:
             res = self._reranker_data_collator(batch, padding_to=padding_to)
+        else:
+            raise ValueError(f'task_type: {self.task_type} is not supported.')
         if not self.remove_unused_columns:
             extra_kwargs = [b['_extra_kwargs'] for b in batch if b.get('_extra_kwargs') is not None]
             extra_kwargs = RowPreprocessor.rows_to_batched(extra_kwargs)
```
Review comment (on `_preprocess_inputs`): The `inputs` argument can be `None` when called from `_encode_truncated` (e.g., with a `None` rejected sample). The type hint should be updated to `Optional[StdTemplateInputs]` to accurately reflect this.
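A minimal sketch of what the two review comments ask for: widen both signatures and guard the `None` case. The class and guard below are hypothetical illustrations, not swift's implementation:

```python
from typing import Optional

class TemplateSketch:
    """Illustrative container for the suggested signatures only."""

    def _preprocess_inputs(self, inputs: Optional['StdTemplateInputs']) -> None:
        if inputs is None:  # e.g. a row with no rejected branch (hypothetical guard)
            return
        # ... existing preprocessing runs only on real samples ...

    def _encode_truncated(self, inputs: Optional['StdTemplateInputs']):
        self._preprocess_inputs(inputs)
        # ... encoding continues as in the diff above ...
```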