Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
06ed7ff
Move logic from completion to encoding
eb8680 Feb 9, 2026
3e89ab9
errors
eb8680 Feb 9, 2026
bf3ce15
attempt to appease mypy
eb8680 Feb 9, 2026
3cb5c13
nit
eb8680 Feb 9, 2026
efb3c68
import
eb8680 Feb 9, 2026
fd8f8c8
nit
eb8680 Feb 9, 2026
a608b54
fix
eb8680 Feb 9, 2026
0fcd308
nit
eb8680 Feb 9, 2026
161c802
Merge branch 'master' into eb-remove-response
eb8680 Feb 13, 2026
304ce5f
simplify and add tests
eb8680 Feb 13, 2026
f2e2923
update tests
eb8680 Feb 13, 2026
a4d27b2
Merge branch 'master' into eb-remove-response
eb8680 Feb 13, 2026
dd06808
tests
eb8680 Feb 14, 2026
37a7564
more consolidation
eb8680 Feb 15, 2026
97f5715
stash
eb8680 Feb 15, 2026
ab9097d
refactor tests
eb8680 Feb 15, 2026
e036596
Fix Encodable for Dataclass and Scalar in #548 (#571)
datvo06 Feb 21, 2026
33ab0bc
simplifications
eb8680 Feb 21, 2026
87c9160
stash
eb8680 Feb 22, 2026
47c19f3
stash
eb8680 Feb 22, 2026
a367d76
mostly works?
eb8680 Feb 22, 2026
5fa18d7
str
eb8680 Feb 23, 2026
3727c99
Merge branch 'master' into eb-pydantic-encodable
eb8680 Feb 23, 2026
09cc238
fix merge
eb8680 Feb 23, 2026
48b0d97
add tests for TypeEvaluator
eb8680 Feb 23, 2026
d01fe1e
concretize
eb8680 Feb 23, 2026
f6f31e2
inline response_format
eb8680 Feb 23, 2026
ec10beb
remove deserialize
eb8680 Feb 23, 2026
346e42f
Merge branch 'master' into eb-pydantic-encodable
eb8680 Feb 24, 2026
9d94906
to_content_blocks
eb8680 Feb 24, 2026
a4ece4b
validation
eb8680 Feb 24, 2026
cac472d
test cases
eb8680 Feb 24, 2026
32acb08
bare strings
eb8680 Feb 24, 2026
6ef6fd2
inline
eb8680 Feb 25, 2026
3a8cfcb
Merge branch 'master' into eb-pydantic-encodable
eb8680 Mar 4, 2026
7e2c6da
fix merge
eb8680 Mar 4, 2026
32c4719
remove serialize
eb8680 Mar 4, 2026
0b5b449
remove encode() from library code
eb8680 Mar 4, 2026
fb9f5fc
remoev decode from library code
eb8680 Mar 4, 2026
f12c327
remove define from library code
eb8680 Mar 4, 2026
25405fd
remove encodable except as type predicate
eb8680 Mar 4, 2026
3c3272d
move tool collection
eb8680 Mar 4, 2026
facd734
fix
eb8680 Mar 4, 2026
e5f1dca
nit
eb8680 Mar 4, 2026
1ca5793
make tests pass
eb8680 Mar 5, 2026
98cfdcd
fix types
eb8680 Mar 5, 2026
8675447
model
eb8680 Mar 5, 2026
1008172
appease mypy on ci?
eb8680 Mar 5, 2026
09c125d
wrapping
eb8680 Mar 5, 2026
86b3b9d
try to fix ci failures.......
eb8680 Mar 5, 2026
9196171
hand off claudes nonsense to ci and give up
eb8680 Mar 5, 2026
541f812
in which claude yanks away the football
eb8680 Mar 5, 2026
5935c4f
remove more claude nonsense
eb8680 Mar 5, 2026
b454a0f
lint
eb8680 Mar 5, 2026
e62d0e5
one more round of claudes insane dream logic
eb8680 Mar 5, 2026
0d7e2c4
revert another insane design choice by claude
eb8680 Mar 5, 2026
808f794
various fixes
eb8680 Mar 6, 2026
052136e
Merge branch 'master' into eb-pydantic-encodable
eb8680 Mar 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 116 additions & 58 deletions effectful/handlers/llm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dataclasses
import functools
import inspect
import json
import string
import textwrap
import traceback
Expand All @@ -24,8 +25,17 @@
OpenAIMessageContentListBlock,
)

from effectful.handlers.llm.encoding import DecodedToolCall, Encodable
from effectful.handlers.llm.template import Template, Tool
from effectful.handlers.llm.encoding import (
DecodedToolCall,
Encodable,
to_content_blocks,
)
from effectful.handlers.llm.template import (
Agent,
Template,
Tool,
_is_recursive_signature,
)
from effectful.internals.unification import nested_type
from effectful.ops.semantics import fwd, handler
from effectful.ops.syntax import ObjectInterpretation, implements
Expand Down Expand Up @@ -59,15 +69,23 @@ class UserMessage(OpenAIChatCompletionUserMessage):
)


