Skip to content

Commit 13ac18a

Browse files
authored
refactor rollout status to be AIP-193 compatible (#109)
* save * update * fix * fix test_status_model * Remove backwards compatibility methods for rollout status from EvaluationRow and associated tests. * fix test_status_migration_integration * fix test_migration_Changes * delete * fix test_retry_mechanism * fix tests * remove unused
1 parent 7dac466 commit 13ac18a

File tree

9 files changed

+1393
-59
lines changed

9 files changed

+1393
-59
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from vendor.tau2.data_model.message import AssistantMessage, UserMessage
2121
from vendor.tau2.user.user_simulator import UserSimulator
2222

23-
from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus
23+
from ...models import EvaluationRow, InputMetadata, Message, Status
2424
from ...types import TerminationReason, Trajectory, NonSkippableException
2525

2626
if TYPE_CHECKING:
@@ -136,15 +136,14 @@ async def _execute_with_semaphore(idx):
136136
}
137137

138138
if trajectory.terminated:
139-
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
140-
evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED
141-
# preserve the true error mesage if there are any
139+
extra_info = None
142140
if trajectory.control_plane_summary.get("error_message"):
143-
evaluation_row.rollout_status.extra_info = {
144-
"error_message": trajectory.control_plane_summary.get("error_message")
145-
}
141+
extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")}
142+
evaluation_row.rollout_status = Status.rollout_finished(
143+
termination_reason=trajectory.termination_reason, extra_info=extra_info
144+
)
146145
else:
147-
evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING
146+
evaluation_row.rollout_status = Status.rollout_running()
148147

149148
return evaluation_row
150149

eval_protocol/models.py

Lines changed: 186 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from datetime import datetime
33
from enum import Enum
4-
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
4+
from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union
55

