1+ # pyright: reportPrivateUsage=false
2+
13import asyncio
24import logging
35import types
4- from typing import List
5-
6- from attr import dataclass
7- from openai .types .chat .chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam
8-
6+ from pydantic_ai .models import Model
7+ from typing_extensions import override
98from eval_protocol .models import EvaluationRow , Message
109from eval_protocol .pytest .rollout_processor import RolloutProcessor
1110from eval_protocol .pytest .types import RolloutProcessorConfig
12- from openai .types .chat import ChatCompletion , ChatCompletionMessageParam
11+ from openai .types .chat import ChatCompletion , ChatCompletionMessage , ChatCompletionMessageParam
1312from openai .types .chat .chat_completion import Choice as ChatCompletionChoice
1413from pydantic_ai .models .anthropic import AnthropicModel
1514from pydantic_ai .models .openai import OpenAIModel
2524 UserPromptPart ,
2625)
2726from pydantic_ai .providers .openai import OpenAIProvider
28- from typing_extensions import TypedDict
2927
3028logger = logging .getLogger (__name__ )
3129
@@ -36,9 +34,10 @@ class PydanticAgentRolloutProcessor(RolloutProcessor):
3634
3735 def __init__ (self ):
3836 # dummy model used for its helper functions for processing messages
39- self .util = OpenAIModel ("dummy-model" , provider = OpenAIProvider (api_key = "dummy" ))
37+ self .util : OpenAIModel = OpenAIModel ("dummy-model" , provider = OpenAIProvider (api_key = "dummy" ))
4038
41- def __call__ (self , rows : List [EvaluationRow ], config : RolloutProcessorConfig ) -> List [asyncio .Task [EvaluationRow ]]:
39+ @override
40+ def __call__ (self , rows : list [EvaluationRow ], config : RolloutProcessorConfig ) -> list [asyncio .Task [EvaluationRow ]]:
4241 """Create agent rollout tasks and return them for external handling."""
4342
4443 max_concurrent = getattr (config , "max_concurrent_rollouts" , 8 ) or 8
@@ -60,34 +59,34 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
6059 raise ValueError (
6160 "completion_params['model'] must be a dict mapping agent argument names to model config dicts (with 'model' and 'provider' keys)"
6261 )
63- kwargs = {}
64- for k , v in config .completion_params ["model" ].items ():
65- if v ["model" ] and v ["model" ].startswith ("anthropic:" ):
62+ kwargs : dict [ str , Model ] = {}
63+ for k , v in config .completion_params ["model" ].items (): # pyright: ignore[reportUnknownVariableType]
64+ if v ["model" ] and v ["model" ].startswith ("anthropic:" ): # pyright: ignore[reportUnknownMemberType]
6665 kwargs [k ] = AnthropicModel (
67- v ["model" ].removeprefix ("anthropic:" ),
66+ v ["model" ].removeprefix ("anthropic:" ), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
6867 )
69- elif v ["model" ] and v ["model" ].startswith ("google:" ):
68+ elif v ["model" ] and v ["model" ].startswith ("google:" ): # pyright: ignore[reportUnknownMemberType]
7069 kwargs [k ] = GoogleModel (
71- v ["model" ].removeprefix ("google:" ),
70+ v ["model" ].removeprefix ("google:" ), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
7271 )
7372 else :
7473 kwargs [k ] = OpenAIModel (
75- v ["model" ],
76- provider = v ["provider" ],
74+ v ["model" ], # pyright: ignore[reportUnknownArgumentType]
75+ provider = v ["provider" ], # pyright: ignore[reportUnknownArgumentType]
7776 )
78- agent = setup_agent (** kwargs )
77+ agent_instance : Agent = setup_agent (** kwargs ) # pyright: ignore[reportAny]
7978 model = None
8079 else :
81- agent = config .kwargs ["agent" ]
80+ agent_instance = config .kwargs ["agent" ] # pyright: ignore[reportAssignmentType ]
8281 model = OpenAIModel (
83- config .completion_params ["model" ],
84- provider = config .completion_params ["provider" ],
82+ config .completion_params ["model" ], # pyright: ignore[reportAny]
83+ provider = config .completion_params ["provider" ], # pyright: ignore[reportAny]
8584 )
8685
8786 async def process_row (row : EvaluationRow ) -> EvaluationRow :
8887 """Process a single row with agent rollout."""
8988 model_messages = [self .convert_ep_message_to_pyd_message (m , row ) for m in row .messages ]
90- response = await agent .run (
89+ response = await agent_instance .run (
9190 message_history = model_messages , model = model , usage_limits = config .kwargs .get ("usage_limits" )
9291 )
9392 row .messages = await self .convert_pyd_message_to_ep_message (response .all_messages ())
@@ -104,11 +103,11 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
104103
105104 async def convert_pyd_message_to_ep_message (self , messages : list [ModelMessage ]) -> list [Message ]:
106105 oai_messages : list [ChatCompletionMessageParam ] = await self .util ._map_messages (messages )
107- return [Message (** m ) for m in oai_messages ]
106+ return [Message (** m ) for m in oai_messages ] # pyright: ignore[reportArgumentType]
108107
109108 def convert_ep_message_to_pyd_message (self , message : Message , row : EvaluationRow ) -> ModelMessage :
110109 if message .role == "assistant" :
111- type_adapter = TypeAdapter (ChatCompletionAssistantMessageParam )
110+ type_adapter = TypeAdapter (ChatCompletionMessage )
112111 oai_message = type_adapter .validate_python (message )
113112 # Fix: Provide required finish_reason and index, and ensure created is int (timestamp)
114113 return self .util ._process_response (
@@ -117,23 +116,23 @@ def convert_ep_message_to_pyd_message(self, message: Message, row: EvaluationRow
117116 object = "chat.completion" ,
118117 model = "" ,
119118 id = "" ,
120- created = (
121- int (row .created_at .timestamp ())
122- if hasattr (row .created_at , "timestamp" )
123- else int (row .created_at )
124- ),
119+ created = int (row .created_at .timestamp ()),
125120 )
126121 )
127122 elif message .role == "user" :
128123 if isinstance (message .content , str ):
129124 return ModelRequest (parts = [UserPromptPart (content = message .content )])
130125 elif isinstance (message .content , list ):
131126 return ModelRequest (parts = [UserPromptPart (content = message .content [0 ].text )])
127+ else :
128+ raise ValueError (f"Unsupported content type for user message: { type (message .content )} " )
132129 elif message .role == "system" :
133130 if isinstance (message .content , str ):
134131 return ModelRequest (parts = [SystemPromptPart (content = message .content )])
135132 elif isinstance (message .content , list ):
136133 return ModelRequest (parts = [SystemPromptPart (content = message .content [0 ].text )])
134+ else :
135+ raise ValueError (f"Unsupported content type for system message: { type (message .content )} " )
137136 elif message .role == "tool" :
138137 return ModelRequest (
139138 parts = [
0 commit comments