diff --git a/docs/source/en/guided_tour.md b/docs/source/en/guided_tour.md index 14fba24be..13c2b9c65 100644 --- a/docs/source/en/guided_tour.md +++ b/docs/source/en/guided_tour.md @@ -332,7 +332,7 @@ The `final_answer_checks` parameter gives you more control over when and how an from smolagents import CodeAgent, InferenceClientModel # Define a custom final answer check function -def is_integer(final_answer: str, agent_memory=None) -> bool: +def is_integer(final_answer: str, agent_memory=None, agent_state=None) -> bool: """Return True if final_answer is an integer.""" try: int(final_answer) @@ -351,7 +351,7 @@ agent.run("Calculate the least common multiple of 3 and 7") ``` The `final_answer_checks` parameter accepts a list of functions that each: -- Take the agent's final_answer string the agent's memory as parameters +- Take the agent's final_answer string, the agent's memory and the agent's state as parameters - Return a boolean indicating whether the final_answer is valid (True) or not (False) If any function returns `False`, the agent will log the error message and continue the run. diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 09dcba75b..839448c60 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -253,7 +253,7 @@ class MultiStepAgent(ABC): provide_run_summary (`bool`, *optional*): Whether to provide a run summary when called as a managed agent. final_answer_checks (`list[Callable]`, *optional*): List of validation functions to run before accepting a final answer. Each function should: - - Take the final answer and the agent's memory as arguments. + - Take the final answer, the agent's memory and the agent's state as arguments. - Return a boolean indicating whether the final answer is valid. """ @@ -576,7 +576,7 @@ def _run_stream( def _validate_final_answer(self, final_answer: Any): for check_function in self.final_answer_checks: try: - assert check_function(final_answer, self.memory) + assert check_function(final_answer, self.memory, self.state) except Exception as e: raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger) diff --git a/tests/test_agents.py b/tests/test_agents.py index 0e2452e0b..9883e3cbb 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -670,7 +670,7 @@ def weather_api(location: str, celsius: str = "") -> str: def test_final_answer_checks(self): error_string = "failed with error" - def check_always_fails(final_answer, agent_memory): + def check_always_fails(final_answer, agent_memory, agent_state): assert False, "Error raised in check" agent = CodeAgent(model=FakeCodeModel(), tools=[], final_answer_checks=[check_always_fails]) @@ -681,7 +681,7 @@ def check_always_fails(final_answer, agent_memory): agent = CodeAgent( model=FakeCodeModel(), tools=[], - final_answer_checks=[lambda x, y: x == 7.2904], + final_answer_checks=[lambda x, y, z: x == 7.2904], verbosity_level=1000, ) output = agent.run("Dummy task.")