|
2 | 2 |
|
3 | 3 | from unittest.mock import Mock, patch |
4 | 4 |
|
5 | | -from eval_protocol.models import EvaluationRow, EvaluateResult, EvalMetadata, ExecutionMetadata, InputMetadata, Message |
| 5 | +import pytest |
| 6 | + |
| 7 | +from eval_protocol.models import ( |
| 8 | + EvaluationRow, |
| 9 | + EvaluateResult, |
| 10 | + EvalMetadata, |
| 11 | + EvaluationThreshold, |
| 12 | + ExecutionMetadata, |
| 13 | + InputMetadata, |
| 14 | + Message, |
| 15 | +) |
6 | 16 | from eval_protocol.pytest.evaluation_test_postprocess import postprocess |
7 | 17 | from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci |
8 | 18 |
|
@@ -206,6 +216,55 @@ def test_all_invalid_scores(self): |
206 | 216 | # Should still call logger.log for all rows |
207 | 217 | assert mock_logger.log.call_count == 2 |
208 | 218 |
|
| 219 | + @patch.dict("os.environ", {"EP_NO_UPLOAD": "1"}) # Disable uploads |
| 220 | + def test_threshold_all_zero_scores_fail(self): |
| 221 | + """When all scores are 0.0 and threshold.success is 0.01, postprocess should fail.""" |
| 222 | + all_results = [ |
| 223 | + [self.create_test_row(0.0), self.create_test_row(0.0)], |
| 224 | + ] |
| 225 | + |
| 226 | + mock_logger = Mock() |
| 227 | + threshold = EvaluationThreshold(success=0.01, standard_error=None) |
| 228 | + |
| 229 | + with pytest.raises(AssertionError) as excinfo: |
| 230 | + postprocess( |
| 231 | + all_results=all_results, |
| 232 | + aggregation_method="mean", |
| 233 | + threshold=threshold, |
| 234 | + active_logger=mock_logger, |
| 235 | + mode="pointwise", |
| 236 | + completion_params={"model": "test-model"}, |
| 237 | + test_func_name="test_threshold_all_zero", |
| 238 | + num_runs=1, |
| 239 | + experiment_duration_seconds=10.0, |
| 240 | + ) |
| 241 | + |
| 242 | + # Sanity check on the assertion message |
| 243 | + assert "below threshold" in str(excinfo.value) |
| 244 | + |
| 245 | + @patch.dict("os.environ", {"EP_NO_UPLOAD": "1"}) # Disable uploads |
| 246 | + def test_threshold_equal_score_passes(self): |
| 247 | + """When agg_score equals threshold.success (0.01), postprocess should pass.""" |
| 248 | + all_results = [ |
| 249 | + [self.create_test_row(0.01)], |
| 250 | + ] |
| 251 | + |
| 252 | + mock_logger = Mock() |
| 253 | + threshold = EvaluationThreshold(success=0.01, standard_error=None) |
| 254 | + |
| 255 | + # Should not raise |
| 256 | + postprocess( |
| 257 | + all_results=all_results, |
| 258 | + aggregation_method="mean", |
| 259 | + threshold=threshold, |
| 260 | + active_logger=mock_logger, |
| 261 | + mode="pointwise", |
| 262 | + completion_params={"model": "test-model"}, |
| 263 | + test_func_name="test_threshold_equal_score", |
| 264 | + num_runs=1, |
| 265 | + experiment_duration_seconds=10.0, |
| 266 | + ) |
| 267 | + |
209 | 268 |
|
210 | 269 | class TestBootstrapEquivalence: |
211 | 270 | def test_bootstrap_equivalence_pandas_vs_pure_python(self): |
|
0 commit comments