Skip to content

Commit cb42aed

Browse files
authored
add better logging to pydantic ai rollouts (#127)
* save
* move created_at column to beginning of table
* fix types
* fix types
* vite build
1 parent 5e75383 commit cb42aed

File tree

7 files changed

+51
-48
lines changed

7 files changed

+51
-48
lines changed

eval_protocol/pytest/default_pydantic_ai_rollout_processor.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1+
# pyright: reportPrivateUsage=false
2+
13
import asyncio
24
import logging
35
import types
4-
from typing import List
5-
6-
from attr import dataclass
7-
from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
8-
6+
from pydantic_ai.models import Model
7+
from typing_extensions import override
98
from eval_protocol.models import EvaluationRow, Message
109
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1110
from eval_protocol.pytest.types import RolloutProcessorConfig
12-
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
11+
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
1312
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
1413
from pydantic_ai.models.anthropic import AnthropicModel
1514
from pydantic_ai.models.openai import OpenAIModel
@@ -25,7 +24,6 @@
2524
UserPromptPart,
2625
)
2726
from pydantic_ai.providers.openai import OpenAIProvider
28-
from typing_extensions import TypedDict
2927

3028
logger = logging.getLogger(__name__)
3129

@@ -36,9 +34,10 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
3634

3735
def __init__(self):
3836
# dummy model used for its helper functions for processing messages
39-
self.util = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
37+
self.util: OpenAIModel = OpenAIModel("dummy-model", provider=OpenAIProvider(api_key="dummy"))
4038

41-
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
39+
@override
40+
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
4241
"""Create agent rollout tasks and return them for external handling."""
4342

