Skip to content
71 changes: 49 additions & 22 deletions effectful/handlers/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from effectful.handlers.llm.encoding import DecodedToolCall, Encodable
from effectful.handlers.llm.template import Template, Tool
from effectful.internals.unification import nested_type
from effectful.ops.semantics import fwd, handler
from effectful.ops.semantics import _simple_type, fwd, handler
from effectful.ops.syntax import ObjectInterpretation, implements
from effectful.ops.types import Operation

Expand Down Expand Up @@ -157,7 +157,7 @@ def to_feedback_message(self, include_traceback: bool) -> Message:
)


type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None]
type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None, bool]


@Operation.define
Expand Down Expand Up @@ -217,6 +217,24 @@ def call_assistant[T, U](
for raw_tool_call in raw_tool_calls:
try:
tool_calls += [encoding.decode(raw_tool_call)] # type: ignore
if tool_calls[-1].is_final:
if len(raw_tool_calls) > 1:
raise ValueError(
f"IsFinal tool '{raw_tool_call.function.name}' must be the "
f"only tool call in a round, but {len(raw_tool_calls)} tool calls "
f"were generated."
)
# Validate that the tool's return type matches the template's.
tool_sig = inspect.signature(tool_calls[-1].tool)
return_annotation = typing.get_args(tool_sig.return_annotation)[0]
if not issubclass(
_simple_type(return_annotation), response_format.base
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might cause trouble in cases where we use IsFinal with return_annotation that doesn't match the outer template.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, this script fails due to a mismatch between the final_text return type and the Template return type. But I guess that's acceptable.

from typing import Annotated

from effectful.handlers.llm import Template, Tool
from effectful.handlers.llm.completions import LiteLLMProvider
from effectful.handlers.llm.template import IsFinal
from effectful.ops.semantics import handler
from effectful.ops.types import NotHandled


# Repro: an IsFinal tool whose return type (str) deliberately does NOT match
# the enclosing template's return type (int). Running this triggers the
# IsFinal return-type validation error in call_assistant.
@Tool.define
def final_text() -> Annotated[str, IsFinal]:
    """Return a final text result."""
    return "123"


@Template.define
def task() -> int:
    """Call final_text."""
    raise NotHandled


if __name__ == "__main__":
    # NOTE(review): requires LiteLLM credentials for gpt-4o-mini — confirm
    # before running outside a mocked environment.
    with handler(LiteLLMProvider(model="gpt-4o-mini")):
        task()

Result:

  File "/Users/nguyendat/Marc/effectful/effectful/handlers/llm/completions.py", line 239, in call_assistant
    raise ToolCallDecodingError(
effectful.handlers.llm.completions.ToolCallDecodingError: Error decoding tool call 'final_text': IsFinal tool 'final_text' has signature <Signature () -> Annotated[str, <effectful.handlers.llm.template._IsFinalAnnotation object at 0x100d3d100>]>, but the enclosing template expects <class 'int'>.. Please provide a valid response and try again.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is more troublesome, but that is because Python forbids class checks on TypedDict. Still, this would work fine if the IsFinal check were not there.

from typing import Annotated, TypedDict

from effectful.handlers.llm import Template, Tool
from effectful.handlers.llm.completions import LiteLLMProvider
from effectful.handlers.llm.template import IsFinal
from effectful.ops.semantics import handler
from effectful.ops.types import NotHandled


class Payload(TypedDict):
    # single integer field carried as the final result
    x: int


# Repro: the tool and template return types match exactly, but the IsFinal
# issubclass check fails anyway because TypedDict forbids class checks.
@Tool.define
def final_payload() -> Annotated[Payload, IsFinal]:
    """Return final payload."""
    return {"x": 1}


@Template.define
def task() -> Payload:
    """Call final_payload."""
    raise NotHandled


if __name__ == "__main__":
    # NOTE(review): requires LiteLLM credentials for gpt-4o-mini — confirm
    # before running outside a mocked environment.
    with handler(LiteLLMProvider(model="gpt-4o-mini")):
        task()

Result:

  File "/Users/nguyendat/Marc/effectful/effectful/handlers/llm/completions.py", line 239, in call_assistant
    raise ToolCallDecodingError(
effectful.handlers.llm.completions.ToolCallDecodingError: Error decoding tool call 'final_payload': TypedDict does not support instance and class checks. Please provide a valid response and try again.
(effectful) ➜  effectful git:(eb-final-answer) ✗ python effectful/handlers/llm/repro_typed_dict.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or Literal:

from typing import Annotated, Literal

from effectful.handlers.llm import Template, Tool
from effectful.handlers.llm.completions import LiteLLMProvider
from effectful.handlers.llm.template import IsFinal
from effectful.ops.semantics import handler
from effectful.ops.types import NotHandled


# Repro: matching Literal return types on tool and template still fail the
# IsFinal validation, because subscripted generics cannot be used with
# issubclass checks.
@Tool.define
def final_payload() -> Annotated[Literal[1, 2, 3], IsFinal]:
    """Return final payload."""
    return 1


@Template.define
def task() -> Literal[1, 2, 3]:
    """Call final_payload."""
    raise NotHandled


if __name__ == "__main__":
    # NOTE(review): requires LiteLLM credentials for gpt-4o-mini — confirm
    # before running outside a mocked environment.
    with handler(LiteLLMProvider(model="gpt-4o-mini")):
        task()

Result:

File "/Users/nguyendat/Marc/effectful/effectful/handlers/llm/completions.py", line 239, in call_assistant
    raise ToolCallDecodingError(
effectful.handlers.llm.completions.ToolCallDecodingError: Error decoding tool call 'final_payload': Subscripted generics cannot be used with class and instance checks. Please provide a valid response and try again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior on these examples is consistent with the design choices laid out in the PR description, so those choices probably need to be revisited. For maximum flexibility we might want to let the LLM choose whether a tool call is final, instead of relying solely on the annotation as in this PR. For example, we could inject a fake is_final argument into every tool schema sent to the LLM and read off its value from the tool call request. We should probably also collect a few more examples like this that reflect more realistic use cases of this behavior.

raise TypeError(
f"IsFinal tool '{raw_tool_call.function.name}' has signature "
f"{tool_sig!r}, but the enclosing template expects "
f"{response_format.base!r}."
)
except Exception as e:
raise ToolCallDecodingError(
raw_tool_call=raw_tool_call,
Expand All @@ -238,33 +256,35 @@ def call_assistant[T, U](
except (pydantic.ValidationError, TypeError, ValueError, SyntaxError) as e:
raise ResultDecodingError(e, raw_message=raw_message) from e

return (raw_message, tool_calls, result)
is_final = any(tc.is_final for tc in tool_calls) or not tool_calls
return (raw_message, tool_calls, result, is_final)


@Operation.define
def call_tool(tool_call: DecodedToolCall) -> Message:
"""Implements a roundtrip call to a python function. Input is a json
string representing an LLM tool call request parameters. The output is
the serialised response to the model.

def call_tool[T](tool_call: DecodedToolCall[T]) -> tuple[Message, T | None, bool]:
"""Execute a tool and return the serialised message, the raw result, and
whether this result is a final answer.

Returns:
A 3-tuple ``(message, result, is_final)``. ``message`` is appended
to the conversation history. When ``is_final`` is ``True`` the
completion loop uses ``result`` directly as the template return value.
"""
# call tool with python types
try:
result = tool_call.tool(
result: T = tool_call.tool(
*tool_call.bound_args.args, **tool_call.bound_args.kwargs
)
except Exception as e:
raise ToolCallExecutionError(raw_tool_call=tool_call, original_error=e) from e

return_type = Encodable.define(
typing.cast(type[typing.Any], nested_type(result).value)
)
encoded_result = return_type.serialize(return_type.encode(result))
return_type = Encodable.define(nested_type(result).value) # type: ignore
encoded_result = return_type.serialize(return_type.encode(result)) # type: ignore
message = _make_message(
dict(role="tool", content=encoded_result, tool_call_id=tool_call.id),
)
append_message(message)
return message
return message, result, tool_call.is_final


@Operation.define
Expand Down Expand Up @@ -417,26 +437,32 @@ def _attempt() -> MessageResult[T]:
return fwd(tools, response_format, model, **kwargs)

with handler({_get_history: lambda: _message_sequence}):
message, tool_calls, result = self.call_assistant_retryer(_attempt)
message, tool_calls, result, is_final = self.call_assistant_retryer(
_attempt
)

append_message(message)
return (message, tool_calls, result)
return (message, tool_calls, result, is_final)

@implements(call_tool)
def _call_tool(self, tool_call: DecodedToolCall) -> Message:
def _call_tool[T](
self, tool_call: DecodedToolCall[T]
) -> tuple[Message, T | None, bool]:
"""Handle tool execution with runtime error capture.

Runtime errors from tool execution are captured and returned as
error messages to the LLM. Only exceptions matching `catch_tool_errors`
are caught; others propagate up.
are caught; others propagate up. When an error is caught,
``is_final`` is always ``False`` so the error feedback goes back
to the LLM rather than being mistaken for a final answer.
"""
try:
return fwd(tool_call)
except ToolCallExecutionError as e:
if isinstance(e.original_error, self.catch_tool_errors):
message = e.to_feedback_message(self.include_traceback)
append_message(message)
return message
return message, None, False
else:
raise

Expand Down Expand Up @@ -477,12 +503,13 @@ def _call[**P, T](
# loop based on: https://cookbook.openai.com/examples/reasoning_function_calls
tool_calls: list[DecodedToolCall] = []
result: T | None = None
while message["role"] != "assistant" or tool_calls:
message, tool_calls, result = call_assistant(
is_final: bool = False
while not is_final:
message, tool_calls, result, is_final = call_assistant(
template.tools, response_model, **self.config
)
for tool_call in tool_calls:
message = call_tool(tool_call)
message, result, is_final = call_tool(tool_call)

try:
_get_history()
Expand Down
9 changes: 8 additions & 1 deletion effectful/handlers/llm/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from PIL import Image

import effectful.handlers.llm.evaluation as evaluation
from effectful.handlers.llm.template import Tool
from effectful.handlers.llm.template import Tool, _IsFinalAnnotation
from effectful.internals.unification import nested_type
from effectful.ops.semantics import _simple_type
from effectful.ops.syntax import _CustomSingleDispatchCallable
Expand Down Expand Up @@ -61,6 +61,13 @@ class DecodedToolCall[T]:
id: ToolCallID
name: str

@property
def is_final(self) -> bool:
    """Whether this decoded tool call produces a final answer.

    Returns ``True`` iff the tool's return annotation is
    ``typing.Annotated[..., IsFinal]``, i.e. its ``Annotated`` metadata
    contains an ``_IsFinalAnnotation`` marker.
    """
    ret = inspect.signature(self.tool).return_annotation
    return typing.get_origin(ret) is typing.Annotated and any(
        isinstance(arg, _IsFinalAnnotation) for arg in ret.__metadata__
    )


class Encodable[T, U](ABC):
base: type[T]
Expand Down
48 changes: 46 additions & 2 deletions effectful/handlers/llm/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ def factorial(n: int) -> Annotated[int, IsRecursive]:

@classmethod
def infer_annotations(cls, sig: inspect.Signature) -> inspect.Signature:
for name, ty in sig.parameters.items():
if not ty or not typing.get_origin(ty) is Annotated:
for name, param in sig.parameters.items():
ty = param.annotation
if ty is inspect.Parameter.empty or typing.get_origin(ty) is not Annotated:
continue
if any(isinstance(arg, cls) for arg in typing.get_args(ty)):
raise TypeError(
Expand All @@ -62,6 +63,48 @@ def _is_recursive_signature(sig: inspect.Signature):
return any(annotation is IsRecursive for annotation in annotations)


class _IsFinalAnnotation(Annotation):
    """
    Marker annotation for return types in the signature of a
    :class:`Tool` indicating that the tool's result should be used
    directly as the final answer of the enclosing :class:`Template`,
    skipping the final LLM API call.

    .. warning::

        :class:`IsFinal` annotations are only defined to ascribe
        return annotations, and if used in a parameter will raise a
        :class:`TypeError` at tool construction time.

    **Example usage**::

        >>> from typing import Annotated
        >>> from effectful.handlers.llm import Tool
        >>> from effectful.handlers.llm.template import IsFinal

        >>> @Tool.define
        ... def generate(prompt: str) -> Annotated[str, IsFinal]:
        ...     \"""Generate content for the given prompt.\"""
        ...     return "generated content"
    """

    @classmethod
    def infer_annotations(cls, sig: inspect.Signature) -> inspect.Signature:
        # Reject any parameter whose Annotated metadata carries this marker:
        # IsFinal is only legal on return annotations.
        for name in sig.parameters:
            ty = sig.parameters[name].annotation
            if ty is inspect.Parameter.empty:
                continue
            if typing.get_origin(ty) is not Annotated:
                continue
            if any(isinstance(meta, cls) for meta in typing.get_args(ty)):
                raise TypeError(
                    f"Illegal annotation {ty} for parameter {name}, "
                    "IsFinal must only be used to annotate return types."
                )
        return sig


IsFinal = _IsFinalAnnotation()


class Tool[**P, T](Operation[P, T]):
"""A :class:`Tool` is a function that may be called by a :class:`Template`.

Expand Down Expand Up @@ -96,6 +139,7 @@ def __init__(
if not default.__doc__:
raise ValueError("Tools must have docstrings.")
signature = IsRecursive.infer_annotations(signature)
signature = IsFinal.infer_annotations(signature)
super().__init__(signature, name, default)

@classmethod
Expand Down
30 changes: 15 additions & 15 deletions tests/test_handlers_llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def test_retry_handler_succeeds_on_first_attempt(self):
handler(mock_handler),
handler(message_sequence_provider),
):
message, tool_calls, result = call_assistant(
message, tool_calls, result, _ = call_assistant(
tools={},
response_format=Encodable.define(str),
model="test-model",
Expand Down Expand Up @@ -537,7 +537,7 @@ def test_retry_handler_retries_on_invalid_tool_call(self):
handler(mock_handler),
handler(message_sequence_provider),
):
message, tool_calls, result = call_assistant(
message, tool_calls, result, _ = call_assistant(
tools={"add_numbers": add_numbers},
response_format=Encodable.define(str),
model="test-model",
Expand Down Expand Up @@ -569,7 +569,7 @@ def test_retry_handler_retries_on_unknown_tool(self):
handler(mock_handler),
handler(message_sequence_provider),
):
message, tool_calls, result = call_assistant(
message, tool_calls, result, _ = call_assistant(
tools={"add_numbers": add_numbers},
response_format=Encodable.define(str),
model="test-model",
Expand Down Expand Up @@ -648,7 +648,7 @@ def test_retry_handler_valid_tool_call_passes_through(self):
handler(mock_handler),
handler(message_sequence_provider),
):
message, tool_calls, result = call_assistant(
message, tool_calls, result, _ = call_assistant(
tools={"add_numbers": add_numbers},
response_format=Encodable.define(str),
model="test-model",
Expand Down Expand Up @@ -723,7 +723,7 @@ def test_retry_handler_retries_on_invalid_result(self):
handler(mock_handler),
handler(message_sequence_provider),
):
message, tool_calls, result = call_assistant(
message, tool_calls, result, _ = call_assistant(
tools={},
response_format=Encodable.define(int),
model="test-model",
Expand Down Expand Up @@ -973,7 +973,7 @@ def test_retry_handler_catches_tool_runtime_error(self):
tool_call = DecodedToolCall(failing_tool, bound_args, "call_1", "failing_tool")

with handler(RetryLLMHandler()):
result = call_tool(tool_call)
result, _, _ = call_tool(tool_call)

# The result should be an error message, not an exception
assert result["role"] == "tool"
Expand All @@ -990,7 +990,7 @@ def test_retry_handler_catches_division_by_zero(self):
tool_call = DecodedToolCall(divide_tool, bound_args, "call_div", "divide_tool")

with handler(RetryLLMHandler()):
result = call_tool(tool_call)
result, _, _ = call_tool(tool_call)

assert result["role"] == "tool"
assert result["tool_call_id"] == "call_div"
Expand All @@ -1005,7 +1005,7 @@ def test_successful_tool_execution_returns_result(self):
tool_call = DecodedToolCall(add_numbers, bound_args, "call_add", "add_numbers")

with handler(RetryLLMHandler()):
result = call_tool(tool_call)
result, _, _ = call_tool(tool_call)

assert result["role"] == "tool"
assert result["tool_call_id"] == "call_add"
Expand Down Expand Up @@ -1042,7 +1042,7 @@ def _call_assistant(self, tools, response_format, model, **kwargs):
handler(mock_handler),
handler(message_sequence_provider),
):
message, tool_calls, result = call_assistant(
message, tool_calls, result, _ = call_assistant(
tools={"failing_tool": failing_tool},
response_format=Encodable.define(str),
model="test-model",
Expand Down Expand Up @@ -1417,13 +1417,13 @@ def _completion(self_, model, messages, *args, **kwargs):
handler({_get_history: lambda: message_sequence}),
):
# First call: input is the latest message (msg_user)
resp1, _, _ = call_assistant(
resp1, _, _, _ = call_assistant(
tools={},
response_format=Encodable.define(str),
model="test-model",
)
# Second call: input is the first response
resp2, _, _ = call_assistant(
resp2, _, _, _ = call_assistant(
tools={},
response_format=Encodable.define(str),
model="test-model",
Expand Down Expand Up @@ -1588,7 +1588,7 @@ def test_call_tool_success_does_not_raise(self):
bound_args = sig.bind(a=3, b=4)
tc = DecodedToolCall(add_numbers, bound_args, "call_ok", "add_numbers")

result = call_tool(tc)
result, _, _ = call_tool(tc)
assert result["role"] == "tool"
assert result["tool_call_id"] == "call_ok"

Expand All @@ -1603,7 +1603,7 @@ def test_matching_error_returns_feedback_message(self):
tc = DecodedToolCall(flaky_tool, bound_args, "call_match", "flaky_tool")

with handler(RetryLLMHandler(catch_tool_errors=ConnectionError)):
result = call_tool(tc)
result, _, _ = call_tool(tc)

assert result["role"] == "tool"
assert result["tool_call_id"] == "call_match"
Expand Down Expand Up @@ -1632,7 +1632,7 @@ def test_default_catch_all_catches_everything(self):
)

with handler(RetryLLMHandler()):
result = call_tool(tc)
result, _, _ = call_tool(tc)

assert result["role"] == "tool"
assert "Tool execution failed" in result["content"]
Expand All @@ -1648,7 +1648,7 @@ def test_tuple_of_error_types(self):
catch_tool_errors=(ConnectionError, ValueError),
)
):
result = call_tool(tc)
result, _, _ = call_tool(tc)

assert result["role"] == "tool"
assert "Tool execution failed" in result["content"]
Expand Down
Loading
Loading