
Commit c37ca45

Add test for rollout processor tool calls
1 parent 7b252d3 commit c37ca45

File tree

2 files changed: +198 -4 lines

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 27 additions & 4 deletions
@@ -2,7 +2,10 @@
 from typing import List
 
 from litellm import acompletion
-from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_message import (
+    ChatCompletionMessageToolCall,
+    FunctionCall,
+)
 
 from eval_protocol.models import EvaluationRow, Message
 from eval_protocol.pytest.types import RolloutProcessorConfig
@@ -18,7 +21,24 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
     if len(row.messages) == 0:
         raise ValueError("Messages is empty. Please provide a non-empty dataset")
 
-    messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
+    messages_payload = []
+    for m in row.messages:
+        payload = {"role": m.role}
+        if m.content is not None:
+            payload["content"] = m.content
+        if m.name is not None:
+            payload["name"] = m.name
+        if m.tool_call_id is not None:
+            payload["tool_call_id"] = m.tool_call_id
+        if m.tool_calls is not None:
+            payload["tool_calls"] = [
+                tc.model_dump(exclude_none=True) for tc in m.tool_calls
+            ]
+        if m.function_call is not None:
+            payload["function_call"] = m.function_call.model_dump(
+                exclude_none=True
+            )
+        messages_payload.append(payload)
 
     request_params = {"model": config.model, "messages": messages_payload, **config.input_params}
 
@@ -27,8 +47,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
 
     response = await acompletion(**request_params)
 
-    assistant_content = response.choices[0].message.content or ""
-    tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
+    assistant_message = response.choices[0].message
+    assistant_content = assistant_message.content or ""
+    tool_calls = assistant_message.tool_calls if assistant_message.tool_calls else None
+    function_call = assistant_message.function_call
 
     converted_tool_calls = None
     if tool_calls:
@@ -49,6 +71,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             role="assistant",
             content=assistant_content,
             tool_calls=converted_tool_calls,
+            function_call=function_call,
        )
    ]

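Context for the change above: the old comprehension forwarded only role and content, so assistant tool-call turns and tool responses lost their tool_calls, tool_call_id, name, and function_call fields before reaching litellm. A minimal sketch of the difference, using hypothetical pydantic stand-ins (Function, ToolCall, Msg) rather than the real OpenAI types:

from typing import List, Optional

from pydantic import BaseModel


class Function(BaseModel):
    name: str
    arguments: str


class ToolCall(BaseModel):
    id: str
    type: str = "function"
    function: Function


class Msg(BaseModel):
    role: str
    content: Optional[str] = None
    tool_call_id: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None


m = Msg(
    role="assistant",
    content="",
    tool_calls=[ToolCall(id="call_1", function=Function(name="get_weather", arguments="{}"))],
)

# Old behavior: tool metadata silently dropped.
old_payload = {"role": m.role, "content": m.content}

# New behavior: optional fields are forwarded only when present.
new_payload = {"role": m.role}
if m.content is not None:
    new_payload["content"] = m.content
if m.tool_call_id is not None:
    new_payload["tool_call_id"] = m.tool_call_id
if m.tool_calls is not None:
    new_payload["tool_calls"] = [tc.model_dump(exclude_none=True) for tc in m.tool_calls]

print(new_payload)
# {'role': 'assistant', 'content': '',
#  'tool_calls': [{'id': 'call_1', 'type': 'function',
#                  'function': {'name': 'get_weather', 'arguments': '{}'}}]}
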
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
+import sys
+import types
+from dataclasses import dataclass
+from typing import Any, Dict, List
+
+import asyncio
+import pytest
+from pydantic import BaseModel
+from unittest import mock
+
+
+# ---- Stub external dependencies ----
+openai = types.ModuleType("openai")
+types_mod = types.ModuleType("openai.types")
+chat_mod = types.ModuleType("openai.types.chat")
+chat_msg_mod = types.ModuleType("openai.types.chat.chat_completion_message")
+
+
+class FunctionCall(BaseModel):
+    name: str
+    arguments: str
+
+
+class ToolFunction(BaseModel):
+    name: str
+    arguments: str
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    type: str
+    function: ToolFunction
+
+
+class CompletionUsage(BaseModel):
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
+
+
+chat_msg_mod.FunctionCall = FunctionCall
+chat_msg_mod.ChatCompletionMessageToolCall = ChatCompletionMessageToolCall
+chat_mod.chat_completion_message = chat_msg_mod
+openai.types = types_mod
+types_mod.chat = chat_mod
+types_mod.CompletionUsage = CompletionUsage
+sys.modules["openai"] = openai
+sys.modules["openai.types"] = types_mod
+sys.modules["openai.types.chat"] = chat_mod
+sys.modules["openai.types.chat.chat_completion_message"] = chat_msg_mod
+
+
+# Stub litellm
+litellm = types.ModuleType("litellm")
+
+
+async def acompletion(**kwargs):
+    raise NotImplementedError
+
+
+litellm.acompletion = acompletion
+sys.modules["litellm"] = litellm
+
+
+# Stub eval_protocol models and types
+class Message(BaseModel):
+    role: str
+    content: Any = ""
+    name: str | None = None
+    tool_call_id: str | None = None
+    tool_calls: List[ChatCompletionMessageToolCall] | None = None
+    function_call: FunctionCall | None = None
+
+
+class EvaluationRow(BaseModel):
+    messages: List[Message]
+    tools: Any = None
+    ground_truth: Any = None
+
+
+@dataclass
+class RolloutProcessorConfig:
+    model: str
+    input_params: Dict[str, Any]
+    mcp_config_path: str
+    server_script_path: str | None = None
+    max_concurrent_rollouts: int = 8
+    steps: int = 30
+
+
+# Register stub modules
+import_path = "/workspace/python-sdk/eval_protocol"
+eval_protocol_pkg = types.ModuleType("eval_protocol")
+eval_protocol_pkg.__path__ = [import_path]
+models_module = types.ModuleType("eval_protocol.models")
+models_module.Message = Message
+models_module.EvaluationRow = EvaluationRow
+pytest_pkg = types.ModuleType("eval_protocol.pytest")
+pytest_pkg.__path__ = [f"{import_path}/pytest"]
+types_module = types.ModuleType("eval_protocol.pytest.types")
+types_module.RolloutProcessorConfig = RolloutProcessorConfig
+
+sys.modules["eval_protocol"] = eval_protocol_pkg
+sys.modules["eval_protocol.models"] = models_module
+sys.modules["eval_protocol.pytest"] = pytest_pkg
+sys.modules["eval_protocol.pytest.types"] = types_module
+
+
+# Now we can import the rollout processor
+from eval_protocol.pytest.default_single_turn_rollout_process import (
+    default_single_turn_rollout_processor,
+)
+
+
+def test_handles_function_call_messages():
+    async def run_test():
+        tool_call = ChatCompletionMessageToolCall(
+            id="call_1",
+            type="function",
+            function=ToolFunction(name="get_weather", arguments="{}"),
+        )
+        row = EvaluationRow(
+            messages=[
+                Message(role="user", content="Hi"),
+                Message(role="assistant", tool_calls=[tool_call], content=""),
+                Message(role="tool", tool_call_id="call_1", content="sunny"),
+            ],
+            tools=[{"type": "function", "function": {"name": "get_weather"}}],
+        )
+        config = RolloutProcessorConfig(
+            model="gpt-4o-mini", input_params={}, mcp_config_path=""
+        )
+
+        captured_messages: List[Dict[str, Any]] = []
+
+        async def fake_acompletion(**kwargs):
+            nonlocal captured_messages
+            captured_messages = kwargs["messages"]
+            return types.SimpleNamespace(
+                choices=[
+                    types.SimpleNamespace(
+                        message=types.SimpleNamespace(
+                            content="done",
+                            tool_calls=[
+                                ChatCompletionMessageToolCall(
+                                    id="call_2",
+                                    type="function",
+                                    function=ToolFunction(name="foo", arguments="{}"),
+                                )
+                            ],
+                            function_call=None,
+                        )
+                    )
+                ]
+            )
+
+        with pytest.raises(NotImplementedError):
+            await acompletion()
+
+        with mock.patch(
+            "eval_protocol.pytest.default_single_turn_rollout_process.acompletion",
+            side_effect=fake_acompletion,
+        ):
+            dataset = await default_single_turn_rollout_processor([row], config)
+
+        assert captured_messages[1]["tool_calls"][0]["id"] == "call_1"
+        assert captured_messages[2]["tool_call_id"] == "call_1"
+        result_row = dataset[0]
+        assert result_row.messages[-1].tool_calls[0].id == "call_2"
+
+    asyncio.run(run_test())
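
For reference, the message payloads the assertions above expect fake_acompletion to capture, written out as literals. The test only checks the id and tool_call_id fields; the full dict shapes below are an assumption based on the payload loop and model_dump(exclude_none=True), not something the test asserts verbatim:

# Hypothetical expected capture for this fixture.
expected_captured_messages = [
    {"role": "user", "content": "Hi"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": "{}"},
            }
        ],
    },
    {"role": "tool", "content": "sunny", "tool_call_id": "call_1"},
]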
