|
| 1 | +import sys |
| 2 | +import types |
| 3 | +from dataclasses import dataclass |
| 4 | +from typing import Any, Dict, List |
| 5 | + |
| 6 | +import asyncio |
| 7 | +import pytest |
| 8 | +from pydantic import BaseModel |
| 9 | +from unittest import mock |
| 10 | + |
| 11 | + |
| 12 | +# ---- Stub external dependencies ---- |
| 13 | +openai = types.ModuleType("openai") |
| 14 | +types_mod = types.ModuleType("openai.types") |
| 15 | +chat_mod = types.ModuleType("openai.types.chat") |
| 16 | +chat_msg_mod = types.ModuleType("openai.types.chat.chat_completion_message") |
| 17 | + |
| 18 | + |
| 19 | +class FunctionCall(BaseModel): |
| 20 | + name: str |
| 21 | + arguments: str |
| 22 | + |
| 23 | + |
| 24 | +class ToolFunction(BaseModel): |
| 25 | + name: str |
| 26 | + arguments: str |
| 27 | + |
| 28 | + |
| 29 | +class ChatCompletionMessageToolCall(BaseModel): |
| 30 | + id: str |
| 31 | + type: str |
| 32 | + function: ToolFunction |
| 33 | + |
| 34 | + |
| 35 | +class CompletionUsage(BaseModel): |
| 36 | + prompt_tokens: int = 0 |
| 37 | + completion_tokens: int = 0 |
| 38 | + total_tokens: int = 0 |
| 39 | + |
| 40 | + |
| 41 | +chat_msg_mod.FunctionCall = FunctionCall |
| 42 | +chat_msg_mod.ChatCompletionMessageToolCall = ChatCompletionMessageToolCall |
| 43 | +chat_mod.chat_completion_message = chat_msg_mod |
| 44 | +openai.types = types_mod |
| 45 | +types_mod.chat = chat_mod |
| 46 | +types_mod.CompletionUsage = CompletionUsage |
| 47 | +sys.modules["openai"] = openai |
| 48 | +sys.modules["openai.types"] = types_mod |
| 49 | +sys.modules["openai.types.chat"] = chat_mod |
| 50 | +sys.modules["openai.types.chat.chat_completion_message"] = chat_msg_mod |
| 51 | + |
| 52 | + |
| 53 | +# Stub litellm |
| 54 | +litellm = types.ModuleType("litellm") |
| 55 | + |
| 56 | + |
| 57 | +async def acompletion(**kwargs): |
| 58 | + raise NotImplementedError |
| 59 | + |
| 60 | + |
| 61 | +litellm.acompletion = acompletion |
| 62 | +sys.modules["litellm"] = litellm |
| 63 | + |
| 64 | + |
| 65 | +# Stub eval_protocol models and types |
| 66 | +class Message(BaseModel): |
| 67 | + role: str |
| 68 | + content: Any = "" |
| 69 | + name: str | None = None |
| 70 | + tool_call_id: str | None = None |
| 71 | + tool_calls: List[ChatCompletionMessageToolCall] | None = None |
| 72 | + function_call: FunctionCall | None = None |
| 73 | + |
| 74 | + |
| 75 | +class EvaluationRow(BaseModel): |
| 76 | + messages: List[Message] |
| 77 | + tools: Any = None |
| 78 | + ground_truth: Any = None |
| 79 | + |
| 80 | + |
| 81 | +@dataclass |
| 82 | +class RolloutProcessorConfig: |
| 83 | + model: str |
| 84 | + input_params: Dict[str, Any] |
| 85 | + mcp_config_path: str |
| 86 | + server_script_path: str | None = None |
| 87 | + max_concurrent_rollouts: int = 8 |
| 88 | + steps: int = 30 |
| 89 | + |
| 90 | + |
| 91 | +# Register stub modules |
| 92 | +import_path = "/workspace/python-sdk/eval_protocol" |
| 93 | +eval_protocol_pkg = types.ModuleType("eval_protocol") |
| 94 | +eval_protocol_pkg.__path__ = [import_path] |
| 95 | +models_module = types.ModuleType("eval_protocol.models") |
| 96 | +models_module.Message = Message |
| 97 | +models_module.EvaluationRow = EvaluationRow |
| 98 | +pytest_pkg = types.ModuleType("eval_protocol.pytest") |
| 99 | +pytest_pkg.__path__ = [f"{import_path}/pytest"] |
| 100 | +types_module = types.ModuleType("eval_protocol.pytest.types") |
| 101 | +types_module.RolloutProcessorConfig = RolloutProcessorConfig |
| 102 | + |
| 103 | +sys.modules["eval_protocol"] = eval_protocol_pkg |
| 104 | +sys.modules["eval_protocol.models"] = models_module |
| 105 | +sys.modules["eval_protocol.pytest"] = pytest_pkg |
| 106 | +sys.modules["eval_protocol.pytest.types"] = types_module |
| 107 | + |
| 108 | + |
| 109 | +# Now we can import the rollout processor |
| 110 | +from eval_protocol.pytest.default_single_turn_rollout_process import ( |
| 111 | + default_single_turn_rollout_processor, |
| 112 | +) |
| 113 | + |
| 114 | + |
| 115 | +def test_handles_function_call_messages(): |
| 116 | + async def run_test(): |
| 117 | + tool_call = ChatCompletionMessageToolCall( |
| 118 | + id="call_1", |
| 119 | + type="function", |
| 120 | + function=ToolFunction(name="get_weather", arguments="{}"), |
| 121 | + ) |
| 122 | + row = EvaluationRow( |
| 123 | + messages=[ |
| 124 | + Message(role="user", content="Hi"), |
| 125 | + Message(role="assistant", tool_calls=[tool_call], content=""), |
| 126 | + Message(role="tool", tool_call_id="call_1", content="sunny"), |
| 127 | + ], |
| 128 | + tools=[{"type": "function", "function": {"name": "get_weather"}}], |
| 129 | + ) |
| 130 | + config = RolloutProcessorConfig( |
| 131 | + model="gpt-4o-mini", input_params={}, mcp_config_path="" |
| 132 | + ) |
| 133 | + |
| 134 | + captured_messages: List[Dict[str, Any]] = [] |
| 135 | + |
| 136 | + async def fake_acompletion(**kwargs): |
| 137 | + nonlocal captured_messages |
| 138 | + captured_messages = kwargs["messages"] |
| 139 | + return types.SimpleNamespace( |
| 140 | + choices=[ |
| 141 | + types.SimpleNamespace( |
| 142 | + message=types.SimpleNamespace( |
| 143 | + content="done", |
| 144 | + tool_calls=[ |
| 145 | + ChatCompletionMessageToolCall( |
| 146 | + id="call_2", |
| 147 | + type="function", |
| 148 | + function=ToolFunction(name="foo", arguments="{}"), |
| 149 | + ) |
| 150 | + ], |
| 151 | + function_call=None, |
| 152 | + ) |
| 153 | + ) |
| 154 | + ] |
| 155 | + ) |
| 156 | + |
| 157 | + with pytest.raises(NotImplementedError): |
| 158 | + await acompletion() |
| 159 | + |
| 160 | + with mock.patch( |
| 161 | + "eval_protocol.pytest.default_single_turn_rollout_process.acompletion", |
| 162 | + side_effect=fake_acompletion, |
| 163 | + ): |
| 164 | + dataset = await default_single_turn_rollout_processor([row], config) |
| 165 | + |
| 166 | + assert captured_messages[1]["tool_calls"][0]["id"] == "call_1" |
| 167 | + assert captured_messages[2]["tool_call_id"] == "call_1" |
| 168 | + result_row = dataset[0] |
| 169 | + assert result_row.messages[-1].tool_calls[0].id == "call_2" |
| 170 | + |
| 171 | + asyncio.run(run_test()) |
0 commit comments