
Commit 7b3defb

accept evaluation_test kwargs in decorator and delete "evaluate" (#8)
* accept evaluation_test kwargs in decorator and delete "evaluate"
* update doc
* update docs
1 parent 0b4a0a3 commit 7b3defb

9 files changed: +199 -70 lines changed

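The gist of the change, as a minimal usage sketch assembled from the tests updated below. The dataset path, model name, and the "strictness" kwarg are illustrative placeholders, and the dataset is assumed to already deserialize into EvaluationRows (dataset_adapter defaults to a pass-through):

from eval_protocol.models import EvaluateResult, EvaluationRow
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test


@evaluation_test(
    input_dataset=["tests/pytest/data/example_dataset.jsonl"],  # illustrative path
    model=["accounts/fireworks/models/llama-v3p1-8b-instruct"],
    rollout_input_params=[{"temperature": 0.0}],  # renamed from input_params
    rollout_processor=default_single_turn_rollout_processor,
    mode="pointwise",
    evaluation_test_kwargs=[{"strictness": "high"}],  # new: forwarded to the test as **kwargs
)
def test_example(row: EvaluationRow, **kwargs) -> EvaluationRow:
    # The removed evaluate() helper is no longer needed: score the row and attach the result directly.
    has_answer = bool(row.messages[-1].content)
    score = 1.0 if has_answer or kwargs.get("strictness") != "high" else 0.0
    row.evaluation_result = EvaluateResult(score=score, reason="illustrative check")
    return row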

eval_protocol/pytest/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -3,14 +3,12 @@
from .default_single_turn_rollout_process import default_single_turn_rollout_processor
from .evaluation_test import evaluation_test
from .types import RolloutProcessor, RolloutProcessorConfig
-from .utils import evaluate

__all__ = [
    "default_agent_rollout_processor",
    "default_no_op_rollout_processor",
    "default_single_turn_rollout_processor",
    "RolloutProcessor",
    "RolloutProcessorConfig",
-    "evaluate",
    "evaluation_test",
]

eval_protocol/pytest/evaluation_test.py

Lines changed: 34 additions & 15 deletions
@@ -8,10 +8,11 @@
from eval_protocol.pytest.types import (
    Dataset,
    DatasetPathParam,
+    EvaluationInputParam,
    EvaluationTestMode,
    InputMessagesParam,
-    InputParam,
    ModelParam,
+    RolloutInputParam,
    RolloutProcessor,
    RolloutProcessorConfig,
    TestFunction,
@@ -32,8 +33,9 @@ def evaluation_test(
    input_messages: Optional[List[InputMessagesParam]] = None,
    input_dataset: Optional[List[DatasetPathParam]] = None,
    dataset_adapter: Optional[Callable[[List[Dict[str, Any]]], Dataset]] = lambda x: x,
-    input_params: Optional[List[InputParam]] = None,
+    rollout_input_params: Optional[List[RolloutInputParam]] = None,
    rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
+    evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
    aggregation_method: AggregationMethod = "mean",
    threshold_of_success: Optional[float] = None,
    num_runs: int = 1,
@@ -56,8 +58,9 @@ def evaluation_test(
            to a list of EvaluationRows if you have a custom dataset format.
        dataset_adapter: Function to convert the input dataset to a list of
            EvaluationRows. This is useful if you have a custom dataset format.
-        input_params: Generation parameters for the model.
+        rollout_input_params: Generation parameters for the rollout.
        rollout_processor: Function used to perform the rollout.
+        evaluation_test_kwargs: Kwargs for the evaluation function.
        aggregation_method: How to aggregate scores across rows.
        threshold_of_success: If set, fail the test if the aggregated score is
            below this threshold.
@@ -104,12 +107,19 @@ def execute_with_params(
        test_func: TestFunction,
        row: EvaluationRow | None = None,
        input_dataset: List[EvaluationRow] | None = None,
+        evaluation_test_kwargs: Optional[EvaluationInputParam] = None,
    ):
        kwargs = {}
        if input_dataset is not None:
            kwargs["rows"] = input_dataset
        if row is not None:
            kwargs["row"] = row
+        if evaluation_test_kwargs is not None:
+            if "row" in evaluation_test_kwargs:
+                raise ValueError("'row' is a reserved parameter for the evaluation function")
+            if "rows" in evaluation_test_kwargs:
+                raise ValueError("'rows' is a reserved parameter for the evaluation function")
+            kwargs.update(evaluation_test_kwargs)
        return execute_function(test_func, **kwargs)

    # Calculate all possible combinations of parameters
@@ -118,21 +128,23 @@ def generate_combinations():

        # Handle optional parameters with defaults
        datasets: List[Optional[DatasetPathParam]] = input_dataset if input_dataset is not None else [None]  # type: ignore
-        params: List[Optional[InputParam]] = input_params if input_params is not None else [None]  # type: ignore
+        params: List[Optional[RolloutInputParam]] = rollout_input_params if rollout_input_params is not None else [None]  # type: ignore
        messages: List[Optional[InputMessagesParam]] = input_messages if input_messages is not None else [None]  # type: ignore
+        kwargs: List[Optional[EvaluationInputParam]] = evaluation_test_kwargs if evaluation_test_kwargs is not None else [None]  # type: ignore

        # Generate all combinations
        for m in model:
            for ds in datasets:
                for ip in params:
                    for im in messages:
-                        # Skip combinations that don't make sense
-                        # If we have a dataset, we should have params for rollout
-                        if ds is not None and ip is None:
-                            continue
-                        # If we have messages but no dataset, that's fine
-                        # If we have no dataset and no messages, that's also fine
-                        combinations.append((m, ds, ip, im))
+                        for etk in kwargs:
+                            # Skip combinations that don't make sense
+                            # If we have a dataset, we should have params for rollout
+                            if ds is not None and ip is None:
+                                continue
+                            # If we have messages but no dataset, that's fine
+                            # If we have no dataset and no messages, that's also fine
+                            combinations.append((m, ds, ip, im, etk))

        return combinations

@@ -141,27 +153,31 @@ def generate_combinations():
    # Create parameter tuples for pytest.mark.parametrize
    param_tuples = []
    for combo in combinations:
-        model_name, dataset, params, messages = combo
+        model_name, dataset, params, messages, etk = combo
        param_tuple = [model_name]
        if input_dataset is not None:
            param_tuple.append(dataset)
-        if input_params is not None:
+        if rollout_input_params is not None:
            param_tuple.append(params)
        if input_messages is not None:
            param_tuple.append(messages)
+        if evaluation_test_kwargs is not None:
+            param_tuple.append(etk)
        param_tuples.append(tuple(param_tuple))

    # For batch mode, use the original parameter names
    test_param_names = ["model"]
    if input_dataset is not None:
        test_param_names.append("dataset_path")
-    if input_params is not None:
+    if rollout_input_params is not None:
        test_param_names.append("input_params")
    if input_messages is not None:
        test_param_names.append("input_messages")
+    if evaluation_test_kwargs is not None:
+        test_param_names.append("evaluation_test_kwargs")

    # Create wrapper function with exact signature that pytest expects
-    def create_wrapper_with_signature():
+    def create_wrapper_with_signature() -> Callable:
        # Create the function body that will be used
        def wrapper_body(**kwargs):
            model_name = kwargs["model"]
@@ -193,6 +209,7 @@ def wrapper_body(**kwargs):
                result = execute_with_params(
                    test_func,
                    row=row,
+                    evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
                )
                if result is None or not isinstance(result, EvaluationRow):
                    raise ValueError(
@@ -204,6 +221,7 @@ def wrapper_body(**kwargs):
                results = execute_with_params(
                    test_func,
                    input_dataset=input_dataset,
+                    evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
                )
                if results is None:
                    raise ValueError(
@@ -234,6 +252,7 @@ def wrapper_body(**kwargs):

    wrapper = create_wrapper_with_signature()
    wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(wrapper)
+    wrapper.original_evaluation_test_func = test_func

    return wrapper
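
To make the new combination logic concrete, here is a small standalone sketch of what generate_combinations now produces. The values are borrowed from the math test later in this commit; this is an illustration of the cross-product, not the library code itself.

models = ["accounts/fireworks/models/kimi-k2-instruct"]
datasets = ["development/gsm8k_sample.jsonl"]
rollout_params = [{"temperature": 0.0}]
eval_kwargs = [
    {"math_reward_kwargs": {"tolerance": 0.001}},
    {"math_reward_kwargs": {"tolerance": 0.1}},  # a second variant, for illustration
]

# Mirrors the nested loops above: one extra "etk" axis, plus the same skip rule
# (a dataset without rollout params is dropped). Each tuple becomes one pytest case,
# and the chosen etk dict is unpacked into the test function's **kwargs.
combos = [
    (m, ds, ip, None, etk)
    for m in models
    for ds in datasets
    for ip in rollout_params
    for etk in eval_kwargs
    if not (ds is not None and ip is None)
]
assert len(combos) == 2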

eval_protocol/pytest/types.py

Lines changed: 3 additions & 2 deletions
@@ -9,8 +9,9 @@

ModelParam = str  # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
DatasetPathParam = str
-InputParam = Dict[str, Any]
+RolloutInputParam = Dict[str, Any]
InputMessagesParam = List[Message]
+EvaluationInputParam = Dict[str, Any]

Dataset = List[EvaluationRow]

@@ -37,7 +38,7 @@
@dataclass
class RolloutProcessorConfig:
    model: ModelParam
-    input_params: InputParam  # optional input parameters for inference
+    input_params: RolloutInputParam  # optional input parameters for inference
    mcp_config_path: str  # for agent rollout processor

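For reference, a sketch of constructing the config with the renamed alias; the values are placeholders, and mcp_config_path only matters for the agent rollout processor per the field comment above.

from eval_protocol.pytest import RolloutProcessorConfig

# Placeholder values; only the three fields shown in the dataclass above are assumed.
config = RolloutProcessorConfig(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    input_params={"temperature": 0.0, "max_tokens": 4096},  # a RolloutInputParam dict
    mcp_config_path="",  # only consulted by the agent rollout processor
)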

eval_protocol/pytest/utils.py

Lines changed: 0 additions & 12 deletions
@@ -39,18 +39,6 @@ def execute_function(func: Callable, **kwargs) -> Any:
    return results


-def evaluate(
-    rows: List[EvaluationRow], reward_fn: Callable[..., EvaluateResult], **kwargs: Any
-) -> List[EvaluationRow]:
-    """Apply a reward function to each row and attach the result."""
-    evaluated: List[EvaluationRow] = []
-    for row in rows:
-        result = reward_fn(messages=row.messages, ground_truth=row.ground_truth, **kwargs)
-        row.evaluation_result = result
-        evaluated.append(row)
-    return evaluated
-
-
AggregationMethod = Literal["mean", "max", "min"]

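With evaluate() gone, the per-row work it did moves into the test body itself. A minimal sketch of the equivalent pointwise logic, with reward_fn standing in for any reward function such as math_reward:

from typing import Any, Callable

from eval_protocol.models import EvaluateResult, EvaluationRow


def score_row(row: EvaluationRow, reward_fn: Callable[..., EvaluateResult], **kwargs: Any) -> EvaluationRow:
    # Same per-row step the removed helper performed, now inlined in a
    # mode="pointwise" test body: call the reward function and attach its result.
    row.evaluation_result = reward_fn(messages=row.messages, ground_truth=row.ground_truth, **kwargs)
    return row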

eval_protocol/rewards/math.py

Lines changed: 5 additions & 0 deletions
@@ -565,6 +565,11 @@ def math_reward(
    require_units: bool = False,
    **kwargs: Any,
) -> EvaluateResult:
+    """
+    NOTE: This is the deprecated/old way of creating an eval in Eval Protocol.
+    What used to be the @reward_function decorator is now the @evaluation_test
+    decorator with the mode="pointwise" parameter.
+    """
    if (
        not messages
        or not isinstance(messages[-1], Message)

tests/pytest/test_markdown_highlighting.py

Lines changed: 17 additions & 22 deletions
@@ -8,7 +8,7 @@
from typing import Any, Dict, List, Optional

from eval_protocol.models import EvaluateResult, EvaluationRow, Message
-from eval_protocol.pytest import evaluation_test, default_single_turn_rollout_processor, evaluate
+from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test


def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
@@ -21,17 +21,27 @@ def markdown_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
    ]


-def markdown_format_evaluate(messages: List[Message], ground_truth: Optional[str] = None, **kwargs) -> EvaluateResult:
+@evaluation_test(
+    input_dataset=["tests/pytest/data/markdown_dataset.jsonl"],
+    dataset_adapter=markdown_dataset_to_evaluation_row,
+    model=["accounts/fireworks/models/llama-v3p1-8b-instruct"],
+    rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
+    threshold_of_success=1.0,
+    rollout_processor=default_single_turn_rollout_processor,
+    num_runs=1,
+    mode="pointwise",
+)
+def test_markdown_highlighting_evaluation(row: EvaluationRow) -> EvaluationRow:
    """
    Evaluation function that checks if the model's response contains the required number of formatted sections.
    """

-    assistant_response = messages[-1].content
+    assistant_response = row.messages[-1].content

    if not assistant_response:
        return EvaluateResult(score=0.0, reason="❌ No assistant response found")

-    required_highlights = int(ground_truth)
+    required_highlights = int(row.ground_truth)

    # Check if the response contains the required number of formatted sections
    # e.g. **bold** or *italic*
@@ -50,26 +60,11 @@ def markdown_format_evaluate(messages: List[Message], ground_truth: Optional[str
    meets_requirement = actual_count >= required_highlights

    if meets_requirement:
-        return EvaluateResult(
+        row.evaluation_result = EvaluateResult(
            score=1.0, reason=f"✅ Found {actual_count} highlighted sections (required: {required_highlights})"
        )
    else:
-        return EvaluateResult(
+        row.evaluation_result = EvaluateResult(
            score=0.0, reason=f"❌ Only found {actual_count} highlighted sections (required: {required_highlights})"
        )
-
-
-@evaluation_test(
-    input_dataset=["tests/pytest/data/markdown_dataset.jsonl"],
-    dataset_adapter=markdown_dataset_to_evaluation_row,
-    model=["accounts/fireworks/models/llama-v3p1-8b-instruct"],
-    input_params=[{"temperature": 0.0, "max_tokens": 4096}],
-    threshold_of_success=1.0,
-    rollout_processor=default_single_turn_rollout_processor,
-    num_runs=1,
-)
-def test_markdown_highlighting_evaluation(rows: List[EvaluationRow]) -> List[EvaluationRow]:
-    """
-    Test markdown highlighting validation using batch mode with evaluate().
-    """
-    return evaluate(rows, markdown_format_evaluate)
+    return row
Lines changed: 65 additions & 8 deletions
@@ -1,19 +1,76 @@
-from typing import List
-from eval_protocol.models import EvaluationRow
-from eval_protocol.pytest import default_single_turn_rollout_processor, evaluate, evaluation_test
-from examples.math_example.main import evaluate as math_evaluate
+from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
+from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
+from eval_protocol.rewards.math import math_reward
+from examples.math_example.main import check_think_answer_format
from tests.pytest.helper.gsm8k_to_evaluation_row import gsm8k_to_evaluation_row


@evaluation_test(
    input_dataset=["development/gsm8k_sample.jsonl"],
    dataset_adapter=gsm8k_to_evaluation_row,
    model=["accounts/fireworks/models/kimi-k2-instruct"],
-    input_params=[{"temperature": 0.0}],
+    rollout_input_params=[{"temperature": 0.0}],
    max_dataset_rows=5,
    threshold_of_success=0.0,
    rollout_processor=default_single_turn_rollout_processor,
+    mode="pointwise",
+    evaluation_test_kwargs=[
+        {"math_reward_kwargs": {"tolerance": 0.001, "absolute_tolerance": 1e-8, "require_units": False}}
+    ],
)
-def test_math_dataset(rows: List[EvaluationRow]) -> List[EvaluationRow]:
-    """Run math evaluation on sample dataset using pytest interface."""
-    return evaluate(rows, math_evaluate)
+def test_math_dataset(row: EvaluationRow, **kwargs) -> EvaluationRow:
+    """
+    Evaluate math problem solving considering both accuracy and format.
+
+    This function demonstrates how to combine multiple evaluation criteria:
+    - Numerical accuracy using built-in math evaluation
+    - Format compliance checking for <think>...</think><answer>...</answer> structure
+
+    Args:
+        row: EvaluationRow containing the conversation messages and ground truth
+        **kwargs: Additional parameters (like math_reward_kwargs)
+
+    Returns:
+        EvaluationRow with the evaluation result
+    """
+    # Get the assistant's response
+    assistant_message = row.messages[-1]
+    if isinstance(assistant_message, dict):
+        assistant_response = assistant_message.get("content", "")
+    else:
+        assistant_response = assistant_message.content or ""
+
+    # Evaluate numerical accuracy using built-in function
+    accuracy_result = math_reward(messages=row.messages, ground_truth=row.ground_truth, **kwargs["math_reward_kwargs"])
+
+    # Evaluate format compliance (looking for <think>...</think><answer>...</answer> format)
+    format_correct = check_think_answer_format(assistant_response)
+    format_score = 1.0 if format_correct else 0.0
+
+    # For math_example, accuracy takes priority - if accuracy is 0, overall score is 0
+    # If accuracy is 1, then format can contribute to the score
+    if accuracy_result.score == 0.0:
+        combined_score = 0.0
+    else:
+        combined_score = accuracy_result.score  # Only accuracy matters for math_example
+
+    # Create metrics structure expected by tests
+    metrics = {
+        "accuracy_reward": MetricResult(
+            score=accuracy_result.score,
+            reason=f"Numerical accuracy: {accuracy_result.reason}",
+            is_score_valid=True,
+        ),
+        "format_reward": MetricResult(
+            score=format_score,
+            reason=f"Format compliance: {'correct' if format_correct else 'incorrect'} <think>...</think><answer>...</answer> structure",
+            is_score_valid=True,
+        ),
+    }
+
+    row.evaluation_result = EvaluateResult(
+        score=combined_score,
+        reason=f"Combined score: {combined_score:.2f} (accuracy: {accuracy_result.score:.2f}, format: {format_score:.2f})",
+        metrics=metrics,
+    )
+    return row