66
from openai.types import CompletionUsage
77
from openai.types.chat.chat_completion_message import (
@@ -15,6 +15,188 @@
1515
from eval_protocol.types import TerminationReason
1616

1717

18+
class ErrorInfo(BaseModel):
    """
    Structured error detail modelled on the AIP-193 ``ErrorInfo`` message.

    Reference: https://google.aip.dev/193#errorinfo

    Attributes:
        reason (str): Short identifier for the cause of the error.
        domain (str): Logical grouping that the reason belongs to.
        metadata (Dict[str, Any]): Free-form contextual information.
    """

    # Domain stamped on every ErrorInfo produced by the factory methods below.
    DOMAIN: ClassVar[str] = "evalprotocol.io"

    # Reason values produced by the factory methods below.
    REASON_TERMINATION_REASON: ClassVar[str] = "TERMINATION_REASON"
    REASON_EXTRA_INFO: ClassVar[str] = "EXTRA_INFO"

    reason: str = Field(..., description="Short snake_case description of the error cause")
    domain: str = Field(..., description="Logical grouping for the error reason")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional dynamic information as context")

    def to_aip193_format(self) -> Dict[str, Any]:
        """Return this detail as a plain dict carrying the AIP-193 ``@type`` tag."""
        payload: Dict[str, Any] = {"@type": "type.googleapis.com/google.rpc.ErrorInfo"}
        payload["reason"] = self.reason
        payload["domain"] = self.domain
        payload["metadata"] = self.metadata
        return payload

    @classmethod
    def termination_reason(cls, reason: TerminationReason) -> "ErrorInfo":
        """Build an ErrorInfo that records why a rollout terminated."""
        # Enum members are stored by their value; anything else is stored as given.
        if isinstance(reason, TerminationReason):
            stored_reason = reason.value
        else:
            stored_reason = reason
        return cls(
            reason=cls.REASON_TERMINATION_REASON,
            domain=cls.DOMAIN,
            metadata={"termination_reason": stored_reason},
        )

    @classmethod
    def extra_info(cls, metadata: Dict[str, Any]) -> "ErrorInfo":
        """Build an ErrorInfo whose metadata is the supplied extra information."""
        return cls(
            reason=cls.REASON_EXTRA_INFO,
            domain=cls.DOMAIN,
            metadata=metadata,
        )
64+
65+
66+
class Status(BaseModel):
    """
    AIP-193 compatible Status model for standardized error responses.

    This model follows Google's AIP-193 standard for error handling:
    https://google.aip.dev/193

    Attributes:
        code (Status.Code): The status code; a google.rpc.Code value or the
            custom ``FINISHED`` rollout code defined below.
        message (str): Developer-facing, human-readable debug message in English.
        details (List[Dict[str, Any]]): Additional error information, each entry
            in the dict form produced by ``ErrorInfo.to_aip193_format()``.
    """

    # Message written by rollout_running(). is_running() matches on it because
    # the running state is encoded as Code.OK plus this exact message.
    MESSAGE_ROLLOUT_RUNNING: ClassVar[str] = "Rollout is running"

    code: "Status.Code" = Field(..., description="The status code from google.rpc.Code enum")
    message: str = Field(..., description="Developer-facing, human-readable debug message in English")
    details: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Additional error information, each packed in a google.protobuf.Any message format",
    )

    # Convenience constants for common status codes
    class Code(int, Enum):
        """Common gRPC status codes as defined in google.rpc.Code"""

        OK = 0
        CANCELLED = 1
        UNKNOWN = 2
        INVALID_ARGUMENT = 3
        DEADLINE_EXCEEDED = 4
        NOT_FOUND = 5
        ALREADY_EXISTS = 6
        PERMISSION_DENIED = 7
        RESOURCE_EXHAUSTED = 8
        FAILED_PRECONDITION = 9
        ABORTED = 10
        OUT_OF_RANGE = 11
        UNIMPLEMENTED = 12
        INTERNAL = 13
        UNAVAILABLE = 14
        DATA_LOSS = 15
        UNAUTHENTICATED = 16

        # Custom codes for rollout states (using higher numbers to avoid conflicts)
        FINISHED = 100

    @classmethod
    def rollout_running(cls) -> "Status":
        """Create a status indicating the rollout is running."""
        return cls(code=cls.Code.OK, message=cls.MESSAGE_ROLLOUT_RUNNING, details=[])

    @classmethod
    def rollout_finished(
        cls,
        termination_reason: Optional[TerminationReason] = None,
        extra_info: Optional[Dict[str, Any]] = None,
    ) -> "Status":
        """
        Create a status indicating the rollout finished.

        Args:
            termination_reason: If truthy, recorded as a TERMINATION_REASON detail.
            extra_info: If non-empty, recorded as an EXTRA_INFO detail.

        Returns:
            A Status with code ``FINISHED`` and zero, one or two detail entries.
        """
        details = []
        if termination_reason:
            details.append(ErrorInfo.termination_reason(termination_reason).to_aip193_format())
        if extra_info:
            details.append(ErrorInfo.extra_info(extra_info).to_aip193_format())
        return cls(code=cls.Code.FINISHED, message="Rollout finished", details=details)

    @classmethod
    def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status":
        """
        Create a status indicating the rollout failed with an error.

        Args:
            error_message: Stored as the status message.
            extra_info: If non-empty, recorded as an EXTRA_INFO detail.
        """
        details = []
        if extra_info:
            details.append(ErrorInfo.extra_info(extra_info).to_aip193_format())
        return cls.error(error_message, details)

    @classmethod
    def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status":
        """
        Create an ``INTERNAL`` error status from a message and pre-built details.

        Args:
            error_message: Stored as the status message.
            details: Detail dicts to attach; ``None`` means no details.
        """
        # `details` is a non-Optional list field, so an explicit None would be
        # rejected by validation; normalise it to an empty list first.
        return cls(
            code=cls.Code.INTERNAL,
            message=error_message,
            details=[] if details is None else details,
        )

    def is_running(self) -> bool:
        """Check if the status indicates the rollout is running."""
        return self.code == self.Code.OK and self.message == self.MESSAGE_ROLLOUT_RUNNING

    def is_finished(self) -> bool:
        """Check if the status carries the ``FINISHED`` code."""
        return self.code == self.Code.FINISHED

    def is_error(self) -> bool:
        """Check if the status carries the ``INTERNAL`` error code."""
        return self.code == self.Code.INTERNAL

    def is_stopped(self) -> bool:
        """Check if the status carries the ``CANCELLED`` code."""
        return self.code == self.Code.CANCELLED

    def get_termination_reason(self) -> Optional[TerminationReason]:
        """
        Extract the termination reason from details if present.

        Returns:
            The parsed TerminationReason from the first TERMINATION_REASON detail,
            or ``None`` if there is no such detail or its value fails to parse.
        """
        for detail in self.details:
            metadata = detail.get("metadata", {})
            if detail.get("reason") == ErrorInfo.REASON_TERMINATION_REASON and "termination_reason" in metadata:
                try:
                    return TerminationReason.from_str(metadata["termination_reason"])
                except ValueError:
                    # If the reason is not a valid enum value, return None
                    return None
        return None

    def get_extra_info(self) -> Optional[Dict[str, Any]]:
        """
        Extract extra info from details if present.

        Returns:
            The metadata of the first EXTRA_INFO detail, or ``None`` if absent.
        """
        for detail in self.details:
            # Only EXTRA_INFO entries qualify; every other reason is skipped.
            if detail.get("reason") == ErrorInfo.REASON_EXTRA_INFO:
                return detail.get("metadata", {})
        return None

    def __hash__(self) -> int:
        """
        Generate a hash from code, message and details.

        Details are serialised with sorted keys so that instances whose detail
        dicts are equal but were built in a different key order hash the same.
        """
        import hashlib
        import json

        def _canonical(detail: Dict[str, Any]) -> str:
            # Sorted-key JSON does not depend on dict insertion order; fall back
            # to str() for payloads json cannot handle (e.g. mixed key types).
            try:
                return json.dumps(detail, sort_keys=True, default=str)
            except (TypeError, ValueError):
                return str(detail)

        # int() avoids formatting the enum itself: f-string output for an
        # int-mixin Enum differs between Python versions.
        parts = [str(int(self.code)), self.message, str(len(self.details))]
        parts.extend(sorted(_canonical(detail) for detail in self.details))

        digest = hashlib.sha256(":".join(parts).encode("utf-8")).digest()
        return int.from_bytes(digest[:8], byteorder="big")
