Skip to content

[FEAT] Add agent state awareness to final answer checks #1556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/en/guided_tour.md
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ The `final_answer_checks` parameter gives you more control over when and how an
from smolagents import CodeAgent, InferenceClientModel

# Define a custom final answer check function
def is_integer(final_answer: str, agent_memory=None) -> bool:
def is_integer(final_answer: str, agent_memory=None, agent_state=None) -> bool:
"""Return True if final_answer is an integer."""
try:
int(final_answer)
Expand All @@ -351,7 +351,7 @@ agent.run("Calculate the least common multiple of 3 and 7")
```

The `final_answer_checks` parameter accepts a list of functions that each:
- Take the agent's final_answer string the agent's memory as parameters
- Take the agent's final_answer string, the agent's memory and the agent's state as parameters
- Return a boolean indicating whether the final_answer is valid (True) or not (False)

If any function returns `False`, the agent will log the error message and continue the run.
Expand Down
4 changes: 2 additions & 2 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class MultiStepAgent(ABC):
provide_run_summary (`bool`, *optional*): Whether to provide a run summary when called as a managed agent.
final_answer_checks (`list[Callable]`, *optional*): List of validation functions to run before accepting a final answer.
Each function should:
- Take the final answer and the agent's memory as arguments.
- Take the final answer, the agent's memory and the agent's state as arguments.
- Return a boolean indicating whether the final answer is valid.
"""

Expand Down Expand Up @@ -576,7 +576,7 @@ def _run_stream(
def _validate_final_answer(self, final_answer: Any):
for check_function in self.final_answer_checks:
try:
assert check_function(final_answer, self.memory)
assert check_function(final_answer, self.memory, self.state)
except Exception as e:
raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def weather_api(location: str, celsius: str = "") -> str:
def test_final_answer_checks(self):
error_string = "failed with error"

def check_always_fails(final_answer, agent_memory):
def check_always_fails(final_answer, agent_memory, agent_state):
assert False, "Error raised in check"

agent = CodeAgent(model=FakeCodeModel(), tools=[], final_answer_checks=[check_always_fails])
Expand All @@ -681,7 +681,7 @@ def check_always_fails(final_answer, agent_memory):
agent = CodeAgent(
model=FakeCodeModel(),
tools=[],
final_answer_checks=[lambda x, y: x == 7.2904],
final_answer_checks=[lambda x, y, z: x == 7.2904],
verbosity_level=1000,
)
output = agent.run("Dummy task.")
Expand Down