diff --git a/eval_protocol/models.py b/eval_protocol/models.py
index 40bb1e44..2ac254a4 100644
--- a/eval_protocol/models.py
+++ b/eval_protocol/models.py
@@ -12,7 +12,7 @@
 from openai.types.chat.chat_completion_message_tool_call import (
     ChatCompletionMessageToolCall,
 )
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 from eval_protocol.get_pep440_version import get_pep440_version
 from eval_protocol.human_id import generate_id
@@ -595,7 +595,7 @@ class EvaluationRow(BaseModel):
     supporting both row-wise batch evaluation and trajectory-based RL evaluation.
     """
 
-    model_config = ConfigDict(extra="allow")
+    model_config = ConfigDict(extra="allow", validate_assignment=True)
 
     # Core OpenAI ChatCompletion compatible conversation data
     messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
@@ -626,6 +626,17 @@ class EvaluationRow(BaseModel):
     evaluation_result: Optional[EvaluateResult] = Field(
         default=None, description="The evaluation result for this row/trajectory."
    )
+    @field_validator("evaluation_result", mode="before")
+    @classmethod
+    def _coerce_evaluation_result(
+        cls, value: EvaluateResult | dict[str, Any] | None
+    ) -> EvaluateResult | None:
+        if value is None or isinstance(value, EvaluateResult):
+            return value
+        if isinstance(value, dict):
+            return EvaluateResult(**value)
+        return value
+
     execution_metadata: ExecutionMetadata = Field(
         default_factory=lambda: ExecutionMetadata(run_id=None),
         description="Metadata about the execution of the evaluation.",
diff --git a/tests/test_models.py b/tests/test_models.py
index 723685b8..bdcd5d9b 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -198,8 +198,14 @@ def test_metric_result_dict_access():
     }
     assert set(metric.items()) == expected_items
 
-    # __iter__
-    assert set(list(metric)) == {"score", "reason", "is_score_valid"}
+
+def test_evaluation_row_accepts_dict_assignment_for_evaluation_result():
+    row = dummy_row()
+    row.evaluation_result = {"score": 0.6}
+
+    assert isinstance(row.evaluation_result, EvaluateResult)
+    assert row.evaluation_result.score == 0.6
+    assert row.evaluation_result.is_score_valid is True
 
 
 def test_evaluate_result_dict_access():
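
For context, a minimal standalone sketch of the Pydantic v2 pattern this patch combines (the `Result` and `Row` names are illustrative only, not part of eval_protocol): `validate_assignment=True` makes Pydantic re-run field validation when an attribute is assigned after construction, and a `mode="before"` validator sees the raw value first, so a plain dict can be coerced into the typed model.

```python
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, field_validator


class Result(BaseModel):
    score: float
    is_valid: bool = True


class Row(BaseModel):
    # validate_assignment=True makes Pydantic re-run field validation
    # whenever an attribute is assigned, not just at construction time.
    model_config = ConfigDict(validate_assignment=True)

    result: Optional[Result] = None

    @field_validator("result", mode="before")
    @classmethod
    def _coerce_result(cls, value: Any) -> Any:
        # mode="before" receives the raw assigned value, so a plain dict
        # can be coerced explicitly into the typed model, mirroring the
        # _coerce_evaluation_result validator in the patch above.
        if isinstance(value, dict):
            return Result(**value)
        return value


row = Row()
row.result = {"score": 0.6}  # assignment triggers the validator chain
assert isinstance(row.result, Result)
assert row.result.score == 0.6 and row.result.is_valid is True
```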