11import asyncio
22from typing import List
33
4- from openai import AsyncOpenAI
4+ from litellm import acompletion
5+ from openai .types .chat .chat_completion_message import ChatCompletionMessageToolCall
56
6- from eval_protocol .auth import get_fireworks_api_base , get_fireworks_api_key
77from eval_protocol .models import EvaluationRow , Message
88from eval_protocol .pytest .types import RolloutProcessorConfig
99
1010
1111async def default_single_turn_rollout_processor (
1212 rows : List [EvaluationRow ], config : RolloutProcessorConfig
1313) -> List [EvaluationRow ]:
14- """Generate a single response from a Fireworks model concurrently."""
15-
16- api_key = get_fireworks_api_key ()
17- api_base = get_fireworks_api_base ()
18- client = AsyncOpenAI (api_key = api_key , base_url = f"{ api_base } /inference/v1" )
14+ """Generate a single response from any supported model provider using LiteLLM."""
1915
2016 async def process_row (row : EvaluationRow ) -> EvaluationRow :
2117 """Process a single row asynchronously."""
@@ -24,17 +20,35 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
2420
2521 messages_payload = [{"role" : m .role , "content" : m .content } for m in row .messages ]
2622
27- create_kwargs = dict (model = config .model , messages = messages_payload , ** config .input_params )
23+ request_params = {"model" : config .model , "messages" : messages_payload , ** config .input_params }
24+
2825 if row .tools is not None :
29- create_kwargs ["tools" ] = row .tools
30- response = await client .chat .completions .create (** create_kwargs )
26+ request_params ["tools" ] = row .tools
27+
28+ response = await acompletion (** request_params )
29+
3130 assistant_content = response .choices [0 ].message .content or ""
3231 tool_calls = response .choices [0 ].message .tool_calls if response .choices [0 ].message .tool_calls else None
32+
33+ converted_tool_calls = None
34+ if tool_calls :
35+ converted_tool_calls = [
36+ ChatCompletionMessageToolCall (
37+ id = tool_call .id ,
38+ type = tool_call .type ,
39+ function = {
40+ "name" : tool_call .function .name ,
41+ "arguments" : tool_call .function .arguments ,
42+ },
43+ )
44+ for tool_call in tool_calls
45+ ]
46+
3347 messages = list (row .messages ) + [
3448 Message (
3549 role = "assistant" ,
3650 content = assistant_content ,
37- tool_calls = tool_calls ,
51+ tool_calls = converted_tool_calls ,
3852 )
3953 ]
4054
0 commit comments