diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index 609ba93e..900b55d3 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -27,7 +27,7 @@ from effectful.handlers.llm.encoding import DecodedToolCall, Encodable from effectful.handlers.llm.template import Template, Tool from effectful.internals.unification import nested_type -from effectful.ops.semantics import fwd, handler +from effectful.ops.semantics import _simple_type, fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import Operation @@ -157,7 +157,7 @@ def to_feedback_message(self, include_traceback: bool) -> Message: ) -type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None] +type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None, bool] @Operation.define @@ -217,6 +217,24 @@ def call_assistant[T, U]( for raw_tool_call in raw_tool_calls: try: tool_calls += [encoding.decode(raw_tool_call)] # type: ignore + if tool_calls[-1].is_final: + if len(raw_tool_calls) > 1: + raise ValueError( + f"IsFinal tool '{raw_tool_call.function.name}' must be the " + f"only tool call in a round, but {len(raw_tool_calls)} tool calls " + f"were generated." + ) + # Validate that the tool's return type matches the template's. + tool_sig = inspect.signature(tool_calls[-1].tool) + return_annotation = typing.get_args(tool_sig.return_annotation)[0] + if not issubclass( + _simple_type(return_annotation), response_format.base + ): + raise TypeError( + f"IsFinal tool '{raw_tool_call.function.name}' has signature " + f"{tool_sig!r}, but the enclosing template expects " + f"{response_format.base!r}." + ) except Exception as e: raise ToolCallDecodingError( raw_tool_call=raw_tool_call, @@ -238,33 +256,35 @@ def call_assistant[T, U]( except (pydantic.ValidationError, TypeError, ValueError, SyntaxError) as e: raise ResultDecodingError(e, raw_message=raw_message) from e - return (raw_message, tool_calls, result) + is_final = any(tc.is_final for tc in tool_calls) or not tool_calls + return (raw_message, tool_calls, result, is_final) @Operation.define -def call_tool(tool_call: DecodedToolCall) -> Message: - """Implements a roundtrip call to a python function. Input is a json - string representing an LLM tool call request parameters. The output is - the serialised response to the model. - +def call_tool[T](tool_call: DecodedToolCall[T]) -> tuple[Message, T | None, bool]: + """Execute a tool and return the serialised message, the raw result, and + whether this result is a final answer. + + Returns: + A 3-tuple ``(message, result, is_final)``. ``message`` is appended + to the conversation history. When ``is_final`` is ``True`` the + completion loop uses ``result`` directly as the template return value. """ # call tool with python types try: - result = tool_call.tool( + result: T = tool_call.tool( *tool_call.bound_args.args, **tool_call.bound_args.kwargs ) except Exception as e: raise ToolCallExecutionError(raw_tool_call=tool_call, original_error=e) from e - return_type = Encodable.define( - typing.cast(type[typing.Any], nested_type(result).value) - ) - encoded_result = return_type.serialize(return_type.encode(result)) + return_type = Encodable.define(nested_type(result).value) # type: ignore + encoded_result = return_type.serialize(return_type.encode(result)) # type: ignore message = _make_message( dict(role="tool", content=encoded_result, tool_call_id=tool_call.id), ) append_message(message) - return message + return message, result, tool_call.is_final @Operation.define @@ -417,18 +437,24 @@ def _attempt() -> MessageResult[T]: return fwd(tools, response_format, model, **kwargs) with handler({_get_history: lambda: _message_sequence}): - message, tool_calls, result = self.call_assistant_retryer(_attempt) + message, tool_calls, result, is_final = self.call_assistant_retryer( + _attempt + ) append_message(message) - return (message, tool_calls, result) + return (message, tool_calls, result, is_final) @implements(call_tool) - def _call_tool(self, tool_call: DecodedToolCall) -> Message: + def _call_tool[T]( + self, tool_call: DecodedToolCall[T] + ) -> tuple[Message, T | None, bool]: """Handle tool execution with runtime error capture. Runtime errors from tool execution are captured and returned as error messages to the LLM. Only exceptions matching `catch_tool_errors` - are caught; others propagate up. + are caught; others propagate up. When an error is caught, + ``is_final`` is always ``False`` so the error feedback goes back + to the LLM rather than being mistaken for a final answer. """ try: return fwd(tool_call) @@ -436,7 +462,7 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message: if isinstance(e.original_error, self.catch_tool_errors): message = e.to_feedback_message(self.include_traceback) append_message(message) - return message + return message, None, False else: raise @@ -477,12 +503,13 @@ def _call[**P, T]( # loop based on: https://cookbook.openai.com/examples/reasoning_function_calls tool_calls: list[DecodedToolCall] = [] result: T | None = None - while message["role"] != "assistant" or tool_calls: - message, tool_calls, result = call_assistant( + is_final: bool = False + while not is_final: + message, tool_calls, result, is_final = call_assistant( template.tools, response_model, **self.config ) for tool_call in tool_calls: - message = call_tool(tool_call) + message, result, is_final = call_tool(tool_call) try: _get_history() diff --git a/effectful/handlers/llm/encoding.py b/effectful/handlers/llm/encoding.py index 40cd4b26..f6af514f 100644 --- a/effectful/handlers/llm/encoding.py +++ b/effectful/handlers/llm/encoding.py @@ -31,7 +31,7 @@ from PIL import Image import effectful.handlers.llm.evaluation as evaluation -from effectful.handlers.llm.template import Tool +from effectful.handlers.llm.template import Tool, _IsFinalAnnotation from effectful.internals.unification import nested_type from effectful.ops.semantics import _simple_type from effectful.ops.syntax import _CustomSingleDispatchCallable @@ -61,6 +61,13 @@ class DecodedToolCall[T]: id: ToolCallID name: str + @property + def is_final(self) -> bool: + ret = inspect.signature(self.tool).return_annotation + return typing.get_origin(ret) is typing.Annotated and any( + isinstance(arg, _IsFinalAnnotation) for arg in ret.__metadata__ + ) + class Encodable[T, U](ABC): base: type[T] diff --git a/effectful/handlers/llm/template.py b/effectful/handlers/llm/template.py index e8c24443..6ebf049b 100644 --- a/effectful/handlers/llm/template.py +++ b/effectful/handlers/llm/template.py @@ -42,8 +42,9 @@ def factorial(n: int) -> Annotated[int, IsRecursive]: @classmethod def infer_annotations(cls, sig: inspect.Signature) -> inspect.Signature: - for name, ty in sig.parameters.items(): - if not ty or not typing.get_origin(ty) is Annotated: + for name, param in sig.parameters.items(): + ty = param.annotation + if ty is inspect.Parameter.empty or typing.get_origin(ty) is not Annotated: continue if any(isinstance(arg, cls) for arg in typing.get_args(ty)): raise TypeError( @@ -62,6 +63,48 @@ def _is_recursive_signature(sig: inspect.Signature): return any(annotation is IsRecursive for annotation in annotations) +class _IsFinalAnnotation(Annotation): + """ + A special type annotation for return types in the signature of a + :class:`Tool` that indicates its result should be returned directly + as the final answer of the enclosing :class:`Template`, skipping + the final LLM API call. + + .. warning:: + + :class:`IsFinal` annotations are only defined to ascribe + return annotations, and if used in a parameter will raise a + :class:`TypeError` at tool construction time. + + **Example usage**:: + + >>> from typing import Annotated + >>> from effectful.handlers.llm import Tool + >>> from effectful.handlers.llm.template import IsFinal + + >>> @Tool.define + ... def generate(prompt: str) -> Annotated[str, IsFinal]: + ... \"""Generate content for the given prompt.\""" + ... return "generated content" + """ + + @classmethod + def infer_annotations(cls, sig: inspect.Signature) -> inspect.Signature: + for name, param in sig.parameters.items(): + ty = param.annotation + if ty is inspect.Parameter.empty or typing.get_origin(ty) is not Annotated: + continue + if any(isinstance(arg, cls) for arg in typing.get_args(ty)): + raise TypeError( + f"Illegal annotation {ty} for parameter {name}, " + "IsFinal must only be used to annotate return types." + ) + return sig + + +IsFinal = _IsFinalAnnotation() + + class Tool[**P, T](Operation[P, T]): """A :class:`Tool` is a function that may be called by a :class:`Template`. @@ -96,6 +139,7 @@ def __init__( if not default.__doc__: raise ValueError("Tools must have docstrings.") signature = IsRecursive.infer_annotations(signature) + signature = IsFinal.infer_annotations(signature) super().__init__(signature, name, default) @classmethod diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index 80bc879c..82a9981c 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -507,7 +507,7 @@ def test_retry_handler_succeeds_on_first_attempt(self): handler(mock_handler), handler(message_sequence_provider), ): - message, tool_calls, result = call_assistant( + message, tool_calls, result, _ = call_assistant( tools={}, response_format=Encodable.define(str), model="test-model", @@ -537,7 +537,7 @@ def test_retry_handler_retries_on_invalid_tool_call(self): handler(mock_handler), handler(message_sequence_provider), ): - message, tool_calls, result = call_assistant( + message, tool_calls, result, _ = call_assistant( tools={"add_numbers": add_numbers}, response_format=Encodable.define(str), model="test-model", @@ -569,7 +569,7 @@ def test_retry_handler_retries_on_unknown_tool(self): handler(mock_handler), handler(message_sequence_provider), ): - message, tool_calls, result = call_assistant( + message, tool_calls, result, _ = call_assistant( tools={"add_numbers": add_numbers}, response_format=Encodable.define(str), model="test-model", @@ -648,7 +648,7 @@ def test_retry_handler_valid_tool_call_passes_through(self): handler(mock_handler), handler(message_sequence_provider), ): - message, tool_calls, result = call_assistant( + message, tool_calls, result, _ = call_assistant( tools={"add_numbers": add_numbers}, response_format=Encodable.define(str), model="test-model", @@ -723,7 +723,7 @@ def test_retry_handler_retries_on_invalid_result(self): handler(mock_handler), handler(message_sequence_provider), ): - message, tool_calls, result = call_assistant( + message, tool_calls, result, _ = call_assistant( tools={}, response_format=Encodable.define(int), model="test-model", @@ -973,7 +973,7 @@ def test_retry_handler_catches_tool_runtime_error(self): tool_call = DecodedToolCall(failing_tool, bound_args, "call_1", "failing_tool") with handler(RetryLLMHandler()): - result = call_tool(tool_call) + result, _, _ = call_tool(tool_call) # The result should be an error message, not an exception assert result["role"] == "tool" @@ -990,7 +990,7 @@ def test_retry_handler_catches_division_by_zero(self): tool_call = DecodedToolCall(divide_tool, bound_args, "call_div", "divide_tool") with handler(RetryLLMHandler()): - result = call_tool(tool_call) + result, _, _ = call_tool(tool_call) assert result["role"] == "tool" assert result["tool_call_id"] == "call_div" @@ -1005,7 +1005,7 @@ def test_successful_tool_execution_returns_result(self): tool_call = DecodedToolCall(add_numbers, bound_args, "call_add", "add_numbers") with handler(RetryLLMHandler()): - result = call_tool(tool_call) + result, _, _ = call_tool(tool_call) assert result["role"] == "tool" assert result["tool_call_id"] == "call_add" @@ -1042,7 +1042,7 @@ def _call_assistant(self, tools, response_format, model, **kwargs): handler(mock_handler), handler(message_sequence_provider), ): - message, tool_calls, result = call_assistant( + message, tool_calls, result, _ = call_assistant( tools={"failing_tool": failing_tool}, response_format=Encodable.define(str), model="test-model", @@ -1417,13 +1417,13 @@ def _completion(self_, model, messages, *args, **kwargs): handler({_get_history: lambda: message_sequence}), ): # First call: input is the latest message (msg_user) - resp1, _, _ = call_assistant( + resp1, _, _, _ = call_assistant( tools={}, response_format=Encodable.define(str), model="test-model", ) # Second call: input is the first response - resp2, _, _ = call_assistant( + resp2, _, _, _ = call_assistant( tools={}, response_format=Encodable.define(str), model="test-model", @@ -1588,7 +1588,7 @@ def test_call_tool_success_does_not_raise(self): bound_args = sig.bind(a=3, b=4) tc = DecodedToolCall(add_numbers, bound_args, "call_ok", "add_numbers") - result = call_tool(tc) + result, _, _ = call_tool(tc) assert result["role"] == "tool" assert result["tool_call_id"] == "call_ok" @@ -1603,7 +1603,7 @@ def test_matching_error_returns_feedback_message(self): tc = DecodedToolCall(flaky_tool, bound_args, "call_match", "flaky_tool") with handler(RetryLLMHandler(catch_tool_errors=ConnectionError)): - result = call_tool(tc) + result, _, _ = call_tool(tc) assert result["role"] == "tool" assert result["tool_call_id"] == "call_match" @@ -1632,7 +1632,7 @@ def test_default_catch_all_catches_everything(self): ) with handler(RetryLLMHandler()): - result = call_tool(tc) + result, _, _ = call_tool(tc) assert result["role"] == "tool" assert "Tool execution failed" in result["content"] @@ -1648,7 +1648,7 @@ def test_tuple_of_error_types(self): catch_tool_errors=(ConnectionError, ValueError), ) ): - result = call_tool(tc) + result, _, _ = call_tool(tc) assert result["role"] == "tool" assert "Tool execution failed" in result["content"] diff --git a/tests/test_handlers_llm_template.py b/tests/test_handlers_llm_template.py index f8e954c0..1198637d 100644 --- a/tests/test_handlers_llm_template.py +++ b/tests/test_handlers_llm_template.py @@ -4,6 +4,7 @@ import dataclasses import inspect from dataclasses import dataclass +from typing import Annotated import pytest from litellm import ModelResponse @@ -13,9 +14,15 @@ DEFAULT_SYSTEM_PROMPT, LiteLLMProvider, RetryLLMHandler, + ToolCallDecodingError, + _get_history, + call_assistant, + call_tool, call_user, completion, ) +from effectful.handlers.llm.encoding import DecodedToolCall, Encodable +from effectful.handlers.llm.template import IsFinal from effectful.ops.semantics import handler from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import NotHandled @@ -1518,3 +1525,262 @@ def test_validate_format_spec_on_undefined_var(): def bad(x: int) -> str: """Value: {x} and {missing:.2f}.""" raise NotHandled + + +# --------------------------------------------------------------------------- +# IsFinal annotation tests +# --------------------------------------------------------------------------- + + +class TestIsFinalCompletionLoop: + """Tests for IsFinal through the full completion loop.""" + + def test_final_answer_tool_skips_final_llm_call(self): + """When LLM calls a final-answer tool, result is returned + directly without a second call_assistant invocation.""" + + @Tool.define + def compute(x: int) -> Annotated[int, IsFinal]: + """Compute and return the result directly.""" + return x * 10 + + @Template.define + def task(n: int) -> int: + """Call compute with {n}.""" + raise NotHandled + + mock = MockCompletionHandler( + [make_tool_call_response("compute", '{"x": {"value": 7}}')] + ) + + with handler(LiteLLMProvider()), handler(mock): + result = task(7) + + assert result == 70 + # Only 1 call_assistant, not 2 (no final LLM round-trip) + assert mock.call_count == 1 + + def test_agent_history_valid_after_final_answer(self): + """Agent history has no orphaned tool_calls after IsFinal.""" + + @Tool.define + def final_tool(x: int) -> Annotated[int, IsFinal]: + """Return final answer.""" + return x + + @dataclasses.dataclass + class MyAgent(Agent): + @Template.define + def do_work(self, n: int) -> int: + """Process {n}.""" + raise NotHandled + + mock = MockCompletionHandler( + [make_tool_call_response("final_tool", '{"x": {"value": 5}}')] + ) + agent = MyAgent() + + with handler(LiteLLMProvider()), handler(mock): + result = agent.do_work(5) + + assert result == 5 + + # Verify no orphaned tool_calls in history + for msg in agent.__history__.values(): + tool_calls = msg.get("tool_calls") + if tool_calls: + for tc in tool_calls: + tc_id = tc["id"] if isinstance(tc, dict) else tc.id + has_response = any( + m.get("tool_call_id") == tc_id + for m in agent.__history__.values() + if m.get("role") == "tool" + ) + assert has_response, f"Orphaned tool_call {tc_id} in history" + + def test_agent_subsequent_call_after_final_answer(self): + """A follow-up call on the same Agent works after IsFinal.""" + + @Tool.define + def final_tool() -> Annotated[str, IsFinal]: + """Return final answer.""" + return "direct result" + + @dataclasses.dataclass + class MyAgent(Agent): + @Template.define + def step(self, msg: str) -> str: + """Do: {msg}""" + raise NotHandled + + call_count = 0 + + class PhaseHandler(ObjectInterpretation): + @implements(completion) + def _completion(self, model, messages=None, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return make_tool_call_response("final_tool", "{}") + return make_text_response("llm result") + + agent = MyAgent() + + with handler(LiteLLMProvider()), handler(PhaseHandler()): + r1 = agent.step("first") + r2 = agent.step("second") + + assert r1 == "direct result" + assert r2 == "llm result" + + def test_final_answer_with_retry_handler_active(self): + """IsFinal works correctly with RetryLLMHandler.""" + + @Tool.define + def final_tool(x: int) -> Annotated[int, IsFinal]: + """Return final answer.""" + return x * 3 + + @Template.define + def task(n: int) -> int: + """Call final_tool with {n}.""" + raise NotHandled + + mock = MockCompletionHandler( + [make_tool_call_response("final_tool", '{"x": {"value": 4}}')] + ) + + with ( + handler(LiteLLMProvider()), + handler(RetryLLMHandler()), + handler(mock), + ): + result = task(4) + + assert result == 12 + assert mock.call_count == 1 + + def test_retry_handler_error_on_final_tool_does_not_produce_final_answer(self): + """When RetryLLMHandler catches an error on an is_final tool, + the error feedback goes back to the LLM instead of None being + returned as the final answer.""" + call_count = 0 + + @Tool.define + def flaky_final(x: int) -> Annotated[int, IsFinal]: + """Return a final answer, but fail on first call.""" + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("transient failure") + return x * 10 + + @Template.define + def task(n: int) -> int: + """Call flaky_final with {n}.""" + raise NotHandled + + # Round 1: LLM calls flaky_final → error caught by RetryLLMHandler + # Round 2: LLM calls flaky_final again → succeeds + mock = MockCompletionHandler( + [ + make_tool_call_response("flaky_final", '{"x": {"value": 5}}'), + make_tool_call_response("flaky_final", '{"x": {"value": 5}}'), + ] + ) + + with ( + handler(LiteLLMProvider()), + handler(RetryLLMHandler()), + handler(mock), + ): + result = task(5) + + assert result == 50 # NOT None + assert call_count == 2 + assert mock.call_count == 2 + + def test_call_tool_returns_is_final_false_on_retry_handler_error(self): + """call_tool returns is_final=False when RetryLLMHandler catches + an error on an is_final tool.""" + + @Tool.define + def failing_final(x: int) -> Annotated[int, IsFinal]: + """Return a final answer.""" + raise ValueError("boom") + + sig = inspect.signature(failing_final) + bound_args = sig.bind(x=1) + tc = DecodedToolCall( + failing_final, bound_args, id="call_err", name="failing_final" + ) + + with handler(RetryLLMHandler()): + message, raw_result, is_final = call_tool(tc) + + assert message["role"] == "tool" + assert message["content"] # non-empty error feedback + assert raw_result is None + assert is_final is False + + def test_mismatched_return_type_raises_tool_call_decoding_error(self): + """IsFinal tool returning str when template expects int is rejected.""" + + @Tool.define + def wrong_type_tool(x: int) -> Annotated[str, IsFinal]: + """Return a string, but template expects int.""" + return str(x) + + message_sequence = collections.OrderedDict( + id1={"id": "id1", "role": "user", "content": "test"}, + ) + + mock = MockCompletionHandler( + [ + make_tool_call_response("wrong_type_tool", '{"x": {"value": 5}}'), + ] + ) + + with ( + handler(mock), + handler({_get_history: lambda: message_sequence}), + ): + with pytest.raises(ToolCallDecodingError) as exc_info: + call_assistant( + tools={"wrong_type_tool": wrong_type_tool}, + response_format=Encodable.define(int), + model="test-model", + ) + + assert isinstance(exc_info.value.original_error, TypeError) + + def test_matching_return_type_passes_validation(self): + """IsFinal tool with matching return type is accepted.""" + + @Tool.define + def correct_tool(x: int) -> Annotated[int, IsFinal]: + """Return an int matching template.""" + return x * 2 + + message_sequence = collections.OrderedDict( + id1={"id": "id1", "role": "user", "content": "test"}, + ) + + mock = MockCompletionHandler( + [ + make_tool_call_response("correct_tool", '{"x": {"value": 5}}'), + ] + ) + + with ( + handler(mock), + handler({_get_history: lambda: message_sequence}), + ): + _, tool_calls, _, is_final = call_assistant( + tools={"correct_tool": correct_tool}, + response_format=Encodable.define(int), + model="test-model", + ) + + assert len(tool_calls) == 1 + assert is_final is True