Commit 4ee6b27

dont allow all 0 (#363)
* dont allow all 0
* .001
* add test
* fix
1 parent: 8219c44

File tree (3 files changed, +93 -5 lines changed)

eval_protocol/cli_commands/create_rft.py
eval_protocol/cli_commands/local_test.py
tests/test_evaluation_postprocess.py


eval_protocol/cli_commands/create_rft.py

Lines changed: 8 additions & 2 deletions
@@ -279,7 +279,13 @@ def _validate_evaluator_locally(
     docker_build_extra: str,
     docker_run_extra: str,
 ) -> bool:
-    """Run pytest locally for the selected evaluation test to validate the evaluator."""
+    """Run pytest locally for the selected evaluation test to validate the evaluator.
+
+    The pytest helpers always enforce a small success threshold (0.01) for
+    evaluation_test-based suites so that an evaluation run where all scores are
+    0.0 will naturally fail with a non-zero pytest exit code, which we then treat
+    as a failed validator.
+    """
     if not selected_test_file or not selected_test_func:
         # No local test associated; skip validation but warn the user.
         print("Warning: Could not resolve a local evaluation test for this evaluator; skipping local validation.")
@@ -702,7 +708,7 @@ def _create_rft_job(
     print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
     if getattr(args, "evaluation_dataset", None):
        body["evaluationDataset"] = args.evaluation_dataset
-
+
     output_model_arg = getattr(args, "output_model", None)
     if output_model_arg:
         if len(output_model_arg) > 63:
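The added docstring describes the mechanism: the validator shells out to pytest with a small success threshold, so an evaluation run whose scores are all 0.0 aggregates below that threshold, the suite raises, and pytest exits non-zero, which the CLI then treats as a failed validator. Below is a minimal sketch of that failure path; `enforce_success_threshold` is a hypothetical helper for illustration only, not the actual eval_protocol implementation, and 0.001 is the value the CLI passes via `--ep-success-threshold`.

    # Hypothetical sketch; the real check lives in eval_protocol's postprocess.
    from statistics import mean

    def enforce_success_threshold(scores: list[float], success: float) -> None:
        agg_score = mean(scores)
        # An aggregate score strictly below the threshold fails the run.
        assert agg_score >= success, f"agg_score {agg_score} is below threshold {success}"

    enforce_success_threshold([0.7, 0.4], success=0.001)  # non-zero scores: passes
    try:
        enforce_success_threshold([0.0, 0.0], success=0.001)
    except AssertionError as exc:
        print(exc)  # "... is below threshold 0.001" -> pytest would exit non-zero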

eval_protocol/cli_commands/local_test.py

Lines changed: 25 additions & 2 deletions
@@ -38,7 +38,9 @@ def _build_docker_image(dockerfile_path: str, image_tag: str, build_extras: List
 def _run_pytest_host(pytest_target: str) -> int:
     """Run pytest against a target on the host and return its exit code."""
     print(f"Running locally: pytest {pytest_target} -vs")
-    proc = subprocess.run([sys.executable, "-m", "pytest", pytest_target, "-vs"])
+    # Always enforce a small success threshold for evaluation_test-based suites so that runs with all-zero scores fail.
+    cmd = [sys.executable, "-m", "pytest", "--ep-success-threshold", "0.001", pytest_target, "-vs"]
+    proc = subprocess.run(cmd)
     return proc.returncode
 
 
@@ -69,6 +71,22 @@ def _run_pytest_in_docker(
         "-w",
         workdir,
     ]
+
+    # If EP_SUMMARY_JSON is set on the host, mirror it into the container so that
+    # pytest evaluation tests can write summary artifacts that are visible to the
+    # host. We map paths under the host logs directory (~/.eval_protocol) into the
+    # mounted container home directory.
+    host_summary_path = os.environ.get("EP_SUMMARY_JSON")
+    if host_summary_path:
+        try:
+            rel_path = os.path.relpath(host_summary_path, host_logs_dir)
+            # Only forward the variable when the summary path is inside the logs dir.
+            if not rel_path.startswith(os.pardir):
+                container_summary_path = os.path.join("/container_home/.eval_protocol", rel_path)
+                cmd += ["-e", f"EP_SUMMARY_JSON={container_summary_path}"]
+        except Exception:
+            # Best-effort only; do not fail docker execution if we can't map the path.
+            pass
     # Try to match host user to avoid permission problems on mounted volume
     try:
         uid = os.getuid()  # type: ignore[attr-defined]
@@ -78,7 +96,12 @@
         pass
     if run_extras:
         cmd += run_extras
-    cmd += [image_tag, "pytest", pytest_target, "-vs"]
+
+    # Build pytest command, always enforcing the same small success threshold as
+    # the host runner so that all-zero score runs fail consistently.
+    pytest_cmd: list[str] = ["pytest", "--ep-success-threshold", "0.001", pytest_target, "-vs"]
+
+    cmd += [image_tag] + pytest_cmd
     print("Running in Docker:", " ".join(cmd))
     try:
         proc = subprocess.run(cmd)
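The Docker runner also forwards EP_SUMMARY_JSON into the container, but only when the host summary path sits inside the mounted logs directory. Here is a standalone sketch of that relpath check with made-up example paths; the real code derives `host_logs_dir` from the CLI's volume mount rather than taking it as an argument like this.

    import os
    from typing import Optional

    # Sketch of the mapping shown in the diff above; paths are hypothetical examples.
    def map_summary_path(host_summary_path: str, host_logs_dir: str) -> Optional[str]:
        """Return the in-container path, or None when the summary file lives outside
        the mounted logs directory and EP_SUMMARY_JSON should not be forwarded."""
        rel_path = os.path.relpath(host_summary_path, host_logs_dir)
        if rel_path.startswith(os.pardir):
            return None
        return os.path.join("/container_home/.eval_protocol", rel_path)

    print(map_summary_path("/home/me/.eval_protocol/run1/summary.json", "/home/me/.eval_protocol"))
    # -> /container_home/.eval_protocol/run1/summary.json
    print(map_summary_path("/tmp/summary.json", "/home/me/.eval_protocol"))
    # -> None (outside the logs dir, so the variable is not forwarded)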

tests/test_evaluation_postprocess.py

Lines changed: 60 additions & 1 deletion
@@ -2,7 +2,17 @@
 
 from unittest.mock import Mock, patch
 
-from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata, Message
+import pytest
+
+from eval_protocol.models import (
+    EvaluationRow,
+    EvaluateResult,
+    EvalMetadata,
+    EvaluationThreshold,
+    ExecutionMetadata,
+    InputMetadata,
+    Message,
+)
 from eval_protocol.pytest.evaluation_test_postprocess import postprocess
 from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
 
@@ -206,6 +216,55 @@ def test_all_invalid_scores(self):
         # Should still call logger.log for all rows
         assert mock_logger.log.call_count == 2
 
+    @patch.dict("os.environ", {"EP_NO_UPLOAD": "1"})  # Disable uploads
+    def test_threshold_all_zero_scores_fail(self):
+        """When all scores are 0.0 and threshold.success is 0.01, postprocess should fail."""
+        all_results = [
+            [self.create_test_row(0.0), self.create_test_row(0.0)],
+        ]
+
+        mock_logger = Mock()
+        threshold = EvaluationThreshold(success=0.01, standard_error=None)
+
+        with pytest.raises(AssertionError) as excinfo:
+            postprocess(
+                all_results=all_results,
+                aggregation_method="mean",
+                threshold=threshold,
+                active_logger=mock_logger,
+                mode="pointwise",
+                completion_params={"model": "test-model"},
+                test_func_name="test_threshold_all_zero",
+                num_runs=1,
+                experiment_duration_seconds=10.0,
+            )
+
+        # Sanity check on the assertion message
+        assert "below threshold" in str(excinfo.value)
+
+    @patch.dict("os.environ", {"EP_NO_UPLOAD": "1"})  # Disable uploads
+    def test_threshold_equal_score_passes(self):
+        """When agg_score equals threshold.success (0.01), postprocess should pass."""
+        all_results = [
+            [self.create_test_row(0.01)],
+        ]
+
+        mock_logger = Mock()
+        threshold = EvaluationThreshold(success=0.01, standard_error=None)
+
+        # Should not raise
+        postprocess(
+            all_results=all_results,
+            aggregation_method="mean",
+            threshold=threshold,
+            active_logger=mock_logger,
+            mode="pointwise",
+            completion_params={"model": "test-model"},
+            test_func_name="test_threshold_equal_score",
+            num_runs=1,
+            experiment_duration_seconds=10.0,
+        )
+
 
 class TestBootstrapEquivalence:
     def test_bootstrap_equivalence_pandas_vs_pure_python(self):
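The two new tests pin down the boundary convention: an aggregate score strictly below threshold.success raises an AssertionError whose message contains "below threshold", while a score exactly equal to the threshold passes. The tiny sketch below restates that convention; it is an inference from the tests, not code taken from postprocess itself.

    # Inferred from the tests above, not from eval_protocol's source:
    # a run passes when the aggregate score is at or above threshold.success.
    def run_passes(agg_score: float, success: float = 0.01) -> bool:
        return agg_score >= success

    assert run_passes(0.01) is True   # exactly at the threshold -> no error
    assert run_passes(0.0) is False   # all-zero run -> postprocess raises AssertionError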