4443
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
@@ -60,34 +59,34 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
6059
raise ValueError(
6160
"completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
6261
)
63-
kwargs = {}
64-
for k, v in config.completion_params["model"].items():
65-
if v["model"] and v["model"].startswith("anthropic:"):
62+
kwargs: dict[str, Model] = {}
63+
for k, v in config.completion_params["model"].items(): # pyright: ignore[reportUnknownVariableType]
64+
if v["model"] and v["model"].startswith("anthropic:"): # pyright: ignore[reportUnknownMemberType]
6665
kwargs[k] = AnthropicModel(
67-
v["model"].removeprefix("anthropic:"),
66+
v["model"].removeprefix("anthropic:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
6867
)
69-
elif v["model"] and v["model"].startswith("google:"):
68+
elif v["model"] and v["model"].startswith("google:"): # pyright: ignore[reportUnknownMemberType]
7069
kwargs[k] = GoogleModel(
71-
v["model"].removeprefix("google:"),
70+
v["model"].removeprefix("google:"), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
7271
)
7372
else:
7473
kwargs[k] = OpenAIModel(
75-
v["model"],
76-
provider=v["provider"],
74+
v["model"], # pyright: ignore[reportUnknownArgumentType]
75+
provider=v["provider"], # pyright: ignore[reportUnknownArgumentType]
7776
)
78-
agent = setup_agent(**kwargs)
77+
agent_instance: Agent = setup_agent(**kwargs) # pyright: ignore[reportAny]
7978
model = None
8079
else:
81-
agent = config.kwargs["agent"]
80+
agent_instance = config.kwargs["agent"] # pyright: ignore[reportAssignmentType]
8281
model = OpenAIModel(
83-
config.completion_params["model"],
84-
provider=config.completion_params["provider"],
82+
config.completion_params["model"], # pyright: ignore[reportAny]
83+
provider=config.completion_params["provider"], # pyright: ignore[reportAny]
8584
)
8685

8786
async def process_row(row: EvaluationRow) -> EvaluationRow:
8887
"""Process a single row with agent rollout."""
8988
model_messages = [self.convert_ep_message_to_pyd_message(m, row) for m in row.messages]
90-
response = await agent.run(
89+
response = await agent_instance.run(
9190
message_history=model_messages, model=model, usage_limits=config.kwargs.get("usage_limits")
9291
)
9392
row.messages = await self.convert_pyd_message_to_ep_message(response.all_messages())
@@ -104,11 +103,11 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
104103

105104
async def convert_pyd_message_to_ep_message(self, messages: list[ModelMessage]) -> list[Message]:
106105
oai_messages: list[ChatCompletionMessageParam] = await self.util._map_messages(messages)
107-
return [Message(**m) for m in oai_messages]
106+
return [Message(**m) for m in oai_messages] # pyright: ignore[reportArgumentType]
108107

109108
def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow) -> ModelMessage:
110109
if message.role == "assistant":
111-
type_adapter = TypeAdapter(ChatCompletionAssistantMessageParam)
110+
type_adapter = TypeAdapter(ChatCompletionMessage)
112111
oai_message = type_adapter.validate_python(message)
113112
# Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
114113
return self.util._process_response(
@@ -117,23 +116,23 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
117116
object="chat.completion",
118117
model="",
119118
id="",
120-
created=(
121-
int(row.created_at.timestamp())
122-
if hasattr(row.created_at, "timestamp")
123-
else int(row.created_at)
124-
),
119+
created=int(row.created_at.timestamp()),
125120
)
126121
)
127122
elif message.role == "user":
128123
if isinstance(message.content, str):
129124
return ModelRequest(parts=[UserPromptPart(content=message.content)])
130125
elif isinstance(message.content, list):
131126
return ModelRequest(parts=[UserPromptPart(content=message.content[0].text)])
127+
else:
128+
raise ValueError(f"Unsupported content type for user message: {type(message.content)}")
132129
elif message.role == "system":
133130
if isinstance(message.content, str):
134131
return ModelRequest(parts=[SystemPromptPart(content=message.content)])
135132
elif isinstance(message.content, list):
136133
return ModelRequest(parts=[SystemPromptPart(content=message.content[0].text)])
134+
else:
135+
raise ValueError(f"Unsupported content type for system message: {type(message.content)}")
137136
elif message.role == "tool":
138137
return ModelRequest(
139138
parts=[

tests/chinook/test_pydantic_chinook.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from eval_protocol.pytest.default_pydantic_ai_rollout_processor import PydanticAgentRolloutProcessor
99
from tests.chinook.agent import setup_agent
10+
import os
1011
from pydantic_ai.models.openai import OpenAIModel
1112

1213
from tests.chinook.dataset import collect_dataset
@@ -82,7 +83,10 @@ class Response(BaseModel):
8283
return row
8384

8485

85-
@pytest.mark.skip(reason="takes too long to run")
86+
@pytest.mark.skipif(
87+
os.environ.get("CI") == "true",
88+
reason="Only run this test locally (skipped in CI)",
89+
)
8690
@pytest.mark.asyncio
8791
@evaluation_test(
8892
input_rows=[collect_dataset()],
Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vite-app/dist/assets/index-CMzLJz8S.js.map renamed to vite-app/dist/assets/index-D04dO2VH.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vite-app/dist/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
66
<title>EP | Log Viewer</title>
77
<link rel="icon" href="/assets/favicon-BkAAWQga.png" />
8-
<script type="module" crossorigin src="/assets/index-CMzLJz8S.js"></script>
8+
<script type="module" crossorigin src="/assets/index-D04dO2VH.js"></script>
99
<link rel="stylesheet" crossorigin href="/assets/index-DLyzGYL0.css">
1010
</head>
1111
<body>

vite-app/src/components/EvaluationRow.tsx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,11 @@ export const EvaluationRow = observer(
402402
<ExpandIcon rolloutId={rolloutId} />
403403
</TableCell>
404404

405+
{/* Created */}
406+
<TableCell className="py-3 text-xs">
407+
<RowCreated created_at={row.created_at} />
408+
</TableCell>
409+
405410
{/* Name */}
406411
<TableCell className="py-3 text-xs">
407412
<RowName name={row.eval_metadata?.name} />
@@ -461,11 +466,6 @@ export const EvaluationRow = observer(
461466
<TableCell className="py-3 text-xs">
462467
<RowScore score={row.evaluation_result?.score} />
463468
</TableCell>
464-
465-
{/* Created */}
466-
<TableCell className="py-3 text-xs">
467-
<RowCreated created_at={row.created_at} />
468-
</TableCell>
469469
</TableRowInteractive>
470470

471471
{/* Expanded Content Row */}

vite-app/src/components/EvaluationTable.tsx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,14 @@ export const EvaluationTable = observer(() => {
165165
<TableHead>
166166
<tr>
167167
<TableHeader className="w-8">&nbsp;</TableHeader>
168+
<SortableTableHeader
169+
sortField="created_at"
170+
currentSortField={state.sortField}
171+
currentSortDirection={state.sortDirection}
172+
onSort={handleSort}
173+
>
174+
Created
175+
</SortableTableHeader>
168176
<SortableTableHeader
169177
sortField="$.eval_metadata.name"
170178
currentSortField={state.sortField}
@@ -245,14 +253,6 @@ export const EvaluationTable = observer(() => {
245253
>
246254
Score
247255
</SortableTableHeader>
248-
<SortableTableHeader
249-
sortField="created_at"
250-
currentSortField={state.sortField}
251-
currentSortDirection={state.sortDirection}
252-
onSort={handleSort}
253-
>
254-
Created
255-
</SortableTableHeader>
256256
</tr>
257257
</TableHead>
258258

0 commit comments

Comments
 (0)