 from dataclasses import replace
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 from collections import defaultdict
-
+import hashlib
+import ast
 from mcp.types import Completion
 import pytest
 
@@ -244,6 +245,7 @@ def evaluation_test( # noqa: C901
     max_dataset_rows: Optional[int] = None,
     mcp_config_path: Optional[str] = None,
     max_concurrent_rollouts: int = 8,
+    max_concurrent_evaluations: int = 64,
     server_script_path: Optional[str] = None,
     steps: int = 30,
     mode: EvaluationTestMode = "pointwise",
@@ -308,6 +310,7 @@ def evaluation_test( # noqa: C901
         max_dataset_rows: Limit dataset to the first N rows.
         mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
         max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
+        max_concurrent_evaluations: Maximum number of concurrent evaluations to run in parallel.
         server_script_path: Path to the MCP server script to run (default: "examples/tau2_mcp/server.py").
         steps: Number of rollout steps to execute (default: 30).
         mode: Evaluation mode. "pointwise" (default) applies test function to each row (rollout result).
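For orientation, here is a minimal sketch of how the new knob could look at a call site. Only the decorator and parameter names are taken from the diff; the argument values and the test body are invented, and the exact test-function signature is dictated by execute_with_params, so treat this as illustrative:

    # Hypothetical usage sketch; only the parameter names come from the diff above
    @evaluation_test(
        max_dataset_rows=100,
        max_concurrent_rollouts=8,        # caps parallel rollouts (pre-existing)
        max_concurrent_evaluations=64,    # new: caps parallel test_func evaluations
        mode="pointwise",
    )
    async def test_scores_row(row: EvaluationRow) -> EvaluationRow:
        # pointwise mode must return an EvaluationRow (enforced in the diff below)
        return row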
@@ -582,29 +585,42 @@ def _log_eval_error(
         for row in fresh_dataset:
             active_logger.log(row)
 
-        if mode == "pointwise":
-            # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
-            semaphore = asyncio.Semaphore(max_concurrent_rollouts)
-            tasks = []
+        # Shared helper: both modes run evaluations in parallel under one concurrency cap
+        semaphore = asyncio.Semaphore(max_concurrent_evaluations)
 
-            async def _execute_with_semaphore(row):
-                async with semaphore:
-                    # NOTE: we will still evaluate errored rows (give users control over this)
-                    # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
+        async def _execute_eval_with_semaphore(**inner_kwargs):
+            async with semaphore:
+                # NOTE: we will still evaluate errored rows (give users control over this)
+                # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
+                if "row" in inner_kwargs:
                     result = await execute_with_params(
                         test_func,
-                        processed_row=row,
+                        processed_row=inner_kwargs["row"],
                         evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
                     )
                     if result is None or not isinstance(result, EvaluationRow):
                         raise ValueError(
                             f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
                         )
                     return result
+                if "rows" in inner_kwargs:
+                    results = await execute_with_params(
+                        test_func,
+                        processed_dataset=inner_kwargs["rows"],
+                        evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
+                    )
+                    if results is None or not isinstance(results, list):
+                        raise ValueError(
+                            f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
+                        )
+                    return results
 
+        if mode == "pointwise":
+            # Pointwise mode: rollouts return as they complete, so evaluation_test execution can be pipelined
+            tasks = []
             # Use wrapper that handles retry logic internally
             async for row in rollout_processor_with_retry(rollout_processor, fresh_dataset, config):
-                tasks.append(asyncio.create_task(_execute_with_semaphore(row)))
+                tasks.append(asyncio.create_task(_execute_eval_with_semaphore(row=row)))
 
             results = await asyncio.gather(*tasks)
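The hunk above is the standard asyncio pipelining pattern: schedule one evaluation task per finished rollout, and let a shared Semaphore cap how many evaluations run at once. A self-contained sketch of that pattern, with produce_rows and evaluate as invented stand-ins for the rollout stream and test_func:

    import asyncio

    async def produce_rows(n):               # stand-in for rollout_processor_with_retry
        for i in range(n):
            await asyncio.sleep(0.01)        # rollouts finish one by one
            yield i

    async def evaluate(item):                # stand-in for test_func
        await asyncio.sleep(0.01)
        return item * 2

    async def main():
        semaphore = asyncio.Semaphore(64)    # analogous to max_concurrent_evaluations

        async def evaluate_bounded(item):
            async with semaphore:            # at most 64 evaluations in flight
                return await evaluate(item)

        tasks = []
        async for item in produce_rows(10):
            # create_task schedules the evaluation immediately, so it overlaps later rollouts
            tasks.append(asyncio.create_task(evaluate_bounded(item)))
        return await asyncio.gather(*tasks)

    print(asyncio.run(main()))               # [0, 2, 4, ..., 18]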
@@ -645,14 +661,13 @@ async def _collect_result(config, lst):
                 for result in rollout_results:
                     for row in result:
                         row_groups[row.input_metadata.row_id].append(row)
-                results = []
+                tasks = []
                 for row_id, rows in row_groups.items():
-                    result = await execute_with_params(
-                        test_func,
-                        processed_dataset=rows,
-                        evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
-                    )
-                    results.extend(result)
+                    tasks.append(asyncio.create_task(_execute_eval_with_semaphore(rows=rows)))
+                results = []
+                for task in tasks:
+                    res = await task
+                    results.extend(res)
                 all_results[i] = results
             else:
                 # Batch mode: collect all results first, then evaluate (no pipelining)
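In this groupwise hunk, awaiting the tasks in creation order is effectively asyncio.gather: the tasks already run concurrently once created, and the loop only fixes the order in which results are collected. The grouping itself uses the defaultdict imported at the top of the file; a small sketch of the same idiom, with a hypothetical Row type standing in for EvaluationRow:

    from collections import defaultdict
    from dataclasses import dataclass

    @dataclass
    class Row:                                # hypothetical stand-in for EvaluationRow
        row_id: str

    rows = [Row("a"), Row("b"), Row("a")]
    groups: dict[str, list[Row]] = defaultdict(list)
    for row in rows:
        groups[row.row_id].append(row)        # same shape as row_groups[row.input_metadata.row_id]
    assert {k: len(v) for k, v in groups.items()} == {"a": 2, "b": 1}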
@@ -789,6 +804,13 @@ async def dual_mode_wrapper(*args, **kwargs):
         # If not a direct call, use the pytest wrapper
         return await pytest_wrapper(*args, **kwargs)
 
+    dual_mode_wrapper._origin_func = test_func
+    dual_mode_wrapper._metainfo = {
+        "mode": mode,
+        "max_rollout_concurrency": max_concurrent_rollouts,
+        "max_evaluation_concurrency": max_concurrent_evaluations,
+    }
+
     # Copy all attributes from the pytest wrapper to our dual mode wrapper
     import functools
 
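The new _origin_func/_metainfo attributes make each wrapped test introspectable. A hedged sketch of a possible downstream consumer (inspect_eval is invented for illustration; only the attribute and key names are taken from the diff):

    def inspect_eval(wrapped):
        meta = getattr(wrapped, "_metainfo", {})        # set by the decorator above
        origin = getattr(wrapped, "_origin_func", None)
        name = origin.__name__ if origin else "<unknown>"
        print(
            f"{name}: mode={meta.get('mode')}, "
            f"rollout cap={meta.get('max_rollout_concurrency')}, "
            f"eval cap={meta.get('max_evaluation_concurrency')}"
        )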