Skip to content

Commit 361369e

Browse files
authored
Add original example idx in input_metadata (#346)
* add * respect original line number from the source dataloader * add unit test * add * add * fix
1 parent 004be2f commit 361369e

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

eval_protocol/data_loader/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _process_variant(self, result: DataLoaderResult) -> DataLoaderResult:
111111

112112
def _apply_metadata(self, result: DataLoaderResult, original_count: int, processed_count: int) -> None:
113113
"""Apply metadata to all rows in the result."""
114-
for row in result.rows:
114+
for idx, row in enumerate(result.rows):
115115
if row.input_metadata.dataset_info is None:
116116
row.input_metadata.dataset_info = {}
117117

@@ -126,3 +126,4 @@ def _apply_metadata(self, result: DataLoaderResult, original_count: int, process
126126
# Apply row counts
127127
row.input_metadata.dataset_info["data_loader_num_rows"] = original_count
128128
row.input_metadata.dataset_info["data_loader_num_rows_after_preprocessing"] = processed_count
129+
row.input_metadata.dataset_info["data_loader_row_idx"] = idx

eval_protocol/pytest/tracing_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def update_row_with_remote_trace(
171171
row.messages = remote_row.messages
172172
row.tools = remote_row.tools
173173
row.input_metadata.session_data = remote_row.input_metadata.session_data
174+
row.input_metadata.dataset_info = remote_row.input_metadata.dataset_info
174175
row.execution_metadata = remote_row.execution_metadata
175176
return None
176177
else:
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
2+
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
3+
from eval_protocol.pytest import evaluation_test
4+
from typing import List
5+
6+
def generator() -> list[EvaluationRow]:
    """Build two evaluation rows that each carry the same single user message."""
    rows = []
    for _ in range(2):
        # A fresh Message/EvaluationRow pair per iteration, all with identical content.
        prompt = Message(role="user", content="What is 2 + 2?")
        rows.append(EvaluationRow(messages=[prompt]))
    return rows
8+
9+
@evaluation_test(
    data_loaders=DynamicDataLoader(
        generators=[generator],
    ),
    mode="all",
)
def test_data_loader_stable_row_id_with_same_content(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """Check that rows with identical content still receive distinct row ids.

    ``generator`` produces two rows whose messages are exactly the same, so this
    test collects ``input_metadata.row_id`` from every row and asserts that two
    different ids were observed.

    Args:
        rows: Rows handed to the test by ``evaluation_test``; presumably the
            full output of the data loader because ``mode="all"`` -- confirm
            against the decorator's documentation.

    Returns:
        The same ``rows`` list, with a placeholder ``evaluation_result`` set on
        each row.
    """
    row_ids = set()
    for row in rows:
        row_ids.add(row.input_metadata.row_id)
        # Placeholder score: this test is about row ids, not evaluation quality.
        row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
    # Two identical-content rows must not collapse onto a single row id.
    assert len(row_ids) == 2, f"expected 2 distinct row ids, got {len(row_ids)}: {row_ids!r}"
    return rows

0 commit comments

Comments
 (0)