Skip to content

Commit 495ff63

Browse files
authored
support parameter override (#227)
* support overwrite * format * remove useless override * add test * format * update comment
1 parent 5d7e5cb commit 495ff63

File tree

5 files changed

+121
-1
lines changed

5 files changed

+121
-1
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@
5252
add_cost_metrics,
5353
log_eval_status_and_rows,
5454
parse_ep_completion_params,
55+
parse_ep_completion_params_overwrite,
5556
parse_ep_max_concurrent_rollouts,
5657
parse_ep_max_rows,
5758
parse_ep_num_runs,
5859
parse_ep_passed_threshold,
60+
parse_ep_dataloaders,
5961
rollout_processor_with_retry,
6062
run_tasks_with_eval_progress,
6163
run_tasks_with_run_progress,
@@ -189,10 +191,18 @@ def evaluation_test(
189191
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
190192
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
191193
completion_params = parse_ep_completion_params(completion_params)
194+
completion_params = parse_ep_completion_params_overwrite(completion_params)
192195
original_completion_params = completion_params
193196
passed_threshold = parse_ep_passed_threshold(passed_threshold)
197+
data_loaders = parse_ep_dataloaders(data_loaders)
194198
custom_invocation_id = os.environ.get("EP_INVOCATION_ID", None)
195199

200+
# ignore other data input params when dataloader is provided
201+
if data_loaders:
202+
input_dataset = None
203+
input_messages = None
204+
input_rows = None
205+
196206
def decorator(
197207
test_func: TestFunction,
198208
) -> TestFunction:

eval_protocol/pytest/plugin.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pathlib
2020
import sys
2121
from pytest import StashKey
22+
import pytest
2223

2324

2425
def pytest_addoption(parser) -> None:
@@ -56,6 +57,7 @@ def pytest_addoption(parser) -> None:
5657
default=None,
5758
help=("Write a JSON summary artifact at the given path (e.g., ./outputs/aime_low.json)."),
5859
)
60+
# deprecate this later
5961
group.addoption(
6062
"--ep-input-param",
6163
action="append",
@@ -115,6 +117,22 @@ def pytest_addoption(parser) -> None:
115117
"Default: false (experiment JSONs are saved and uploaded by default)."
116118
),
117119
)
120+
group.addoption(
121+
"--ep-jsonl-path",
122+
default=None,
123+
help=("Load input from a jsonl file that is already in EvaluationRow or openai CHAT format"),
124+
)
125+
group.addoption(
126+
"--ep-completion-params",
127+
default=[],
128+
action="append",
129+
help=("Overwrite completion params with json. Can be used multiple times. "),
130+
)
131+
group.addoption(
132+
"--ep-remote-rollout-processor-base-url",
133+
default=None,
134+
help=("If set, use this base URL for remote rollout processing. Example: http://localhost:8000"),
135+
)
118136

119137

120138
def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
@@ -243,6 +261,15 @@ def pytest_configure(config) -> None:
243261
if config.getoption("--ep-no-upload"):
244262
os.environ["EP_NO_UPLOAD"] = "1"
245263

264+
if config.getoption("--ep-jsonl-path"):
265+
os.environ["EP_JSONL_PATH"] = config.getoption("--ep-jsonl-path")
266+
267+
if config.getoption("--ep-completion-params"):
268+
# redump to json to make sure they are legit
269+
os.environ["EP_COMPLETION_PARAMS"] = json.dumps(
270+
[json.loads(s) for s in config.getoption("--ep-completion-params") or []]
271+
)
272+
246273
# Allow ad-hoc overrides of input params via CLI flags
247274
try:
248275
merged: dict = {}

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from eval_protocol.types.remote_rollout_processor import InitRequest, RolloutMetadata
1010
from .rollout_processor import RolloutProcessor
1111
from .types import RolloutProcessorConfig
12+
import os
1213

1314

