[dataset/template] refactor rejected_messages #5560
Summary of Changes
Hello @Jintao-Huang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request refactors the handling of "rejected messages" within the dataset and template processing components. The primary change introduces a new `TemplateInputs` dataclass that acts as a wrapper, explicitly containing both `chosen` and `rejected` instances of `StdTemplateInputs`. This design centralizes the logic for processing rejected data, removes the `rejected_response` and `rejected_images` fields from `StdTemplateInputs`, and streamlines the data flow for tasks like Reinforcement Learning from Human Feedback (RLHF), leading to a clearer and more maintainable data structure.
Highlights
- New Data Structure: A new `TemplateInputs` dataclass has been introduced to explicitly encapsulate both `chosen` and `rejected` data, each represented by an instance of `StdTemplateInputs`.
- Simplified `StdTemplateInputs`: The `rejected_response` and `rejected_images` fields have been removed from `StdTemplateInputs`, simplifying its structure and promoting a clearer separation of concerns.
- Centralized Rejected Data Logic: The logic for processing and converting a `rejected_response` into a full `rejected_messages` structure has been centralized within the new `TemplateInputs` class, improving code organization.
- Dynamic Key Generation: The `RowPreprocessor` now dynamically generates `rejected_` keys for various media types, enhancing the maintainability and extensibility of the dataset processing component.
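To make the wrapper structure described above concrete, here is a minimal sketch. It is not the code from the pull request: the two dataclasses below are simplified stand-ins that borrow the names `StdTemplateInputs` and `TemplateInputs`, and the field set (`messages`, `images`) is an assumption chosen for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class StdTemplateInputs:
    # Simplified stand-in: the real class carries more fields.
    messages: List[Dict[str, Any]]
    images: List[str] = field(default_factory=list)


@dataclass
class TemplateInputs:
    # Wrapper holding the chosen sample and an optional rejected sample.
    chosen: StdTemplateInputs
    rejected: Optional[StdTemplateInputs] = None


question = {'role': 'user', 'content': 'What is this?'}
pair = TemplateInputs(
    chosen=StdTemplateInputs(messages=[question, {'role': 'assistant', 'content': 'This is a kitten.'}]),
    rejected=StdTemplateInputs(messages=[question, {'role': 'assistant', 'content': 'This is a puppy.'}]),
)
```

With this shape, code that only needs the preferred sample reads `pair.chosen`, and preference-based trainers can check `pair.rejected is not None` before encoding the second sample.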
Code Review
This pull request refactors the handling of rejected samples in datasets and templates by introducing a `TemplateInputs` class with `chosen` and `rejected` attributes. This is a good structural improvement. However, the implementation has critical flaws. The `from_dict` method in `TemplateInputs` is buggy and will cause crashes. Furthermore, the data processing paths for the RLHF and KTO training modes are broken because they were not updated to the new data structures. These issues need to be addressed for the refactoring to be functional.
/gemini review
Code Review
This pull request refactors the handling of chosen and rejected samples for RLHF-style training. It introduces a new `TemplateInputs` data class that explicitly separates `chosen` and `rejected` inputs, which is a significant improvement in code clarity and structure over the previous flat format using `rejected_response`. The changes are consistently applied across the dataset preprocessors, templates, and trainers. However, this refactoring has introduced a critical issue that breaks the functionality for embedding and reranking tasks, as they rely on a list of multiple negative samples, which the new `TemplateInputs` structure does not support.
/gemini review
Code Review
This pull request introduces a significant and well-designed refactoring of how rejected messages and responses are handled. The new `TemplateInputs` dataclass, which separates `chosen` and `rejected` samples, is a major improvement in code clarity and structure. The changes are consistently applied across the dataset preprocessor, template base class, and trainer, and the tests have been updated accordingly.
However, I've identified a critical issue in `swift/llm/template/base.py` where the refactoring for the `embedding` and `reranker` modes appears incomplete, which will lead to a runtime error. Additionally, there's a logic bug in `swift/llm/template/template_inputs.py` that prevents attributes from being correctly copied from the `chosen` to the `rejected` sample. These issues should be addressed to complete this excellent refactoring.
swift/llm/template/base.py
Outdated
def _compat_rejected_response(inputs: TemplateInputs) -> StdTemplateInputs:
    chosen = inputs.chosen
    rejected_response = None
    if inputs.rejected:
        rejected_response = inputs.rejected.messages
    chosen.rejected_response = rejected_response
    return chosen
The `_compat_rejected_response` method has two critical issues that will cause it to fail at runtime:
- `AttributeError`: The line `chosen.rejected_response = rejected_response` attempts to assign to `rejected_response` on a `StdTemplateInputs` object. However, this attribute was removed from the `StdTemplateInputs` dataclass in `swift/llm/template/template_inputs.py`, which will result in an `AttributeError`.
- Type Mismatch: The line `rejected_response = inputs.rejected.messages` assigns a `List[Dict]` to `rejected_response`. The legacy code in `_embedding_encode` and `_reranker_encode`, which this method is intended to support, expects `rejected_response` to be a `List[str]`.
This method is marked with a `TODO: refactor`, and it seems this part of the refactoring is incomplete. To fix this, `_embedding_encode` and `_reranker_encode` should be updated to work directly with the new `TemplateInputs` structure, similar to how `_rlhf_encode` was refactored. This compatibility method should then be removed.
if getattr(self.rejected, k) is None:
    setattr(self.rejected, k, getattr(self.chosen, k))
The logic to fill empty entries in the `rejected` sample from the `chosen` sample is incorrect for fields with a `default_factory` (e.g., `images`, `videos`, `audios`, `objects`).
The current check `if getattr(self.rejected, k) is None:` will not work for these fields because `getattr` will return an empty list or dict (e.g., `[]`), which is not `None`. As a result, if `chosen` has images but `rejected_images` is not provided, `rejected.images` will remain an empty list instead of being populated from `chosen.images`.
To fix this, you should also check if the attribute is an empty container.
if getattr(self.rejected, k) is None:
    setattr(self.rejected, k, getattr(self.chosen, k))
if getattr(self.rejected, k) is None or not getattr(self.rejected, k):
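The difference between an `is None` test and an emptiness test can be sketched with a small stand-alone example. The `Sample` dataclass and the `fill_missing` helper are hypothetical names introduced only for illustration; they are not taken from the pull request.

```python
from dataclasses import dataclass, field, fields
from typing import List, Optional


@dataclass
class Sample:
    # Hypothetical stand-in with one plain field and one default_factory field.
    system: Optional[str] = None
    images: List[str] = field(default_factory=list)


def fill_missing(rejected: Sample, chosen: Sample) -> None:
    """Copy every field that is None or empty on `rejected` from `chosen`."""
    for f in fields(rejected):
        value = getattr(rejected, f.name)
        # `not value` is True for None and for empty containers such as [].
        if not value:
            setattr(rejected, f.name, getattr(chosen, f.name))


chosen = Sample(system='Be brief.', images=['cat.png'])
rejected = Sample()
# An `is None` test alone would skip this field, because it holds [].
images_is_none = rejected.images is None
fill_missing(rejected, chosen)
```

Note that `not value` also treats `0`, `False`, and `''` as missing, so a field where those are meaningful values would need an explicit check instead.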
/gemini review
Code Review
This pull request introduces a significant and well-executed refactoring for handling paired chosen/rejected data, which is common in RLHF scenarios. The new `TemplateInputs` data class is a major improvement in terms of code clarity, maintainability, and extensibility. The changes are consistently applied across the codebase, and the inclusion of compatibility layers for older data formats is a thoughtful addition. The test cases have also been diligently updated to reflect the new API. I've found one potential bug in the implementation, which I've detailed in a specific comment.
if len(messages[i:]) == 1:
    response = messages[i]['content']
    assert value != response, (f'rejected_response: {value}, response: {response}')
In the `_use_rejected_messages` method, the logic that checks whether the `rejected_response` differs from the chosen response seems to use the wrong index. The loop variable `i` is used to index `messages`, but `i` is the counter for multiple rejected responses, not an index into the `messages` list. This can lead to an `IndexError` or incorrect comparisons.
You should use the `idx` variable, which is correctly calculated earlier to identify the start of the assistant's turn.
if len(messages[idx:]) == 1:
    response = messages[idx]['content']
    assert value != response, (f'rejected_response: {value}, response: {response}')
…essages' into refactor_rejected_messages
…essages' into refactor_rejected_messages
/gemini review
Code Review
This pull request introduces a significant and well-executed refactoring of how rejected messages and preference data are handled. The new `TemplateInputs` class provides a much cleaner and more extensible structure for various RLHF-style training tasks. The implementation, including backward compatibility for the old `rejected_response` format, is commendable. My review focuses on a potential crash when handling optional rejected samples and a few opportunities to strengthen type validation for increased robustness. Overall, this is an excellent improvement to the codebase.
if inputs.is_multimodal:
    self._add_default_tags(inputs)

self._preprocess_inputs(inputs)

@@ -236,36 +236,38 @@
self,
inputs: StdTemplateInputs,

@@ -1193,9 +1187,7 @@
return input_ids, labels, loss_mask

def _encode_truncated(self, inputs: StdTemplateInputs):
The type hint for `inputs` should be `Optional[StdTemplateInputs]` since it can receive `None` (e.g., from `inputs.rejected` in `_rlhf_encode`). This will be addressed in a subsequent comment, but updating the type hint is the first step for correctness.
def _encode_truncated(self, inputs: Optional[StdTemplateInputs]):
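As a rough illustration of why the `Optional` hint matters, the sketch below pairs it with an explicit `None` guard. The function `encode_truncated` is a hypothetical stand-in that operates on plain strings; it does not reproduce the body of `_encode_truncated`, which is not shown in this review.

```python
from typing import List, Optional


def encode_truncated(inputs: Optional[List[str]]) -> Optional[List[int]]:
    """Hypothetical stand-in: return a length per text, or None when there is no sample."""
    # Callers may pass None when no rejected sample exists, so guard first.
    if inputs is None:
        return None
    return [len(text) for text in inputs]


no_sample = encode_truncated(None)
lengths = encode_truncated(['ab', 'c'])
```

Declaring the parameter as `Optional[...]` lets a type checker flag any code path that uses `inputs` before the `None` check.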
if self.rejected_images and not isinstance(self.rejected_images, (list, tuple)):
    self.rejected_images = [self.rejected_images]
if self.rejected_response:
    assert isinstance(self.rejected_response, list) and isinstance(self.rejected_response[0], str)
The assertion only checks whether the first element of `rejected_response` is a string. This could allow a list with mixed types (e.g., `['foo', 123]`) to pass validation, potentially causing errors downstream. The check should be strengthened to validate all elements in the list.
assert isinstance(self.rejected_response, list) and all(isinstance(item, str) for item in self.rejected_response)
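The two checks can be compared side by side in a small self-contained sketch. The helper names `first_is_str` and `all_are_str` are made up for this example.

```python
from typing import Any


def first_is_str(values: Any) -> bool:
    # Mirrors the original assertion: only the first element is inspected.
    return isinstance(values, list) and isinstance(values[0], str)


def all_are_str(values: Any) -> bool:
    # Mirrors the suggested assertion: every element is inspected.
    return isinstance(values, list) and all(isinstance(item, str) for item in values)


mixed = ['foo', 123]
first_only_result = first_is_str(mixed)
all_items_result = all_are_str(mixed)
```

One behavioral difference: `all_are_str([])` returns `True`, while `first_is_str([])` raises an `IndexError`. In the snippet under review the assertion sits behind `if self.rejected_response:`, so an empty list never reaches it.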
…essages' into refactor_rejected_messages
9d11523 to 2e4bdfa
/gemini review
Code Review
This pull request introduces a significant and well-executed refactoring of how preference data, particularly `rejected_messages`, is handled for RLHF training. The introduction of the `TemplateInputs` dataclass to structure `chosen` and `rejected` pairs is a commendable design choice that enhances code clarity and maintainability. The changes are consistently applied across dataset preprocessing, templates, and trainers, and the documentation has been updated accordingly.
My review focuses on ensuring consistency and correctness in the updated documentation and tests. I've identified a minor issue in the English documentation where examples contain untranslated Chinese text. Additionally, I've pointed out several instances in the test files where the new `TemplateInputs` class should be instantiated using the more robust `from_dict()` classmethod for consistency and to prevent potential issues. These are medium-severity suggestions to improve the quality of the changes.
```jsonl
{"messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "images": ["cat.png"], "rejected_messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小狗。"}], "rejected_images": ["cat.png"]}
{"messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "images": ["cat.png"], "rejected_messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "rejected_images": ["dog.png"]}
```

The above format is equivalent to:

```jsonl
{"messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "images": ["cat.png"], "rejected_response": "这是一只小狗。"}
{"messages": [{"role": "user", "content": "<image>这是什么"}, {"role": "assistant", "content": "这是一只小猫咪。"}], "images": ["cat.png"], "rejected_images": ["dog.png"]}
```
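The equivalence between the short `rejected_response` form and the full `rejected_messages` form can be sketched as a small conversion. This is an illustration under assumptions, not the project's implementation: `to_rejected_messages` is a hypothetical helper, and it assumes the conversation ends with the chosen assistant reply.

```python
import copy
from typing import Any, Dict, List

Message = Dict[str, Any]


def to_rejected_messages(messages: List[Message], rejected_response: str) -> List[Message]:
    """Copy `messages` and swap the final assistant reply for `rejected_response`."""
    assert messages and messages[-1]['role'] == 'assistant', 'expected a trailing assistant message'
    assert messages[-1]['content'] != rejected_response, 'rejected reply equals the chosen reply'
    rejected = copy.deepcopy(messages)  # leave the chosen conversation untouched
    rejected[-1]['content'] = rejected_response
    return rejected


row = {
    'messages': [
        {'role': 'user', 'content': '<image>What is this?'},
        {'role': 'assistant', 'content': 'This is a kitten.'},
    ],
    'images': ['cat.png'],
    'rejected_response': 'This is a puppy.',
}
rejected_messages = to_rejected_messages(row['messages'], row['rejected_response'])
```

The deep copy keeps the chosen conversation independent of the rejected one, so later edits to either list do not leak into the other.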
The JSON examples in the English documentation contain Chinese text. For consistency and clarity for an English-speaking audience, it would be better to translate this text into English or use English placeholders.
For example:
- 这是什么 -> What is this?
- 这是一只小猫咪。 -> This is a kitten.
- 这是一只小狗。 -> This is a puppy.
```jsonl
{"messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "images": ["cat.png"], "rejected_messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a puppy."}], "rejected_images": ["cat.png"]}
{"messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "images": ["cat.png"], "rejected_messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "rejected_images": ["dog.png"]}
```

The above format is equivalent to:

```jsonl
{"messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "images": ["cat.png"], "rejected_response": "This is a puppy."}
{"messages": [{"role": "user", "content": "<image>What is this?"}, {"role": "assistant", "content": "This is a kitten."}], "images": ["cat.png"], "rejected_images": ["dog.png"]}
```
inputs = TemplateInputs({
    'messages': [{
        'role': 'system',
        'content': '000'
    }, {
        'role': 'user',
        'content': 'aaa'
    }, {
        'role': 'assistant',
        'content': 'bbb'
    }, {
        'role': 'user',
        'content': 'ccc'
    }]
})
For consistency and clarity, it's better to use the `TemplateInputs.from_dict()` classmethod to construct the `TemplateInputs` object from a dictionary. This is the intended way to create an instance from a flat data structure and aligns with the implementation in other test files.
inputs = TemplateInputs.from_dict({
    'messages': [{
        'role': 'system',
        'content': '000'
    }, {
        'role': 'user',
        'content': 'aaa'
    }, {
        'role': 'assistant',
        'content': 'bbb'
    }, {
        'role': 'user',
        'content': 'ccc'
    }]
})
inputs = TemplateInputs({
    'messages': [{
        'role':
        'user',
        'content':
        'Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins '
        "for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per "
        "fresh duck egg. How much in dollars does she make every day at the farmers' market?"
    }, {
        'role':
        'assistant',
        'content':
        "To determine how much Janet makes from selling the duck eggs at the farmers' market, we need to "
        'follow these steps:\n\n1. Calculate the total number of eggs laid by the ducks each day.\n2. '
        'Determine how many eggs Janet eats and bakes for herself each day.\n3. Find out how many eggs are '
        "left to be sold.\n4. Calculate the revenue from selling the remaining eggs at $2 per egg.\n\nLet's "
        "start with the first step:\n\n1. Janet's ducks lay 16 eggs per day.\n\nNext, we calculate how many "
        'eggs Janet eats and bakes for herself each day:\n\n2. Janet eats 3 eggs for breakfast every morning.'
        '\n3. Janet bakes 4 eggs for her friends every day.\n\nSo, the total number of eggs Janet eats and '
        'bakes for herself each day is:\n\\[ 3 + 4 = 7 \\text{ eggs} \\]\n\nNow, we find out how many eggs '
        'are left to be sold:\n\\[ 16 - 7 = 9 \\text{ eggs} \\]\n\nFinally, we calculate the revenue from '
        'selling the remaining eggs at $2 per egg:\n\\[ 9 \\times 2 = 18 \\text{ dollars} \\]\n\nTherefore, '
        "Janet makes \\(\\boxed{18}\\) dollars every day at the farmers' market."
    }]
})
For consistency and clarity, it's better to use the `TemplateInputs.from_dict()` classmethod to construct the `TemplateInputs` object from a dictionary. This is the intended way to create an instance from a flat data structure and aligns with the implementation in other test files.
inputs = TemplateInputs.from_dict({
'messages': [{
'role':
'user',
'content':
'Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins '
"for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per "
"fresh duck egg. How much in dollars does she make every day at the farmers' market?"
}, {
'role':
'assistant',
'content':
"To determine how much Janet makes from selling the duck eggs at the farmers' market, we need to "
'follow these steps:\n\n1. Calculate the total number of eggs laid by the ducks each day.\n2. '
'Determine how many eggs Janet eats and bakes for herself each day.\n3. Find out how many eggs are '
"left to be sold.\n4. Calculate the revenue from selling the remaining eggs at $2 per egg.\n\nLet's "
"start with the first step:\n\n1. Janet's ducks lay 16 eggs per day.\n\nNext, we calculate how many "
'eggs Janet eats and bakes for herself each day:\n\n2. Janet eats 3 eggs for breakfast every morning.'
'\n3. Janet bakes 4 eggs for her friends every day.\n\nSo, the total number of eggs Janet eats and '
'bakes for herself each day is:\n\\[ 3 + 4 = 7 \\text{ eggs} \\]\n\nNow, we find out how many eggs '
'are left to be sold:\n\\[ 16 - 7 = 9 \\text{ eggs} \\]\n\nFinally, we calculate the revenue from '
'selling the remaining eggs at $2 per egg:\n\\[ 9 \\times 2 = 18 \\text{ dollars} \\]\n\nTherefore, '
"Janet makes \\(\\boxed{18}\\) dollars every day at the farmers' market."
}]
})
inputs = TemplateInputs({
    'messages': [{
        'role': 'system',
        'content': 'You are a helpful assistant created by MiniMax based on MiniMax-Text-01 model.'
    }, {
        'role': 'user',
        'content': 'Hello!'
    }]
})
For consistency and clarity, it's better to use the `TemplateInputs.from_dict()` classmethod to construct the `TemplateInputs` object from a dictionary. This is the intended way to create an instance from a flat data structure and aligns with the implementation in other test files.
inputs = TemplateInputs.from_dict({
    'messages': [{
        'role': 'system',
        'content': 'You are a helpful assistant created by MiniMax based on MiniMax-Text-01 model.'
    }, {
        'role': 'user',
        'content': 'Hello!'
    }]
})
inputs = TemplateInputs({
    'messages': [{
        'role': 'system',
        'content': 'You are a helpful assistant created by MiniMax based on MiniMax-VL-01 model.'
    }, {
        'role': 'user',
        'content': '<image>Describe this image.'
    }],
    'images': ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png']
})
For consistency and clarity, it's better to use the `TemplateInputs.from_dict()` classmethod to construct the `TemplateInputs` object from a dictionary. This is the intended way to create an instance from a flat data structure and aligns with the implementation in other test files.
inputs = TemplateInputs.from_dict({
    'messages': [{
        'role': 'system',
        'content': 'You are a helpful assistant created by MiniMax based on MiniMax-VL-01 model.'
    }, {
        'role': 'user',
        'content': '<image>Describe this image.'
    }],
    'images': ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png']
})
…essages' into refactor_rejected_messages
/gemini review
Code Review
This pull request introduces a significant and well-executed refactoring of how preference data, especially rejected samples, is handled. The introduction of the new `TemplateInputs` dataclass with `chosen` and `rejected` fields provides a much cleaner and more structured way to manage paired data for RLHF, DPO, and other preference-based training methods. The changes are consistent across the codebase, including updates to preprocessing, templates, and tests. The logic to convert legacy `rejected_response` formats to the new `rejected_messages` structure is well-encapsulated. My main feedback is a minor but important typo in `template_inputs.py` that affects the data supplementation logic.
has_rejected_mssages = inputs.get('rejected_messages') is not None
cls._compat_rejected_response(inputs)
rejected_response = inputs.pop('rejected_response', None)
kwargs = {}
non_chosen_keys = ['rejected', 'positive', 'negative']
for prefix in ['chosen'] + non_chosen_keys:
    if prefix == 'chosen':
        std_inputs = {
            k: v
            for k, v in inputs.items() if not any(k.startswith(f'{p}_') for p in non_chosen_keys)
        }
    else:
        std_inputs = {k[len(f'{prefix}_'):]: v for k, v in inputs.items() if k.startswith(f'{prefix}_')}
    if std_inputs:
        kwargs[prefix] = std_inputs

if not has_rejected_mssages and kwargs.get('rejected') is not None:
There's a typo in the `has_rejected_mssages` variable name. It should be `has_rejected_messages`. This typo causes the data supplementation logic for the `rejected` sample to always execute, even when `rejected_messages` is explicitly provided. This can lead to incorrect data being used for training, for example, by incorrectly copying `images` from the `chosen` sample to a `rejected` sample that should not have any. Please correct the variable name to ensure the supplementation logic is conditional as intended.
has_rejected_messages = inputs.get('rejected_messages') is not None
cls._compat_rejected_response(inputs)
rejected_response = inputs.pop('rejected_response', None)
kwargs = {}
non_chosen_keys = ['rejected', 'positive', 'negative']
for prefix in ['chosen'] + non_chosen_keys:
    if prefix == 'chosen':
        std_inputs = {
            k: v
            for k, v in inputs.items() if not any(k.startswith(f'{p}_') for p in non_chosen_keys)
        }
    else:
        std_inputs = {k[len(f'{prefix}_'):]: v for k, v in inputs.items() if k.startswith(f'{prefix}_')}
    if std_inputs:
        kwargs[prefix] = std_inputs
if not has_rejected_messages and kwargs.get('rejected') is not None:
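The prefix-splitting loop in the snippet above can be exercised in isolation. The sketch below repackages the same dictionary comprehensions into a stand-alone function; `split_by_prefix` and `NON_CHOSEN_PREFIXES` are names invented for this example, and the surrounding class logic (compatibility handling, supplementation from `chosen`) is deliberately left out.

```python
from typing import Any, Dict

NON_CHOSEN_PREFIXES = ('rejected', 'positive', 'negative')


def split_by_prefix(row: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Group a flat row into per-sample dicts keyed by 'chosen', 'rejected', etc."""
    groups: Dict[str, Dict[str, Any]] = {}
    # Keys without a known prefix belong to the chosen sample.
    chosen = {k: v for k, v in row.items()
              if not any(k.startswith(f'{p}_') for p in NON_CHOSEN_PREFIXES)}
    if chosen:
        groups['chosen'] = chosen
    for prefix in NON_CHOSEN_PREFIXES:
        marker = f'{prefix}_'
        # Strip the prefix so 'rejected_images' becomes 'images'.
        sub = {k[len(marker):]: v for k, v in row.items() if k.startswith(marker)}
        if sub:
            groups[prefix] = sub
    return groups


messages = [{'role': 'user', 'content': '<image>What is this?'},
            {'role': 'assistant', 'content': 'This is a kitten.'}]
row = {'messages': messages, 'images': ['cat.png'], 'rejected_images': ['dog.png']}
groups = split_by_prefix(row)
```

Here the row yields a `chosen` group with `messages` and `images`, and a `rejected` group containing only `images`; no `positive` or `negative` group is created because no key carries those prefixes.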
No description provided.