
Commit 8986902

Fix duplicate log (#111)
* add test to ensure only 19 logs are generated
* fix assertion error logging
1 parent e106227 commit 8986902

File tree

2 files changed: +103 −2 lines changed


eval_protocol/pytest/evaluation_test.py

Lines changed: 11 additions & 2 deletions
@@ -474,6 +474,8 @@ def _log_eval_error(
         try:
             # Handle dataset loading
             data: List[EvaluationRow] = []
+            # Track all rows processed in the current run for error logging
+            processed_rows_in_run: List[EvaluationRow] = []
             if "dataset_path" in kwargs and kwargs["dataset_path"] is not None:
                 ds_arg = kwargs["dataset_path"]
                 # Support either a single path or a list of paths; if a list is provided,
@@ -587,6 +589,7 @@ def _log_eval_error(
             # log the fresh_dataset
             for row in fresh_dataset:
                 active_logger.log(row)
+                processed_rows_in_run.append(row)

             # prepare parallel eval helper function
             semaphore = asyncio.Semaphore(max_concurrent_evaluations)
@@ -741,10 +744,16 @@ async def _collect_result(config, lst):
                 )

         except AssertionError:
-            _log_eval_error("finished", data if "data" in locals() else None, passed=False)
+            _log_eval_error(
+                "finished",
+                processed_rows_in_run if "processed_rows_in_run" in locals() else None,
+                passed=False,
+            )
             raise
         except Exception:
-            _log_eval_error("error", data if "data" in locals() else None, passed=False)
+            _log_eval_error(
+                "error", processed_rows_in_run if "processed_rows_in_run" in locals() else None, passed=False
+            )
             raise

     return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)
New test file

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+from typing import List, Set
+import asyncio
+
+from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
+from eval_protocol.models import EvaluationRow
+from eval_protocol.pytest.rollout_processor import RolloutProcessor
+from eval_protocol.pytest.types import RolloutProcessorConfig
+from tests.pytest.test_markdown_highlighting import markdown_dataset_to_evaluation_row
+
+
+class TrackingRolloutProcessor(RolloutProcessor):
+    """Custom rollout processor that tracks which rollout IDs are generated during rollout phase."""
+
+    def __init__(self, shared_rollout_ids: Set[str]):
+        self.shared_rollout_ids = shared_rollout_ids
+
+    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
+        """Process rows and track rollout IDs generated during rollout phase."""
+
+        async def process_row(row: EvaluationRow) -> EvaluationRow:
+            # Track this rollout ID as being generated during rollout phase
+            self.shared_rollout_ids.add(row.execution_metadata.rollout_id)
+            return row
+
+        # Create tasks that process the rows and track IDs
+        tasks = [asyncio.create_task(process_row(row)) for row in rows]
+        return tasks
+
+
+class TrackingLogger(DatasetLogger):
+    """Custom logger that tracks all rollout IDs that are logged."""
+
+    def __init__(self, shared_rollout_ids: Set[str]):
+        self.shared_rollout_ids = shared_rollout_ids
+
+    def log(self, row: EvaluationRow):
+        self.shared_rollout_ids.add(row.execution_metadata.rollout_id)
+
+    def read(self):
+        return []
+
+
+async def test_assertion_error_no_new_rollouts():
+    """
+    Test that when an assertion error occurs due to failing threshold,
+    no new rollout IDs are logged beyond those generated during the rollout phase.
+    """
+    from eval_protocol.pytest.evaluation_test import evaluation_test
+
+    # Create shared set to track rollout IDs generated during rollout phase
+    shared_rollout_ids: Set[str] = set()
+
+    # Create custom processor and logger for tracking with shared set
+    rollout_processor = TrackingRolloutProcessor(shared_rollout_ids)
+    logger = TrackingLogger(shared_rollout_ids)
+
+    input_dataset: list[str] = [
+        "tests/pytest/data/markdown_dataset.jsonl",
+    ]
+    completion_params: list[dict] = [{"temperature": 0.0, "model": "dummy/local-model"}]
+
+    @evaluation_test(
+        input_dataset=input_dataset,
+        completion_params=completion_params,
+        dataset_adapter=markdown_dataset_to_evaluation_row,
+        rollout_processor=rollout_processor,
+        mode="pointwise",
+        combine_datasets=False,
+        num_runs=1,  # Single run to simplify tracking
+        passed_threshold=0.5,  # Threshold that will fail since we return 0.0
+        logger=logger,
+    )
+    def eval_fn(row: EvaluationRow) -> EvaluationRow:
+        # Always return score 0.0, which will fail the 0.5 threshold
+        from eval_protocol.models import EvaluateResult
+
+        row.evaluation_result = EvaluateResult(score=0.0)
+        return row
+
+    try:
+        # This should fail due to threshold not being met
+        for ds_path in input_dataset:
+            for completion_param in completion_params:
+                await eval_fn(dataset_path=ds_path, completion_params=completion_param)
+    except AssertionError:
+        # Expected - the threshold check should fail
+        pass
+    else:
+        assert False, "Expected AssertionError due to failing threshold"
+
+    # Get the final set of rollout IDs that were generated during rollout phase
+    assert len(shared_rollout_ids) == 19, "Only 19 rollout IDs should have been logged"
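The final assertion hard-codes 19, the number of log entries a single run over the markdown dataset is expected to produce. Assuming tests/pytest/data/markdown_dataset.jsonl holds one JSON row per line and each row yields exactly one logged rollout per run, the expected count can also be derived from the file itself; the sketch below is illustrative and not part of the committed test.

# Hedged sketch: derive the expected log count from the dataset instead of
# hard-coding 19. Assumes one JSON object per line and one logged rollout per
# row per run; adjust if the evaluation fans out differently.
import json
from pathlib import Path

DATASET = Path("tests/pytest/data/markdown_dataset.jsonl")
NUM_RUNS = 1


def expected_log_count(dataset_path: Path, num_runs: int) -> int:
    with dataset_path.open() as f:
        rows = [json.loads(line) for line in f if line.strip()]
    return len(rows) * num_runs


# Possible use in the test body, replacing the literal 19:
#     assert len(shared_rollout_ids) == expected_log_count(DATASET, NUM_RUNS)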