1415
class RemoteRolloutProcessor(RolloutProcessor):
@@ -30,7 +31,8 @@ def __init__(
3031
# Prefer constructor-provided configuration. These can be overridden via
3132
# config.kwargs at call time for backward compatibility.
3233
self._remote_base_url = remote_base_url
33-
self._model_base_url = model_base_url
34+
if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"):
35+
self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL")
3436
self._poll_interval = poll_interval
3537
self._timeout_seconds = timeout_seconds
3638
self._output_data_loader = output_data_loader

eval_protocol/pytest/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
EvaluationThresholdDict,
2020
Status,
2121
)
22+
from eval_protocol.data_loader import DynamicDataLoader
23+
from eval_protocol.data_loader.models import EvaluationDataLoader
2224
from eval_protocol.pytest.rollout_processor import RolloutProcessor
2325
from eval_protocol.pytest.types import (
2426
RolloutProcessorConfig,
@@ -239,6 +241,45 @@ def parse_ep_completion_params(
239241
return completion_params
240242

241243

244+
def parse_ep_completion_params_overwrite(
    completion_params: Sequence[CompletionParams | None] | None,
) -> Sequence[CompletionParams | None]:
    """Apply the ``EP_COMPLETION_PARAMS`` environment override, if one is set.

    The environment variable is expected to hold a JSON list of completion
    parameter dicts. When it parses to a list, that list replaces
    ``completion_params`` wholesale (no merging with the original values).

    Args:
        completion_params: The completion params configured in code. May be
            ``None``.

    Returns:
        The list decoded from ``EP_COMPLETION_PARAMS`` when it is set and is a
        valid JSON list; otherwise ``completion_params``, or ``[]`` when that
        is ``None``/empty.
    """
    # Function-scope import so this helper does not depend on the module's
    # top-level import list.
    import logging

    raw_override = os.getenv("EP_COMPLETION_PARAMS")
    if raw_override:
        logger = logging.getLogger(__name__)
        try:
            override = json.loads(raw_override)
        except (ValueError, RecursionError) as exc:
            # json.JSONDecodeError is a ValueError; RecursionError covers
            # pathologically nested input. The override stays best-effort, but
            # the user is told that it was not applied.
            logger.warning("Ignoring EP_COMPLETION_PARAMS: not valid JSON (%s)", exc)
        else:
            if isinstance(override, list):
                return override
            logger.warning(
                "Ignoring EP_COMPLETION_PARAMS: expected a JSON list, got %s",
                type(override).__name__,
            )
    return completion_params or []
256+
257+
258+
def _rows_from_jsonl(path: str) -> list[EvaluationRow]:
    """Load every record of a JSONL file as an ``EvaluationRow``.

    Each non-blank line is decoded as JSON and expanded into
    ``EvaluationRow(**record)``. Blank or whitespace-only lines (for example a
    stray empty line at the end of the file) are skipped instead of
    invalidating the whole file.

    Loading is all-or-nothing: if the file cannot be opened or any line fails
    to decode or construct a row, the error is printed and an empty list is
    returned.

    Args:
        path: Filesystem path of the JSONL file, read as UTF-8.

    Returns:
        The rows in file order, or ``[]`` on any failure.
    """
    rows: list[EvaluationRow] = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for line_number, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    rows.append(EvaluationRow(**json.loads(line)))
                except Exception as line_error:
                    # Re-raise with the offending line number so the message
                    # printed below points at the bad record.
                    raise ValueError(f"line {line_number}: {line_error}") from line_error
    except Exception as e:
        print(f"❌ Failed to load rows from JSONL at {path}: {e}")
        return []

    return rows
269+
270+
271+
def parse_ep_dataloaders(
    dataloaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None,
) -> Sequence[EvaluationDataLoader] | EvaluationDataLoader | None:
    """Return the data loader(s) to use, honoring the ``EP_JSONL_PATH`` override.

    When ``EP_JSONL_PATH`` is set to a non-empty value, the configured loaders
    are replaced by a single ``DynamicDataLoader`` whose generator reads rows
    from that JSONL file via ``_rows_from_jsonl``.

    Args:
        dataloaders: The loader(s) configured in code, or ``None``.

    Returns:
        The JSONL-backed ``DynamicDataLoader`` when the override is set and
        the loader can be constructed; otherwise ``dataloaders``, normalized
        so that any falsy value (``None``, empty sequence) becomes ``None``.
    """
    # Function-scope import so this helper does not depend on the module's
    # top-level import list.
    import logging

    jsonl_path = os.getenv("EP_JSONL_PATH")
    if jsonl_path:
        try:
            # Bind the path as a default argument so the generator keeps the
            # value read here, independent of later environment changes.
            return DynamicDataLoader(generators=[lambda path=jsonl_path: _rows_from_jsonl(path)])
        except Exception:
            # Still best-effort, but no longer silent: report why the override
            # was not applied before falling back.
            logging.getLogger(__name__).warning(
                "EP_JSONL_PATH=%s is set but the JSONL data loader could not be built; "
                "falling back to the configured data loaders",
                jsonl_path,
                exc_info=True,
            )
    return dataloaders or None
281+
282+
242283
def parse_ep_passed_threshold(
243284
default_value: float | EvaluationThresholdDict | EvaluationThreshold | None,
244285
) -> EvaluationThreshold | None:

tests/pytest/test_pytest_env_overwrite.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import atexit
2+
import shutil
3+
import tempfile
14
from eval_protocol.models import EvaluationRow, Message
25
from eval_protocol.pytest import evaluation_test
36
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
@@ -18,3 +21,40 @@ def test_input_messages_in_decorator(row: EvaluationRow) -> EvaluationRow:
1821
assert row.messages[0].content == "What is the capital of France?"
1922
assert row.execution_metadata.invocation_id == "test-invocation-123"
2023
return row
24+
25+
26+
# The environment is patched only while this module-level ``with`` block runs,
# i.e. while the decorator below is being applied.
with mock.patch.dict(os.environ, EP_COMPLETION_PARAMS='[{"model": "gpt-40"}]'):

    @evaluation_test(
        input_rows=[[EvaluationRow(messages=[Message(role="user", content="What is 5 * 6?")])]],
        completion_params=[{"model": "no-op"}],  # This should be overridden by the env var
        rollout_processor=NoOpRolloutProcessor(),
        mode="pointwise",
    )
    def test_input_messages_in_env(row: EvaluationRow) -> EvaluationRow:
        """Check that EP_COMPLETION_PARAMS replaces the decorator's completion params."""
        first_message = row.messages[0]
        assert first_message.content == "What is 5 * 6?"
        effective_params = row.input_metadata.completion_params
        assert effective_params["model"] == "gpt-40"
        return row
39+
40+
41+
# Scratch directory holding the JSONL fixture; removed when the interpreter exits.
_jsonl_tmpdir = tempfile.mkdtemp()
atexit.register(shutil.rmtree, _jsonl_tmpdir, ignore_errors=True)

input_path = os.path.join(_jsonl_tmpdir, "input.jsonl")
# Write with an explicit encoding so the fixture does not depend on the
# platform's locale default; the JSONL loader reads the file back as UTF-8.
with open(input_path, "w", encoding="utf-8") as f:
    f.write(
        '{"messages": [{"role": "user", "content": "What is 10 / 2?"}], "input_metadata": {"some_key": "some_value"}}\n'
    )
print(f"finish prepare input file {input_path}")
50+
# EP_JSONL_PATH is set only while this module-level ``with`` block runs,
# i.e. while the decorator below is being applied.
with mock.patch.dict(os.environ, EP_JSONL_PATH=input_path):

    @evaluation_test(
        input_rows=[[EvaluationRow(messages=[Message(role="user", content="This will be ignored")])]],
        completion_params=[{"model": "no-op"}],
        rollout_processor=NoOpRolloutProcessor(),
        mode="pointwise",
    )
    def test_input_override(row: EvaluationRow) -> EvaluationRow:
        """Check that rows come from the EP_JSONL_PATH file rather than ``input_rows``."""
        first_message = row.messages[0]
        assert first_message.content == "What is 10 / 2?"
        return row

0 commit comments

Comments
 (0)