@@ -1,30 +1,42 @@
+import asyncio
 from typing import List
 
-from openai import OpenAI
+from openai import AsyncOpenAI
 
 from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key
-from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message
-from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig
+from eval_protocol.models import EvaluationRow, Message
+from eval_protocol.pytest.types import RolloutProcessorConfig
 
 
-def default_single_turn_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]:
-    """Generate a single response from a Fireworks model."""
+async def default_single_turn_rollout_processor(
+    rows: List[EvaluationRow], config: RolloutProcessorConfig
+) -> List[EvaluationRow]:
+    """Generate a single response from a Fireworks model for each row, concurrently."""
 
     api_key = get_fireworks_api_key()
     api_base = get_fireworks_api_base()
-    client = OpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")
-
-    if len(row.messages) == 0:
-        raise ValueError("Messages is empty. Please provide a non-empty dataset")
-
-    messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
-
-    response = client.chat.completions.create(model=config.model, messages=messages_payload, **config.input_params)
-    assistant_content = response.choices[0].message.content or ""
-    messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
-    processed = EvaluationRow(
-        messages=messages,
-        ground_truth=row.ground_truth,
-        input_metadata=InputMetadata(completion_params=CompletionParams(model=config.model)),
-    )
-    return [processed]
+    client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")
+
+    async def process_row(row: EvaluationRow) -> EvaluationRow:
+        """Process a single row asynchronously."""
+        if len(row.messages) == 0:
+            raise ValueError("Messages is empty. Please provide a non-empty dataset")
+
+        messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
+
+        response = await client.chat.completions.create(
+            model=config.model, messages=messages_payload, **config.input_params
+        )
+        assistant_content = response.choices[0].message.content or ""
+        messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
+
+        return EvaluationRow(
+            messages=messages,
+            **row.model_dump(exclude={"messages"}),
+        )
+
+    # Process all rows concurrently
+    tasks = [process_row(row) for row in rows]
+    dataset = await asyncio.gather(*tasks)
+
+    return dataset
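Since the processor is now a coroutine, synchronous callers need an event loop. Below is a minimal sketch of driving it with `asyncio.run`. Only `config.model` and `config.input_params` are confirmed by the diff; the constructor arguments, the import path of the processor, and the model id are assumptions for illustration.

```python
import asyncio

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig

# Import path is an assumption based on the package layout seen above.
from eval_protocol.pytest import default_single_turn_rollout_processor

# Assumed constructors: field names beyond `model` and `input_params`
# may differ in eval_protocol, and the model id is a placeholder.
rows = [
    EvaluationRow(messages=[Message(role="user", content="What is 2 + 2?")]),
]
config = RolloutProcessorConfig(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    input_params={"temperature": 0.0},
)

# The processor is a coroutine, so a sync caller runs it in an event loop.
results = asyncio.run(default_single_turn_rollout_processor(rows, config))
print(results[0].messages[-1].content)
```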
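One design note on the final `asyncio.gather`: it starts a request for every row at once, so a large dataset can exhaust provider rate limits. A common refinement is to cap in-flight requests with `asyncio.Semaphore`. The sketch below is not part of this change; `gather_bounded` and its `limit` parameter are hypothetical names.

```python
import asyncio

async def gather_bounded(coros, limit: int = 8):
    """Run coroutines concurrently, with at most `limit` in flight at a time."""
    semaphore = asyncio.Semaphore(limit)

    async def bounded(coro):
        # Each coroutine waits for a semaphore slot before it starts running.
        async with semaphore:
            return await coro

    return await asyncio.gather(*(bounded(c) for c in coros))

# Inside the processor, the gather call could then become:
#     dataset = await gather_bounded((process_row(row) for row in rows), limit=8)
```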