class _NoActiveHistoryException(Exception):
    """Raised when there is no active message history to append to."""


@Operation.define
def _get_history() -> collections.OrderedDict[str, Message]:
    """Return the mutable message history for the current handler scope.

    This default (unhandled) implementation always raises
    ``_NoActiveHistoryException``; handlers that provide a history
    override this operation to return their own OrderedDict.
    """
    # NOTE: the scrape retained the pre-change `raise NotImplementedError`
    # line above this raise, which made it unreachable; only the new raise
    # belongs in the post-change code.
    raise _NoActiveHistoryException(
        "No active message history. This operation should only be used within a handler that provides a message history."
    )


def append_message(message: Message, last: bool = True) -> None:
    """Insert *message* into the active history, keyed by ``message["id"]``.

    Args:
        message: The message to record.
        last: When False, the message is moved to the front of the history
            (e.g. system messages — see the ``last=False`` call site).

    If no history handler is active, the message is silently dropped.
    """
    try:
        _get_history()[message["id"]] = message
        if not last:
            _get_history().move_to_end(message["id"], last=False)
    except _NoActiveHistoryException:
        pass


Expand Down Expand Up @@ -160,6 +178,37 @@ def to_feedback_message(self, include_traceback: bool) -> Message:
type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None]


def _collect_tools(
env: collections.abc.Mapping[str, typing.Any],
) -> collections.abc.Mapping[str, Tool]:
"""Operations and Templates available as tools. Auto-capture from lexical context."""
result = {}

for name, obj in env.items():
# Collect tools directly in context
if isinstance(obj, Tool | Template):
result[name] = obj

# Collect tools as methods on Agent instances in context
elif isinstance(obj, Agent):
for cls in type(obj).__mro__:
for attr_name in vars(cls):
if isinstance(getattr(obj, attr_name), Tool):
result[f"{name}__{attr_name}"] = getattr(obj, attr_name)

# The same Tool can appear under multiple names when it is both
# visible in the enclosing scope *and* discovered via an Agent
# instance's MRO. Since Tools are hashable Operations and
# instance-method Tools are cached per instance, we keep only
# the last name for each unique tool object.
tool2name = {tool: name for name, tool in sorted(result.items())}
for name, tool in tuple(result.items()):
if tool2name[tool] != name:
del result[name]

return result


@Operation.define
@functools.wraps(litellm.completion)
def completion(*args, **kwargs) -> typing.Any:
Expand All @@ -172,10 +221,14 @@ def completion(*args, **kwargs) -> typing.Any:
return litellm.completion(*args, **kwargs)


class _BoxedResponse[T](pydantic.BaseModel):
value: T


