
Commit 0b4a0a3

Refactor rollouts to accept list (#7)
* refactor rollout processor to accept entire input dataset
* run single turn rollouts in parallel
1 parent 6029271 commit 0b4a0a3

File tree: 5 files changed (+49, -36 lines)


eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 11 additions & 6 deletions
@@ -9,7 +9,7 @@
 from eval_protocol.mcp.execution.policy import LiteLLMPolicy
 from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
 from eval_protocol.models import EvaluationRow, Message
-from eval_protocol.pytest.types import RolloutProcessorConfig
+from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig


 class Agent:
@@ -73,8 +73,13 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> str:
         return first_content.text


-async def default_agent_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]:
-    agent = Agent(model=config.model, initial_messages=config.initial_messages, config_path=config.mcp_config_path)
-    await agent.setup()
-    await agent.call_agent()
-    return [EvaluationRow(messages=agent.messages)]
+async def default_agent_rollout_processor(
+    rows: List[EvaluationRow], config: RolloutProcessorConfig
+) -> List[EvaluationRow]:
+    dataset: Dataset = []
+    for row in rows:
+        agent = Agent(model=config.model, initial_messages=row.messages, config_path=config.mcp_config_path)
+        await agent.setup()
+        await agent.call_agent()
+        dataset.append(EvaluationRow(messages=agent.messages, ground_truth=row.ground_truth))
+    return dataset
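For orientation, here is a minimal sketch of how the refactored entry point might be driven, assuming the import path matches the file above, that `EvaluationRow.ground_truth` is optional, and that `RolloutProcessorConfig` takes the keyword fields shown in `types.py` below; the model name, prompts, and MCP config path are placeholders, not values from this repository.

```python
import asyncio
from typing import List

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.default_agent_rollout_processor import default_agent_rollout_processor
from eval_protocol.pytest.types import RolloutProcessorConfig

# Hypothetical two-row dataset; the prompts are placeholders.
rows: List[EvaluationRow] = [
    EvaluationRow(messages=[Message(role="user", content="List the tools you have access to.")]),
    EvaluationRow(messages=[Message(role="user", content="What is 2 + 2?")]),
]

# Field names follow RolloutProcessorConfig as shown in types.py; values are placeholders.
config = RolloutProcessorConfig(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    input_params={},
    mcp_config_path="mcp_config.json",
)

# The processor now consumes the whole list and returns one EvaluationRow per input row.
results = asyncio.run(default_agent_rollout_processor(rows, config))
assert len(results) == len(rows)
```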
Lines changed: 3 additions & 3 deletions
@@ -1,12 +1,12 @@
 from typing import List

 from eval_protocol.models import EvaluationRow
-from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig
+from eval_protocol.pytest.types import RolloutProcessorConfig


-def default_no_op_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]:
+def default_no_op_rollout_processor(rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[EvaluationRow]:
     """
     Simply passes input dataset through to the test function. This can be useful
     if you want to run the rollout yourself.
     """
-    return [row]
+    return rows
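A hedged usage sketch of the no-op path: rollouts were produced by hand, so the processor should hand the dataset straight back. The rows, config values, and conversation content are made up for illustration.

```python
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig

# NOTE: the module path of default_no_op_rollout_processor is not visible in this
# diff, so the import is omitted; assume the function is already in scope.

rows = [
    EvaluationRow(
        messages=[
            Message(role="user", content="Say hi"),
            Message(role="assistant", content="Hi!"),  # rollout already produced by hand
        ]
    )
]
config = RolloutProcessorConfig(model="placeholder-model", input_params={}, mcp_config_path="")

# With the list-based signature the processor hands the whole dataset back unchanged.
assert default_no_op_rollout_processor(rows, config) is rows
```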
Lines changed: 33 additions & 21 deletions
@@ -1,30 +1,42 @@
+import asyncio
 from typing import List

-from openai import OpenAI
+from openai import AsyncOpenAI

 from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key
-from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message
-from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig
+from eval_protocol.models import EvaluationRow, Message
+from eval_protocol.pytest.types import RolloutProcessorConfig


-def default_single_turn_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]:
-    """Generate a single response from a Fireworks model."""
+async def default_single_turn_rollout_processor(
+    rows: List[EvaluationRow], config: RolloutProcessorConfig
+) -> List[EvaluationRow]:
+    """Generate a single response from a Fireworks model concurrently."""

     api_key = get_fireworks_api_key()
     api_base = get_fireworks_api_base()
-    client = OpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")
-
-    if len(row.messages) == 0:
-        raise ValueError("Messages is empty. Please provide a non-empty dataset")
-
-    messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
-
-    response = client.chat.completions.create(model=config.model, messages=messages_payload, **config.input_params)
-    assistant_content = response.choices[0].message.content or ""
-    messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
-    processed = EvaluationRow(
-        messages=messages,
-        ground_truth=row.ground_truth,
-        input_metadata=InputMetadata(completion_params=CompletionParams(model=config.model)),
-    )
-    return [processed]
+    client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")
+
+    async def process_row(row: EvaluationRow) -> EvaluationRow:
+        """Process a single row asynchronously."""
+        if len(row.messages) == 0:
+            raise ValueError("Messages is empty. Please provide a non-empty dataset")
+
+        messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
+
+        response = await client.chat.completions.create(
+            model=config.model, messages=messages_payload, **config.input_params
+        )
+        assistant_content = response.choices[0].message.content or ""
+        messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
+
+        return EvaluationRow(
+            messages=messages,
+            **row.model_dump(exclude={"messages"}),
+        )
+
+    # Process all rows concurrently
+    tasks = [process_row(row) for row in rows]
+    dataset = await asyncio.gather(*tasks)
+
+    return dataset
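The refactored processor fires one request per row via `asyncio.gather`, with no cap on in-flight calls. A hedged variation, not part of this commit, that bounds concurrency with an `asyncio.Semaphore` could look like the sketch below; the helper name `gather_bounded` and the limit of 8 are made up.

```python
import asyncio
from typing import Awaitable, Callable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


async def gather_bounded(
    items: List[T],
    worker: Callable[[T], Awaitable[R]],
    max_concurrency: int = 8,
) -> List[R]:
    """Run `worker` over every item concurrently, keeping at most
    `max_concurrency` coroutines in flight; results preserve input order."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded(item: T) -> R:
        async with semaphore:
            return await worker(item)

    return list(await asyncio.gather(*(bounded(item) for item in items)))


# Inside a processor like the one above, the gather line could become:
#     dataset = await gather_bounded(rows, process_row, max_concurrency=8)
```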

eval_protocol/pytest/evaluation_test.py

Lines changed: 1 addition & 4 deletions
@@ -182,11 +182,8 @@ def wrapper_body(**kwargs):
             model=model_name,
             input_params=kwargs.get("input_params") or {},
             mcp_config_path=mcp_config_path or "",
-            initial_messages=kwargs.get("input_messages") if "input_messages" in kwargs else [],
         )
-        for row in data:
-            processed: List[EvaluationRow] = execute_function(rollout_processor, row=row, config=config)
-            input_dataset.extend(processed)
+        input_dataset = execute_function(rollout_processor, rows=data, config=config)

         all_results: List[EvaluationRow] = []
         for _ in range(num_runs):
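Both synchronous (no-op) and asynchronous (agent, single-turn) processors are now routed through `execute_function` with the full dataset in one call. Its real implementation is not shown in this diff; the following is only a guess at the kind of sync/async dispatch it would need, with the name `execute_function_sketch` chosen to avoid implying this is the library's actual code.

```python
import asyncio
import inspect
from typing import Any, Callable


def execute_function_sketch(fn: Callable[..., Any], **kwargs: Any) -> Any:
    """Call `fn` with keyword arguments; if it is an async def (like the agent
    and single-turn processors above), drive the coroutine to completion."""
    if inspect.iscoroutinefunction(fn):
        return asyncio.run(fn(**kwargs))
    return fn(**kwargs)


# e.g. input_dataset = execute_function_sketch(rollout_processor, rows=data, config=config)
```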

eval_protocol/pytest/types.py

Lines changed: 1 addition & 2 deletions
@@ -39,7 +39,6 @@ class RolloutProcessorConfig:
     model: ModelParam
     input_params: InputParam  # optional input parameters for inference
     mcp_config_path: str  # for agent rollout processor
-    initial_messages: list[Message]  # for agent rollout processor


-RolloutProcessor = Callable[[EvaluationRow, RolloutProcessorConfig], List[EvaluationRow]]
+RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[EvaluationRow]]
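With the type alias changed, any callable that takes the whole dataset plus a config and returns a list of rows satisfies `RolloutProcessor`. A toy conforming processor, assuming only the `EvaluationRow`, `Message`, and `RolloutProcessorConfig` constructors visible in the diffs above, might look like this:

```python
from typing import List

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig


def canned_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """Toy RolloutProcessor: append a fixed assistant reply to every row
    instead of calling a model."""
    dataset: List[EvaluationRow] = []
    for row in rows:
        messages = list(row.messages) + [Message(role="assistant", content="(dry run)")]
        dataset.append(EvaluationRow(messages=messages, ground_truth=row.ground_truth))
    return dataset
```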
