Commit c185d84
preserve metadata and evaluator id etc to the wrapped eval func (#108)
* preserve metadata and evaluator id etc to the wrapped eval func
* format
* fix ut
* fix format
* remove id gen logic
* fix ut
1 parent 8c87240 commit c185d84
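The change, in short: the `dual_mode_wrapper` returned by `@evaluation_test` now carries the undecorated function (`_origin_func`) and its configuration (`_metainfo`) as plain attributes, so tooling can inspect an evaluator without running it. A minimal, hypothetical sketch of that pattern (`evaluation_test_like` is a stand-in, not the real eval_protocol implementation):

# Hypothetical, simplified sketch of the attribute-preservation pattern this
# commit adds to dual_mode_wrapper; not eval_protocol code.
import asyncio
import functools


def evaluation_test_like(mode: str = "pointwise"):
    def decorator(test_func):
        @functools.wraps(test_func)
        async def wrapper(*args, **kwargs):
            return test_func(*args, **kwargs)

        # Expose the undecorated function and its config for introspection,
        # mirroring the _origin_func / _metainfo attributes in the diff below.
        wrapper._origin_func = test_func
        wrapper._metainfo = {"mode": mode}
        return wrapper

    return decorator


@evaluation_test_like(mode="groupwise")
def my_eval(rows):
    return rows


assert my_eval._metainfo["mode"] == "groupwise"
assert not asyncio.iscoroutinefunction(my_eval._origin_func)  # original is sync
assert asyncio.iscoroutinefunction(my_eval)  # the wrapper itself is async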

File tree

eval_protocol/pytest/evaluation_test.py
tests/pytest/test_get_metadata.py

2 files changed, +74 -18 lines

eval_protocol/pytest/evaluation_test.py

Lines changed: 40 additions & 18 deletions
@@ -11,7 +11,8 @@
 from dataclasses import replace
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 from collections import defaultdict
-
+import hashlib
+import ast
 from mcp.types import Completion
 import pytest
 
@@ -244,6 +245,7 @@ def evaluation_test( # noqa: C901
     max_dataset_rows: Optional[int] = None,
     mcp_config_path: Optional[str] = None,
     max_concurrent_rollouts: int = 8,
+    max_concurrent_evaluations: int = 64,
     server_script_path: Optional[str] = None,
     steps: int = 30,
     mode: EvaluationTestMode = "pointwise",
@@ -308,6 +310,7 @@ def evaluation_test( # noqa: C901
         max_dataset_rows: Limit dataset to the first N rows.
         mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
         max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
+        max_concurrent_evaluations: Maximum number of concurrent evaluations to run in parallel.
         server_script_path: Path to the MCP server script to run (default: "examples/tau2_mcp/server.py").
         steps: Number of rollout steps to execute (default: 30).
         mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result).
@@ -582,29 +585,42 @@ def _log_eval_error(
         for row in fresh_dataset:
             active_logger.log(row)
 
-        if mode == "pointwise":
-            # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
-            semaphore = asyncio.Semaphore(max_concurrent_rollouts)
-            tasks = []
+        # prepare parallel eval helper function
+        semaphore = asyncio.Semaphore(max_concurrent_evaluations)
 
-            async def _execute_with_semaphore(row):
-                async with semaphore:
-                    # NOTE: we will still evaluate errored rows (give users control over this)
-                    # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
+        async def _execute_eval_with_semaphore(**inner_kwargs):
+            async with semaphore:
+                # NOTE: we will still evaluate errored rows (give users control over this)
+                # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
+                if "row" in inner_kwargs:
                     result = await execute_with_params(
                         test_func,
-                        processed_row=row,
+                        processed_row=inner_kwargs["row"],
                         evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
                     )
                     if result is None or not isinstance(result, EvaluationRow):
                         raise ValueError(
                             f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
                         )
                     return result
+                if "rows" in inner_kwargs:
+                    results = await execute_with_params(
+                        test_func,
+                        processed_dataset=inner_kwargs["rows"],
+                        evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
+                    )
+                    if results is None or not isinstance(results, list):
+                        raise ValueError(
+                            f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
+                        )
+                    return results
 
+        if mode == "pointwise":
+            # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
+            tasks = []
             # Use wrapper that handles retry logic internally
             async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
-                tasks.append(asyncio.create_task(_execute_with_semaphore(row)))
+                tasks.append(asyncio.create_task(_execute_eval_with_semaphore(row=row)))
 
             results = await asyncio.gather(*tasks)
 
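The key change in this hunk: a single shared `asyncio.Semaphore`, now sized by the new `max_concurrent_evaluations` rather than `max_concurrent_rollouts`, bounds evaluation concurrency for both modes, and pointwise evaluations are scheduled as each rollout streams in. A runnable sketch of the pattern, with hypothetical `fake_rollouts` / `evaluate_with_semaphore` stand-ins in place of eval_protocol's `rollout_processor_with_retry` and `execute_with_params`:

# Runnable sketch of the pipelined, semaphore-bounded evaluation pattern above.
# All names here are hypothetical stand-ins, not eval_protocol internals.
import asyncio
from typing import AsyncIterator, List


async def fake_rollouts(n: int) -> AsyncIterator[int]:
    """Yield rollout results one at a time, as they complete."""
    for i in range(n):
        await asyncio.sleep(0.01)
        yield i


async def run(max_concurrent_evaluations: int = 4) -> List[int]:
    semaphore = asyncio.Semaphore(max_concurrent_evaluations)

    async def evaluate_with_semaphore(row: int) -> int:
        async with semaphore:  # at most max_concurrent_evaluations evals at once
            await asyncio.sleep(0.05)  # stand-in for the real test_func call
            return row * 2

    tasks = []
    # Each row's evaluation starts as soon as its rollout arrives,
    # instead of waiting for the whole batch to finish first.
    async for row in fake_rollouts(10):
        tasks.append(asyncio.create_task(evaluate_with_semaphore(row)))
    return await asyncio.gather(*tasks)


print(asyncio.run(run()))  # [0, 2, 4, ..., 18]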

@@ -645,14 +661,13 @@ async def _collect_result(config, lst):
             for result in rollout_results:
                 for row in result:
                     row_groups[row.input_metadata.row_id].append(row)
-            results = []
+            tasks = []
             for row_id, rows in row_groups.items():
-                result = await execute_with_params(
-                    test_func,
-                    processed_dataset=rows,
-                    evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
-                )
-                results.extend(result)
+                tasks.append(asyncio.create_task(_execute_eval_with_semaphore(rows=rows)))
+            results = []
+            for task in tasks:
+                res = await task
+                results.extend(res)
             all_results[i] = results
         else:
             # Batch mode: collect all results first, then evaluate (no pipelining)
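Groupwise evaluation now schedules every group's eval up front with `asyncio.create_task` and only then awaits the tasks in order, so the sequential-looking loop still overlaps work (bounded by the shared semaphore). A small sketch of why, using a hypothetical `evaluate_group` in place of `_execute_eval_with_semaphore`:

# Sketch of the scheduling behavior of the groupwise loop above: create_task
# starts every group's evaluation immediately, so awaiting the tasks one by
# one preserves group order without serializing the work.
import asyncio
from typing import List


async def evaluate_group(rows: List[int]) -> List[int]:
    await asyncio.sleep(0.1)  # stand-in for the real groupwise eval
    return [r * 2 for r in rows]


async def main() -> None:
    groups = [[1, 2], [3, 4], [5, 6]]
    tasks = [asyncio.create_task(evaluate_group(g)) for g in groups]
    results: List[int] = []
    for task in tasks:  # all three evals already run concurrently: ~0.1s total
        results.extend(await task)
    print(results)  # [2, 4, 6, 8, 10, 12]


asyncio.run(main())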
@@ -789,6 +804,13 @@ async def dual_mode_wrapper(*args, **kwargs):
         # If not a direct call, use the pytest wrapper
         return await pytest_wrapper(*args, **kwargs)
 
+    dual_mode_wrapper._origin_func = test_func
+    dual_mode_wrapper._metainfo = {
+        "mode": mode,
+        "max_rollout_concurrency": max_concurrent_rollouts,
+        "max_evaluation_concurrency": max_concurrent_evaluations,
+    }
+
     # Copy all attributes from the pytest wrapper to our dual mode wrapper
     import functools
 
tests/pytest/test_get_metadata.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import asyncio
+from typing import Dict, List
+
+from eval_protocol.pytest import evaluation_test
+from eval_protocol.models import EvaluationRow, Message
+
+
+@evaluation_test(
+    input_messages=[
+        [
+            Message(role="user", content="What is the capital of France?"),
+        ],
+        [
+            Message(role="user", content="What is the capital of the moon?"),
+        ],
+    ],
+    completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}] * 2,
+    mode="groupwise",
+    max_concurrent_rollouts=5,
+    max_concurrent_evaluations=10,
+)
+def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
+    """Run math evaluation on sample dataset using pytest interface."""
+    return rows
+
+
+def test_pytest_func_metainfo():
+    assert hasattr(test_pytest_async, "_origin_func")
+    origin_func = test_pytest_async._origin_func
+    assert not asyncio.iscoroutinefunction(origin_func)
+    assert asyncio.iscoroutinefunction(test_pytest_async)
+    assert test_pytest_async._metainfo["mode"] == "groupwise"
+    assert test_pytest_async._metainfo["max_rollout_concurrency"] == 5
+    assert test_pytest_async._metainfo["max_evaluation_concurrency"] == 10