@Operation.define
def call_assistant[T, U](
tools: collections.abc.Mapping[str, Tool],
response_format: Encodable[T, U],
def call_assistant[T](
env: collections.abc.Mapping[str, typing.Any],
response_type: type[T],
model: str,
**kwargs,
) -> MessageResult[T]:
Expand All @@ -190,15 +243,28 @@ def call_assistant[T, U](
ResultDecodingError: If the result cannot be decoded. The error
includes the raw assistant message for retry handling.
"""
tools = _collect_tools(env)
tool_specs = {
k: Encodable.define(type(t), tools).encode(t) # type: ignore
k: typing.cast(
pydantic.TypeAdapter[typing.Any],
pydantic.TypeAdapter(Encodable[type(t)]), # type: ignore[misc]
).dump_python(t, mode="json", context={k: t})
for k, t in tools.items()
}
messages = list(_get_history().values())

# The OpenAI API requires a wrapper object for non-object structured output types,
# so we create one on the fly here. Using a Pydantic model offloads JSON schema
# generation and validation logic to litellm, and offers better error messages.
response_format: type[_BoxedResponse[T]] = pydantic.create_model(
"BoxedResponse",
value=Encodable[response_type], # type: ignore[valid-type]
__base__=_BoxedResponse,
)

response: litellm.types.utils.ModelResponse = completion(
model,
messages=list(messages),
response_format=None if response_format.enc is str else response_format.enc,
messages=list(_get_history().values()),
response_format=None if response_type is str else response_format,
tools=list(tool_specs.values()),
**kwargs,
)
Expand All @@ -212,11 +278,12 @@ def call_assistant[T, U](
append_message(raw_message)

tool_calls: list[DecodedToolCall] = []
raw_tool_calls = message.get("tool_calls") or []
encoding = Encodable.define(DecodedToolCall, tools) # type: ignore
for raw_tool_call in raw_tool_calls:
encoding: pydantic.TypeAdapter[DecodedToolCall] = pydantic.TypeAdapter(
Encodable[DecodedToolCall]
)
for raw_tool_call in message.get("tool_calls") or []:
try:
tool_calls += [encoding.decode(raw_tool_call)] # type: ignore
tool_calls += [encoding.validate_python(raw_tool_call, context=tools)]
except Exception as e:
raise ToolCallDecodingError(
raw_tool_call=raw_tool_call,
Expand All @@ -231,12 +298,15 @@ def call_assistant[T, U](
assert isinstance(serialized_result, str), (
"final response from the model should be a string"
)
try:
result = response_format.decode(
response_format.deserialize(serialized_result)
)
except (pydantic.ValidationError, TypeError, ValueError, SyntaxError) as e:
raise ResultDecodingError(e, raw_message=raw_message) from e
if response_type is str:
result = typing.cast(T, serialized_result)
else:
try:
result = response_format.model_validate(
json.loads(serialized_result), context=env
).value
except Exception as e:
raise ResultDecodingError(e, raw_message=raw_message) from e

return (raw_message, tool_calls, result)

Expand All @@ -256,10 +326,12 @@ def call_tool(tool_call: DecodedToolCall) -> Message:
except Exception as e:
raise ToolCallExecutionError(raw_tool_call=tool_call, original_error=e) from e

return_type = Encodable.define(
typing.cast(type[typing.Any], nested_type(result).value)
return_type: pydantic.TypeAdapter[typing.Any] = pydantic.TypeAdapter(
Encodable[nested_type(result).value] # type: ignore[misc]
)
encoded_result = to_content_blocks(
return_type.dump_python(result, mode="json", context={})
)
encoded_result = return_type.serialize(return_type.encode(result))
message = _make_message(
dict(role="tool", content=encoded_result, tool_call_id=tool_call.id),
)
Expand Down Expand Up @@ -295,13 +367,11 @@ def flush_text() -> None:
continue

obj, _ = formatter.get_field(field_name, (), env)
encoder = Encodable.define(
typing.cast(type[typing.Any], nested_type(obj).value), env
)
encoded_obj: typing.Sequence[OpenAIMessageContentListBlock] = encoder.serialize(
encoder.encode(obj)
encoder: pydantic.TypeAdapter[typing.Any] = pydantic.TypeAdapter(
Encodable[nested_type(obj).value] # type: ignore[misc]
)
for part in encoded_obj:
encoded_obj = encoder.dump_python(obj, mode="json", context=env)
for part in to_content_blocks(encoded_obj):
if part["type"] == "text":
text = (
formatter.convert_field(part["text"], conversion)
def call_system(template: Template) -> Message:
    """Get system instruction message(s) to prepend to all LLM prompts."""
    system_prompt = template.__system_prompt__ or DEFAULT_SYSTEM_PROMPT
    message = _make_message(dict(role="system", content=system_prompt))
    # Move the system message to the front of any active history;
    # append_message is a silent no-op when no history is active.
    append_message(message, last=False)
    return message


class RetryLLMHandler(ObjectInterpretation):
Expand Down Expand Up @@ -404,20 +461,17 @@ def _before_sleep(self, retry_state: tenacity.RetryCallState) -> None:
self._user_before_sleep(retry_state)

@implements(call_assistant)
def _call_assistant[T, U](
def _call_assistant[T](
self,
tools: collections.abc.Mapping[str, Tool],
response_format: Encodable[T, U],
env: collections.abc.Mapping[str, typing.Any],
response_type: type[T],
model: str,
**kwargs,
) -> MessageResult[T]:
_message_sequence = _get_history().copy()

def _attempt() -> MessageResult[T]:
return fwd(tools, response_format, model, **kwargs)

with handler({_get_history: lambda: _message_sequence}):
message, tool_calls, result = self.call_assistant_retryer(_attempt)
message, tool_calls, result = self.call_assistant_retryer(fwd)

append_message(message)
return (message, tool_calls, result)
Expand Down Expand Up @@ -461,16 +515,20 @@ def _call[**P, T](
bound_args.apply_defaults()
env = template.__context__.new_child(bound_args.arguments)

# Create response_model with env so tools passed as arguments are available
response_model = Encodable.define(template.__signature__.return_annotation, env)
if not _is_recursive_signature(template.__signature__):
env = env.new_child({k: None for k, v in env.items() if v is template})

history: collections.OrderedDict[str, Message] = getattr(
template, "__history__", collections.OrderedDict()
) # type: ignore
history_copy = history.copy()

with handler({_get_history: lambda: history_copy}):
call_system(template)
if (
not _get_history()
or next(iter(_get_history().values()))["role"] != "system"
):
call_system(template)

message: Message = call_user(template.__prompt_template__, env)

Expand All @@ -479,14 +537,14 @@ def _call[**P, T](
result: T | None = None
while message["role"] != "assistant" or tool_calls:
message, tool_calls, result = call_assistant(
template.tools, response_model, **self.config
env, template.__signature__.return_annotation, **self.config
)
for tool_call in tool_calls:
message = call_tool(tool_call)

try:
_get_history()
except NotImplementedError:
except _NoActiveHistoryException:
history.clear()
history.update(history_copy)
return typing.cast(T, result)
Loading
Loading