Skip to content

Commit 87fdc2c

Browse files
authored
Add "weight" to "Message" and helper func to dump model for chat completion request (#208)
* add * format * add ut * add * add * format * fix * add * format
1 parent 1002941 commit 87fdc2c

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

eval_protocol/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,11 @@ class Message(BaseModel):
281281
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
282282
function_call: Optional[FunctionCall] = None
283283
control_plane_step: Optional[Dict[str, Any]] = None
284+
weight: Optional[int] = None
285+
286+
def dump_mdoel_for_chat_completion_request(self):
287+
"""Only keep chat completion accepted fields"""
288+
return self.model_dump(exclude_none=True, exclude={"control_plane_step", "reasoning_content", "weight"})
284289

285290
@classmethod
286291
def model_validate(cls, obj, *args, **kwargs):

tests/adapters/test_openai_responses_adapter.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ def test_openai_responses_adapter_with_real_response_simple(snapshot: SnapshotAs
2222
assert len(eval_rows) == 1
2323

2424
# Convert to dict for snapshot testing
25-
eval_rows_dict = [row.model_dump(exclude={"created_at", "execution_metadata"}) for row in eval_rows]
25+
eval_rows_dict = [
26+
row.model_dump(exclude={"created_at": True, "execution_metadata": True, "messages": {"__all__": {"weight"}}})
27+
for row in eval_rows
28+
]
2629

2730
# Assert against snapshot
2831
assert eval_rows_dict == snapshot
@@ -42,7 +45,10 @@ def test_openai_responses_adapter_with_real_response_parallel_tool_calls(snapsho
4245
assert len(eval_rows) == 1
4346

4447
# Convert to dict for snapshot testing
45-
eval_rows_dict = [row.model_dump(exclude={"created_at", "execution_metadata"}) for row in eval_rows]
48+
eval_rows_dict = [
49+
row.model_dump(exclude={"created_at": True, "execution_metadata": True, "messages": {"__all__": {"weight"}}})
50+
for row in eval_rows
51+
]
4652

4753
# Assert against snapshot
4854
assert eval_rows_dict == snapshot

tests/test_models.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,3 +694,30 @@ def test_evaluation_row_extra_fields():
694694
assert "eval" in dictionary
695695
assert "accuracy" in dictionary["eval_details"]["metrics"]
696696
assert "test" in dictionary["extra_fields"]
697+
698+
699+
def test_message_with_weight_dump():
700+
example = {
701+
"role": "user",
702+
"content": "Hello, how are you?",
703+
"weight": 0,
704+
}
705+
706+
message = Message(**example)
707+
dictionary = message.model_dump()
708+
assert "weight" in dictionary
709+
assert dictionary["weight"] == 0
710+
711+
712+
def test_message_dump_for_chat_completion_request():
713+
example = {
714+
"role": "user",
715+
"content": "Hello, how are you?",
716+
"weight": 0,
717+
"reasoning_content": "I am thinking about the user's question",
718+
}
719+
message = Message(**example)
720+
dictionary = message.dump_mdoel_for_chat_completion_request()
721+
assert "weight" not in dictionary
722+
assert "reasoning_content" not in dictionary
723+
assert dictionary["content"] == "Hello, how are you?"

0 commit comments

Comments
 (0)