
Commit 6045ee9

Use LLM finish reason as the termination reason (#56)
* let policy decide end of loop
* add
* fix linter
* fix lint
* fix test
* update
* rename RolloutStatus reason to termination_reason
1 parent a5e1479 commit 6045ee9

8 files changed (+89 / -60 lines changed)

eval_protocol/mcp/execution/base_policy.py

Lines changed: 17 additions & 11 deletions
@@ -151,7 +151,7 @@ async def _generate_live_tool_calls(
         tool_schemas: List[Dict],
         env_index: int,
         conversation_history: List[Dict[str, Any]],
-    ) -> Tuple[List[MCPToolCall], CompletionUsage]:
+    ) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
         """
         Generate tool calls using conversation history for proper OpenAI trajectories.

@@ -161,7 +161,7 @@ async def _generate_live_tool_calls(
             user_prompt: Current user prompt with observation

         Returns:
-            List of MCPToolCall objects
+            List of MCPToolCall objects, LLM usage stats, and finish reason
         """
         # Convert MCP tools to LLM format
         llm_tools = self._convert_mcp_tools_to_llm_format(tool_schemas)
@@ -190,6 +190,8 @@ async def _generate_live_tool_calls(
             total_tokens=response["usage"]["total_tokens"],
         )

+        finish_reason = response["choices"][0]["finish_reason"]
+
         # Extract tool call from response
         message = response["choices"][0]["message"]
         logger.debug(f"Environment {env_index} - Response message: {message}")
@@ -217,15 +219,19 @@
             if self.max_tools_per_turn:
                 mcp_tool_calls = mcp_tool_calls[: self.max_tools_per_turn]

-            return mcp_tool_calls, usage_stats
+            return mcp_tool_calls, usage_stats, finish_reason
         else:
             # No tool calls in response - this is normal when episode ends or LLM provides only text
             logger.debug(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
-            return [
-                MCPToolCall(
-                    tool_name="_no_tool_call",
-                    arguments={
-                        "reason": "no_tool_call_generated",
-                    },
-                )
-            ], usage_stats
+            return (
+                [
+                    MCPToolCall(
+                        tool_name="_no_tool_call",
+                        arguments={
+                            "reason": "no_tool_call_generated",
+                        },
+                    )
+                ],
+                usage_stats,
+                finish_reason,
+            )
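The key change in this file is that the policy now returns the model's finish reason alongside the tool calls and usage stats. A minimal sketch of that extraction against an OpenAI-style response dict (the helper name and the stubbed response are illustrative, not part of the commit):

```python
from typing import Any, Dict, List, Tuple

def extract_policy_result(response: Dict[str, Any]) -> Tuple[List[Dict[str, Any]], Dict[str, int], str]:
    """Mirror the new (tool_calls, usage_stats, finish_reason) contract."""
    choice = response["choices"][0]
    finish_reason = choice["finish_reason"]  # "stop", "length", or "tool_calls"
    tool_calls = choice["message"].get("tool_calls") or []
    usage = response["usage"]
    return tool_calls, usage, finish_reason

# Stubbed response with no tool calls, as when the model ends the episode with text only.
stub = {
    "choices": [{"message": {"content": "Task complete.", "tool_calls": None}, "finish_reason": "stop"}],
    "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15},
}
calls, usage, reason = extract_policy_result(stub)
assert calls == [] and reason == "stop"
```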

eval_protocol/mcp/execution/manager.py

Lines changed: 27 additions & 33 deletions
@@ -169,21 +169,14 @@ async def _execute_with_semaphore(idx):
                 max_tool_calls=getattr(policy, "max_tools_per_turn", None),
             )
             if trajectory.terminated:
-                if trajectory.termination_reason in {
-                    TerminationReason.CONTROL_PLANE_SIGNAL,
-                    TerminationReason.USER_STOP,
-                }:
-                    evaluation_rows[idx].rollout_status.status = "finished"
-                elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}:
-                    evaluation_rows[idx].rollout_status.status = "stopped"
-                    evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
-                        "termination_reason", trajectory.termination_reason
-                    )
-                else:
+                if trajectory.termination_reason == TerminationReason.ERROR:
                     evaluation_rows[idx].rollout_status.status = "error"
-                    evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
+                    evaluation_rows[idx].rollout_status.termination_reason = trajectory.control_plane_summary.get(
                         "error_message", None
                     )
+                else:
+                    evaluation_rows[idx].rollout_status.status = "finished"
+                    evaluation_rows[idx].rollout_status.termination_reason = trajectory.termination_reason
             else:
                 evaluation_rows[idx].rollout_status.status = "running"

@@ -266,7 +259,7 @@ async def _execute_rollout(

         # Run rollout loop for this specific environment
         step = 0
-        rollout_end = False
+        env_end = False  # if the env indicates the rollout reaches the goal

         while step < steps and not trajectory.terminated:
             turn_completed = False
@@ -297,7 +290,9 @@

             # In each turn: keep looping until assistant is ready to provide final response
             while not turn_completed and not trajectory.terminated:
-                tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
+                tool_calls, usage_stats, finish_reason = await policy(
+                    tool_schema, rollout_idx, conversation_history
+                )

                 # calc llm usage stats happened in this turn if there is aany
                 if usage_stats:
@@ -311,17 +306,17 @@
                 if tool_calls[0].tool_name == "_no_tool_call" and user_simulator:
                     turn_completed = True
                     break
-                # If there's no user simulator, no tool call means policy failed and we should terminate the rollout
+                # If there's no user simulator, then it marks the end of the episode as LLM think there is no tool call needed.
                 elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
                     trajectory.terminated = True
-                    trajectory.termination_reason = TerminationReason.INTERRUPTED
+                    trajectory.termination_reason = TerminationReason.from_str(finish_reason)
                     break

                 # Execute each tool call sequentially
                 for tool_call in tool_calls:

                     # Execute tool call for this environment
-                    observation, reward, rollout_end, info = await envs.step(rollout_idx, tool_call)
+                    observation, reward, env_end, info = await envs.step(rollout_idx, tool_call)

                     tool_response = envs.format_tool_response(observation)

@@ -331,7 +326,7 @@
                         tool_response,
                         conversation_history,
                         reward,
-                        rollout_end,
+                        env_end,
                         info,
                     )

@@ -354,7 +349,7 @@
                     control_plane_step = {
                         "step": step - 1,
                         "reward": reward,
-                        "terminated": rollout_end,
+                        "terminated": env_end,
                         "info": info.get("control_plane", {}),
                         "tool_calls": [f"{tool_call.tool_name}({tool_call.arguments})"],
                         "num_tool_calls": 1,
@@ -367,11 +362,13 @@
                     if recording_mode:
                         policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)

-                    if rollout_end:
+                    if env_end:
+                        # if the env marks the end of the rollout, break the tool call loop
+                        # but set the termination reason later after the final policy call
                         trajectory.terminated = True
-                        trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL
                         break
-                    elif step >= steps:
+
+                    if step >= steps:
                         trajectory.terminated = True
                         trajectory.termination_reason = TerminationReason.MAX_STEPS
                         break
@@ -392,7 +389,7 @@
                 control_plane_step = {
                     "step": step - 1,
                     "reward": reward,
-                    "terminated": rollout_end,
+                    "terminated": env_end,
                     "info": info.get("control_plane", {}),
                     "tool_calls": tool_calls_summary,
                     "num_tool_calls": len(tool_calls),
@@ -404,19 +401,16 @@
                 if recording_mode:
                     policy.log_conversation_state_for_playback(rollout_idx, step - 1, conversation_history)

-                # Use control plane information for termination decision
-                if rollout_end:
-                    trajectory.terminated = True
-                    trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL
-
-                    # tool indicates rollout should be terminated, call policy one last time to get the final response
-                    _, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
+                # if the env marks end, update control plane summary and do one last policy call, then break the agent loop
+                # this is to ensure each turn ends with an assistant message, which will align with the actual agentic llm behavior
+                if env_end:
+                    _, usage_stats, finish_reason = await policy(tool_schema, rollout_idx, conversation_history)
                     if usage_stats:
                         trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
                         trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
                         trajectory.usage["total_tokens"] += usage_stats.total_tokens
-
-                    # Add final control plane summary
+                    trajectory.terminated = True
+                    trajectory.termination_reason = TerminationReason.from_str(finish_reason)
                     trajectory.control_plane_summary.update(
                         {
                             "total_reward": trajectory.total_reward,
@@ -445,7 +439,7 @@
                     )

                     logger.info(
-                        f"🏁 Rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}"
+                        f"🏁 Environmnet indicates rollout {rollout_idx} terminated at step {step} (reward: {trajectory.total_reward}) in thread {threading.current_thread().name}"
                     )
                     break

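With the finish reason flowing back from the policy, the status mapping in _execute_with_semaphore collapses to two terminal cases. A condensed, self-contained sketch of that branch (the stub dataclasses below stand in for the real Trajectory and RolloutStatus types):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class RolloutStatusStub:
    status: str = "running"
    termination_reason: Optional[str] = ""

@dataclass
class TrajectoryStub:
    terminated: bool = False
    termination_reason: Optional[str] = None
    control_plane_summary: dict = field(default_factory=dict)

def update_rollout_status(trajectory: TrajectoryStub, rollout_status: RolloutStatusStub) -> None:
    # Same shape as the simplified branch in the diff above.
    if not trajectory.terminated:
        rollout_status.status = "running"
    elif trajectory.termination_reason == "error":
        rollout_status.status = "error"
        rollout_status.termination_reason = trajectory.control_plane_summary.get("error_message")
    else:
        rollout_status.status = "finished"
        rollout_status.termination_reason = trajectory.termination_reason

status = RolloutStatusStub()
update_rollout_status(TrajectoryStub(terminated=True, termination_reason="stop"), status)
assert status.status == "finished" and status.termination_reason == "stop"
```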

eval_protocol/mcp/execution/policy.py

Lines changed: 2 additions & 1 deletion
@@ -213,7 +213,8 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
                         if response.choices[0].message.tool_calls
                         else []
                     ),
-                }
+                },
+                "finish_reason": response.choices[0].finish_reason,
             }
         ],
         "usage": {

eval_protocol/models.py

Lines changed: 3 additions & 3 deletions
@@ -270,10 +270,10 @@ class RolloutStatus(BaseModel):
         error: Rollout failed.
         stopped: Rollout terminated unexpectedly (e.g. max step, control plane signal, user stop).
     """
-    status: Literal["running", "finished", "error", "stopped"] = Field(
-        "finished", description="Status of the rollout."
+    status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.")
+    termination_reason: Optional[str] = Field(
+        "", description="reason of the rollout status, mapped to values in TerminationReason"
     )
-    error_message: Optional[str] = Field(None, description="Error message if the rollout failed.")


 class EvaluationRow(BaseModel):
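After this change a rollout's outcome is fully described by status plus termination_reason; the separate "stopped" state and error_message field are gone. A minimal pydantic sketch of the updated model, mirroring the fields in the diff:

```python
from typing import Literal, Optional
from pydantic import BaseModel, Field

class RolloutStatus(BaseModel):
    status: Literal["running", "finished", "error"] = Field("running", description="Status of the rollout.")
    termination_reason: Optional[str] = Field(
        "", description="reason of the rollout status, mapped to values in TerminationReason"
    )

# A rollout whose LLM ended the episode with finish reason "stop":
print(RolloutStatus(status="finished", termination_reason="stop"))
```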

eval_protocol/playback_policy.py

Lines changed: 2 additions & 2 deletions
@@ -207,7 +207,7 @@ async def _generate_live_tool_calls(
         tool_schemas: List[Dict],
         env_index: int,
         conversation_history: List[Dict[str, Any]],
-    ) -> Tuple[List["MCPToolCall"], CompletionUsage]:
+    ) -> Tuple[List["MCPToolCall"], CompletionUsage, str]:
         """
         Generate tool calls in live mode. Concrete classes must implement this.

@@ -253,7 +253,7 @@ async def __call__(
             ]

             # Return the recorded tool call
-            return self._extract_tool_call_from_messages(messages, env_index), None
+            return self._extract_tool_call_from_messages(messages, env_index), None, None
         else:
             # Live mode - generate tool call using provided conversation history
             return await self._generate_live_tool_calls(tool_schemas, env_index, conversation_history)

eval_protocol/types/types.py

Lines changed: 27 additions & 3 deletions
@@ -1,8 +1,9 @@
+from contextlib import AsyncExitStack
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Dict, List, Optional
+
 from mcp.client.session import ClientSession
-from contextlib import AsyncExitStack


 class TerminationReason(str, Enum):
@@ -11,15 +12,38 @@ class TerminationReason(str, Enum):
     MAX_STEPS: Trajectory ends because we hit the step limit
     CONTROL_PLANE_SIGNAL: Trajectory ends because the control plane signals termination (e.g. env goal reached or failure condition)
     USER_STOP: Trajectory ends because the simulated user signals to stop
-    INTERRUPTED: Trajectory ends unexpectedly, for example, expecting tool call but there is no tool call
     ERROR: Trajectory ends because of an error
+    STOP: Trajectory ends by the policy (mapped to llm response stop reason "stop")
+    LENGTH: Trajectory ends by the policy (mapped to llm response stop reason "length")
+    TOOL_CALLS: Trajectory ends by the policy with a hanging tool call response (mapped to llm response stop reason "tool_calls")
     """

     MAX_STEPS = "max_steps"
     CONTROL_PLANE_SIGNAL = "control_plane_signal"
     USER_STOP = "user_stop"
-    INTERRUPTED = "interrupted"
     ERROR = "error"
+    STOP = "stop"
+    LENGTH = "length"
+    TOOL_CALLS = "tool_calls"
+
+    @classmethod
+    def from_str(cls, value: str) -> "TerminationReason":
+        if value == "stop":
+            return cls.STOP
+        elif value == "length":
+            return cls.LENGTH
+        elif value == "max_steps":
+            return cls.MAX_STEPS
+        elif value == "control_plane_signal":
+            return cls.CONTROL_PLANE_SIGNAL
+        elif value == "user_stop":
+            return cls.USER_STOP
+        elif value == "error":
+            return cls.ERROR
+        elif value == "tool_calls":
+            return cls.TOOL_CALLS
+        else:
+            raise ValueError(f"Invalid termination reason: {value}")


 @dataclass
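Since TerminationReason is a str-valued Enum whose members match the raw finish-reason strings, from_str behaves like a validated lookup; plain value construction gives the same result. A small standalone check (the enum is copied here so the snippet runs on its own):

```python
from enum import Enum

class TerminationReason(str, Enum):
    MAX_STEPS = "max_steps"
    CONTROL_PLANE_SIGNAL = "control_plane_signal"
    USER_STOP = "user_stop"
    ERROR = "error"
    STOP = "stop"
    LENGTH = "length"
    TOOL_CALLS = "tool_calls"

# Value construction maps a raw LLM finish reason straight to a member,
# and raises ValueError for anything unrecognized, just like from_str.
assert TerminationReason("stop") is TerminationReason.STOP
assert TerminationReason("tool_calls") is TerminationReason.TOOL_CALLS
try:
    TerminationReason("not_a_reason")
except ValueError as exc:
    print(exc)  # 'not_a_reason' is not a valid TerminationReason
```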

eval_protocol/utils/static_policy.py

Lines changed: 4 additions & 4 deletions
@@ -73,7 +73,7 @@ async def _generate_live_tool_calls(
         tool_schemas: List[Dict],
         env_index: int,
         conversation_history: List[Dict[str, Any]],
-    ) -> Tuple[List[MCPToolCall], CompletionUsage]:
+    ) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
         """
         Generate tool calls in live mode using the static action sequence.

@@ -106,7 +106,7 @@ async def _generate_live_tool_calls(
         logger.debug(f"🎮 Env {env_index} step {step_count}: {action}")

         usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
-        return [tool_call], usage_stats
+        return [tool_call], usage_stats, None

     def add_tool_response(
         self,
@@ -220,7 +220,7 @@ async def _generate_live_tool_calls(
         tool_schemas: List[Dict],
         env_index: int,
         conversation_history: List[Dict[str, Any]],
-    ) -> Tuple[List[MCPToolCall], CompletionUsage]:
+    ) -> Tuple[List[MCPToolCall], CompletionUsage, str]:
         """
         Generate random tool calls in live mode.

@@ -241,7 +241,7 @@ async def _generate_live_tool_calls(
         logger.debug(f"🎲 Env {env_index}: {action}")

         usage_stats = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
-        return [tool_call], usage_stats
+        return [tool_call], usage_stats, None

     def add_tool_response(
         self,

tests/test_rollout_control_plane_integration.py

Lines changed: 7 additions & 3 deletions
@@ -50,9 +50,13 @@ async def __call__(self, tool_schema, env_index, conversation_history):
         tool_calls = []
         tool_call = MCPToolCall(tool_name="lake_move", arguments={"action": action})
         tool_calls.append(tool_call)
+        if self.step_count == 3:
+            self.step_count += 1
+            no_tool_call = MCPToolCall(tool_name="_no_tool_call", arguments={})
+            return [no_tool_call], None, "stop"

         self.step_count += 1
-        return tool_calls, None
+        return tool_calls, None, None

     def add_tool_response(
         self,
@@ -285,11 +289,11 @@ def mock_step_side_effect(env_index, tool_call):
         final_cp_step = final_msg.control_plane_step
         assert final_cp_step["terminated"] == True, "Final step should be terminated"
         assert final_cp_step["reward"] == 1.0, "Final step should have correct reward"
-        assert final_cp_step["termination_reason"] == "control_plane_signal", "Should terminate via control plane"
+        assert final_cp_step["termination_reason"] == "stop", "Should terminate via control plane"
         assert final_cp_step["step"] == 2, "Should record final step"

         # Validate policy interaction
-        assert policy.step_count == 4, "Policy should have been called 3 times"
+        assert policy.step_count == 4, "Policy should have been called 4 times"

     @pytest.mark.asyncio
     async def test_rollout_trajectory_recording_with_control_plane(self):
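The mock policy in this test now ends the episode itself: on its fourth call (step_count == 3) it returns a _no_tool_call marker with finish reason "stop", which the manager maps to a "stop" termination. A condensed sketch of that behavior (plain dicts stand in for MCPToolCall):

```python
class ScriptedPolicyStub:
    """Illustrative stand-in for the test's scripted policy."""

    def __init__(self, actions):
        self.actions = actions
        self.step_count = 0

    async def __call__(self, tool_schema, env_index, conversation_history):
        if self.step_count == 3:
            self.step_count += 1
            # No tool call this turn; report the LLM-style finish reason "stop".
            return [{"tool_name": "_no_tool_call", "arguments": {}}], None, "stop"
        action = self.actions[self.step_count % len(self.actions)]
        self.step_count += 1
        return [{"tool_name": "lake_move", "arguments": {"action": action}}], None, None
```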
