Skip to content

Commit b0f92ac

Browse files
authored
accept in-memory rows as input (#112)
1 parent 8986902 commit b0f92ac

File tree

4 files changed

+62
-11
lines changed

4 files changed

+62
-11
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
EvaluationInputParam,
3737
EvaluationTestMode,
3838
InputMessagesParam,
39+
InputRowsParam,
3940
ModelParam,
4041
RolloutProcessorConfig,
4142
RolloutProcessorInputParam,
@@ -238,6 +239,7 @@ def evaluation_test( # noqa: C901
238239
completion_params: List[CompletionParams],
239240
input_messages: Optional[List[InputMessagesParam]] = None,
240241
input_dataset: Optional[List[DatasetPathParam]] = None,
242+
input_rows: Optional[List[InputRowsParam]] = None,
241243
dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
242244
rollout_processor: RolloutProcessor = NoOpRolloutProcessor(),
243245
evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
@@ -299,6 +301,9 @@ def evaluation_test( # noqa: C901
299301
input_dataset: Paths to JSONL datasets. This is useful if you have a
300302
dataset already. Provide a dataset_adapter to convert the input dataset
301303
to a list of EvaluationRows if you have a custom dataset format.
304+
input_rows: Pre-constructed EvaluationRow objects to use directly. This is useful
305+
when you want to provide EvaluationRow objects with custom metadata, input_messages,
306+
or other fields already populated. Will be passed as "input_dataset" to the test function.
302307
dataset_adapter: Function to convert the input dataset to a list of
303308
EvaluationRows. This is useful if you have a custom dataset format.
304309
completion_params: Generation parameters for the rollout.
@@ -413,33 +418,42 @@ async def execute_with_params(
413418
# Calculate all possible combinations of parameters
414419
if mode == "groupwise":
415420
combinations = generate_parameter_combinations(
416-
input_dataset, None, input_messages, evaluation_test_kwargs, max_dataset_rows, combine_datasets
421+
input_dataset,
422+
None,
423+
input_messages,
424+
input_rows,
425+
evaluation_test_kwargs,
426+
max_dataset_rows,
427+
combine_datasets,
417428
)
418429
else:
419430
combinations = generate_parameter_combinations(
420431
input_dataset,
421432
completion_params,
422433
input_messages,
434+
input_rows,
423435
evaluation_test_kwargs,
424436
max_dataset_rows,
425437
combine_datasets,
426438
)
427439
if len(combinations) == 0:
428440
raise ValueError(
429-
"No combinations of parameters were found. Please provide at least a model and one of input_dataset or input_messages."
441+
"No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows."
430442
)
431443

432444
# Create parameter tuples for pytest.mark.parametrize
433445
param_tuples = []
434446
for combo in combinations:
435-
dataset, cp, messages, etk = combo
447+
dataset, cp, messages, rows, etk = combo
436448
param_tuple = []
437449
if input_dataset is not None:
438450
param_tuple.append(dataset)
439451
if completion_params is not None:
440452
param_tuple.append(cp)
441453
if input_messages is not None:
442454
param_tuple.append(messages)
455+
if input_rows is not None:
456+
param_tuple.append(rows)
443457
if evaluation_test_kwargs is not None:
444458
param_tuple.append(etk)
445459
param_tuples.append(tuple(param_tuple))
@@ -452,6 +466,8 @@ async def execute_with_params(
452466
test_param_names.append("completion_params")
453467
if input_messages is not None:
454468
test_param_names.append("input_messages")
469+
if input_rows is not None:
470+
test_param_names.append("input_rows")
455471
if evaluation_test_kwargs is not None:
456472
test_param_names.append("evaluation_test_kwargs")
457473

@@ -500,8 +516,11 @@ def _log_eval_error(
500516
else:
501517
# Multiple rows: list of List[Message]
502518
data = [EvaluationRow(messages=m) for m in im]
519+
elif "input_rows" in kwargs and kwargs["input_rows"] is not None:
520+
# Use pre-constructed EvaluationRow objects directly
521+
data = kwargs["input_rows"]
503522
else:
504-
raise ValueError("No input dataset or input messages provided")
523+
raise ValueError("No input dataset, input messages, or input rows provided")
505524

506525
for row in data:
507526
# generate a stable row_id for each row

eval_protocol/pytest/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
1616
DatasetPathParam = str
1717
InputMessagesParam = List[Message]
18+
InputRowsParam = List[EvaluationRow]
1819
EvaluationInputParam = Dict[str, Any]
1920
RolloutProcessorInputParam = Dict[str, Any]
2021

eval_protocol/pytest/utils.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
DatasetPathParam,
1414
EvaluationInputParam,
1515
InputMessagesParam,
16+
InputRowsParam,
1617
RolloutProcessorConfig,
1718
)
1819
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config
@@ -166,6 +167,7 @@ def generate_parameter_combinations(
166167
input_dataset: Optional[List[DatasetPathParam]],
167168
completion_params: List[CompletionParams],
168169
input_messages: Optional[List[InputMessagesParam]],
170+
input_rows: Optional[List[InputRowsParam]],
169171
evaluation_test_kwargs: Optional[List[EvaluationInputParam]],
170172
max_dataset_rows: Optional[int],
171173
combine_datasets: bool,
@@ -177,6 +179,7 @@ def generate_parameter_combinations(
177179
input_dataset: Dataset paths to use
178180
completion_params: Completion parameters to test
179181
input_messages: Input messages to use
182+
input_rows: Pre-constructed EvaluationRow objects to use
180183
evaluation_test_kwargs: Additional kwargs for evaluation tests
181184
max_dataset_rows: Maximum number of dataset rows to process
182185
combine_datasets: Whether to combine multiple datasets into one test
@@ -217,6 +220,18 @@ def generate_parameter_combinations(
217220
else:
218221
messages = [None] # type: ignore
219222

223+
# Handle input_rows - similar to input_messages, apply max_dataset_rows if specified
224+
if input_rows is not None and isinstance(input_rows, list):
225+
effective_max_rows = parse_ep_max_rows(max_dataset_rows)
226+
if effective_max_rows is not None:
227+
sliced_rows = input_rows[:effective_max_rows] # type: ignore
228+
else:
229+
sliced_rows = input_rows # type: ignore
230+
# Wrap as a single parameter payload
231+
rows = [sliced_rows] # type: ignore
232+
else:
233+
rows = [None] # type: ignore
234+
220235
kwargs: List[Optional[EvaluationInputParam]] = (
221236
evaluation_test_kwargs if evaluation_test_kwargs is not None else [None]
222237
) # type: ignore
@@ -225,13 +240,14 @@ def generate_parameter_combinations(
225240
for ds in datasets:
226241
for cp in cps:
227242
for im in messages:
228-
for etk in kwargs:
229-
# if no dataset and no messages, raise an error
230-
if ds is None and im is None:
231-
raise ValueError(
232-
"No dataset or messages provided. Please provide at least one of input_dataset or input_messages."
233-
)
234-
combinations.append((ds, cp, im, etk))
243+
for ir in rows:
244+
for etk in kwargs:
245+
# if no dataset, no messages, and no rows, raise an error
246+
if ds is None and im is None and ir is None:
247+
raise ValueError(
248+
"No dataset, messages, or rows provided. Please provide at least one of input_dataset, input_messages, or input_rows."
249+
)
250+
combinations.append((ds, cp, im, ir, etk))
235251

236252
return combinations
237253

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor


@evaluation_test(
    # Exercise the new `input_rows` path: pre-constructed EvaluationRow
    # objects are passed through directly, instead of being built from
    # `input_messages` or loaded from an `input_dataset` JSONL file.
    input_rows=[EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])],
    completion_params=[{"model": "no-op"}],
    rollout_processor=NoOpRolloutProcessor(),
    mode="pointwise",
)
def test_input_messages_in_decorator(row: EvaluationRow) -> EvaluationRow:
    """Verify that a row supplied via ``input_rows`` reaches the test unchanged.

    NOTE(review): the name says "input_messages" but this test exercises the
    ``input_rows`` parameter — consider renaming to
    ``test_input_rows_in_decorator``. The name is kept as-is here to avoid
    breaking any name-based test selection (``pytest -k``, CI filters).
    """
    # The single pre-constructed row should arrive with its message intact.
    assert row.messages[0].content == "What is the capital of France?"
    return row

0 commit comments

Comments (0)