198+
199+
18200
class ChatCompletionContentPartTextParam(BaseModel):
19201
text: str = Field(..., description="The text content.")
20202
type: Literal["text"] = Field("text", description="The type of the content part.")
@@ -289,27 +471,6 @@ class ExecutionMetadata(BaseModel):
289471
)
290472

291473

292-
class RolloutStatus(BaseModel):
293-
"""Status of the rollout."""
294-
295-
"""
296-
running: Unfinished rollout which is still in progress.
297-
finished: Rollout finished.
298-
error: Rollout failed due to unexpected error. The rollout record should be discard.
299-
"""
300-
301-
class Status(str, Enum):
302-
RUNNING = "running"
303-
FINISHED = "finished"
304-
ERROR = "error"
305-
306-
status: Status = Field(Status.RUNNING, description="Status of the rollout.")
307-
termination_reason: Optional[TerminationReason] = Field(
308-
None, description="reason of the rollout status, mapped to values in TerminationReason"
309-
)
310-
extra_info: Optional[Dict[str, Any]] = Field(None, description="Extra information about the rollout status.")
311-
312-
313474
class EvaluationRow(BaseModel):
314475
"""
315476
Unified data structure for a single evaluation unit that contains messages,
@@ -334,9 +495,9 @@ class EvaluationRow(BaseModel):
334495
description="Metadata related to the input (dataset info, model config, session data, etc.).",
335496
)
336497

337-
rollout_status: RolloutStatus = Field(
338-
default_factory=RolloutStatus,
339-
description="The status of the rollout.",
498+
rollout_status: Status = Field(
499+
default_factory=Status.rollout_running,
500+
description="The status of the rollout following AIP-193 standards.",
340501
)
341502

342503
# Ground truth reference (moved from EvaluateResult to top level)

eval_protocol/pytest/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def pytest_addoption(parser) -> None:
6464
action="store",
6565
type=int,
6666
default=0,
67-
help=("Failed rollouts (with rollout_status.status == 'error') will be retried up to this many times."),
67+
help=("Failed rollouts (with rollout_status.code indicating error) will be retried up to this many times."),
6868
)
6969
group.addoption(
7070
"--ep-fail-on-max-retry",

eval_protocol/pytest/utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Callable, Dict, List, Literal, Optional, Union
77

88
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
9-
from eval_protocol.models import EvalMetadata, EvaluationRow, RolloutStatus
9+
from eval_protocol.models import EvalMetadata, EvaluationRow, Status
1010
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1111
from eval_protocol.pytest.types import (
1212
CompletionParams,
@@ -282,7 +282,7 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
282282
try:
283283
# Try original task first
284284
result = await task
285-
result.rollout_status.status = RolloutStatus.Status.FINISHED
285+
result.rollout_status = Status.rollout_finished()
286286
return result
287287
except Exception as e:
288288
# NOTE: we perform these checks because we don't put the backoff decorator on initial batch call. we don't want to retry whole batch if anything fails.
@@ -295,17 +295,15 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
295295
# Use shared backoff function for retryable exceptions
296296
try:
297297
result = await execute_row_with_backoff_retry(row)
298-
result.rollout_status.status = RolloutStatus.Status.FINISHED
298+
result.rollout_status = Status.rollout_finished()
299299
return result
300300
except Exception as retry_error:
301301
# Backoff gave up
302-
row.rollout_status.status = RolloutStatus.Status.ERROR
303-
# row.rollout_status.termination_reason = str(retry_error)
302+
row.rollout_status = Status.rollout_error(str(retry_error))
304303
return row
305304
else:
306305
# Non-retryable exception - fail immediately
307-
row.rollout_status.status = RolloutStatus.Status.ERROR
308-
# row.rollout_status.termination_reason = str(e)
306+
row.rollout_status = Status.rollout_error(str(e))
309307
return row
310308

311309
# Process all tasks concurrently with backoff retry

tests/test_retry_mechanism.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import pytest
1313

14-
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus
14+
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, Status
1515
from eval_protocol.pytest.evaluation_test import evaluation_test
1616
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1717
from eval_protocol.pytest.types import RolloutProcessorConfig
@@ -95,11 +95,11 @@ async def process_single_row(
9595
def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow:
9696
"""MOCK TEST: Tests that retry mechanism works - one task fails on first attempt, succeeds on retry."""
9797
print(
98-
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})"
98+
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
9999
)
100100

101101
# Assign a score based on success/failure
102-
score = 1.0 if row.rollout_status.status == "finished" else 0.0
102+
score = 1.0 if row.rollout_status.is_finished() else 0.0
103103
row.evaluation_result = EvaluateResult(score=score)
104104

105105
return row
@@ -191,9 +191,9 @@ async def process_single_row(row: EvaluationRow) -> EvaluationRow:
191191
def test_fail_fast_exceptions(row: EvaluationRow) -> EvaluationRow:
192192
"""Test that fail-fast exceptions like ValueError are not retried."""
193193
print(
194-
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})"
194+
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
195195
)
196-
score = 1.0 if row.rollout_status.status == "finished" else 0.0
196+
score = 1.0 if row.rollout_status.is_finished() else 0.0
197197
row.evaluation_result = EvaluateResult(score=score)
198198
return row
199199

@@ -283,8 +283,8 @@ def custom_http_giveup(e):
283283
def test_custom_giveup_function(row: EvaluationRow) -> EvaluationRow:
284284
"""Test custom giveup function behavior."""
285285
task_content = row.messages[0].content if row.messages else ""
286-
print(f"📊 EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})")
287-
score = 1.0 if row.rollout_status.status == "finished" else 0.0
286+
print(f"📊 EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})")
287+
score = 1.0 if row.rollout_status.is_finished() else 0.0
288288
row.evaluation_result = EvaluateResult(score=score)
289289
return row
290290

@@ -368,9 +368,9 @@ def simple_4xx_giveup(e):
368368
def test_simple_giveup_function(row: EvaluationRow) -> EvaluationRow:
369369
"""Test that giveup function prevents retries immediately."""
370370
print(
371-
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})"
371+
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
372372
)
373-
score = 1.0 if row.rollout_status.status == "finished" else 0.0
373+
score = 1.0 if row.rollout_status.is_finished() else 0.0
374374
row.evaluation_result = EvaluateResult(score=score)
375375
return row
376376

0 commit comments

Comments
(0)