11import asyncio
22from typing import List
33
4- from openai import AsyncOpenAI
4+ from litellm import acompletion
5+ from openai .types .chat .chat_completion_message import ChatCompletionMessageToolCall
56
6- from eval_protocol .auth import get_fireworks_api_base , get_fireworks_api_key
77from eval_protocol .dataset_logger import default_logger
88from eval_protocol .models import EvaluationRow , Message
99from eval_protocol .pytest .types import RolloutProcessorConfig
1212async def default_single_turn_rollout_processor (
1313 rows : List [EvaluationRow ], config : RolloutProcessorConfig
1414) -> List [EvaluationRow ]:
15- """Generate a single response from a Fireworks model concurrently."""
16-
17- api_key = get_fireworks_api_key ()
18- api_base = get_fireworks_api_base ()
19- client = AsyncOpenAI (api_key = api_key , base_url = f"{ api_base } /inference/v1" )
15+ """Generate a single response from any supported model provider using LiteLLM."""
2016
2117 async def process_row (row : EvaluationRow ) -> EvaluationRow :
2218 """Process a single row asynchronously."""
@@ -25,17 +21,35 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
2521
2622 messages_payload = [{"role" : m .role , "content" : m .content } for m in row .messages ]
2723
28- create_kwargs = dict (model = config .model , messages = messages_payload , ** config .input_params )
24+ request_params = {"model" : config .model , "messages" : messages_payload , ** config .input_params }
25+
2926 if row .tools is not None :
30- create_kwargs ["tools" ] = row .tools
31- response = await client .chat .completions .create (** create_kwargs )
27+ request_params ["tools" ] = row .tools
28+
29+ response = await acompletion (** request_params )
30+
3231 assistant_content = response .choices [0 ].message .content or ""
3332 tool_calls = response .choices [0 ].message .tool_calls if response .choices [0 ].message .tool_calls else None
33+
34+ converted_tool_calls = None
35+ if tool_calls :
36+ converted_tool_calls = [
37+ ChatCompletionMessageToolCall (
38+ id = tool_call .id ,
39+ type = tool_call .type ,
40+ function = {
41+ "name" : tool_call .function .name ,
42+ "arguments" : tool_call .function .arguments ,
43+ },
44+ )
45+ for tool_call in tool_calls
46+ ]
47+
3448 messages = list (row .messages ) + [
3549 Message (
3650 role = "assistant" ,
3751 content = assistant_content ,
38- tool_calls = tool_calls ,
52+ tool_calls = converted_tool_calls ,
3953 )
4054 ]
4155
0 commit comments