eval_protocol/models.py (13 additions, 2 deletions)

```diff
@@ -12,7 +12,7 @@
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall,
 )
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, field_validator

 from eval_protocol.get_pep440_version import get_pep440_version
 from eval_protocol.human_id import generate_id
@@ -595,7 +595,7 @@ class EvaluationRow(BaseModel):
     supporting both row-wise batch evaluation and trajectory-based RL evaluation.
     """

-    model_config = ConfigDict(extra="allow")
+    model_config = ConfigDict(extra="allow", validate_assignment=True)

     # Core OpenAI ChatCompletion compatible conversation data
     messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
@@ -626,6 +626,17 @@ class EvaluationRow(BaseModel):
         default=None, description="The evaluation result for this row/trajectory."
     )

+    @field_validator("evaluation_result", mode="before")
+    @classmethod
+    def _coerce_evaluation_result(
+        cls, value: EvaluateResult | dict[str, Any] | None
+    ) -> EvaluateResult | None:
+        if value is None or isinstance(value, EvaluateResult):
+            return value
+        if isinstance(value, dict):
+            return EvaluateResult(**value)
+        return value
+
     execution_metadata: ExecutionMetadata = Field(
         default_factory=lambda: ExecutionMetadata(run_id=None),
         description="Metadata about the execution of the evaluation.",
```
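Taken together, the two edits let `EvaluationRow.evaluation_result` be set from a plain dict after construction: `validate_assignment=True` re-runs validation when the field is assigned, and the `mode="before"` validator coerces a dict into an `EvaluateResult`. Below is a minimal standalone sketch of the same pattern, assuming a simplified `EvaluateResult` with only `score` and `is_score_valid` fields (the real models in `eval_protocol` carry more state):

```python
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, field_validator


class EvaluateResult(BaseModel):
    # Simplified stand-in for eval_protocol's EvaluateResult (assumption).
    score: float
    is_score_valid: bool = True


class EvaluationRow(BaseModel):
    # validate_assignment=True re-runs field validators when a field is set
    # after construction, so the coercion below also applies to assignment.
    model_config = ConfigDict(validate_assignment=True)

    evaluation_result: Optional[EvaluateResult] = None

    @field_validator("evaluation_result", mode="before")
    @classmethod
    def _coerce_evaluation_result(cls, value: Any) -> Any:
        # Build the model from a raw dict; pass anything else through so
        # normal validation errors still surface for invalid inputs.
        if isinstance(value, dict):
            return EvaluateResult(**value)
        return value


row = EvaluationRow()
row.evaluation_result = {"score": 0.6}  # plain dict assignment
assert isinstance(row.evaluation_result, EvaluateResult)
assert row.evaluation_result.score == 0.6
```

Without `validate_assignment=True`, validators only run during construction, so a later dict assignment would store the raw dict unvalidated.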
tests/test_models.py (8 additions, 2 deletions)

```diff
@@ -198,8 +198,14 @@ def test_metric_result_dict_access():
     }
     assert set(metric.items()) == expected_items

-    # __iter__
-    assert set(list(metric)) == {"score", "reason", "is_score_valid"}
+
+def test_evaluation_row_accepts_dict_assignment_for_evaluation_result():
+    row = dummy_row()
+    row.evaluation_result = {"score": 0.6}
+
+    assert isinstance(row.evaluation_result, EvaluateResult)
+    assert row.evaluation_result.score == 0.6
+    assert row.evaluation_result.is_score_valid is True


 def test_evaluate_result_dict_access():
```