Skip to content

Commit 2b765e0

Browse files
xzrderekshreymodi1
andauthored
gepa integration (#359)
* gepa integration part 1 * update * skeleton of gepa trainer * abc trainer * assign * fix lock * attempt at primitive conversion * gepa wokring * gepa work * updates * cleaning up 1 * undo * fixes * fix * updated --------- Co-authored-by: Shrey Modi <[email protected]>
1 parent fdccc92 commit 2b765e0

File tree

12 files changed

+1452
-12
lines changed

12 files changed

+1452
-12
lines changed

eval_protocol/models.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import importlib
44
from datetime import datetime, timezone
55
from enum import Enum
6-
from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union
6+
from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union, Callable, Sequence
77

88
JSONType = Union[Dict[str, Any], List[Any], str, int, float, bool, None]
99

@@ -1190,3 +1190,35 @@ class MCPMultiClientConfiguration(BaseModel):
11901190
"""Represents a MCP configuration."""
11911191

11921192
mcpServers: Dict[str, Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]]
1193+
1194+
1195+
class EPParameters(BaseModel):
1196+
"""The parameters of an `@evaluation_test`. Used for trainable integrations."""
1197+
1198+
model_config = ConfigDict(arbitrary_types_allowed=True)
1199+
1200+
completion_params: Any = None
1201+
input_messages: Any = None
1202+
input_dataset: Any = None
1203+
input_rows: Any = None
1204+
data_loaders: Any = None
1205+
dataset_adapter: Optional[Callable[..., Any]] = None
1206+
rollout_processor: Any = None
1207+
rollout_processor_kwargs: Dict[str, Any] | None = None
1208+
evaluation_test_kwargs: Any = None
1209+
aggregation_method: Any = Field(default="mean")
1210+
passed_threshold: Any = None
1211+
disable_browser_open: bool = False
1212+
num_runs: int = 1
1213+
filtered_row_ids: Optional[Sequence[str]] = None
1214+
max_dataset_rows: Optional[int] = None
1215+
mcp_config_path: Optional[str] = None
1216+
max_concurrent_rollouts: int = 8
1217+
max_concurrent_evaluations: int = 64
1218+
server_script_path: Optional[str] = None
1219+
steps: int = 30
1220+
mode: Any = Field(default="pointwise")
1221+
combine_datasets: bool = True
1222+
preprocess_fn: Optional[Callable[[list[EvaluationRow]], list[EvaluationRow]]] = None
1223+
logger: Any = None
1224+
exception_handler_config: Any = None

eval_protocol/pytest/evaluation_test.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
EvaluationThresholdDict,
2323
EvaluateResult,
2424
Status,
25+
EPParameters,
2526
)
2627
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
2728
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
@@ -753,13 +754,34 @@ async def _collect_result(config, lst):
753754
)
754755
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
755756

756-
ep_params: dict[str, Any] = {
757-
"rollout_processor": rollout_processor,
758-
"server_script_path": server_script_path,
759-
"mcp_config_path": mcp_config_path,
760-
"rollout_processor_kwargs": rollout_processor_kwargs,
761-
"mode": mode,
762-
}
757+
# Attach full evaluation parameter metadata for training integrations
758+
ep_params: EPParameters = EPParameters(
759+
completion_params=completion_params,
760+
input_messages=input_messages,
761+
input_dataset=input_dataset,
762+
input_rows=input_rows,
763+
data_loaders=data_loaders,
764+
dataset_adapter=dataset_adapter,
765+
rollout_processor=rollout_processor,
766+
rollout_processor_kwargs=rollout_processor_kwargs,
767+
evaluation_test_kwargs=evaluation_test_kwargs,
768+
aggregation_method=aggregation_method,
769+
passed_threshold=passed_threshold,
770+
disable_browser_open=disable_browser_open,
771+
num_runs=num_runs,
772+
filtered_row_ids=filtered_row_ids,
773+
max_dataset_rows=max_dataset_rows,
774+
mcp_config_path=mcp_config_path,
775+
max_concurrent_rollouts=max_concurrent_rollouts,
776+
max_concurrent_evaluations=max_concurrent_evaluations,
777+
server_script_path=server_script_path,
778+
steps=steps,
779+
mode=mode,
780+
combine_datasets=combine_datasets,
781+
preprocess_fn=preprocess_fn,
782+
logger=logger,
783+
exception_handler_config=exception_handler_config,
784+
)
763785

764786
# Create the dual mode wrapper
765787
dual_mode_wrapper = create_dual_mode_wrapper(

eval_protocol/pytest/integrations/openenv_trl_vllm.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,13 @@ def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:
121121

122122
eval_func = candidate_tests[0]
123123
ep_eval_func = eval_func # used later after rollouts complete
124-
ep_params: Dict[str, Any] = getattr(eval_func, "__ep_params__", {})
125-
ep_rollout_processor = ep_params.get("rollout_processor")
126-
ep_rollout_processor_kwargs = ep_params.get("rollout_processor_kwargs") or {}
127-
ep_mcp_config_path = ep_params.get("mcp_config_path") or ""
124+
ep_params = getattr(eval_func, "__ep_params__", None)
125+
# ep_params is an EPParameters model (Pydantic), use attribute access
126+
ep_rollout_processor = getattr(ep_params, "rollout_processor", None) if ep_params else None
127+
ep_rollout_processor_kwargs = (
128+
(getattr(ep_params, "rollout_processor_kwargs", None) or {}) if ep_params else {}
129+
)
130+
ep_mcp_config_path = (getattr(ep_params, "mcp_config_path", None) or "") if ep_params else ""
128131
logger.info(
129132
"[OpenEnvVLLM] Loaded eval test '%s' with rollout_processor=%s",
130133
getattr(eval_func, "__name__", str(eval_func)),

eval_protocol/training/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from .gepa_trainer import GEPATrainer
2+
from .gepa_utils import (
3+
DSPyModuleType,
4+
DSPyModuleFactory,
5+
create_single_turn_program,
6+
create_signature,
7+
build_reflection_lm,
8+
)
9+
10+
__all__ = [
11+
"GEPATrainer",
12+
# DSPy module creation utilities
13+
"DSPyModuleType",
14+
"DSPyModuleFactory",
15+
"create_single_turn_program",
16+
"create_signature",
17+
# Reflection LM helpers
18+
"build_reflection_lm",
19+
]

0 commit comments

Comments
 (0)