diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index fc6ca47a..2e169d93 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -4,6 +4,7 @@ import dataclasses import functools import inspect +import json import string import textwrap import traceback @@ -24,8 +25,17 @@ OpenAIMessageContentListBlock, ) -from effectful.handlers.llm.encoding import DecodedToolCall, Encodable -from effectful.handlers.llm.template import Template, Tool +from effectful.handlers.llm.encoding import ( + DecodedToolCall, + Encodable, + to_content_blocks, +) +from effectful.handlers.llm.template import ( + Agent, + Template, + Tool, + _is_recursive_signature, +) from effectful.internals.unification import nested_type from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements @@ -59,15 +69,23 @@ class UserMessage(OpenAIChatCompletionUserMessage): ) +class _NoActiveHistoryException(Exception): + """Raised when there is no active message history to append to.""" + + @Operation.define def _get_history() -> collections.OrderedDict[str, Message]: - raise NotImplementedError + raise _NoActiveHistoryException( + "No active message history. This operation should only be used within a handler that provides a message history." + ) -def append_message(message: Message): +def append_message(message: Message, last: bool = True) -> None: try: _get_history()[message["id"]] = message - except NotImplementedError: + if not last: + _get_history().move_to_end(message["id"], last=False) + except _NoActiveHistoryException: pass @@ -160,6 +178,37 @@ def to_feedback_message(self, include_traceback: bool) -> Message: type MessageResult[T] = tuple[Message, typing.Sequence[DecodedToolCall], T | None] +def _collect_tools( + env: collections.abc.Mapping[str, typing.Any], +) -> collections.abc.Mapping[str, Tool]: + """Operations and Templates available as tools. 
Auto-capture from lexical context.""" + result = {} + + for name, obj in env.items(): + # Collect tools directly in context + if isinstance(obj, Tool | Template): + result[name] = obj + + # Collect tools as methods on Agent instances in context + elif isinstance(obj, Agent): + for cls in type(obj).__mro__: + for attr_name in vars(cls): + if isinstance(getattr(obj, attr_name), Tool): + result[f"{name}__{attr_name}"] = getattr(obj, attr_name) + + # The same Tool can appear under multiple names when it is both + # visible in the enclosing scope *and* discovered via an Agent + # instance's MRO. Since Tools are hashable Operations and + # instance-method Tools are cached per instance, we keep only + # the last name for each unique tool object. + tool2name = {tool: name for name, tool in sorted(result.items())} + for name, tool in tuple(result.items()): + if tool2name[tool] != name: + del result[name] + + return result + + @Operation.define @functools.wraps(litellm.completion) def completion(*args, **kwargs) -> typing.Any: @@ -172,10 +221,14 @@ def completion(*args, **kwargs) -> typing.Any: return litellm.completion(*args, **kwargs) +class _BoxedResponse[T](pydantic.BaseModel): + value: T + + @Operation.define -def call_assistant[T, U]( - tools: collections.abc.Mapping[str, Tool], - response_format: Encodable[T, U], +def call_assistant[T]( + env: collections.abc.Mapping[str, typing.Any], + response_type: type[T], model: str, **kwargs, ) -> MessageResult[T]: @@ -190,15 +243,28 @@ def call_assistant[T, U]( ResultDecodingError: If the result cannot be decoded. The error includes the raw assistant message for retry handling. 
""" + tools = _collect_tools(env) tool_specs = { - k: Encodable.define(type(t), tools).encode(t) # type: ignore + k: typing.cast( + pydantic.TypeAdapter[typing.Any], + pydantic.TypeAdapter(Encodable[type(t)]), # type: ignore[misc] + ).dump_python(t, mode="json", context={k: t}) for k, t in tools.items() } - messages = list(_get_history().values()) + + # The OpenAI API requires a wrapper object for non-object structured output types, + # so we create one on the fly here. Using a Pydantic model offloads JSON schema + # generation and validation logic to litellm, and offers better error messages. + response_format: type[_BoxedResponse[T]] = pydantic.create_model( + "BoxedResponse", + value=Encodable[response_type], # type: ignore[valid-type] + __base__=_BoxedResponse, + ) + response: litellm.types.utils.ModelResponse = completion( model, - messages=list(messages), - response_format=None if response_format.enc is str else response_format.enc, + messages=list(_get_history().values()), + response_format=None if response_type is str else response_format, tools=list(tool_specs.values()), **kwargs, ) @@ -212,11 +278,12 @@ def call_assistant[T, U]( append_message(raw_message) tool_calls: list[DecodedToolCall] = [] - raw_tool_calls = message.get("tool_calls") or [] - encoding = Encodable.define(DecodedToolCall, tools) # type: ignore - for raw_tool_call in raw_tool_calls: + encoding: pydantic.TypeAdapter[DecodedToolCall] = pydantic.TypeAdapter( + Encodable[DecodedToolCall] + ) + for raw_tool_call in message.get("tool_calls") or []: try: - tool_calls += [encoding.decode(raw_tool_call)] # type: ignore + tool_calls += [encoding.validate_python(raw_tool_call, context=tools)] except Exception as e: raise ToolCallDecodingError( raw_tool_call=raw_tool_call, @@ -231,12 +298,15 @@ def call_assistant[T, U]( assert isinstance(serialized_result, str), ( "final response from the model should be a string" ) - try: - result = response_format.decode( - 
response_format.deserialize(serialized_result) - ) - except (pydantic.ValidationError, TypeError, ValueError, SyntaxError) as e: - raise ResultDecodingError(e, raw_message=raw_message) from e + if response_type is str: + result = typing.cast(T, serialized_result) + else: + try: + result = response_format.model_validate( + json.loads(serialized_result), context=env + ).value + except Exception as e: + raise ResultDecodingError(e, raw_message=raw_message) from e return (raw_message, tool_calls, result) @@ -256,10 +326,12 @@ def call_tool(tool_call: DecodedToolCall) -> Message: except Exception as e: raise ToolCallExecutionError(raw_tool_call=tool_call, original_error=e) from e - return_type = Encodable.define( - typing.cast(type[typing.Any], nested_type(result).value) + return_type: pydantic.TypeAdapter[typing.Any] = pydantic.TypeAdapter( + Encodable[nested_type(result).value] # type: ignore[misc] + ) + encoded_result = to_content_blocks( + return_type.dump_python(result, mode="json", context={}) ) - encoded_result = return_type.serialize(return_type.encode(result)) message = _make_message( dict(role="tool", content=encoded_result, tool_call_id=tool_call.id), ) @@ -295,13 +367,11 @@ def flush_text() -> None: continue obj, _ = formatter.get_field(field_name, (), env) - encoder = Encodable.define( - typing.cast(type[typing.Any], nested_type(obj).value), env - ) - encoded_obj: typing.Sequence[OpenAIMessageContentListBlock] = encoder.serialize( - encoder.encode(obj) + encoder: pydantic.TypeAdapter[typing.Any] = pydantic.TypeAdapter( + Encodable[nested_type(obj).value] # type: ignore[misc] ) - for part in encoded_obj: + encoded_obj = encoder.dump_python(obj, mode="json", context=env) + for part in to_content_blocks(encoded_obj): if part["type"] == "text": text = ( formatter.convert_field(part["text"], conversion) @@ -327,21 +397,8 @@ def call_system(template: Template) -> Message: """Get system instruction message(s) to prepend to all LLM prompts.""" system_prompt = 
template.__system_prompt__ or DEFAULT_SYSTEM_PROMPT message = _make_message(dict(role="system", content=system_prompt)) - try: - history: collections.OrderedDict[str, Message] = _get_history() - if any(m["role"] == "system" for m in history.values()): - assert sum(1 for m in history.values() if m["role"] == "system") == 1, ( - "There should be at most one system message in the history" - ) - assert history[next(iter(history))]["role"] == "system", ( - "The system message should be the first message in the history" - ) - history.popitem(last=False) # remove existing system message - history[message["id"]] = message - history.move_to_end(message["id"], last=False) - return message - except NotImplementedError: - return message + append_message(message, last=False) + return message class RetryLLMHandler(ObjectInterpretation): @@ -404,20 +461,17 @@ def _before_sleep(self, retry_state: tenacity.RetryCallState) -> None: self._user_before_sleep(retry_state) @implements(call_assistant) - def _call_assistant[T, U]( + def _call_assistant[T]( self, - tools: collections.abc.Mapping[str, Tool], - response_format: Encodable[T, U], + env: collections.abc.Mapping[str, typing.Any], + response_type: type[T], model: str, **kwargs, ) -> MessageResult[T]: _message_sequence = _get_history().copy() - def _attempt() -> MessageResult[T]: - return fwd(tools, response_format, model, **kwargs) - with handler({_get_history: lambda: _message_sequence}): - message, tool_calls, result = self.call_assistant_retryer(_attempt) + message, tool_calls, result = self.call_assistant_retryer(fwd) append_message(message) return (message, tool_calls, result) @@ -461,8 +515,8 @@ def _call[**P, T]( bound_args.apply_defaults() env = template.__context__.new_child(bound_args.arguments) - # Create response_model with env so tools passed as arguments are available - response_model = Encodable.define(template.__signature__.return_annotation, env) + if not _is_recursive_signature(template.__signature__): + env = 
env.new_child({k: None for k, v in env.items() if v is template}) history: collections.OrderedDict[str, Message] = getattr( template, "__history__", collections.OrderedDict() @@ -470,7 +524,11 @@ def _call[**P, T]( history_copy = history.copy() with handler({_get_history: lambda: history_copy}): - call_system(template) + if ( + not _get_history() + or next(iter(_get_history().values()))["role"] != "system" + ): + call_system(template) message: Message = call_user(template.__prompt_template__, env) @@ -479,14 +537,14 @@ def _call[**P, T]( result: T | None = None while message["role"] != "assistant" or tool_calls: message, tool_calls, result = call_assistant( - template.tools, response_model, **self.config + env, template.__signature__.return_annotation, **self.config ) for tool_call in tool_calls: message = call_tool(tool_call) try: _get_history() - except NotImplementedError: + except _NoActiveHistoryException: history.clear() history.update(history_copy) return typing.cast(T, result) diff --git a/effectful/handlers/llm/encoding.py b/effectful/handlers/llm/encoding.py index edd5b07d..eff24f9c 100644 --- a/effectful/handlers/llm/encoding.py +++ b/effectful/handlers/llm/encoding.py @@ -1,28 +1,24 @@ import ast import base64 +import dataclasses import functools import inspect import io +import json import textwrap import types import typing -from abc import ABC, abstractmethod from collections.abc import ( Callable, - Hashable, Mapping, MutableMapping, - MutableSequence, - Sequence, ) -from dataclasses import dataclass -from types import CodeType from typing import Any import litellm import pydantic from litellm import ( - ChatCompletionImageUrlObject, + ChatCompletionImageObject, ChatCompletionMessageToolCall, ChatCompletionTextObject, ChatCompletionToolParam, @@ -32,25 +28,70 @@ import effectful.handlers.llm.evaluation as evaluation from effectful.handlers.llm.template import Tool -from effectful.internals.unification import nested_type -from effectful.ops.semantics 
import _simple_type -from effectful.ops.syntax import _CustomSingleDispatchCallable +from effectful.internals.unification import GenericAlias, TypeEvaluator, nested_type from effectful.ops.types import Operation, Term type ToolCallID = str +CONTENT_BLOCK_TYPES: frozenset[str] = frozenset( + literal + for member in typing.get_args(OpenAIMessageContentListBlock) + for literal in typing.get_args(typing.get_type_hints(member).get("type", str)) + if isinstance(literal, str) +) -def _pil_image_to_base64_data(pil_image: Image.Image) -> str: - buf = io.BytesIO() - pil_image.save(buf, format="PNG") - return base64.b64encode(buf.getvalue()).decode("utf-8") +@pydantic.validate_call(validate_return=True) +def to_content_blocks(value: typing.Any) -> list[OpenAIMessageContentListBlock]: + """Convert an encoded JSON-compatible value into a flat list of content blocks. -def _pil_image_to_base64_data_uri(pil_image: Image.Image) -> str: - return f"data:image/png;base64,{_pil_image_to_base64_data(pil_image)}" + Walks the value tree, extracting content-block-shaped dicts (identified by + their ``type`` discriminator) and emitting JSON syntax as text around them. + Top-level strings are emitted bare (for natural template rendering). + Inside JSON structures, separators match ``json.dumps`` defaults so that + the linearization law holds for non-string encoded values: + ``linearize(to_content_blocks(v)) == json.dumps(v)``. 
+ """ + if isinstance(value, str): + return [ChatCompletionTextObject(type="text", text=value)] + + buf: list[str] = [] + blocks: list[OpenAIMessageContentListBlock] = [] + + def flush() -> None: + if buf: + blocks.append(ChatCompletionTextObject(type="text", text="".join(buf))) + buf.clear() + + def walk(v: typing.Any) -> None: + if isinstance(v, dict) and v.get("type") in CONTENT_BLOCK_TYPES: + flush() + blocks.append(typing.cast(OpenAIMessageContentListBlock, v)) + elif isinstance(v, dict): + buf.append("{") + for i, (k, val) in enumerate(v.items()): + if i: + buf.append(", ") + buf.append(json.dumps(k) + ": ") + walk(val) + buf.append("}") + elif isinstance(v, list): + buf.append("[") + for i, item in enumerate(v): + if i: + buf.append(", ") + walk(item) + buf.append("]") + else: + buf.append(json.dumps(v)) -@dataclass(frozen=True, eq=True) + walk(value) + flush() + return blocks + + +@dataclasses.dataclass(frozen=True, eq=True) class DecodedToolCall[T]: """ Structured representation of a tool call decoded from an LLM response. @@ -62,374 +103,197 @@ class DecodedToolCall[T]: name: str -class Encodable[T, U](ABC): - base: type[T] - enc: type[U] - ctx: Mapping[str, Any] +if typing.TYPE_CHECKING: + type Encodable[T] = typing.Annotated[T, "encoded"] +else: - @abstractmethod - def encode(self, value: T) -> U: - raise NotImplementedError + class Encodable: + def __class_getitem__(cls, item): + return TypeToPydanticType().evaluate(item) - @abstractmethod - def decode(self, encoded_value: U) -> T: - raise NotImplementedError - @abstractmethod - def serialize(self, encoded_value: U) -> Sequence[OpenAIMessageContentListBlock]: - raise NotImplementedError +class TypeToPydanticType(TypeEvaluator): + """Substitute custom types with their Pydantic Annotated equivalents. 
- # serialize and deserialize have different types reflecting the LLM api chat.completions(list[content]) -> str - @abstractmethod - def deserialize(self, serialized_value: str) -> U: - raise NotImplementedError + Recursively walks a type annotation tree, replacing leaf types that have + registered Pydantic annotations (e.g., Image.Image -> PydanticImage) and + reconstructing the full generic type. - @typing.final - @staticmethod - @_CustomSingleDispatchCallable - def define( - __dispatch: Callable[ - [type[T]], Callable[[type[T], Mapping[str, Any] | None], "Encodable[T, U]"] - ], - t: type[T], - ctx: Mapping[str, Any] | None = None, - ) -> "Encodable[T, U]": - dispatch_ty = _simple_type(t) - encodable: Encodable[T, U] = __dispatch(dispatch_ty)(t, ctx) - assert issubclass( - pydantic.create_model("Model", v=(encodable.enc, ...)), pydantic.BaseModel - ), f"enc type {encodable.enc} is not a valid pydantic field type for {t}" - return encodable - - -class _BoxEncoding[T](pydantic.BaseModel): - value: T - - -@dataclass -class BaseEncodable[T](Encodable[T, _BoxEncoding[T]]): - base: type[T] - enc: type[_BoxEncoding[T]] - ctx: Mapping[str, Any] - - def encode(self, value: T) -> _BoxEncoding[T]: - return self.enc(value=value) - - def decode(self, encoded_value: _BoxEncoding[T]) -> T: - return typing.cast(T, encoded_value.value) - - def serialize( - self, encoded_value: _BoxEncoding[T] - ) -> Sequence[OpenAIMessageContentListBlock]: - return [{"type": "text", "text": encoded_value.model_dump_json()}] - - def deserialize(self, serialized_value: str) -> _BoxEncoding[T]: - return self.enc.model_validate_json(serialized_value) + The result can be passed to pydantic.TypeAdapter() for automatic + validation and serialization of nested structures. 
+ """ @staticmethod - @functools.cache - def wrapped_model(ty: Hashable) -> type[_BoxEncoding[Any]]: - scalar_ty = typing.cast(type[Any], ty) - return typing.cast( - type[_BoxEncoding[Any]], - pydantic.create_model( - f"Response_{getattr(scalar_ty, '__name__', 'scalar')}", - value=(scalar_ty, ...), - __base__=_BoxEncoding, - __config__={"extra": "forbid"}, - ), - ) + @functools.singledispatch + def _registry(ty: type): + raise RuntimeError("should not be here!") + + @classmethod + def register(cls, *args, **kwargs): + return cls._registry.register(*args, **kwargs) + + def evaluate(self, ty): + app = super().evaluate(ty) + if ( + isinstance(app, type | GenericAlias) + and typing.get_origin(app) is not typing.Annotated + ): + return self._registry.dispatch(typing.get_origin(app) or app)(app) + else: + return app -@dataclass -class StrEncodable(Encodable[str, str]): - base: type[str] - enc: type[str] - ctx: Mapping[str, Any] +@TypeToPydanticType.register(object) +@TypeToPydanticType.register(str) +@TypeToPydanticType.register(pydantic.BaseModel) +def _pydantic_type_base[T](ty: type[T]) -> type[T]: + return ty - def encode(self, value: str) -> str: - return value - def decode(self, encoded_value: str) -> str: - return encoded_value +class _ComplexModel(typing.TypedDict): + real: float + imag: float - def serialize(self, encoded_value: str) -> Sequence[ChatCompletionTextObject]: - # Serialize strings without JSON encoding (no extra quotes) - return [{"type": "text", "text": encoded_value}] - def deserialize(self, serialized_value: str) -> str: - return serialized_value +@pydantic.validate_call(validate_return=True) +def _validate_complex(value: _ComplexModel) -> complex: + return complex(value["real"], value["imag"]) -@dataclass -class PydanticBaseModelEncodable[T: pydantic.BaseModel](Encodable[T, T]): - base: type[T] - enc: type[T] - ctx: Mapping[str, Any] +@pydantic.validate_call(validate_return=True) +def _serialize_complex(value: complex) -> _ComplexModel: + return 
{"real": value.real, "imag": value.imag} - def decode(self, encoded_value: T) -> T: - return encoded_value - def encode(self, value: T) -> T: - return value +@TypeToPydanticType.register(complex) +def _pydantic_type_complex(ty): + """Encode ``complex`` as ``{"real": float, "imag": float}``.""" - def serialize(self, encoded_value: T) -> Sequence[ChatCompletionTextObject]: - return [{"type": "text", "text": encoded_value.model_dump_json()}] + adapted_schema = pydantic.TypeAdapter(_ComplexModel).json_schema() - def deserialize(self, serialized_value: str) -> T: - return typing.cast(T, self.base.model_validate_json(serialized_value)) + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate_complex), + pydantic.PlainSerializer(_serialize_complex), + pydantic.WithJsonSchema(adapted_schema), + ] -@dataclass -class ImageEncodable(Encodable[Image.Image, pydantic.BaseModel]): - base: type[Image.Image] - enc: type[pydantic.BaseModel] - ctx: Mapping[str, Any] +@TypeToPydanticType.register(tuple) +def _pydantic_type_tuple(ty): + """Convert finitary tuples to object-based schemas (``properties/required``). - def encode(self, value: Image.Image) -> pydantic.BaseModel: - return self.enc( - detail="auto", - url=_pil_image_to_base64_data_uri(value), - ) + OpenAI's strict mode rejects the ``prefixItems`` array schema that Pydantic + emits for fixed-length tuples. We convert them to a Pydantic model with + positional ``item_0``, ``item_1``, … fields instead. 
- def decode( - self, encoded_value: pydantic.BaseModel | Mapping[str, Any] - ) -> Image.Image: - normalized = self.enc.model_validate(encoded_value) - image_url = typing.cast(str, getattr(normalized, "url")) - if not image_url.startswith("data:image/"): - raise TypeError( - f"expected base64 encoded image as data uri, received {image_url}" - ) - data = image_url.split(",")[1] - return Image.open(fp=io.BytesIO(base64.b64decode(data))) - - def serialize( - self, encoded_value: pydantic.BaseModel - ) -> Sequence[OpenAIMessageContentListBlock]: - return [ - { - "type": "image_url", - "image_url": typing.cast( - ChatCompletionImageUrlObject, - encoded_value.model_dump(exclude_none=True), - ), - } - ] - - def deserialize(self, serialized_value: str) -> pydantic.BaseModel: - # Images are serialized as image_url blocks, not text - # This shouldn't be called in normal flow, but provide a fallback - raise NotImplementedError("Image deserialization from string is not supported") - - -@dataclass -class TupleEncodable[T](Encodable[T, typing.Any]): - """Encodes fixed-length heterogeneous tuples (e.g. ``tuple[int, str]``). - - ``model_cls`` is a dynamic pydantic model (``TupleItems``) with one field - per position, producing an object JSON schema that OpenAI accepts - (unlike the ``prefixItems`` schema from native tuple types). + Bare ``tuple`` and variadic ``tuple[T, ...]`` are passed through unchanged. 
""" + args = typing.get_args(ty) - base: type[T] - enc: type[typing.Any] - model_cls: type[pydantic.BaseModel] - ctx: Mapping[str, Any] - has_image: bool - element_encoders: list[Encodable] - - def encode(self, value: T) -> typing.Any: - if not isinstance(value, tuple): - raise TypeError(f"Expected tuple, got {type(value)}") - if len(value) != len(self.element_encoders): - raise ValueError( - f"Tuple length {len(value)} does not match expected length {len(self.element_encoders)}" - ) - return tuple( - enc.encode(elem) for enc, elem in zip(self.element_encoders, value) - ) + # Bare tuple or tuple[T, ...] — Pydantic's native handling is fine. + # Note: tuple[()] also has get_args() == (), but has origin=tuple. + if (not args and typing.get_origin(ty) is None) or ( + len(args) == 2 and args[1] is Ellipsis + ): + return ty - def decode(self, encoded_value: typing.Any) -> T: - # Pydantic validation produces a TupleItems model instance; - # extract the positional fields back into a sequence. - if isinstance(encoded_value, pydantic.BaseModel): - items = list(encoded_value.model_dump().values()) - else: - items = list(encoded_value) - if len(items) != len(self.element_encoders): - raise ValueError( - f"tuple length {len(items)} does not match expected length {len(self.element_encoders)}" - ) - return typing.cast( - T, - tuple(enc.decode(elem) for enc, elem in zip(self.element_encoders, items)), - ) + # tuple[()] (empty args with origin) maps to zero fields; otherwise use args. 
+ effective: list[typing.Any] = list(args) - def serialize( - self, encoded_value: typing.Any - ) -> Sequence[OpenAIMessageContentListBlock]: - if self.has_image: - result: list[OpenAIMessageContentListBlock] = [] - for enc, elem in zip(self.element_encoders, encoded_value): - result.extend(enc.serialize(elem)) - return result - model_instance = self.model_cls( - **{f"item_{i}": v for i, v in enumerate(encoded_value)} - ) - json_str = model_instance.model_dump_json() - return [{"type": "text", "text": json_str}] - - def deserialize(self, serialized_value: str) -> typing.Any: - model = self.model_cls.model_validate_json(serialized_value) - # Return raw field values (preserving nested pydantic models). - # Use tuple to be compatible with SequenceEncodable (which also - # produces tuples), ensuring encode idempotency via nested_type. + adapters = [pydantic.TypeAdapter(a) for a in effective] + + model = pydantic.create_model( + "TupleItems", + __config__={"extra": "forbid"}, + **{f"item_{i}": (a, ...) 
for i, a in enumerate(effective)}, + ) + + def _validate(value, info: pydantic.ValidationInfo): + if isinstance(value, tuple | list): + value = {f"item_{i}": v for i, v in enumerate(value)} return tuple( - getattr(model, f"item_{i}") for i in range(len(self.element_encoders)) + adapters[i].validate_python(value[f"item_{i}"], context=info.context) + for i in range(len(effective)) ) + def _serialize(value, info: pydantic.SerializationInfo): + return { + f"item_{i}": adapters[i].dump_python(v, mode="json", context=info.context) + for i, v in enumerate(value) + } -@dataclass -class NamedTupleEncodable[T](TupleEncodable[T]): - """Tuple encodable that reconstructs the original NamedTuple type on decode.""" + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate), + pydantic.PlainSerializer(_serialize), + pydantic.WithJsonSchema(_inline_refs(model.model_json_schema())), + ] - def decode(self, encoded_value: typing.Any) -> T: - if isinstance(encoded_value, pydantic.BaseModel): - items = list(encoded_value.model_dump().values()) - else: - items = list(encoded_value) - if len(items) != len(self.element_encoders): - raise ValueError( - f"tuple length {len(items)} does not match expected length {len(self.element_encoders)}" - ) - decoded_elements: list[typing.Any] = [ - enc.decode(elem) for enc, elem in zip(self.element_encoders, items) - ] - return typing.cast(T, self.base(*decoded_elements)) - - -@dataclass -class SequenceEncodable[T](Encodable[Sequence[T], typing.Any]): - """Variable-length sequence encoded as a JSON array, decoded back to tuple.""" - - base: type[typing.Any] - enc: type[typing.Any] - ctx: Mapping[str, Any] - has_image: bool - element_encoder: Encodable[T, typing.Any] - - def encode(self, value: Sequence[T]) -> typing.Any: - # Return a tuple so that nested_type routes back through the tuple - # dispatcher, preserving encode idempotency. 
- return tuple(self.element_encoder.encode(elem) for elem in value) - - def decode(self, encoded_value: typing.Any) -> Sequence[T]: - return typing.cast( - Sequence[T], - tuple(self.element_encoder.decode(elem) for elem in encoded_value), - ) - def serialize( - self, encoded_value: typing.Any - ) -> Sequence[OpenAIMessageContentListBlock]: - if self.has_image: - result: list[OpenAIMessageContentListBlock] = [] - for elem in encoded_value: - result.extend(self.element_encoder.serialize(elem)) - return result - adapter = pydantic.TypeAdapter(self.enc) - # Convert to list for pydantic serialization (enc is list[...]) - json_str = adapter.dump_json(list(encoded_value)).decode("utf-8") - return [{"type": "text", "text": json_str}] - - def deserialize(self, serialized_value: str) -> typing.Any: - adapter = pydantic.TypeAdapter(self.enc) - # validate_json returns a list; convert back to tuple for - # compatibility with SequenceEncodable (which uses tuples). - return tuple(adapter.validate_json(serialized_value)) - - -@dataclass -class MutableSequenceEncodable[T](SequenceEncodable[T]): - """Mutable sequence (list) — same as SequenceEncodable but returns a list.""" - - def encode(self, value: Sequence[T]) -> typing.Any: - if not isinstance(value, MutableSequence): - raise TypeError(f"Expected MutableSequence, got {type(value)}") - return [self.element_encoder.encode(elem) for elem in value] - - def decode(self, encoded_value: typing.Any) -> MutableSequence[T]: - decoded_elements: list[T] = [ - self.element_encoder.decode(elem) for elem in encoded_value - ] - return typing.cast(MutableSequence[T], decoded_elements) - - def deserialize(self, serialized_value: str) -> typing.Any: - adapter = pydantic.TypeAdapter(self.enc) - return adapter.validate_json(serialized_value) - - -@dataclass -class TypedDictEncodable[T](Encodable[T, pydantic.BaseModel]): - base: type[T] - enc: type[pydantic.BaseModel] - ctx: Mapping[str, Any] - - def encode(self, value: T) -> pydantic.BaseModel: - 
return self.enc.model_validate(value) - - def decode(self, encoded_value: pydantic.BaseModel) -> T: - decoded_value: dict[str, Any] = encoded_value.model_dump() - adapter = pydantic.TypeAdapter(self.base) - return typing.cast(T, adapter.validate_python(decoded_value)) - - def serialize( - self, encoded_value: pydantic.BaseModel - ) -> Sequence[OpenAIMessageContentListBlock]: - return [{"type": "text", "text": encoded_value.model_dump_json()}] - - def deserialize(self, serialized_value: str) -> pydantic.BaseModel: - return self.enc.model_validate_json(serialized_value) +@TypeToPydanticType.register(Term) +def _pydantic_type_term(ty: type[Term]): + raise TypeError("Terms cannot be converted to Pydantic types.") - @staticmethod - @functools.cache - def _typeddict_model(td: type[Any]) -> type[pydantic.BaseModel]: - hints = typing.get_type_hints(td) - required = typing.cast( - frozenset[str], getattr(td, "__required_keys__", frozenset()) - ) - fields: dict[str, Any] = {} - for k, v in hints.items(): - fields[k] = (v, ...) 
if k in required else (v, None) - return pydantic.create_model( - td.__name__, - **fields, - ) +@TypeToPydanticType.register(Operation) +def _pydantic_type_operation(ty: type[Operation]): + raise TypeError("Operations cannot be converted to Pydantic types.") -def _format_callable_type(callable_type: type[Callable]) -> str: - """Format a Callable type annotation as a string for LLM instructions.""" - args = typing.get_args(callable_type) - if not args: - return "Callable" - # Callable[[arg1, arg2, ...], return_type] - if len(args) >= 2: - param_types = args[0] - return_type = args[-1] +@pydantic.validate_call(validate_return=False) +def _validate_image(value: ChatCompletionImageObject) -> Image.Image: + value = pydantic.TypeAdapter(ChatCompletionImageObject).validate_python(value) + image_url: litellm.ChatCompletionImageUrlObject | str = value["image_url"] + url: str = image_url["url"] if isinstance(image_url, dict) else image_url + prefix, data = url.split(",") + if not prefix.startswith("data:image/"): + raise ValueError(f"expected base64 encoded image as data uri, received {url}") + return Image.open(fp=io.BytesIO(base64.b64decode(data))) - if param_types is ...: - params_str = "..." - elif isinstance(param_types, list | tuple): - params_str = ", ".join(getattr(t, "__name__", str(t)) for t in param_types) - else: - params_str = str(param_types) - return_str = getattr(return_type, "__name__", str(return_type)) - return f"Callable[[{params_str}], {return_str}]" +def _serialize_image(value: Image.Image) -> ChatCompletionImageObject: + buf = io.BytesIO() + value.save(buf, format="PNG") + url = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}" + return pydantic.TypeAdapter(ChatCompletionImageObject).validate_python( + {"type": "image_url", "image_url": {"detail": "auto", "url": url}} + ) + + +def _inline_refs(schema: dict) -> dict: + """Inline ``$ref`` pointers so ``WithJsonSchema`` never emits orphan refs. 
+ + Workaround for https://github.com/pydantic/pydantic/issues/12145 — + Pydantic's ``GenerateJsonSchema`` does not merge user-provided ``$defs`` + into its internal ref map, so any ``$ref`` in a ``WithJsonSchema`` value + causes a ``KeyError`` when the annotated type is composed into a model. + """ + defs = schema.get("$defs", {}) - return str(callable_type) + def _resolve(obj): + if isinstance(obj, dict): + if "$ref" in obj: + ref_name = obj["$ref"].split("/")[-1] + if ref_name in defs: + return _resolve(defs[ref_name]) + return {k: _resolve(v) for k, v in obj.items() if k != "$defs"} + if isinstance(obj, list): + return [_resolve(item) for item in obj] + return obj + + return _resolve(schema) + + +@TypeToPydanticType.register(Image.Image) +def _pydantic_type_image(ty: type[Image.Image]): + adapter = pydantic.TypeAdapter(ChatCompletionImageObject) + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate_image), + pydantic.PlainSerializer(_serialize_image), + pydantic.WithJsonSchema(_inline_refs(adapter.json_schema())), + ] class SynthesizedFunction(pydantic.BaseModel): @@ -452,7 +316,24 @@ def _create_typed_synthesized_function( Uses pydantic.create_model to ensure the description is included in the JSON schema sent to the LLM, informing it of the expected function signature. """ - type_signature = _format_callable_type(callable_type) + if not typing.get_args(callable_type): + type_signature = "Callable" + # Callable[[arg1, arg2, ...], return_type] + elif len(typing.get_args(callable_type)) >= 2: + param_types = typing.get_args(callable_type)[0] + return_type = typing.get_args(callable_type)[-1] + + if param_types is ...: + params_str = "..." 
+ elif isinstance(param_types, list | tuple): + params_str = ", ".join(getattr(t, "__name__", str(t)) for t in param_types) + else: + params_str = str(param_types) + + return_str = getattr(return_type, "__name__", str(return_type)) + type_signature = f"Callable[[{params_str}], {return_str}]" + else: + type_signature = str(callable_type) description = f"""Given the specification above, generate a Python function satisfying the following specification and type signature. @@ -516,65 +397,56 @@ def _validate_signature_callable( ) -@dataclass -class CallableEncodable(Encodable[Callable, SynthesizedFunction]): - base: type[Callable] - enc: type[SynthesizedFunction] - ctx: Mapping[str, Any] - expected_params: list[type] | None = None - expected_return: type | None = None # None means decode is disabled +@TypeToPydanticType.register(Callable) +def _pydantic_callable(callable_type: Any) -> Any: + """Create a Pydantic-compatible Annotated type for a parameterized Callable. - def encode(self, value: Callable) -> SynthesizedFunction: - # (https://github.com/python/mypy/issues/14928) - if not isinstance(value, Callable): # type: ignore - raise TypeError(f"Expected callable, got {type(value)}") - - try: - source = inspect.getsource(value) - except (OSError, TypeError): - source = None - - if source: - return self.enc(module_code=textwrap.dedent(source)) + Usage: PydanticCallable(Callable[[int, str], bool]) + """ + type_args = typing.get_args(callable_type) - # Source not available - create stub from name, signature, and docstring - # This is useful for builtins and C extensions - name = getattr(value, "__name__", None) - docstring = inspect.getdoc(value) - if name is None or docstring is None: + if not type_args: + typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any]) # type: ignore[arg-type] + expected_params = None + expected_return = None + else: + if len(type_args) < 2: + raise TypeError( + f"Callable type signature incomplete: {callable_type}. 
" + "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]." + ) + param_types, expected_return = type_args[0], type_args[-1] + typed_enc = _create_typed_synthesized_function(callable_type) + if param_types is not ... and isinstance(param_types, list | tuple): + expected_params = list(param_types) + else: + expected_params = None + + def _validate(value: Any, info: pydantic.ValidationInfo) -> Callable: + if callable(value) and not isinstance(value, dict): + return value + if isinstance(value, SynthesizedFunction): + encoded = value + elif isinstance(value, dict): + encoded = typed_enc.model_validate(value) + elif isinstance(value, str): + encoded = typed_enc.model_validate_json(value) + else: raise ValueError( - f"Cannot encode callable {value}: no source code and no __name__ or docstring" + f"Expected callable, SynthesizedFunction dict, or JSON string, " + f"got {type(value)}" ) - try: - sig = inspect.signature(value) - sig_str = str(sig) - except (ValueError, TypeError): - # Some builtins don't have inspectable signatures - sig_str = "(...)" - - # Format as a stub function with docstring - stub_code = f'''def {name}{sig_str}: - """{docstring}""" - ... -''' - return self.enc(module_code=stub_code) - - def decode(self, encoded_value: SynthesizedFunction) -> Callable: - # Decode requires a concrete return type for synthesis - if self.expected_return is None: + if expected_return is None: raise TypeError( "Cannot decode/synthesize callable without a concrete type signature. " "Use Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType] " "with a concrete return type (not Any)." 
) - filename = f"" - - module_code = encoded_value.module_code - - # Parse and validate AST before execution - module: ast.AST = evaluation.parse(module_code, filename) + ctx = info.context or {} + filename = f"" + module: ast.AST = evaluation.parse(encoded.module_code, filename) if not isinstance(module, ast.Module) or not module.body: raise ValueError( @@ -588,20 +460,12 @@ def decode(self, encoded_value: SynthesizedFunction) -> Callable: f"got {type(last_stmt).__name__}" ) - # Validate signature from AST before execution - _validate_signature_ast(last_stmt, self.expected_params) + _validate_signature_ast(last_stmt, expected_params) + evaluation.type_check(module, ctx, expected_params, expected_return) - # Type-check with mypy; pass original module_code so mypy sees exact source - evaluation.type_check( - module, self.ctx, self.expected_params, self.expected_return - ) - - # Compile and execute - # https://docs.python.org/3/library/functions.html#exec g: MutableMapping[str, Any] = {} - g.update(self.ctx or {}) - - bytecode: CodeType = evaluation.compile(module, filename) + g.update(ctx) + bytecode: types.CodeType = evaluation.compile(module, filename) evaluation.exec(bytecode, g) func_name = last_stmt.name @@ -616,450 +480,156 @@ def decode(self, encoded_value: SynthesizedFunction) -> Callable: f"decode() expected '{func_name}' to be callable, got {type(result)}" ) - # Validate signature from runtime callable after execution - _validate_signature_callable(result, self.expected_params, self.expected_return) - + _validate_signature_callable(result, expected_params, expected_return) return result - def serialize( - self, encoded_value: SynthesizedFunction - ) -> Sequence[OpenAIMessageContentListBlock]: - return [{"type": "text", "text": encoded_value.model_dump_json()}] - - def deserialize(self, serialized_value: str) -> SynthesizedFunction: - return SynthesizedFunction.model_validate_json(serialized_value) - - -def _param_model(sig: inspect.Signature) -> 
type[pydantic.BaseModel]: - return pydantic.create_model( - "Params", - __config__={"extra": "forbid"}, - **{ - name: Encodable.define(param.annotation).enc - for name, param in sig.parameters.items() - }, # type: ignore - ) - - -@dataclass -class ToolEncodable[**P, T](Encodable[Tool[P, T], pydantic.BaseModel]): - base: type[Tool] - enc: type[pydantic.BaseModel] - ctx: Mapping[str, Any] - - def encode(self, value: Tool[P, T]) -> pydantic.BaseModel: - response_format = litellm.utils.type_to_response_format_param( - _param_model(inspect.signature(value)) - ) - assert response_format is not None - assert value.__default__.__doc__ is not None - return self.enc.model_validate( - { - "type": "function", - "function": { - "name": next( - (k for k, v in self.ctx.items() if v is value), - value.__name__, - ), - "description": textwrap.dedent(value.__default__.__doc__), - "parameters": response_format["json_schema"]["schema"], - "strict": True, - }, - } - ) - - def decode(self, encoded_value: pydantic.BaseModel) -> Tool[P, T]: - raise NotImplementedError("Tools cannot yet be decoded from LLM responses") - - def serialize( - self, encoded_value: pydantic.BaseModel - ) -> Sequence[OpenAIMessageContentListBlock]: - return [ - { - "type": "text", - "text": encoded_value.model_dump_json(exclude_none=True), - } - ] - - def deserialize(self, serialized_value: str) -> pydantic.BaseModel: - return self.enc.model_validate_json(serialized_value) - - -@dataclass -class ToolCallEncodable[T]( - Encodable[DecodedToolCall[T], ChatCompletionMessageToolCall] -): - base: type[DecodedToolCall[T]] - enc: type[ChatCompletionMessageToolCall] - ctx: Mapping[str, Any] - - def encode(self, value: DecodedToolCall[T]) -> ChatCompletionMessageToolCall: - sig = inspect.signature(value.tool) - encoded_args = _param_model(sig).model_validate( - { - k: Encodable.define( - typing.cast(type[Any], nested_type(v).value), self.ctx - ).encode(v) - for k, v in value.bound_args.arguments.items() - } - ) - return 
ChatCompletionMessageToolCall.model_validate( - { - "type": "tool_call", - "id": value.id, - "function": { - "name": next( - (k for k, v in self.ctx.items() if v is value.tool), - value.tool.__name__, - ), - "arguments": encoded_args.model_dump_json(), - }, - } - ) - - def decode( - self, encoded_value: ChatCompletionMessageToolCall - ) -> DecodedToolCall[T]: - """Decode a tool call from the LLM response into a DecodedToolCall. - - Args: - encoded_value: The tool call to decode. - """ - assert encoded_value.function.name is not None - tool: Tool[..., T] = self.ctx[encoded_value.function.name] - assert isinstance(tool, Tool) - - json_str = encoded_value.function.arguments - sig = inspect.signature(tool) - - raw_args = _param_model(sig).model_validate_json(json_str) - - bound_args: inspect.BoundArguments = sig.bind( - **{ - name: Encodable.define( - typing.cast(type[Any], sig.parameters[name].annotation), self.ctx - ).decode(getattr(raw_args, name)) - for name in raw_args.model_fields_set - } - ) - return DecodedToolCall( - tool=tool, - bound_args=bound_args, - id=encoded_value.id, - name=encoded_value.function.name, - ) - - def serialize( - self, encoded_value: ChatCompletionMessageToolCall - ) -> Sequence[OpenAIMessageContentListBlock]: - return [{"type": "text", "text": encoded_value.model_dump_json()}] - - def deserialize(self, serialized_value: str) -> ChatCompletionMessageToolCall: - return self.enc.model_validate_json(serialized_value) - - -@Encodable.define.register(object) -def _encodable_object[T, U]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - ctx = {} if ctx is None else ctx - wrapped = BaseEncodable.wrapped_model(typing.cast(Hashable, ty)) - return typing.cast(Encodable[T, U], BaseEncodable(ty, wrapped, ctx)) - - -@Encodable.define.register(str) -def _encodable_str(ty: type[str], ctx: Mapping[str, Any] | None) -> Encodable[str, str]: - """Handler for str type that serializes without JSON encoding.""" - return StrEncodable(ty, 
ty, ctx or {}) - - -class _ComplexParts(pydantic.BaseModel): - model_config = pydantic.ConfigDict(extra="forbid") - real: float - imag: float - - -@dataclass -class _ComplexEncodable(Encodable[complex, _ComplexParts]): - base: type[complex] - enc: type[_ComplexParts] - ctx: Mapping[str, Any] - - def encode(self, value: complex) -> _ComplexParts: - return _ComplexParts(real=value.real, imag=value.imag) - - def decode(self, encoded_value: _ComplexParts) -> complex: - return complex(encoded_value.real, encoded_value.imag) - - def serialize( - self, encoded_value: _ComplexParts - ) -> Sequence[OpenAIMessageContentListBlock]: - return [{"type": "text", "text": encoded_value.model_dump_json()}] - - def deserialize(self, serialized_value: str) -> _ComplexParts: - return _ComplexParts.model_validate_json(serialized_value) - - -@Encodable.define.register(complex) -def _encodable_complex( - ty: type[complex], ctx: Mapping[str, Any] | None -) -> Encodable[complex, _ComplexParts]: - return _ComplexEncodable(ty, _ComplexParts, ctx or {}) - - -@Encodable.define.register(Term) -def _encodable_term[T: Term, U]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - raise TypeError("Terms cannot be encoded or decoded in general.") - - -@Encodable.define.register(Operation) -def _encodable_operation[T: Operation, U]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - raise TypeError("Operations cannot be encoded or decoded in general.") - - -@Encodable.define.register(pydantic.BaseModel) -def _encodable_pydantic_base_model[T: pydantic.BaseModel]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, T]: - return PydanticBaseModelEncodable(ty, ty, ctx or {}) - - -@Encodable.define.register(Image.Image) -def _encodable_image( - ty: type[Image.Image], ctx: Mapping[str, Any] | None -) -> Encodable[Image.Image, pydantic.BaseModel]: - image_model = TypedDictEncodable._typeddict_model(ChatCompletionImageUrlObject) - return ImageEncodable(ty, 
image_model, ctx or {}) - - -@Encodable.define.register(tuple) -def _encodable_tuple[T, U]( - ty: type[T], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - """Dispatch for ``tuple`` types. - - * Bare ``tuple`` (no type params) or ``tuple[T, ...]`` - (variadic) to :class:`SequenceEncodable` (JSON array). - * Named-tuples (subclasses with ``_fields``) → :class:`NamedTupleEncodable`. - * Finitary forms (``tuple[()]``, ``tuple[T]``, ``tuple[T1, T2]``, ...) - to :class:`TupleEncodable` (JSON object with positional fields). + def _serialize(value: Callable) -> dict: + if not callable(value): + raise TypeError(f"Expected callable, got {type(value)}") - https://docs.python.org/3/library/typing.html#annotating-tuples - """ + try: + source = inspect.getsource(value) + except (OSError, TypeError): + source = None - def _is_namedtuple_type(ty: type[Any]) -> bool: - return isinstance(ty, type) and issubclass(ty, tuple) and hasattr(ty, "_fields") + if source: + return typed_enc(module_code=textwrap.dedent(source)).model_dump() - args = typing.get_args(ty) - ctx = {} if ctx is None else ctx - - if typing.get_origin(ty) is None: - if ty is tuple: - # Bare tuple — treat as tuple[Any, ...]. - element_encoder = Encodable.define(typing.cast(type, typing.Any), ctx) - encoded_ty = typing.cast(type[typing.Any], list[element_encoder.enc]) # type: ignore - return typing.cast( - Encodable[T, U], - SequenceEncodable(ty, encoded_ty, ctx, False, element_encoder), - ) - if _is_namedtuple_type(ty): - # NamedTuple — route through tuple logic but decode back into - # the concrete NamedTuple class. - hints = typing.get_type_hints(ty) - tuple_field_types: list[type[Any]] = list(hints.values()) - if not tuple_field_types: - tuple_field_types = [typing.Any] * len(getattr(ty, "_fields", ())) - if not tuple_field_types: - # Empty namedtuple. 
- empty_model = pydantic.create_model( - "TupleItems", __config__={"extra": "forbid"} - ) - return typing.cast( - Encodable[T, U], - NamedTupleEncodable(ty, empty_model, empty_model, ctx, False, []), - ) - element_encoders = [Encodable.define(arg, ctx) for arg in tuple_field_types] - has_image = any(arg is Image.Image for arg in tuple_field_types) - model_cls = pydantic.create_model( # type: ignore[call-overload] - "TupleItems", - __config__={"extra": "forbid"}, - **{ - f"item_{i}": (enc.enc, ...) - for i, enc in enumerate(element_encoders) - }, - ) - return typing.cast( - Encodable[T, U], - NamedTupleEncodable( - ty, model_cls, model_cls, ctx, has_image, element_encoders - ), + name = getattr(value, "__name__", None) + docstring = inspect.getdoc(value) + if name is None or docstring is None: + raise ValueError( + f"Cannot encode callable {value}: no source code and no __name__ or docstring" ) - # Other tuple subclass — delegate to object. - return _encodable_object(ty, ctx) - - # tuple[T, ...] — variable-length, encode as JSON array. - if len(args) == 2 and args[1] is Ellipsis: - element_ty = args[0] - element_encoder = Encodable.define(element_ty, ctx) - has_image = element_ty is Image.Image - encoded_ty = typing.cast(type[typing.Any], list[element_encoder.enc]) # type: ignore - return typing.cast( - Encodable[T, U], - SequenceEncodable(ty, encoded_ty, ctx, has_image, element_encoder), - ) - # Finitary tuple — fixed-length positional struct. - # Build a pydantic model with item_0, item_1, ... fields so the JSON - # schema uses "properties"/"required" (accepted by OpenAI), not - # "prefixItems" (rejected by OpenAI). 
- effective_args = [] if (not args or args == ((),)) else list(args) - element_encoders = [Encodable.define(arg, ctx) for arg in effective_args] - has_image = any(arg is Image.Image for arg in effective_args) - model_cls = pydantic.create_model( # type: ignore[call-overload] - "TupleItems", + try: + sig = inspect.signature(value) + sig_str = str(sig) + except (ValueError, TypeError): + sig_str = "(...)" + + stub_code = f'''def {name}{sig_str}: + """{docstring}""" + ... +''' + return typed_enc(module_code=stub_code).model_dump() + + return typing.Annotated[ + callable_type, + pydantic.PlainValidator(_validate), + pydantic.PlainSerializer(_serialize), + pydantic.WithJsonSchema( + _inline_refs(pydantic.TypeAdapter(typed_enc).json_schema()) + ), + ] + + +def _validate_tool( + value: ChatCompletionToolParam, info: pydantic.ValidationInfo +) -> Tool: + assert isinstance(info.context, Mapping), "Tool decoding requires context" + value = pydantic.TypeAdapter(ChatCompletionToolParam).validate_python(value) + try: + return info.context[value["function"]["name"]] + except KeyError as e: + raise NotImplementedError(f"Unknown tool: {value['function']['name']}") from e + + +def _serialize_tool(value: Tool) -> ChatCompletionToolParam: + fields: dict[str, Any] = { + name: TypeToPydanticType().evaluate(param.annotation) + for name, param in inspect.signature(value).parameters.items() + } + sig_model = pydantic.create_model( + "Params", __config__={"extra": "forbid"}, - **{f"item_{i}": (enc.enc, ...) 
for i, enc in enumerate(element_encoders)}, + **fields, ) - - return typing.cast( - Encodable[T, U], - TupleEncodable(ty, model_cls, model_cls, ctx, has_image, element_encoders), + response_format = litellm.utils.type_to_response_format_param(sig_model) + assert response_format is not None + assert value.__default__.__doc__ is not None + return pydantic.TypeAdapter(ChatCompletionToolParam).validate_python( + { + "type": "function", + "function": { + "name": value.__name__, + "description": textwrap.dedent(value.__default__.__doc__), + "parameters": response_format["json_schema"]["schema"], + "strict": True, + }, + } ) -@Encodable.define.register(Sequence) -def _encodable_sequence[T, U]( - ty: type[Sequence[T]], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - """Dispatch for ``Sequence[T]`` — immutable variable-length sequence.""" - args = typing.get_args(ty) - ctx = {} if ctx is None else ctx - - if not args: - return _encodable_object(ty, ctx) - - element_ty = args[0] - element_encoder = Encodable.define(element_ty, ctx) - has_image = element_ty is Image.Image - encoded_ty = typing.cast(type[typing.Any], list[element_encoder.enc]) # type: ignore - - return typing.cast( - Encodable[T, U], - SequenceEncodable(ty, encoded_ty, ctx, has_image, element_encoder), - ) - - -@Encodable.define.register(list) -@Encodable.define.register(MutableSequence) -def _encodable_mutable_sequence[T, U]( - ty: type[MutableSequence[T]], ctx: Mapping[str, Any] | None -) -> Encodable[T, U]: - args = typing.get_args(ty) - ctx = {} if ctx is None else ctx - - # Handle unparameterized list (list without type args) - if not args: - identity_encoder = typing.cast( - Encodable[T, typing.Any], - BaseEncodable( - typing.cast(type[T], object), - typing.cast( - type[_BoxEncoding[T]], - BaseEncodable.wrapped_model(typing.cast(Hashable, object)), - ), - ctx, - ), +@TypeToPydanticType.register(Tool) +def _pydantic_type_tool(ty: type[Tool]): + adapter = 
pydantic.TypeAdapter(ChatCompletionToolParam) + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate_tool), + pydantic.PlainSerializer(_serialize_tool), + pydantic.WithJsonSchema(_inline_refs(adapter.json_schema())), + ] + + +def _validate_tool_call( + value: ChatCompletionMessageToolCall, + info: pydantic.ValidationInfo, +) -> DecodedToolCall: + if isinstance(value, dict): + value = ChatCompletionMessageToolCall.model_validate(value) + ctx = info.context or {} + assert value.function.name is not None + tool = ctx[value.function.name] + assert isinstance(tool, Tool) + sig = inspect.signature(tool) + decoded_args = {} + for name, raw_arg in json.loads(value.function.arguments).items(): + assert name in sig.parameters, ( + f"Unexpected argument {name} for tool {tool.__name__}" ) - return typing.cast( - Encodable[T, U], - MutableSequenceEncodable(ty, list[Any], ctx, False, identity_encoder), + param = sig.parameters[name] + arg_enc: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter( + Encodable[param.annotation] # type: ignore[name-defined] ) - - # Get the element type (first type argument) - element_ty = args[0] - element_encoder = Encodable.define(element_ty, ctx) - - # Check if element type is Image.Image - has_image = element_ty is Image.Image - - # Use enc for Image (schema-valid), base otherwise - encoded_ty: type[typing.Any] = typing.cast( - type[typing.Any], - list[element_encoder.enc], # type: ignore - ) - - return typing.cast( - Encodable[T, U], - MutableSequenceEncodable(ty, encoded_ty, ctx, has_image, element_encoder), + decoded_args[name] = arg_enc.validate_python(raw_arg, context=ctx) + return DecodedToolCall( + tool=tool, + bound_args=sig.bind(**decoded_args), + id=value.id, + name=value.function.name, ) -@Encodable.define.register(dict) -@Encodable.define.register(MutableMapping) -@Encodable.define.register(Mapping) -def _encodable_mapping[K, V, U]( - ty: type[Mapping[K, V]], ctx: Mapping[str, Any] | None -) -> Encodable[Mapping[K, V], U]: 
- ctx = {} if ctx is None else ctx - - if typing.is_typeddict(ty): - return typing.cast( - Encodable[Mapping[K, V], U], - TypedDictEncodable(ty, TypedDictEncodable._typeddict_model(ty), ctx), +def _serialize_tool_call( + value: DecodedToolCall, info: pydantic.SerializationInfo +) -> dict: + ctx = info.context or {} + encoded_args = {} + for k, v in value.bound_args.arguments.items(): + v_enc: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter( + Encodable[nested_type(v).value] # type: ignore[misc] ) - - return _encodable_object(ty, ctx) - - -@Encodable.define.register(Callable) -def _encodable_callable( - ty: type[Callable], ctx: Mapping[str, Any] | None -) -> Encodable[Callable, SynthesizedFunction]: - ctx = ctx or {} - - type_args = typing.get_args(ty) - - # Bare Callable without type args - allow encoding but disable decode - # this occurs when decoding the result of Tools which return callable (need to Encodable.define(return_type) for return type) - if not type_args: - assert ty is types.FunctionType, f"Callable must have type signatures {ty}" - typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any]) # type: ignore[arg-type] - return CallableEncodable(ty, typed_enc, ctx) - - if len(type_args) < 2: - raise TypeError( - f"Callable type signature incomplete: {ty}. " - "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]." - ) - - param_types, expected_return = type_args[0], type_args[-1] - - typed_enc = _create_typed_synthesized_function(ty) - - # Ellipsis means any params, skip param validation - expected_params: list[type] | None = None - if param_types is not ... 
and isinstance(param_types, list | tuple): - expected_params = list(param_types) - - return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return) - - -@Encodable.define.register(Tool) -def _encodable_tool[**P, T]( - ty: type[Tool[P, T]], ctx: Mapping[str, Any] | None -) -> Encodable[Tool[P, T], pydantic.BaseModel]: - ctx = ctx or {} - tool_model = TypedDictEncodable._typeddict_model(ChatCompletionToolParam) - return ToolEncodable(ty, tool_model, ctx) - - -@Encodable.define.register(DecodedToolCall) -def _encodable_tool_call[T]( - ty: type[DecodedToolCall[T]], ctx: Mapping[str, Any] | None -) -> Encodable[DecodedToolCall[T], ChatCompletionMessageToolCall]: - ctx = ctx or {} - return ToolCallEncodable(ty, ChatCompletionMessageToolCall, ctx) + encoded_args[k] = v_enc.dump_python(v, mode="json", context=ctx) + return ChatCompletionMessageToolCall.model_validate( + { + "type": "tool_call", + "id": value.id, + "function": { + "name": value.tool.__name__, + "arguments": encoded_args, + }, + } + ).model_dump(mode="json") + + +@TypeToPydanticType.register(DecodedToolCall) +def _pydantic_type_tool_call(ty: type[DecodedToolCall]): + return typing.Annotated[ + ty, + pydantic.PlainValidator(_validate_tool_call), + pydantic.PlainSerializer(_serialize_tool_call), + pydantic.WithJsonSchema( + _inline_refs(ChatCompletionMessageToolCall.model_json_schema()) + ), + ] diff --git a/effectful/handlers/llm/evaluation.py b/effectful/handlers/llm/evaluation.py index 07348cc9..2c3f734c 100644 --- a/effectful/handlers/llm/evaluation.py +++ b/effectful/handlers/llm/evaluation.py @@ -392,7 +392,7 @@ def signature_to_ast(name: str, sig: inspect.Signature) -> ast.FunctionDef: except TypeError: returns = type_to_ast(typing.Any) - node = ast.FunctionDef( # type: ignore + node = ast.FunctionDef( name=name, args=ast.arguments( posonlyargs=[], @@ -413,8 +413,9 @@ def signature_to_ast(name: str, sig: inspect.Signature) -> ast.FunctionDef: cause=None, ) ], - decorator_list=[], + 
decorator_list=typing.cast(list[ast.expr], []), returns=returns, + type_params=[], ) return ast.fix_missing_locations(node) diff --git a/effectful/handlers/llm/template.py b/effectful/handlers/llm/template.py index a9903005..e0c396dd 100644 --- a/effectful/handlers/llm/template.py +++ b/effectful/handlers/llm/template.py @@ -215,34 +215,17 @@ def __prompt_template__(self) -> str: @property def tools(self) -> Mapping[str, Tool]: """Operations and Templates available as tools. Auto-capture from lexical context.""" - result = {} - is_recursive = _is_recursive_signature(self.__signature__) - - for name, obj in self.__context__.items(): - # Collect tools directly in context - if isinstance(obj, Tool): - result[name] = obj - - # Collect tools as methods on Agent instances in context - elif isinstance(obj, Agent): - for cls in type(obj).__mro__: - for attr_name in vars(cls): - if isinstance(getattr(obj, attr_name), Tool): - result[f"{name}__{attr_name}"] = getattr(obj, attr_name) - - # Deduplicate by tool identity and remove self-references. - # - # The same Tool can appear under multiple names when it is both - # visible in the enclosing scope *and* discovered via an Agent - # instance's MRO. Since Tools are hashable Operations and - # instance-method Tools are cached per instance, we keep only - # the last name for each unique tool object. We also remove - # the template itself from the tool map unless it is explicitly + from effectful.handlers.llm.completions import _collect_tools + + result = _collect_tools(self.__context__) + + # We remove the template itself from the tool map unless it is explicitly # marked as recursive (see test_template_method, test_template_method_nested_class). 
- tool2name = {tool: name for name, tool in sorted(result.items())} - for name, tool in tuple(result.items()): - if tool2name[tool] != name or (tool is self and not is_recursive): - del result[name] + if not _is_recursive_signature(self.__signature__): + result = dict(result) # copy to allow mutation + for name, tool in tuple(result.items()): + if tool is self: + del result[name] return result @@ -377,18 +360,10 @@ def send(self, user_input: str) -> str: """ - __history__: OrderedDict[str, Mapping[str, Any]] - __system_prompt__: str + @functools.cached_property + def __history__(self) -> OrderedDict[str, Mapping[str, Any]]: + return OrderedDict() - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - if not hasattr(cls, "__history__"): - prop = functools.cached_property(lambda _: OrderedDict()) - prop.__set_name__(cls, "__history__") - cls.__history__ = prop - if not hasattr(cls, "__system_prompt__"): - sp = functools.cached_property( - lambda self: inspect.getdoc(type(self)) or "" - ) - sp.__set_name__(cls, "__system_prompt__") - cls.__system_prompt__ = sp + @functools.cached_property + def __system_prompt__(self) -> str: + return inspect.getdoc(type(self)) or "" diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 71d6583f..8841b2aa 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -104,6 +104,86 @@ class Box[T]: value: T +class TypeEvaluator(abc.ABC): + """ + Abstract base class for evaluating type expressions. + + This class defines the interface for evaluating type expressions, which may + involve resolving type variables, computing canonical forms of types, or + performing other transformations. Subclasses should implement the evaluate + method to provide specific evaluation logic. + + The TypeEvaluator can be used in contexts where type expressions need to be + processed or normalized before unification or other type operations. 
+ """ + + @functools.singledispatchmethod + def evaluate(self, typ) -> TypeExpressions: + """ + Normalize generic types + """ + raise TypeError(f"Cannot traverse type {typ}.") + + @evaluate.register + def _(self, typ: TypeConstant | TypeVariable): + return typ + + @evaluate.register + def _(self, typ: GenericAlias): + origin, args = typing.get_origin(typ), typing.get_args(typ) + return origin[self.evaluate(args)] # type: ignore[index] + + @evaluate.register + def _(self, typ: UnionType): + ctyp = self.evaluate(typing.get_args(typ)[0]) + for arg in typing.get_args(typ)[1:]: + ctyp = ctyp | self.evaluate(arg) # type: ignore + return ctyp + + @evaluate.register + def _(self, typ: typing._AnnotatedAlias): # type: ignore + return typing.Annotated[ + self.evaluate(typing.get_args(typ)[0]), + typ.__metadata__, + ] + + @evaluate.register + def _(self, typ: typing._LiteralGenericAlias): # type: ignore + return typ + + @evaluate.register + def _(self, typ: typing.ParamSpecArgs | typing.ParamSpecKwargs): + return typ + + @evaluate.register + def _(self, typ: typing._SpecialGenericAlias): # type: ignore + assert not typing.get_args(typ), "Should not have type arguments" + return typ + + @evaluate.register + def _(self, typ: typing._ConcatenateGenericAlias): # type: ignore + return typing.Concatenate[self.evaluate(typing.get_args(typ))] + + @evaluate.register + def _(self, typ: list | tuple): + return type(typ)(self.evaluate(item) for item in typ) + + @evaluate.register + def _(self, typ: typing.NewType): + return typing.NewType(typ.__name__, self.evaluate(typ.__supertype__)) # type: ignore[attr-defined] + + @evaluate.register + def _(self, typ: typing.TypeAliasType): + return self.evaluate(typ.__value__) + + @evaluate.register + def _(self, typ: typing.ForwardRef): + if typ.__forward_value__ is not None: + return self.evaluate(typ.__forward_value__) + else: + return typ + + @typing.overload def unify( typ: inspect.Signature, diff --git 
a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLLMLoggingHandler__test_custom_logger.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLLMLoggingHandler__test_custom_logger.json index c9d68532..0107fefe 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLLMLoggingHandler__test_custom_logger.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLLMLoggingHandler__test_custom_logger.json @@ -41,4 +41,4 @@ }, "total_tokens": 332 } -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLLMLoggingHandler__test_logs_requests.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLLMLoggingHandler__test_logs_requests.json index 022def82..dc1729b2 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLLMLoggingHandler__test_logs_requests.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLLMLoggingHandler__test_logs_requests.json @@ -41,4 +41,4 @@ }, "total_tokens": 332 } -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_bare_tuple_param.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_bare_tuple_param.json index b1828b6d..196b7407 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_bare_tuple_param.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_bare_tuple_param.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"The items include apples, bananas, and cherries.\"}", + "content": "The items include apples, bananas, and cherries.", "role": "assistant", "tool_calls": null, "function_call": null, diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_integer_return_type.json 
b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_integer_return_type.json index bc5d3bc0..ffa7d90a 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_integer_return_type.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_integer_return_type.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":73}", + "content": "{\"value\": 73}", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_nested_template_with_tuple_param_1.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_nested_template_with_tuple_param_1.json index 7681eb88..b2f2adba 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_nested_template_with_tuple_param_1.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_nested_template_with_tuple_param_1.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Call search(query='nearby restaurants', max_results=10) to find relevant results.\"}", + "content": "Call search(query='nearby restaurants', max_results=10) to find relevant results.", "role": "assistant", "tool_calls": null, "function_call": null, diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_nested_template_with_tuple_param_2.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_nested_template_with_tuple_param_2.json index 254db88d..a48512eb 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_nested_template_with_tuple_param_2.json +++ 
b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_nested_template_with_tuple_param_2.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"To find nearby restaurants, use the search tool with query='nearby restaurants' and max_results=10. This will return a list of relevant restaurant options.\"}", + "content": "To find nearby restaurants, use the search tool with query='nearby restaurants' and max_results=10. This will return a list of relevant restaurant options.", "role": "assistant", "tool_calls": null, "function_call": null, diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_cross_endpoint[claude-haiku-4-5].json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_cross_endpoint[claude-haiku-4-5].json index 7fdb0c6e..a7b10c1d 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_cross_endpoint[claude-haiku-4-5].json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_cross_endpoint[claude-haiku-4-5].json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\": \"Testing is a crucial process that ensures software quality, identifies bugs, and validates that systems work as intended before deployment to users.\"}", + "content": "Testing is a crucial process that ensures software quality, identifies bugs, and validates that systems work as intended before deployment to users.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -36,4 +36,4 @@ "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0 } -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_cross_endpoint[gpt-4o-mini].json 
b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_cross_endpoint[gpt-4o-mini].json index be954644..8dae5bf9 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_cross_endpoint[gpt-4o-mini].json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_cross_endpoint[gpt-4o-mini].json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Testing ensures that a product meets its requirements and functions as intended.\"}", + "content": "Testing ensures that a product meets its requirements and functions as intended.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_multiple_models[gpt-4o-mini].json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_multiple_models[gpt-4o-mini].json index 9d40c5f2..5f053059 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_multiple_models[gpt-4o-mini].json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_multiple_models[gpt-4o-mini].json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Testing is the process of evaluating a system or component to ensure it meets specified requirements and functions correctly.\"}", + "content": "Testing is the process of evaluating a system or component to ensure it meets specified requirements and functions correctly.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git 
a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_multiple_models[gpt-5-nano].json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_multiple_models[gpt-5-nano].json index 83ad3c82..47d240c0 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_multiple_models[gpt-5-nano].json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_simple_prompt_multiple_models[gpt-5-nano].json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Testing helps catch mistakes before they reach users.\"}", + "content": "Testing helps catch mistakes before they reach users.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_structured_output.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_structured_output.json index ae5497ee..68073973 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_structured_output.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_structured_output.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\": {\"genre\": \"action\", \"explanation\": \"The plot centers on a rogue cop taking on an evil group to stop them from taking over a skyscraper, featuring high-stakes conflict and action-oriented sequences typical of the action genre.\"}}", + "content": "{\"genre\": \"action\", \"explanation\": \"The plot centers on a rogue cop taking on an evil group to stop them from taking over a skyscraper, featuring high-stakes conflict and action-oriented sequences typical of the action genre.\"}", "role": "assistant", "tool_calls": 
null, "function_call": null, diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_with_config_params.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_with_config_params.json index 3a5c57fd..17a9b408 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_with_config_params.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestLiteLLMProvider__test_with_config_params.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"A deterministic test consistently produces the same output for a given input, ensuring reliability and repeatability in its results.\"}", + "content": "A deterministic test consistently produces the same output for a given input, ensuring reliability and repeatability in its results.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_simple_prompt_unique_message_ids.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_simple_prompt_unique_message_ids.json index 09007763..ac4b5e5a 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_simple_prompt_unique_message_ids.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_simple_prompt_unique_message_ids.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Testing is essential to ensure the quality and reliability of software before its release.\"}", + "content": "Testing is essential to ensure the quality and reliability of software before its release.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at 
end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_tool_calling_no_duplicate_message_ids.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_tool_calling_no_duplicate_message_ids.json index 88246e5b..ec9143b1 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_tool_calling_no_duplicate_message_ids.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_tool_calling_no_duplicate_message_ids.json @@ -14,7 +14,7 @@ "tool_calls": [ { "function": { - "arguments": "{\"a\":{\"value\":3},\"b\":{\"value\":5}}", + "arguments": "{\"a\":3,\"b\":5}", "name": "add_numbers" }, "id": "call_Enwfy9ZeiH3Oj5TTAad5owj6", diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_tool_calling_no_duplicate_message_ids_1.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_tool_calling_no_duplicate_message_ids_1.json index e641e33a..aa5fec00 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_tool_calling_no_duplicate_message_ids_1.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestMessageSequenceReplay__test_tool_calling_no_duplicate_message_ids_1.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":8}", + "content": "{\"value\": 8}", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestProgramSynthesis__test_generates_callable.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestProgramSynthesis__test_generates_callable.json index b511db5f..9e71d7e8 100644 --- 
a/tests/fixtures/tests_test_handlers_llm_provider.py__TestProgramSynthesis__test_generates_callable.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestProgramSynthesis__test_generates_callable.json @@ -41,4 +41,4 @@ }, "total_tokens": 489 } -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_1.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_1.json index 26a7e114..3f6d1402 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_1.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_1.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":{\"module_code\":\"def solution() -> str:\\n return \\\"The moon hung low in the midnight sky with a mysterious glow, inviting all to take a long walk. Shadows danced across the street, tumbling with every breeze that whispered secrets. Somewhere in the distance, there was the sound of another, an echo of a life unknown. In the middle of it all, the unmistakable laugh of a lunatic rang out, sending chills down spines.\\\"\"}}", + "content": "{\"module_code\": \"def solution() -> str:\\n return \\\"The moon hung low in the midnight sky with a mysterious glow, inviting all to take a long walk. Shadows danced across the street, tumbling with every breeze that whispered secrets. Somewhere in the distance, there was the sound of another, an echo of a life unknown. 
In the middle of it all, the unmistakable laugh of a lunatic rang out, sending chills down spines.\\\"\"}", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_2.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_2.json index ce6556b9..3fa09f26 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_2.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_2.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":{\"module_code\":\"def solution() -> str:\\n return \\\"The moon hung low in the midnight sky with a mysterious glow, inviting all to take a long walk. Shadows danced across the street, tumbling with every breeze that whispered secrets. Somewhere in the distance, there was the sound of another, an echo of a life unknown. In the middle of it all, the unmistakable laugh of a lunatic rang out, sending chills down spines.\\\"\"}}", + "content": "{\"module_code\": \"def solution() -> str:\\n return \\\"The moon hung low in the midnight sky with a mysterious glow, inviting all to take a long walk. Shadows danced across the street, tumbling with every breeze that whispered secrets. Somewhere in the distance, there was the sound of another, an echo of a life unknown. 
In the middle of it all, the unmistakable laugh of a lunatic rang out, sending chills down spines.\\\"\"}", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_3.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_3.json index aec0fb90..11dff89b 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_3.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_3.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":{\"module_code\":\"def solution() -> str:\\n return \\\"The moon hung low in the midnight sky with a mysterious glow, inviting all to take a long walk. Shadows danced across the street, tumbling with every breeze that whispered secrets. Somewhere in the distance, there was the sound of another, an echo of a life unknown. In the middle of it all, the unmistakable laugh of a lunatic rang out, sending chills down spines.\\\"\"}}", + "content": "{\"module_code\": \"def solution() -> str:\\n return \\\"The moon hung low in the midnight sky with a mysterious glow, inviting all to take a long walk. Shadows danced across the street, tumbling with every breeze that whispered secrets. Somewhere in the distance, there was the sound of another, an echo of a life unknown. 
In the middle of it all, the unmistakable laugh of a lunatic rang out, sending chills down spines.\\\"\"}", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_4.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_4.json index 7accf770..c3dd832b 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_4.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestRetryLLMHandler__test_codeadapt_notebook_replay_fixture_4.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"It seems there was an issue with executing the `codeact` tool. However, I can manually provide you with a solution.\\n\\nHere's a Python function named `solution` that generates a paragraph:\\n\\n```python\\ndef solution() -> str:\\n return \\\"The moon hung low in the midnight sky with a mysterious glow, inviting all to take a long walk. Shadows danced across the street, tumbling with every breeze that whispered secrets. Somewhere in the distance, there was the sound of another, an echo of a life unknown. In the middle of it all, the unmistakable laugh of a lunatic rang out, sending chills down spines.\\\"\\n```\\n\\nThis function simply returns a string that matches the sentence structure requirements (ending with the words 'walk', 'tumbling', 'another', and 'lunatic').\"}", + "content": "It seems there was an issue with executing the `codeact` tool. 
However, I can manually provide you with a solution.\n\nHere's a Python function named `solution` that generates a paragraph:\n\n```python\ndef solution() -> str:\n return \"The moon hung low in the midnight sky with a mysterious glow, inviting all to take a long walk. Shadows danced across the street, tumbling with every breeze that whispered secrets. Somewhere in the distance, there was the sound of another, an echo of a life unknown. In the middle of it all, the unmistakable laugh of a lunatic rang out, sending chills down spines.\"\n```\n\nThis function simply returns a string that matches the sentence structure requirements (ending with the words 'walk', 'tumbling', 'another', and 'lunatic').", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_image_input.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_image_input.json index 3c2d22d5..e1b5d368 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_image_input.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_image_input.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"The image is a simple pixel art of a smiley face. It features two square white eyes and a wide white smile set against a black background.\"}", + "content": "The image is a simple pixel art of a smiley face. 
It features two square white eyes and a wide white smile set against a black background.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration.json index 40b905a6..58b33f9c 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are crisp, juicy fruits that come in a variety of colors, offering a sweet or tart flavor profile perfect for snacking or baking.\"}", + "content": "Apples are crisp, juicy fruits that come in a variety of colors, offering a sweet or tart flavor profile perfect for snacking or baking.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_1.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_1.json index 8c5a0bf8..2cc3db94 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_1.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_1.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are crisp, juicy fruits that come in a variety of colors, offering a sweet or tart flavor profile perfect for snacking or baking.\"}", + "content": "Apples are crisp, juicy fruits that come in a variety of colors, offering a sweet or tart flavor profile perfect for snacking or 
baking.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -39,4 +39,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_2.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_2.json index 34190c89..fd5358f9 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_2.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_2.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Oranges are vibrant citrus fruits known for their juicy sweetness and rich vitamin C content.\"}", + "content": "Oranges are vibrant citrus fruits known for their juicy sweetness and rich vitamin C content.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_disabled.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_disabled.json index 8a57f076..3fdcfbcb 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_disabled.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_disabled.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are a versatile and nutritious fruit, enjoyed in a variety of dishes worldwide.\"}", + "content": "Apples are a versatile and nutritious fruit, enjoyed in a variety of dishes worldwide.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git 
a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_disabled_1.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_disabled_1.json index 45b4613d..8038d1b3 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_disabled_1.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_integration_disabled_1.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are crisp, sweet fruits that come in a variety of colors and are packed with essential vitamins and fiber.\"}", + "content": "Apples are crisp, sweet fruits that come in a variety of colors and are packed with essential vitamins and fiber.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective.json index fd332bfc..1e109770 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are crisp, juicy fruits that come in a variety of colors and flavors, often enjoyed fresh or used in delicious recipes.\"}", + "content": "Apples are crisp, juicy fruits that come in a variety of colors and flavors, often enjoyed fresh or used in delicious recipes.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_1.json 
b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_1.json index 4822c6d3..88c3e48e 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_1.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_1.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are crisp, sweet fruits that come in a variety of colors, including red, green, and yellow.\"}", + "content": "Apples are crisp, sweet fruits that come in a variety of colors, including red, green, and yellow.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_2.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_2.json index beb3c03e..16a72d5f 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_2.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_2.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are a popular and nutritious fruit, known for their crisp texture and a range of flavors from sweet to tart.\"}", + "content": "Apples are a popular and nutritious fruit, known for their crisp texture and a range of flavors from sweet to tart.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_3.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_3.json index d4211c06..7253e7d7 100644 --- 
a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_3.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_3.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are a popular and nutritious fruit, known for their crisp texture and a range of flavors from sweet to tart.\"}", + "content": "Apples are a popular and nutritious fruit, known for their crisp texture and a range of flavors from sweet to tart.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -39,4 +39,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_4.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_4.json index 2bf9bc05..c67f63dd 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_4.json +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_4.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are crisp, juicy fruits that come in a variety of colors, offering both health benefits and a sweet taste.\"}", + "content": "Apples are crisp, juicy fruits that come in a variety of colors, offering both health benefits and a sweet taste.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_5.json b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_5.json index 92639a81..12d85814 100644 --- a/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_5.json +++ 
b/tests/fixtures/tests_test_handlers_llm_provider.py__test_litellm_caching_selective_5.json @@ -9,7 +9,7 @@ "finish_reason": "stop", "index": 0, "message": { - "content": "{\"value\":\"Apples are nutritious fruits that come in a variety of colors and flavors, offering a refreshing and naturally sweet snack.\"}", + "content": "Apples are nutritious fruits that come in a variety of colors and flavors, offering a refreshing and naturally sweet snack.", "role": "assistant", "tool_calls": null, "function_call": null, @@ -41,4 +41,4 @@ } }, "service_tier": "default" -} \ No newline at end of file +} diff --git a/tests/test_handlers_llm_encoding.py b/tests/test_handlers_llm_encoding.py index 9cccf93d..fd3d2889 100644 --- a/tests/test_handlers_llm_encoding.py +++ b/tests/test_handlers_llm_encoding.py @@ -1,12 +1,13 @@ """ Law-based test suite for effectful.handlers.llm.encoding. -Each test function verifies a single equational law of the Encodable[T, U] -interface, parametrized over many types and values. +Each test function verifies a single equational law of the Encodable[T] +type-level encoding, parametrized over many types and values. 
""" import inspect import io +import json import re from collections.abc import Callable, Mapping from dataclasses import dataclass @@ -15,22 +16,23 @@ import litellm import pydantic import pytest -from litellm import ChatCompletionMessageToolCall +from litellm import ChatCompletionMessageToolCall, OpenAIMessageContentListBlock from PIL import Image from effectful.handlers.llm.encoding import ( + CONTENT_BLOCK_TYPES, DecodedToolCall, Encodable, SynthesizedFunction, + to_content_blocks, ) from effectful.handlers.llm.evaluation import RestrictedEvalProvider, UnsafeEvalProvider from effectful.handlers.llm.template import Tool from effectful.internals.unification import nested_type from effectful.ops.semantics import handler -from effectful.ops.types import Operation, Term from tests.test_handlers_llm_tool_calling_book import requires_openai -CHEAP_MODEL = "gpt-4o-mini" +CHEAP_MODEL = "gpt-4o-mini" # --------------------------------------------------------------------------- # Module-level type definitions @@ -225,7 +227,7 @@ def _make_dtc(tool, kwargs, call_id): # --------------------------------------------------------------------------- # (type_annotation, value, ctx) triples — reused across law tests. -# ctx=None means Encodable.define(ty), otherwise Encodable.define(ty, ctx). +# ctx=None means no context, otherwise passed as context to dump_python/validate_python. 
ROUNDTRIP_CASES = [ # --- str --- pytest.param(str, "hello", None, id="str-hello"), @@ -317,6 +319,12 @@ def _make_dtc(tool, kwargs, call_id): None, id="tuple-img-str", ), + pytest.param( + tuple[str, Image.Image, str], + ("before", _make_png_image("RGB", (5, 5), "green"), "after"), + None, + id="tuple-str-img-str", + ), pytest.param( list[Image.Image], [ @@ -326,12 +334,39 @@ def _make_dtc(tool, kwargs, call_id): None, id="list-img", ), + # --- deeper generic composition with Image --- + pytest.param( + list[tuple[str, Image.Image]], + [ + ("first", _make_png_image("RGB", (4, 4), "red")), + ("second", _make_png_image("RGB", (4, 4), "blue")), + ], + None, + id="list-tuple-str-img", + ), # --- Tool --- - pytest.param(type(_tool_add), _tool_add, None, id="tool-add"), - pytest.param(type(_tool_greet), _tool_greet, None, id="tool-greet"), - pytest.param(type(_tool_process), _tool_process, None, id="tool-process"), - pytest.param(type(_tool_get_value), _tool_get_value, None, id="tool-no-params"), - pytest.param(type(_tool_distance), _tool_distance, None, id="tool-pydantic-param"), + pytest.param(type(_tool_add), _tool_add, {"_tool_add": _tool_add}, id="tool-add"), + pytest.param( + type(_tool_greet), _tool_greet, {"_tool_greet": _tool_greet}, id="tool-greet" + ), + pytest.param( + type(_tool_process), + _tool_process, + {"_tool_process": _tool_process}, + id="tool-process", + ), + pytest.param( + type(_tool_get_value), + _tool_get_value, + {"_tool_get_value": _tool_get_value}, + id="tool-no-params", + ), + pytest.param( + type(_tool_distance), + _tool_distance, + {"_tool_distance": _tool_distance}, + id="tool-pydantic-param", + ), # --- DecodedToolCall --- pytest.param( DecodedToolCall, @@ -365,122 +400,140 @@ def _make_dtc(tool, kwargs, call_id): ), ] -# Filter ID sets -_IMAGE_IDS = frozenset({"img-red", "img-blue-alpha", "tuple-img-str", "list-img"}) -_TOOL_IDS = frozenset( - {"tool-add", "tool-greet", "tool-process", "tool-no-params", "tool-pydantic-param"} -) - 
-_tool_decode_xfail = pytest.mark.xfail( - raises=NotImplementedError, reason="Tool.decode not yet implemented" -) +# ============================================================================ +# Law 1: decode(encode(v)) == v +# ============================================================================ -def _xfail_tools(cases): - """Add xfail mark to Tool cases (whose decode raises NotImplementedError).""" - return [ - pytest.param(*c.values, marks=[*c.marks, _tool_decode_xfail], id=c.id) - if c.id in _TOOL_IDS - else c - for c in cases - ] +@pytest.mark.parametrize("ty,value,ctx", ROUNDTRIP_CASES) +def test_encode_decode_roundtrip(ty, value, ctx): + enc = pydantic.TypeAdapter(Encodable[ty]) + encoded = enc.dump_python(value, mode="json", context=ctx or {}) + assert enc.validate_python(encoded, context=ctx or {}) == value -# Derived case lists -# decode: Tool cases are xfail (decode raises NotImplementedError) -DECODE_CASES = _xfail_tools(ROUNDTRIP_CASES) +# ============================================================================ +# Law 2: json.loads(json.dumps(encode(v))) == encode(v) +# ============================================================================ -# Text-serializable: everything except Image-containing types -TEXT_CASES = [c for c in ROUNDTRIP_CASES if c.id not in _IMAGE_IDS] -# Full pipeline (encode→serialize→deserialize→decode): needs both text and decode -FULL_PIPELINE_CASES = _xfail_tools( - [c for c in ROUNDTRIP_CASES if c.id not in _IMAGE_IDS] -) +@pytest.mark.parametrize("ty,value,ctx", ROUNDTRIP_CASES) +def test_serialize_deserialize_roundtrip(ty, value, ctx): + enc = pydantic.TypeAdapter(Encodable[ty]) + encoded = enc.dump_python(value, mode="json", context=ctx or {}) + assert json.loads(json.dumps(encoded)) == encoded # ============================================================================ -# Law 1: decode(encode(v)) == v +# Law 3: decode(json.loads(json.dumps(encode(v)))) == v # 
============================================================================ -@pytest.mark.parametrize("ty,value,ctx", DECODE_CASES) -def test_encode_decode_roundtrip(ty, value, ctx): - enc = Encodable.define(ty, ctx) - assert enc.decode(enc.encode(value)) == value +@pytest.mark.parametrize("ty,value,ctx", ROUNDTRIP_CASES) +def test_full_pipeline_roundtrip(ty, value, ctx): + enc = pydantic.TypeAdapter(Encodable[ty]) + encoded = enc.dump_python(value, mode="json", context=ctx or {}) + assert ( + enc.validate_python(json.loads(json.dumps(encoded)), context=ctx or {}) == value + ) # ============================================================================ -# Law 2: deserialize(serialize(encode(v))[0]["text"]) == encode(v) +# Law 5: encode(encode(v)) == encode(v) (idempotency) # ============================================================================ -@pytest.mark.parametrize("ty,value,ctx", TEXT_CASES) -def test_serialize_deserialize_roundtrip(ty, value, ctx): - enc = Encodable.define(ty, ctx) - encoded = enc.encode(value) - blocks = enc.serialize(encoded) - assert len(blocks) == 1 - assert blocks[0]["type"] == "text" - assert enc.deserialize(blocks[0]["text"]) == encoded +@pytest.mark.parametrize("ty,value,ctx", ROUNDTRIP_CASES) +def test_encode_idempotent(ty, value, ctx): + once = pydantic.TypeAdapter(Encodable[ty]).dump_python( + value, mode="json", context=ctx or {} + ) + twice = pydantic.TypeAdapter(Encodable[nested_type(once).value]).dump_python( + once, mode="json", context=ctx or {} + ) + assert once == twice # ============================================================================ -# Law 3: decode(deserialize(serialize(encode(v))[0]["text"])) == v +# to_content_blocks helpers # ============================================================================ -@pytest.mark.parametrize("ty,value,ctx", FULL_PIPELINE_CASES) -def test_full_pipeline_roundtrip(ty, value, ctx): - enc = Encodable.define(ty, ctx) - encoded = enc.encode(value) - text = 
enc.serialize(encoded)[0]["text"] - assert enc.decode(enc.deserialize(text)) == value +def _linearize(blocks: list[OpenAIMessageContentListBlock]) -> str: + """Concatenate content blocks back into a JSON string.""" + return "".join(b["text"] if b["type"] == "text" else json.dumps(b) for b in blocks) + + +def _has_content_block(v): + """Recursively check whether v contains any content-block-shaped dicts.""" + if isinstance(v, dict) and v.get("type") in CONTENT_BLOCK_TYPES: + return True + if isinstance(v, dict): + return any(_has_content_block(val) for val in v.values()) + if isinstance(v, list): + return any(_has_content_block(item) for item in v) + return False # ============================================================================ -# Law 4: serialize(encode(v)) succeeds +# Law 6: linearize(to_content_blocks(encode(v))) == json.dumps(encode(v)) +# (for non-string encoded values; bare strings are emitted unquoted) # ============================================================================ @pytest.mark.parametrize("ty,value,ctx", ROUNDTRIP_CASES) -def test_serialize_succeeds(ty, value, ctx): - enc = Encodable.define(ty, ctx) - enc.serialize(enc.encode(value)) +def test_to_content_blocks_linearization(ty, value, ctx): + encoded = pydantic.TypeAdapter(Encodable[ty]).dump_python( + value, mode="json", context=ctx or {} + ) + if isinstance(encoded, str): + # Bare strings are emitted without JSON quoting for natural template rendering + assert _linearize(to_content_blocks(encoded)) == encoded + else: + assert _linearize(to_content_blocks(encoded)) == json.dumps(encoded) # ============================================================================ -# Law 5: encode(encode(v)) == encode(v) (idempotency) +# Law 7: decode(json.loads(linearize(to_content_blocks(encode(v))))) == v +# (for non-string encoded values; bare strings roundtrip directly) # ============================================================================ -@pytest.mark.parametrize( - 
"ty,value,ctx", - ROUNDTRIP_CASES, -) -def test_encode_idempotent(ty, value, ctx): - enc = Encodable.define(ty, ctx) - once = enc.encode(value) - twice = Encodable.define(nested_type(once).value, ctx).encode(once) - assert once == twice +@pytest.mark.parametrize("ty,value,ctx", ROUNDTRIP_CASES) +def test_to_content_blocks_full_pipeline(ty, value, ctx): + enc = pydantic.TypeAdapter(Encodable[ty]) + encoded = enc.dump_python(value, mode="json", context=ctx or {}) + linearized = _linearize(to_content_blocks(encoded)) + if isinstance(encoded, str): + assert enc.validate_python(linearized, context=ctx or {}) == value + else: + assert enc.validate_python(json.loads(linearized), context=ctx or {}) == value # ============================================================================ -# Term-specific: Encodable.define raises TypeError for Term and Operation +# Law 8: no content blocks hidden in text (maximal extraction) # ============================================================================ -@pytest.mark.parametrize("ty", [Term, Operation]) -def test_define_raises_for_invalid_types(ty): - with pytest.raises(TypeError): - Encodable.define(ty) +@pytest.mark.parametrize("ty,value,ctx", ROUNDTRIP_CASES) +def test_to_content_blocks_maximal_extraction(ty, value, ctx): + encoded = pydantic.TypeAdapter(Encodable[ty]).dump_python( + value, mode="json", context=ctx or {} + ) + if isinstance(encoded, str): + # Bare strings are emitted unquoted; they can't contain content blocks + return + blocks = to_content_blocks(encoded) + skeleton = json.loads( + "".join(b["text"] if b["type"] == "text" else "null" for b in blocks) + ) + assert not _has_content_block(skeleton) # ============================================================================ -# Image-specific: deserialize raises, decode rejects invalid URLs +# Tuple-specific: error cases # ============================================================================ - TUPLE_SCHEMA_CASES = [ pytest.param(tuple[int, str], 
id="tuple-int-str"), pytest.param(tuple[int, str, bool], id="tuple-three"), @@ -491,25 +544,12 @@ def test_define_raises_for_invalid_types(ty): @pytest.mark.parametrize("ty", TUPLE_SCHEMA_CASES) def test_tuple_schema_no_prefix_items(ty): """Finitary tuple schemas use properties/required, not prefixItems.""" - enc = Encodable.define(ty) - schema = pydantic.TypeAdapter(enc.enc).json_schema() + schema = pydantic.TypeAdapter(Encodable[ty]).json_schema() assert "prefixItems" not in str(schema), ( f"Schema for {ty} should not contain prefixItems: {schema}" ) -def test_image_deserialize_raises(): - enc = Encodable.define(Image.Image) - with pytest.raises(NotImplementedError): - enc.deserialize("anything") - - -def test_image_decode_rejects_non_data_uri(): - enc = Encodable.define(Image.Image) - with pytest.raises(TypeError): - enc.decode({"url": "http://example.com/image.png", "detail": "auto"}) - - # ============================================================================ # DecodedToolCall-specific: error cases # ============================================================================ @@ -565,9 +605,10 @@ def test_toolcall_decode_rejects_invalid(tool_name, args_json, ctx, exc_type): "function": {"name": tool_name, "arguments": args_json}, } ) - enc = Encodable.define(DecodedToolCall, ctx) with pytest.raises(exc_type): - enc.decode(tool_call) + pydantic.TypeAdapter(Encodable[DecodedToolCall]).validate_python( + tool_call, context=ctx + ) # ============================================================================ @@ -616,9 +657,11 @@ def test_callable_encode_decode_behavioral( ty, func, ctx, args, expected, eval_provider ): """Decoded callable is behaviorally equivalent to the original.""" - enc = Encodable.define(ty, ctx) + enc = pydantic.TypeAdapter(Encodable[ty]) with handler(eval_provider): - decoded = enc.decode(enc.encode(func)) + decoded = enc.validate_python( + enc.dump_python(func, mode="json", context=ctx), context=ctx + ) assert decoded(*args) == 
expected @@ -628,10 +671,10 @@ def test_callable_full_pipeline_behavioral( ty, func, ctx, args, expected, eval_provider ): """Full encode->serialize->deserialize->decode pipeline is behaviorally equivalent.""" - enc = Encodable.define(ty, ctx) - text = enc.serialize(enc.encode(func))[0]["text"] + enc = pydantic.TypeAdapter(Encodable[ty]) + text = json.dumps(enc.dump_python(func, mode="json", context=ctx)) with handler(eval_provider): - decoded = enc.decode(enc.deserialize(text)) + decoded = enc.validate_python(json.loads(text), context=ctx) assert decoded(*args) == expected @@ -673,16 +716,16 @@ def test_callable_full_pipeline_behavioral( @pytest.mark.parametrize("ty,ctx,source,exc_type", CALLABLE_ERROR_CASES) @pytest.mark.parametrize("eval_provider", EVAL_PROVIDERS) def test_callable_decode_rejects_invalid(ty, ctx, source, exc_type, eval_provider): - enc = Encodable.define(ty, ctx) with pytest.raises(exc_type): with handler(eval_provider): - enc.decode(source) + pydantic.TypeAdapter(Encodable[ty]).validate_python(source, context=ctx) def test_callable_encode_non_callable(): - enc = Encodable.define(Callable[..., int], {}) - with pytest.raises(TypeError): - enc.encode("not a callable") + with pytest.raises(Exception): + pydantic.TypeAdapter(Encodable[Callable[..., int]]).dump_python( + "not a callable", mode="json", context={} + ) def test_callable_encode_no_source_no_docstring(): @@ -694,9 +737,10 @@ class _NoDocCallable: def __call__(self): pass - enc = Encodable.define(Callable[..., int], {}) with pytest.raises(ValueError): - enc.encode(_NoDocCallable()) + pydantic.TypeAdapter(Encodable[Callable[..., int]]).dump_python( + _NoDocCallable(), mode="json", context={} + ) # --------------------------------------------------------------------------- @@ -710,7 +754,7 @@ def __call__(self): def _provider_case_marks(case_id: str) -> list[pytest.MarkDecorator]: marks: list[pytest.MarkDecorator] = [] - if case_id.startswith(("list-", "img-", "tool-", "dtc-")): + if "img" 
in case_id or "tool" in case_id or "dtc" in case_id: marks.append(_provider_response_format_xfail) return marks @@ -733,52 +777,46 @@ def _cases_with_provider_xfails(cases: list[Any]) -> list[Any]: PROVIDER_CASES = _cases_with_provider_xfails(ROUNDTRIP_CASES) -def _encode_tool_spec(tool: Tool[..., Any]) -> dict[str, Any]: - tool_ty: type[Any] = type(tool) - tool_enc: Encodable[Any, Any] = Encodable.define(tool_ty) - tool_spec_obj = tool_enc.encode(tool) - if isinstance(tool_spec_obj, Mapping): - return dict(tool_spec_obj) - elif hasattr(tool_spec_obj, "model_dump"): - return dict(tool_spec_obj.model_dump()) - raise TypeError(f"Unexpected encoded tool spec type: {type(tool_spec_obj)}") - - @requires_openai @pytest.mark.parametrize("ty,_value,ctx", PROVIDER_CASES) def test_litellm_completion_accepts_encodable_response_model_for_supported_types( ty: Any, _value: Any, ctx: Mapping[str, Any] | None ) -> None: - enc = Encodable.define(ty, ctx) - kwargs: dict[str, Any] = { - "model": CHEAP_MODEL, - "messages": [ + enc: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Encodable[ty]) + response = litellm.completion( + model=CHEAP_MODEL, + response_format={ + "type": "json_schema", + "json_schema": { + "name": "response", + "schema": enc.json_schema(), + "strict": True, + }, + }, + messages=[ { "role": "user", "content": f"Return an instance of {getattr(ty, '__name__', repr(ty))}.", } ], - "max_tokens": 200, - } - if enc.enc is not str: - kwargs["response_format"] = enc.enc - response = litellm.completion(**kwargs) - assert response is not None - - content = response.choices[0].message.content + max_tokens=400, + ) + assert isinstance(response, litellm.ModelResponse) + + choice = response.choices[0] + assert isinstance(choice, litellm.Choices) + content = choice.message.content assert content is not None, ( f"Expected content in response for {getattr(ty, '__name__', repr(ty))}" ) - deserialized = enc.deserialize(content) - 
pydantic.TypeAdapter(enc.enc).validate_python(deserialized) - - decoded = enc.decode(deserialized) - pydantic.TypeAdapter(enc.base).validate_python(decoded) + deserialized = json.loads(content) + decoded = enc.validate_python(deserialized, context=ctx or {}) + pydantic.TypeAdapter(ty).validate_python(decoded) @requires_openai -@pytest.mark.parametrize("ty,_value,ctx", PROVIDER_CASES) +@pytest.mark.parametrize("ty,_value,ctx", ROUNDTRIP_CASES) def test_litellm_completion_accepts_tool_with_type_as_param( ty: Any, _value: Any, ctx: Mapping[str, Any] | None ) -> None: @@ -792,18 +830,21 @@ def _fn(value): _fn.__annotations__ = {"value": ty, "return": None} tool: Tool[..., Any] = Tool.define(_fn) + enc: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter( + Encodable[type(tool)] # type: ignore[misc] + ) response = litellm.completion( model=CHEAP_MODEL, messages=[{"role": "user", "content": "Return hello, do NOT call any tools."}], - tools=[_encode_tool_spec(tool)], + tools=[enc.dump_python(tool, mode="json", context=ctx or {})], tool_choice="none", - max_tokens=200, + max_tokens=400, ) - assert response is not None + assert isinstance(response, litellm.ModelResponse) @requires_openai -@pytest.mark.parametrize("ty,_value,ctx", PROVIDER_CASES) +@pytest.mark.parametrize("ty,_value,ctx", ROUNDTRIP_CASES) def test_litellm_completion_accepts_tool_with_type_as_return( ty: Any, _value: Any, ctx: Mapping[str, Any] | None ) -> None: @@ -817,11 +858,14 @@ def _fn(): _fn.__annotations__ = {"return": ty} tool: Tool[..., Any] = Tool.define(_fn) + enc: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter( + Encodable[type(tool)] # type: ignore[misc] + ) response = litellm.completion( model=CHEAP_MODEL, messages=[{"role": "user", "content": "Return hello, do NOT call any tools."}], - tools=[_encode_tool_spec(tool)], + tools=[enc.dump_python(tool, mode="json", context=ctx or {})], tool_choice="none", - max_tokens=200, + max_tokens=400, ) - assert response is not None + assert 
isinstance(response, litellm.ModelResponse) diff --git a/tests/test_handlers_llm_evaluation.py b/tests/test_handlers_llm_evaluation.py index 934ba8b5..2cc1e1cd 100644 --- a/tests/test_handlers_llm_evaluation.py +++ b/tests/test_handlers_llm_evaluation.py @@ -1401,7 +1401,6 @@ def run(x: int) -> bool: def test_restricted_blocks_private_attribute_access(): """RestrictedPython blocks access to underscore-prefixed attributes by default.""" - encodable = Encodable.define(Callable[[str], int], {}) source = SynthesizedFunction( module_code="""def get_private(s: str) -> int: return s.__class__.__name__""" @@ -1409,7 +1408,9 @@ def test_restricted_blocks_private_attribute_access(): # Should raise due to restricted attribute access with pytest.raises(Exception): # Could be NameError or AttributeError with handler(RestrictedEvalProvider()): - fn = encodable.decode(source) + fn = pydantic.TypeAdapter(Encodable[Callable[[str], int]]).validate_python( + source, context={} + ) fn("test") @@ -1420,13 +1421,14 @@ def test_restricted_with_custom_policy(): class CustomPolicy(RestrictingNodeTransformer): pass - encodable = Encodable.define(Callable[[int, int], int], {}) source = SynthesizedFunction( module_code="""def add(a: int, b: int) -> int: return a + b""" ) with handler(RestrictedEvalProvider(policy=CustomPolicy)): - fn = encodable.decode(source) + fn = pydantic.TypeAdapter(Encodable[Callable[[int, int], int]]).validate_python( + source, context={} + ) assert fn(2, 3) == 5 @@ -1443,18 +1445,18 @@ def test_builtins_in_env_does_not_bypass_security(): dangerous_ctx = {"__builtins__": builtins.__dict__} # Test 1: open() should not be usable even with __builtins__ in context - encodable_open = Encodable.define(Callable[[str], str], dangerous_ctx) source_open = SynthesizedFunction( module_code="""def read_file(path: str) -> str: return open(path).read()""" ) with pytest.raises(Exception): # Could be NameError, ValueError, or other with handler(RestrictedEvalProvider()): - fn = 
encodable_open.decode(source_open) + fn = pydantic.TypeAdapter(Encodable[Callable[[str], str]]).validate_python( + source_open, context=dangerous_ctx + ) fn("/etc/passwd") # Test 2: __import__ should not be usable - encodable_import = Encodable.define(Callable[[], str], dangerous_ctx) source_import = SynthesizedFunction( module_code="""def get_os_name() -> str: os = __import__('os') @@ -1462,26 +1464,30 @@ def test_builtins_in_env_does_not_bypass_security(): ) with pytest.raises(Exception): with handler(RestrictedEvalProvider()): - fn = encodable_import.decode(source_import) + fn = pydantic.TypeAdapter(Encodable[Callable[[], str]]).validate_python( + source_import, context=dangerous_ctx + ) fn() # Test 3: Verify safe code still works with dangerous context - encodable_safe = Encodable.define(Callable[[int, int], int], dangerous_ctx) source_safe = SynthesizedFunction( module_code="""def add(a: int, b: int) -> int: return a + b""" ) with handler(RestrictedEvalProvider()): - fn = encodable_safe.decode(source_safe) + fn = pydantic.TypeAdapter(Encodable[Callable[[int, int], int]]).validate_python( + source_safe, context=dangerous_ctx + ) assert fn(2, 3) == 5, "Safe code should still work" # Test 4: Private attribute access should still be blocked - encodable_private = Encodable.define(Callable[[str], str], dangerous_ctx) source_private = SynthesizedFunction( module_code="""def get_class(s: str) -> str: return s.__class__.__name__""" ) with pytest.raises(Exception): with handler(RestrictedEvalProvider()): - fn = encodable_private.decode(source_private) + fn = pydantic.TypeAdapter(Encodable[Callable[[str], str]]).validate_python( + source_private, context=dangerous_ctx + ) fn("test") diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index eec2fc78..69a70331 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -15,6 +15,7 @@ from pathlib import Path import litellm +import pydantic import pytest 
import tenacity from litellm import ChatCompletionMessageToolCall @@ -38,7 +39,7 @@ call_tool, completion, ) -from effectful.handlers.llm.encoding import Encodable, SynthesizedFunction +from effectful.handlers.llm.encoding import Encodable from effectful.handlers.llm.evaluation import UnsafeEvalProvider from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements @@ -502,14 +503,15 @@ def make_tool_call_response( ) -def make_text_response(content: str) -> ModelResponse: +def make_text_response(content: object) -> ModelResponse: """Create a ModelResponse with text content.""" + encoded = content if isinstance(content, str) else json.dumps(content) return ModelResponse( id="test", choices=[ { "index": 0, - "message": {"role": "assistant", "content": content}, + "message": {"role": "assistant", "content": encoded}, "finish_reason": "stop", } ], @@ -544,8 +546,8 @@ def test_retry_handler_succeeds_on_first_attempt(self): handler(message_sequence_provider), ): message, tool_calls, result = call_assistant( - tools={}, - response_format=Encodable.define(str), + env={}, + response_type=str, model="test-model", ) @@ -574,8 +576,8 @@ def test_retry_handler_retries_on_invalid_tool_call(self): handler(message_sequence_provider), ): message, tool_calls, result = call_assistant( - tools={"add_numbers": add_numbers}, - response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) @@ -606,8 +608,8 @@ def test_retry_handler_retries_on_unknown_tool(self): handler(message_sequence_provider), ): message, tool_calls, result = call_assistant( - tools={"add_numbers": add_numbers}, - response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) @@ -633,8 +635,8 @@ def test_retry_handler_exhausts_retries(self): handler(message_sequence_provider), ): call_assistant( - tools={"add_numbers": add_numbers}, - 
response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) @@ -660,17 +662,15 @@ def test_retry_handler_with_zero_retries(self): handler(message_sequence_provider), ): call_assistant( - tools={"add_numbers": add_numbers}, - response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) def test_retry_handler_valid_tool_call_passes_through(self): """Test that valid tool calls are decoded and returned.""" responses = [ - make_tool_call_response( - "add_numbers", '{"a": {"value": 1}, "b": {"value": 2}}' - ), + make_tool_call_response("add_numbers", '{"a": 1, "b": 2}'), ] mock_handler = MockCompletionHandler(responses) @@ -685,8 +685,8 @@ def test_retry_handler_valid_tool_call_passes_through(self): handler(message_sequence_provider), ): message, tool_calls, result = call_assistant( - tools={"add_numbers": add_numbers}, - response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) @@ -744,8 +744,8 @@ def test_retry_handler_retries_on_invalid_result(self): """Test that RetryLLMHandler retries when result decoding fails.""" # First response has invalid JSON, second has valid response responses = [ - make_text_response('"not valid for int"'), # Invalid for int - make_text_response('{"value": 42}'), # Valid + make_text_response("not_valid_for_int"), # Invalid for int + make_text_response({"value": 42}), # Valid (wrapped for non-object types) ] mock_handler = MockCompletionHandler(responses) @@ -760,8 +760,8 @@ def test_retry_handler_retries_on_invalid_result(self): handler(message_sequence_provider), ): message, tool_calls, result = call_assistant( - tools={}, - response_format=Encodable.define(int), + env={}, + response_type=int, model="test-model", ) @@ -776,7 +776,7 @@ def test_retry_handler_exhausts_retries_on_result_decoding(self): """Test that RetryLLMHandler raises after exhausting retries on 
result decoding.""" # All responses have invalid results for int type responses = [ - make_text_response('"not an int"'), + make_text_response("not_an_int"), ] mock_handler = MockCompletionHandler(responses) @@ -792,8 +792,8 @@ def test_retry_handler_exhausts_retries_on_result_decoding(self): handler(message_sequence_provider), ): call_assistant( - tools={}, - response_format=Encodable.define(int), + env={}, + response_type=int, model="test-model", ) @@ -819,8 +819,8 @@ def test_retry_handler_raises_tool_call_decoding_error(self): handler(message_sequence_provider), ): call_assistant( - tools={"add_numbers": add_numbers}, - response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) @@ -833,7 +833,7 @@ def test_retry_handler_raises_tool_call_decoding_error(self): def test_retry_handler_raises_result_decoding_error(self): """Test that RetryLLMHandler raises ResultDecodingError with correct attributes.""" responses = [ - make_text_response('"not an int"'), + make_text_response("not_an_int"), ] mock_handler = MockCompletionHandler(responses) @@ -849,8 +849,8 @@ def test_retry_handler_raises_result_decoding_error(self): handler(message_sequence_provider), ): call_assistant( - tools={}, - response_format=Encodable.define(int), + env={}, + response_type=int, model="test-model", ) @@ -877,8 +877,8 @@ def test_retry_handler_error_feedback_contains_tool_name(self): handler(message_sequence_provider), ): call_assistant( - tools={"add_numbers": add_numbers}, - response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) @@ -907,8 +907,8 @@ def test_retry_handler_unknown_tool_error_contains_tool_name(self): handler(message_sequence_provider), ): call_assistant( - tools={"add_numbers": add_numbers}, - response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) @@ -937,8 +937,8 @@ def 
test_retry_handler_include_traceback_in_error_feedback(self): handler(message_sequence_provider), ): call_assistant( - tools={"add_numbers": add_numbers}, - response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) @@ -968,8 +968,8 @@ def test_retry_handler_no_traceback_when_disabled(self): handler(message_sequence_provider), ): call_assistant( - tools={"add_numbers": add_numbers}, - response_format=Encodable.define(str), + env={"add_numbers": add_numbers}, + response_type=str, model="test-model", ) @@ -1056,7 +1056,7 @@ def test_tool_execution_error_not_pruned_from_messages(self): # First call: valid tool call that will fail at runtime # Second call: successful text response responses = [ - make_tool_call_response("failing_tool", '{"x": {"value": 42}}'), + make_tool_call_response("failing_tool", '{"x": 42}'), make_text_response("handled the error"), ] @@ -1069,8 +1069,8 @@ def test_tool_execution_error_not_pruned_from_messages(self): # We need a custom provider that actually calls call_tool class TestProvider(ObjectInterpretation): @implements(call_assistant) - def _call_assistant(self, tools, response_format, model, **kwargs): - return fwd(tools, response_format, model, **kwargs) + def _call_assistant(self, env, response_type, model, **kwargs): + return fwd(env, response_type, model, **kwargs) with ( handler(RetryLLMHandler()), @@ -1079,8 +1079,8 @@ def _call_assistant(self, tools, response_format, model, **kwargs): handler(message_sequence_provider), ): message, tool_calls, result = call_assistant( - tools={"failing_tool": failing_tool}, - response_format=Encodable.define(str), + env={"failing_tool": failing_tool}, + response_type=str, model="test-model", ) @@ -1280,13 +1280,13 @@ def test_synthesized_function_roundtrip(self, request): assert callable(add_func) # Encode it back to SynthesizedFunction - encodable = Encodable.define(Callable[[int, int], int], {}) - encoded = encodable.encode(add_func) - 
assert isinstance(encoded, SynthesizedFunction) - assert "def " in encoded.module_code + enc = pydantic.TypeAdapter(Encodable[Callable[[int, int], int]]) + encoded = enc.dump_python(add_func, mode="json", context={}) + assert isinstance(encoded, dict) + assert "def " in encoded["module_code"] # Decode it again and verify it still works - decoded = encodable.decode(encoded) + decoded = enc.validate_python(encoded, context={}) assert callable(decoded) assert decoded(5, 7) == 12 @@ -1412,8 +1412,8 @@ def _completion(self_, model, messages, *args, **kwargs): handler({_get_history: lambda: message_sequence}), ): call_assistant( - tools={}, - response_format=Encodable.define(str), + env={}, + response_type=str, model="test-model", ) @@ -1454,14 +1454,14 @@ def _completion(self_, model, messages, *args, **kwargs): ): # First call: input is the latest message (msg_user) resp1, _, _ = call_assistant( - tools={}, - response_format=Encodable.define(str), + env={}, + response_type=str, model="test-model", ) # Second call: input is the first response resp2, _, _ = call_assistant( - tools={}, - response_format=Encodable.define(str), + env={}, + response_type=str, model="test-model", ) @@ -1493,8 +1493,8 @@ def _call_assistant(self_, messages, *args, **kwargs): ): call_assistant( messages=[msg], - tools={}, - response_format=Encodable.define(str), + env={}, + response_type=str, model="test-model", ) @@ -1706,7 +1706,7 @@ def test_messages_pruned_on_tool_execution_error(self): """When a tool error propagates, all messages from that call are pruned.""" # LLM says "call flaky_tool", then tool raises unhandled error responses = [ - make_tool_call_response("flaky_tool", '{"x": {"value": 1}}'), + make_tool_call_response("flaky_tool", '{"x": 1}'), ] mock_handler = MockCompletionHandler(responses) @@ -1755,7 +1755,7 @@ def task_with_tools(instruction: str) -> str: def test_pre_existing_messages_preserved_on_error(self): """Pre-existing messages in the sequence are not pruned when a call 
fails.""" responses = [ - make_tool_call_response("flaky_tool", '{"x": {"value": 1}}'), + make_tool_call_response("flaky_tool", '{"x": 1}'), ] mock_handler = MockCompletionHandler(responses) diff --git a/tests/test_handlers_llm_template.py b/tests/test_handlers_llm_template.py index 7135f170..ece3058b 100644 --- a/tests/test_handlers_llm_template.py +++ b/tests/test_handlers_llm_template.py @@ -3,6 +3,7 @@ import collections import dataclasses import inspect +import json from dataclasses import dataclass import pytest @@ -134,13 +135,19 @@ def j(self) -> str: # --------------------------------------------------------------------------- -def make_text_response(content: str) -> ModelResponse: +def make_text_response(content: object) -> ModelResponse: + """Create a ModelResponse mimicking real LLM output. + + For ``str`` content the LLM returns plain text (no JSON wrapping). + For non-``str`` content the LLM returns JSON. + """ + encoded = content if isinstance(content, str) else json.dumps(content) return ModelResponse( id="test", choices=[ { "index": 0, - "message": {"role": "assistant", "content": content}, + "message": {"role": "assistant", "content": encoded}, "finish_reason": "stop", } ], @@ -414,8 +421,8 @@ def pick_number(self) -> int: mock = MockCompletionHandler( [ - make_text_response('"not_an_int"'), - make_text_response('{"value": 7}'), + make_text_response("not_an_int"), + make_text_response({"value": 7}), ] ) @@ -559,9 +566,7 @@ def compute(self, question: str) -> str: mock = MockCompletionHandler( [ - make_tool_call_response( - "add", '{"a": {"value": 2}, "b": {"value": 3}}' - ), + make_tool_call_response("add", '{"a": 2, "b": 3}'), make_text_response("The answer is 5"), ] ) @@ -586,9 +591,9 @@ def test_failed_retries_dont_pollute_history(self): mock = MockCompletionHandler( [ # First attempt: invalid result for int - make_text_response('"not_an_int"'), + make_text_response("not_an_int"), # Retry: valid - make_text_response('{"value": 42}'), + 
make_text_response({"value": 42}), ] ) @@ -1226,7 +1231,7 @@ def convert(feet: int) -> float: with handler(TemplateStringIntp()): assert ( convert(7920) - == 'How many miles is {"value":7920} feet? There are {"value":5280} feet per mile.' + == "How many miles is 7920 feet? There are 5280 feet per mile." ) @@ -1341,7 +1346,7 @@ def convert(feet: int) -> str: with handler(TemplateStringIntp()): assert ( convert(7920) - == 'How many miles is {"value":7920} feet? There are {"value":5280} feet per mile.' + == "How many miles is 7920 feet? There are 5280 feet per mile." ) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 18b158d8..5bb2b979 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -1,4 +1,6 @@ import collections.abc +import dataclasses +import functools import inspect import typing from typing import Literal @@ -7,6 +9,10 @@ from effectful.internals.unification import ( Box, + Substitutions, + TypeEvaluator, + TypeExpressions, + TypeVariable, canonicalize, freetypevars, nested_type, @@ -28,8 +34,58 @@ W = typing.TypeVar("W") +@dataclasses.dataclass +class _Substitute(TypeEvaluator): + """ + Helper class to perform substitution using TypeEvaluator. + Used to test TypeEvaluator's traversal logic against the ground-truth ``substitute`` + """ + + subs: Substitutions + + @classmethod + def substitute(cls, typ: TypeExpressions, subs: Substitutions) -> TypeExpressions: + return cls(subs).evaluate(typ) + + @functools.singledispatchmethod + def evaluate(self, typ: TypeExpressions) -> TypeExpressions: + return super().evaluate(typ) + + @evaluate.register + def _(self, typ: TypeVariable): + if typ in self.subs: + return self.evaluate(self.subs[typ]) + else: + return super().evaluate(typ) + + +@dataclasses.dataclass +class _FreeTypeVars(TypeEvaluator): + """ + Helper class to perform free variable collection using TypeEvaluator. 
+ Used to test TypeEvaluator's traversal logic against the ground-truth ``freetypevars`` + """ + + fvs: set[TypeVariable] = dataclasses.field(default_factory=set) + + @classmethod + def freetypevars(cls, typ: TypeExpressions) -> collections.abc.Set[TypeVariable]: + evaluator = cls() + evaluator.evaluate(typ) + return evaluator.fvs + + @functools.singledispatchmethod + def evaluate(self, typ: TypeExpressions) -> TypeExpressions: + return super().evaluate(typ) + + @evaluate.register + def _(self, typ: TypeVariable): + self.fvs.add(typ) + return super().evaluate(typ) + + @pytest.mark.parametrize( - "typ,fvs", + "typ,expected", [ # Basic cases (T, {T}), @@ -79,8 +135,17 @@ # (collections.abc.Callable[typing.ParamSpec("P"), T], {T}), # Would need to handle ParamSpec ], ) -def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): - assert freetypevars(typ) == fvs +@pytest.mark.parametrize( + "freetypevars_impl", [freetypevars, _FreeTypeVars.freetypevars] +) +def test_freetypevars( + freetypevars_impl: collections.abc.Callable[ + [TypeExpressions], collections.abc.Set[TypeVariable] + ], + typ: TypeExpressions, + expected: collections.abc.Set[typing.TypeVar], +): + assert freetypevars_impl(typ) == expected def test_canonicalize_1(): @@ -249,10 +314,16 @@ class GenericClass[T]: ), ], ) +@pytest.mark.parametrize("substitute_impl", [substitute, _Substitute.substitute]) def test_substitute( - typ: type, subs: typing.Mapping[typing.TypeVar, type], expected: type + substitute_impl: collections.abc.Callable[ + [TypeExpressions, Substitutions], TypeExpressions + ], + typ: TypeExpressions, + subs: Substitutions, + expected: TypeExpressions, ): - assert substitute(typ, subs) == expected # type: ignore + assert substitute_impl(typ, subs) == expected @pytest.mark.parametrize(