
Commit 9f352ed

gsm8k math example (#294)

* gsm8k math example
* fix tests

1 parent 9902b0f, commit 9f352ed

File tree

10 files changed: +1204 -85 lines

.github/workflows/ci.yml
Lines changed: 0 additions & 3 deletions

@@ -41,9 +41,6 @@ jobs:
      - name: Install tau2 for testing
        run: uv pip install git+https://github.com/sierra-research/tau2-bench.git@main

-      - name: Ruff format (check)
-        run: uv run ruff format --check .
-
      - name: Ruff lint
        run: uv run ruff check .

development/gsm8k_sample.jsonl

Lines changed: 1000 additions & 5 deletions
Large diffs are not rendered by default.
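The sample rows themselves are not rendered here. As a purely hypothetical illustration (field names mirror the EvaluationRow/Message models used by the tests in this commit; the system prompt, question, and numbers are invented, not taken from the file), a single-turn row could look like the sketch below. A system message is assumed only because the gsm8k example later in this commit reads the assistant reply at row.messages[2].

# Hypothetical row shape only -- the actual contents of development/gsm8k_sample.jsonl
# are not shown in this diff. Keys follow the EvaluationRow fields the tests use
# (messages, ground_truth); the question and answer are invented.
example_jsonl_row = {
    "messages": [
        {"role": "system", "content": "Answer inside <think>...</think><answer>...</answer> tags."},
        {"role": "user", "content": "A box holds 12 pencils. How many pencils are in 3 boxes?"},
    ],
    "ground_truth": "<think>12 * 3 = 36</think><answer>36</answer>",
}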

eval_protocol/pytest/default_single_turn_rollout_process.py
Lines changed: 18 additions & 2 deletions

@@ -21,6 +21,16 @@
class SingleTurnRolloutProcessor(RolloutProcessor):
    """Single turn rollout processor for direct LLM calls."""

+    def __init__(self, *, drop_trailing_assistant_messages: bool = True) -> None:
+        """
+        Args:
+            drop_trailing_assistant_messages: When True (default), strip any trailing
+                assistant messages from the input conversation before calling the model.
+                This helps when datasets include previous assistant turns and you want
+                the model to answer the latest user query.
+        """
+        self.drop_trailing_assistant_messages = drop_trailing_assistant_messages
+
    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate single turn rollout tasks and return them for external handling."""
        # Do not modify global LiteLLM cache. Disable caching per-request instead.

@@ -32,7 +42,13 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
            if len(row.messages) == 0:
                raise ValueError("Messages is empty. Please provide a non-empty dataset")

-            messages_payload = [message.model_dump() for message in row.messages]
+            # Optionally drop trailing assistant messages for single-turn prompts
+            messages_for_request: List[Message] = list(row.messages)
+            if self.drop_trailing_assistant_messages:
+                while messages_for_request and messages_for_request[-1].role == "assistant":
+                    messages_for_request.pop()
+
+            messages_payload = [message.model_dump() for message in messages_for_request]

            request_params = {"messages": messages_payload, **config.completion_params}
            # Ensure caching is disabled only for this request (review feedback)

@@ -114,7 +130,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
            except Exception:
                pass

-            messages = list(row.messages) + [
+            messages = list(messages_for_request) + [
                Message(
                    role="assistant",
                    content=assistant_content,
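
A minimal usage sketch of the new flag (the constructor signature and the import path come from this commit; the rest is illustrative and assumes the package is installed):

# Sketch, not part of the diff: shows the default vs. opt-out behaviour.
from eval_protocol.pytest import SingleTurnRolloutProcessor

# Default: trailing assistant messages are stripped before the model is called.
processor = SingleTurnRolloutProcessor()

# Opt out to keep any pre-existing assistant turns in the request payload.
processor_keep_history = SingleTurnRolloutProcessor(drop_trailing_assistant_messages=False)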

eval_protocol/pytest/exception_config.py
Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
import requests
import httpx

+
# Default exceptions that should be retried with backoff
DEFAULT_RETRYABLE_EXCEPTIONS: Set[Type[Exception]] = {
    # Standard library exceptions

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+eval-protocol

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+import re
+from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult, Message
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
+from typing import List, Dict, Any, Optional
+
+
+def extract_answer_digits(ground_truth: str) -> Optional[str]:
+    """
+    Extract the digits from the answer string.
+    """
+    answer_string = ground_truth.split("<answer>")[1].split("</answer>")[0]
+    return re.search(r"(\d+)", answer_string).group(1) if answer_string else None
+
+
+@evaluation_test(
+    input_dataset=["development/gsm8k_sample.jsonl"],
+    completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
+    max_dataset_rows=5,
+    passed_threshold=0.0,
+    rollout_processor=SingleTurnRolloutProcessor(),
+    mode="pointwise",
+    evaluation_test_kwargs=[
+        {"math_reward_kwargs": {"tolerance": 0.001, "absolute_tolerance": 1e-8, "require_units": False}}
+    ],
+)
+def test_math_dataset(row: EvaluationRow, **kwargs) -> EvaluationRow:
+    """
+    Evaluate math problem solving considering both accuracy and format.
+
+    This function demonstrates how to combine multiple evaluation criteria:
+    - Numerical accuracy using built-in math evaluation (80% weight)
+    - Format compliance checking for <think>...</think><answer>...</answer> structure (20% weight)
+
+    Args:
+        row: EvaluationRow containing the conversation messages and ground truth
+        **kwargs: Additional parameters (like math_reward_kwargs)
+
+    Returns:
+        EvaluationRow with the evaluation result
+    """
+    #### Get predicted answer value
+    prediction = extract_answer_digits(str(row.messages[2].content))
+    gt = extract_answer_digits(str(row.ground_truth))
+
+    #### Get score
+    if prediction is None or gt is None:
+        score = 0
+        reason = "Missing answer tags in prediction or ground truth."
+
+    elif gt == prediction:
+        score = 1
+        reason = "Model answer is correct."
+
+    else:
+        score = 0
+        reason = "Model answer is not correct."
+
+    reason += f" Prediction: {prediction}, Ground Truth: {gt}"
+
+    evaluation_result = EvaluateResult(
+        score=score,  # Required: The final evaluation score
+        is_score_valid=True,  # Optional: Whether the score is valid, true by default
+        reason=reason,  # Optional: The reason for the score
+    )
+    row.evaluation_result = evaluation_result
+    return row
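
To make the tag handling concrete, a small sketch of what extract_answer_digits does to a made-up ground-truth string (the string is invented for illustration, not taken from the dataset):

import re

# Same extraction steps as extract_answer_digits above, applied to an invented string.
ground_truth = "<think>4 + 3 = 7 cookies</think><answer>7</answer>"
answer_string = ground_truth.split("<answer>")[1].split("</answer>")[0]  # -> "7"
digits = re.search(r"(\d+)", answer_string).group(1) if answer_string else None
assert digits == "7"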

tests/pytest/test_pytest_math_example.py

Lines changed: 0 additions & 71 deletions
This file was deleted.

tests/pytest/test_pytest_math_format_length.py
Lines changed: 0 additions & 2 deletions

@@ -5,12 +5,10 @@
from eval_protocol.rewards.length import count_tokens
from eval_protocol.rewards.math import math_reward
from examples.math_with_format_and_length.main import check_think_answer_format
-from tests.pytest.helper.gsm8k_to_evaluation_row import gsm8k_to_evaluation_row


@evaluation_test(
    input_dataset=["development/gsm8k_sample.jsonl"],
-    dataset_adapter=gsm8k_to_evaluation_row,
    completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
    max_dataset_rows=5,
    passed_threshold=0.0,

tests/pytest/test_pytest_word_count_example.py
Lines changed: 0 additions & 2 deletions

@@ -2,12 +2,10 @@

from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
-from tests.pytest.helper.word_count_to_evaluation_row import word_count_to_evaluation_row


@evaluation_test(
    input_dataset=["development/gsm8k_sample.jsonl"],
-    dataset_adapter=word_count_to_evaluation_row,
    completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
    max_dataset_rows=5,
    passed_threshold=0.3,  # Reasonable threshold for word count evaluation

Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@
+import asyncio
+from types import SimpleNamespace
+
+import pytest
+
+from eval_protocol.models import EvaluationRow, Message
+from eval_protocol.pytest import SingleTurnRolloutProcessor
+
+
+class _DummyConfig:
+    def __init__(self):
+        self.completion_params = {"model": "fake-model", "temperature": 0}
+        self.semaphore = asyncio.Semaphore(10)
+
+
+@pytest.mark.asyncio
+async def test_single_turn_drops_trailing_assistant_by_default(monkeypatch):
+    # Arrange dataset row with trailing assistant message
+    row = EvaluationRow(
+        messages=[
+            Message(role="user", content="What is 2+2?"),
+            Message(role="assistant", content="Old response"),
+        ]
+    )
+
+    # Capture the messages payload passed to the LLM call
+    captured = {}
+
+    # Patch module-level imports in the processor module
+    import eval_protocol.pytest.default_single_turn_rollout_process as mod
+
+    class StubChoices:
+        pass
+
+    class StubModelResponse:
+        def __init__(self, text: str):
+            self.choices = [StubChoices()]
+            # Emulate OpenAI-like response.message fields
+            self.choices[0].message = SimpleNamespace(content=text, tool_calls=None)
+            # Minimal usage payload
+            self.usage = SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2)
+
+    async def fake_acompletion(**kwargs):
+        # Verify that trailing assistant was dropped before sending
+        msgs = kwargs.get("messages", [])
+        assert msgs, "Expected non-empty messages payload"
+        captured["messages"] = msgs
+        assert msgs[-1]["role"] != "assistant", "Trailing assistant should be dropped by default"
+        return StubModelResponse(text="4")
+
+    # Monkeypatch the processor module's symbols to avoid dependency on litellm types
+    monkeypatch.setattr(mod, "ModelResponse", StubModelResponse, raising=True)
+    monkeypatch.setattr(mod, "Choices", StubChoices, raising=True)
+    monkeypatch.setattr(mod, "acompletion", fake_acompletion, raising=True)
+
+    processor = SingleTurnRolloutProcessor()
+    config = _DummyConfig()
+
+    # Act
+    tasks = processor([row], config)
+    out = await tasks[0]
+
+    # Assert: request trimmed the trailing assistant
+    sent_msgs = captured["messages"]
+    assert len(sent_msgs) == 1
+    assert sent_msgs[0]["role"] == "user"
+    assert out.messages[-1].role == "assistant"
+    assert out.messages[-1].content == "4"
+    # Ensure previous trailing assistant was not duplicated
+    assert [m.role for m in out.messages] == ["user", "assistant"]
+
+
+@pytest.mark.asyncio
+async def test_single_turn_keeps_trailing_assistant_when_disabled(monkeypatch):
+    # Arrange dataset row with trailing assistant message
+    row = EvaluationRow(
+        messages=[
+            Message(role="user", content="Say hi"),
+            Message(role="assistant", content="Hi!"),
+        ]
+    )
+
+    captured = {}
+
+    import eval_protocol.pytest.default_single_turn_rollout_process as mod
+
+    class StubChoices:
+        pass
+
+    class StubModelResponse:
+        def __init__(self, text: str):
+            self.choices = [StubChoices()]
+            self.choices[0].message = SimpleNamespace(content=text, tool_calls=None)
+            self.usage = SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2)
+
+    async def fake_acompletion(**kwargs):
+        msgs = kwargs.get("messages", [])
+        captured["messages"] = msgs
+        # With opt-out, trailing assistant is preserved
+        assert msgs[-1]["role"] == "assistant"
+        return StubModelResponse(text="Hello again")
+
+    monkeypatch.setattr(mod, "ModelResponse", StubModelResponse, raising=True)
+    monkeypatch.setattr(mod, "Choices", StubChoices, raising=True)
+    monkeypatch.setattr(mod, "acompletion", fake_acompletion, raising=True)
+
+    processor = SingleTurnRolloutProcessor(drop_trailing_assistant_messages=False)
+    config = _DummyConfig()
+
+    # Act
+    tasks = processor([row], config)
+    out = await tasks[0]
+
+    # Assert: both original messages plus new assistant
+    sent_msgs = captured["messages"]
+    assert [m["role"] for m in sent_msgs] == ["user", "assistant"]
+    assert [m.role for m in out.messages] == ["user", "assistant", "assistant"]
+    assert out.messages[-1].content == "Hello again"
