Merged
34 changes: 33 additions & 1 deletion eval_protocol/models.py
@@ -3,7 +3,7 @@
import importlib
from datetime import datetime, timezone
from enum import Enum
-from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union
+from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union, Callable, Sequence

JSONType = Union[Dict[str, Any], List[Any], str, int, float, bool, None]

@@ -1190,3 +1190,35 @@ class MCPMultiClientConfiguration(BaseModel):
"""Represents a MCP configuration."""

mcpServers: Dict[str, Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]]


class EPParameters(BaseModel):
"""The parameters of an `@evaluation_test`. Used for trainable integrations."""

model_config = ConfigDict(arbitrary_types_allowed=True)

completion_params: Any = None
input_messages: Any = None
input_dataset: Any = None
input_rows: Any = None
data_loaders: Any = None
dataset_adapter: Optional[Callable[..., Any]] = None
rollout_processor: Any = None
rollout_processor_kwargs: Dict[str, Any] | None = None
evaluation_test_kwargs: Any = None
aggregation_method: Any = Field(default="mean")
passed_threshold: Any = None
disable_browser_open: bool = False
num_runs: int = 1
filtered_row_ids: Optional[Sequence[str]] = None
max_dataset_rows: Optional[int] = None
mcp_config_path: Optional[str] = None
max_concurrent_rollouts: int = 8
max_concurrent_evaluations: int = 64
server_script_path: Optional[str] = None
steps: int = 30
mode: Any = Field(default="pointwise")
combine_datasets: bool = True
preprocess_fn: Optional[Callable[[list[EvaluationRow]], list[EvaluationRow]]] = None
logger: Any = None
exception_handler_config: Any = None
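
For orientation (not part of the diff), a minimal sketch of how the new `EPParameters` model behaves: every field carries a default, so callers only pass what they need, and consumers read values via attribute access instead of dict lookups. The field names come from the class above; the concrete values here are hypothetical.

```python
from eval_protocol.models import EPParameters

# Hypothetical values, for illustration only.
params = EPParameters(
    mcp_config_path="configs/mcp.json",
    rollout_processor_kwargs={"temperature": 0.7},
    steps=10,
)

assert params.mode == "pointwise"              # declared default
assert params.max_concurrent_rollouts == 8     # declared default
kwargs = params.rollout_processor_kwargs or {}  # None-safe handling, as in the vLLM integration below
```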
36 changes: 29 additions & 7 deletions eval_protocol/pytest/evaluation_test.py
@@ -21,6 +21,7 @@
EvaluationThresholdDict,
EvaluateResult,
Status,
EPParameters,
)
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
@@ -695,13 +696,34 @@ async def _collect_result(config, lst):
)
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)

-ep_params: dict[str, Any] = {
-    "rollout_processor": rollout_processor,
-    "server_script_path": server_script_path,
-    "mcp_config_path": mcp_config_path,
-    "rollout_processor_kwargs": rollout_processor_kwargs,
-    "mode": mode,
-}
+# Attach full evaluation parameter metadata for training integrations
+ep_params: EPParameters = EPParameters(
+    completion_params=completion_params,
+    input_messages=input_messages,
+    input_dataset=input_dataset,
+    input_rows=input_rows,
+    data_loaders=data_loaders,
+    dataset_adapter=dataset_adapter,
+    rollout_processor=rollout_processor,
+    rollout_processor_kwargs=rollout_processor_kwargs,
+    evaluation_test_kwargs=evaluation_test_kwargs,
+    aggregation_method=aggregation_method,
+    passed_threshold=passed_threshold,
+    disable_browser_open=disable_browser_open,
+    num_runs=num_runs,
+    filtered_row_ids=filtered_row_ids,
+    max_dataset_rows=max_dataset_rows,
+    mcp_config_path=mcp_config_path,
+    max_concurrent_rollouts=max_concurrent_rollouts,
+    max_concurrent_evaluations=max_concurrent_evaluations,
+    server_script_path=server_script_path,
+    steps=steps,
+    mode=mode,
+    combine_datasets=combine_datasets,
+    preprocess_fn=preprocess_fn,
+    logger=logger,
+    exception_handler_config=exception_handler_config,
+)

# Create the dual mode wrapper
dual_mode_wrapper = create_dual_mode_wrapper(
11 changes: 7 additions & 4 deletions eval_protocol/pytest/integrations/openenv_trl_vllm.py
@@ -121,10 +121,13 @@ def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:

eval_func = candidate_tests[0]
ep_eval_func = eval_func # used later after rollouts complete
-ep_params: Dict[str, Any] = getattr(eval_func, "__ep_params__", {})
-ep_rollout_processor = ep_params.get("rollout_processor")
-ep_rollout_processor_kwargs = ep_params.get("rollout_processor_kwargs") or {}
-ep_mcp_config_path = ep_params.get("mcp_config_path") or ""
+ep_params = getattr(eval_func, "__ep_params__", None)
+# ep_params is an EPParameters Pydantic model, so use attribute access
+ep_rollout_processor = getattr(ep_params, "rollout_processor", None) if ep_params else None
+ep_rollout_processor_kwargs = (
+    (getattr(ep_params, "rollout_processor_kwargs", None) or {}) if ep_params else {}
+)
+ep_mcp_config_path = (getattr(ep_params, "mcp_config_path", None) or "") if ep_params else ""
logger.info(
"[OpenEnvVLLM] Loaded eval test '%s' with rollout_processor=%s",
getattr(eval_func, "__name__", str(eval_func)),
19 changes: 19 additions & 0 deletions eval_protocol/training/__init__.py
@@ -0,0 +1,19 @@
from .gepa_trainer import GEPATrainer
from .gepa_utils import (
DSPyModuleType,
DSPyModuleFactory,
create_single_turn_program,
create_signature,
build_reflection_lm,
)

__all__ = [
"GEPATrainer",
# DSPy module creation utilities
"DSPyModuleType",
"DSPyModuleFactory",
"create_single_turn_program",
"create_signature",
# Reflection LM helpers
"build_reflection_lm",
]
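
As a sketch of the new public surface, only the names re-exported above are assumed importable; constructor signatures and arguments are not part of this diff.

```python
# Import sketch: these names come straight from __all__ above.
from eval_protocol.training import (
    GEPATrainer,
    DSPyModuleType,
    DSPyModuleFactory,
    create_single_turn_program,
    create_signature,
    build_reflection_lm,
)
```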