Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c192c78
Fresh diff
kiranandcode Jan 30, 2026
ea3d25c
remove instructionhandler
eb8680 Jan 16, 2026
f83d312
updated internal interface to make all tests pass
kiranandcode Jan 28, 2026
d8d52e7
fixed caching tests
kiranandcode Jan 28, 2026
a322a35
updated llm.ipynb
kiranandcode Jan 28, 2026
41beb78
removed unnecessarily defensive validation
kiranandcode Jan 29, 2026
1400d19
updated tool call decoding to use concrete type of tool result instea…
kiranandcode Jan 29, 2026
a06296d
updated completions to fix basic type errors
kiranandcode Jan 30, 2026
c47abd6
updated call assistant to handle decoding tool calls
kiranandcode Jan 30, 2026
43e5b78
dropped stale comments
kiranandcode Jan 30, 2026
6bb2b13
moved model and param model back to internals of `completions`
kiranandcode Jan 30, 2026
88da657
added default encodable instance for Callable
kiranandcode Jan 30, 2026
88c65ee
fixed type errors
kiranandcode Jan 30, 2026
18da11b
update to use more structured type for synthesis
kiranandcode Jan 31, 2026
52df3f6
updated callable encoding tests
kiranandcode Jan 31, 2026
4b10ac3
s/TypeError/NotImplementedError
kiranandcode Jan 31, 2026
2a8af94
Merge branch 'master' into kg-encodable-default
kiranandcode Jan 31, 2026
2b4449a
simplified smart constructor
kiranandcode Jan 31, 2026
553450f
bare callables not allowed
kiranandcode Jan 31, 2026
aac0eb9
droped synthesis and removed encoding_instructions
kiranandcode Jan 31, 2026
5b3a559
fixed imports
kiranandcode Jan 31, 2026
711b27d
fixed imports and tests
kiranandcode Jan 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 270 additions & 1 deletion effectful/handlers/llm/encoding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import ast
import base64
import inspect
import io
import textwrap
import types
import typing
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from dataclasses import dataclass
from types import CodeType
from typing import Any

import pydantic
Expand All @@ -13,6 +18,7 @@
)
from PIL import Image

import effectful.handlers.llm.evaluation as evaluation
from effectful.ops.semantics import _simple_type
from effectful.ops.syntax import _CustomSingleDispatchCallable
from effectful.ops.types import Operation, Term
Expand Down Expand Up @@ -253,6 +259,236 @@ def deserialize(self, serialized_value: str) -> typing.Any:
return typing.cast(typing.Any, adapter.validate_json(serialized_value))


def _format_callable_type(callable_type: type[Callable]) -> str:
"""Format a Callable type annotation as a string for LLM instructions."""
args = typing.get_args(callable_type)
if not args:
return "Callable"

# Callable[[arg1, arg2, ...], return_type]
if len(args) >= 2:
param_types = args[0]
return_type = args[-1]

if param_types is ...:
params_str = "..."
elif isinstance(param_types, (list, tuple)):
params_str = ", ".join(getattr(t, "__name__", str(t)) for t in param_types)
else:
params_str = str(param_types)

return_str = getattr(return_type, "__name__", str(return_type))
return f"Callable[[{params_str}], {return_str}]"

return str(callable_type)


# Base structured-output schema for synthesized code. It declares a single
# field, `module_code`.
# NOTE(review): the docstring below mentions a "function name", but no such
# field is declared on this model. The docstring is left untouched because it
# may be surfaced as the model's description in the JSON schema — confirm
# before rewording it.
class SynthesizedFunction(pydantic.BaseModel):
    """Structured output for function synthesis.

    Pydantic model representing synthesized code with function name and module code.
    """

    # Source text of the synthesized Python module.
    module_code: str = pydantic.Field(
        ...,
        description="Complete Python module code (no imports needed)",
    )


def _create_typed_synthesized_function(
    callable_type: type[Callable],
) -> type[SynthesizedFunction]:
    """Build a ``SynthesizedFunction`` subclass that advertises a signature.

    The formatted signature of ``callable_type`` is embedded in a block of
    synthesis instructions, and that text is handed to
    ``pydantic.create_model`` as ``__doc__`` so it becomes the description of
    the model in the JSON schema shown to the LLM.

    Args:
        callable_type: The ``Callable`` annotation the synthesized function
            is expected to satisfy.

    Returns:
        A new model class named ``TypedSynthesizedFunction`` derived from
        ``SynthesizedFunction``.
    """
    type_signature = _format_callable_type(callable_type)

    description = f"""Given the specification above, generate a Python function satisfying the following specification and type signature.

<signature>{type_signature}</signature>

<instructions>
1. Produce one block of Python code.
2. The function MUST have type annotations for all parameters and the return type.
3. The function definition must be the LAST statement - do not add any code after it.
4. Do not include usage examples or function calls.
</instructions>
"""

    # create_model is used (rather than a plain subclass statement) so the
    # per-signature text can be supplied as the model's __doc__.
    return pydantic.create_model(
        "TypedSynthesizedFunction",
        __doc__=description,
        __base__=SynthesizedFunction,
    )


def _validate_signature_ast(
func_ast: ast.FunctionDef | ast.AsyncFunctionDef,
expected_params: list[type] | None,
) -> None:
"""Validate the function signature from AST before execution."""
if expected_params is not None:
ast_params = func_ast.args.args + func_ast.args.posonlyargs
if len(ast_params) != len(expected_params):
raise ValueError(
f"decode() expected function with {len(expected_params)} parameters, "
f"got {len(ast_params)}"
)


def _validate_signature_callable(
func: Callable,
expected_params: list[type] | None,
expected_return: type,
) -> None:
"""Validate the function signature from runtime callable after execution.

The synthesized function must have type annotations for parameters and return type.
"""
sig = inspect.signature(func)

if expected_params is not None:
actual_params = list(sig.parameters.values())
if len(actual_params) != len(expected_params):
raise ValueError(
f"decode() expected function with {len(expected_params)} parameters, "
f"got {len(actual_params)}"
)

actual_return = sig.return_annotation
if actual_return is inspect.Parameter.empty:
raise ValueError(
"decode() requires synthesized function to have a return type annotation"
)

expected_name = getattr(expected_return, "__name__", str(expected_return))
actual_name = getattr(actual_return, "__name__", str(actual_return))
if expected_name != actual_name:
raise ValueError(
f"decode() expected function with return type {expected_name}, "
f"got {actual_name}"
)


@dataclass
class CallableEncodable(Encodable[Callable, SynthesizedFunction]):
    """Encodable that carries callables as Python source text.

    Encoding turns a callable into a ``SynthesizedFunction`` holding either
    its source or a stub built from its name, signature and docstring.
    Decoding parses, compiles and executes module code through the
    ``evaluation`` operations and returns the function defined by the last
    statement of that module.
    """

    base: type[Callable]
    enc: type[SynthesizedFunction]
    ctx: Mapping[str, Any]
    expected_params: list[type] | None = None
    expected_return: type | None = None  # None means decode is disabled

    def encode(self, t: Callable) -> SynthesizedFunction:
        """Encode a callable as module source code.

        Args:
            t: The callable to encode.

        Returns:
            An instance of ``self.enc`` whose ``module_code`` is the dedented
            source of ``t``, or a stub definition when no source is available.

        Raises:
            TypeError: If ``t`` is not callable.
            RuntimeError: If ``t`` has no retrievable source and lacks either
                a ``__name__`` or a docstring to build a stub from.
        """
        # (https://github.com/python/mypy/issues/14928)
        if not isinstance(t, Callable):  # type: ignore
            raise TypeError(f"Expected callable, got {type(t)}")

        try:
            source = inspect.getsource(t)
        except (OSError, TypeError):
            source = None

        if source:
            return self.enc(module_code=textwrap.dedent(source))

        # Source not available - create stub from name, signature, and docstring
        # This is useful for builtins and C extensions
        name = getattr(t, "__name__", None)
        if not name:
            raise RuntimeError(
                f"Cannot encode callable {t}: no source code and no __name__"
            )

        try:
            sig_str = str(inspect.signature(t))
        except (ValueError, TypeError):
            # Some builtins don't have inspectable signatures
            sig_str = "(...)"

        docstring = inspect.getdoc(t)
        if not docstring:
            raise RuntimeError(
                f"Cannot encode callable {t}: no source code and no docstring"
            )

        # Escape backslashes and quotes so the docstring cannot terminate (or
        # corrupt) the triple-quoted literal it is embedded in.
        doc_body = docstring.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
        if doc_body.endswith('"'):
            doc_body = doc_body[:-1] + '\\"'

        # Format as a stub function with docstring; the body lines are
        # indented so the stub reads as a well-formed definition.
        stub_code = f'def {name}{sig_str}:\n    """{doc_body}"""\n    ...\n'
        return self.enc(module_code=stub_code)

    def decode(self, encoded_value: SynthesizedFunction) -> Callable:
        """Execute synthesized module code and return the function it defines.

        Args:
            encoded_value: Model whose ``module_code`` ends with the function
                definition to return.

        Returns:
            The callable bound to the name of the module's last statement
            after execution.

        Raises:
            TypeError: If this encodable has no ``expected_return`` (decoding
                is disabled).
            ValueError: If the module is empty, its last statement is not a
                ``def``, the function is missing from the namespace or is not
                callable, or a signature check fails.
        """
        # Decode requires a concrete return type for synthesis
        if self.expected_return is None:
            raise TypeError(
                "Cannot decode/synthesize callable without a concrete type signature. "
                "Use Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType] "
                "with a concrete return type (not Any)."
            )

        module_code = encoded_value.module_code

        # The filename is the key under which an eval provider may cache the
        # source for inspect.getsource(). Deriving it from the code as well as
        # from this instance keeps a later decode from overwriting the cached
        # source of a different, earlier-decoded function.
        code_tag = hash(module_code) & 0xFFFFFFFFFFFFFFFF
        filename = f"<synthesis:{id(self)}:{code_tag:016x}>"

        # Parse and validate AST before execution
        module: ast.AST = evaluation.parse(module_code, filename)

        if not isinstance(module, ast.Module) or not module.body:
            raise ValueError(
                "decode() requires module code with at least one statement."
            )

        last_stmt = module.body[-1]
        if not isinstance(last_stmt, ast.FunctionDef):
            raise ValueError(
                f"decode() requires the last statement to be a function definition, "
                f"got {type(last_stmt).__name__}"
            )

        # Validate signature from AST before execution
        _validate_signature_ast(last_stmt, self.expected_params)

        # Compile and execute in a fresh namespace seeded from the context.
        # https://docs.python.org/3/library/functions.html#exec
        g: dict[str, Any] = dict(self.ctx or {})

        bytecode: CodeType = evaluation.compile(module, filename)
        evaluation.exec(bytecode, g)

        func_name = last_stmt.name
        if func_name not in g:
            raise ValueError(
                f"decode() expected function '{func_name}' to be defined in globals"
            )

        result = g[func_name]
        if not callable(result):
            raise ValueError(
                f"decode() expected '{func_name}' to be callable, got {type(result)}"
            )

        # Validate signature from runtime callable after execution
        _validate_signature_callable(result, self.expected_params, self.expected_return)

        return result

    def serialize(
        self, encoded_value: SynthesizedFunction
    ) -> Sequence[OpenAIMessageContentListBlock]:
        """Render the encoded value as a single text block containing its JSON dump."""
        return [{"type": "text", "text": encoded_value.model_dump_json()}]

    def deserialize(self, serialized_value: str) -> SynthesizedFunction:
        """Validate a JSON string against ``SynthesizedFunction`` and return the model."""
        return SynthesizedFunction.model_validate_json(serialized_value)


@Encodable.define.register(object)
def _encodable_object[T, U](
ty: type[T], ctx: Mapping[str, Any] | None
Expand Down Expand Up @@ -355,3 +591,36 @@ def _encodable_list[T, U](
return typing.cast(
Encodable[T, U], ListEncodable(ty, encoded_ty, ctx, has_image, element_encoder)
)


@Encodable.define.register(Callable)
def _encodable_callable(
    ty: type[Callable], ctx: Mapping[str, Any] | None
) -> Encodable[Callable, SynthesizedFunction]:
    """Build the ``CallableEncodable`` for a callable type.

    Args:
        ty: Either a parameterized ``Callable[...]`` annotation or
            ``types.FunctionType``.
        ctx: Names made available to synthesized code when decoding;
            ``None`` is treated as an empty mapping.

    Returns:
        A ``CallableEncodable``. For ``types.FunctionType`` it is
        encode-only (no expected return type, so ``decode`` is disabled).

    Raises:
        AssertionError: If ``ty`` has no type arguments and is not
            ``types.FunctionType``.
        TypeError: If ``ty`` has fewer than two type arguments.
    """
    ctx = ctx or {}

    type_args = typing.get_args(ty)

    # No type args: only the concrete function type is accepted, and the
    # resulting encodable can encode but not decode.
    # this occurs when decoding the result of Tools which return callable (need to Encodable.define(return_type) for return type)
    if not type_args:
        # Raised explicitly instead of via `assert` so the check still runs
        # when Python is started with -O; the exception type is unchanged.
        if ty is not types.FunctionType:
            raise AssertionError(f"Callable must have type signatures {ty}")
        typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any])  # type: ignore[arg-type]
        return CallableEncodable(ty, typed_enc, ctx)

    if len(type_args) < 2:
        raise TypeError(
            f"Callable type signature incomplete: {ty}. "
            "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]."
        )

    param_types, expected_return = type_args[0], type_args[-1]

    typed_enc = _create_typed_synthesized_function(ty)

    # Anything other than an explicit parameter list (e.g. an ellipsis) means
    # any params are accepted, so parameter validation is skipped.
    expected_params: list[type] | None = None
    if isinstance(param_types, (list, tuple)):
        expected_params = list(param_types)

    return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return)
88 changes: 88 additions & 0 deletions effectful/handlers/llm/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import ast
import builtins
import linecache
import typing
from types import CodeType
from typing import Any

from effectful.ops.syntax import ObjectInterpretation, defop, implements


@defop
def parse(source: str, filename: str) -> ast.Module:
    """Turn Python source text into an AST.

    The body here is only a default that always raises; a working
    implementation has to be supplied by an installed eval provider.

    Args:
        source: The Python source code to parse.
        filename: The filename recorded in the resulting AST for tracebacks
            and tooling.

    Returns:
        The parsed AST.

    Raises:
        NotImplementedError: Always, from this default body.
    """
    message = "An eval provider must be installed in order to parse code."
    raise NotImplementedError(message)


@defop
def compile(module: ast.Module, filename: str) -> CodeType:
    """Turn an AST into a Python code object.

    The body here is only a default that always raises; a working
    implementation has to be supplied by an installed eval provider.

    Args:
        module: The AST to compile (typically produced by parse()).
        filename: The filename recorded in the resulting code object
            (CodeType.co_filename), used in tracebacks and by
            inspect.getsource().

    Returns:
        The compiled code object.

    Raises:
        NotImplementedError: Always, from this default body.
    """
    message = "An eval provider must be installed in order to compile code."
    raise NotImplementedError(message)


@defop
def exec(
    bytecode: CodeType,
    env: dict[str, Any],
) -> None:
    """Run a compiled code object.

    The body here is only a default that always raises; a working
    implementation has to be supplied by an installed eval provider.

    Args:
        bytecode: A code object to execute (typically produced by compile()).
        env: The namespace mapping used during execution.

    Raises:
        NotImplementedError: Always, from this default body.
    """
    message = "An eval provider must be installed in order to execute code."
    raise NotImplementedError(message)


class UnsafeEvalProvider(ObjectInterpretation):
    """UNSAFE provider that handles the parse, compile and exec operations
    in-process with Python's own ``ast.parse`` / ``compile`` / ``exec``,
    *without* any further checks. Only use for testing."""

    @implements(parse)
    def parse(self, source: str, filename: str) -> ast.Module:
        """Parse ``source`` in ``exec`` mode and register it with linecache.

        Args:
            source: The Python source code to parse.
            filename: Name recorded in the AST and used as the linecache key.

        Returns:
            The parsed module.
        """
        # Cache source under `filename` so inspect.getsource() can retrieve it later.
        # inspect uses f.__code__.co_filename -> linecache.getlines(filename)
        linecache.cache[filename] = (
            len(source),
            None,
            source.splitlines(True),
            filename,
        )

        return ast.parse(source, filename=filename, mode="exec")

    @implements(compile)
    def compile(self, module: ast.AST, filename: str) -> CodeType:
        """Compile ``module`` in ``exec`` mode with the builtin ``compile``.

        Args:
            module: The AST to compile.
            filename: Name stored on the resulting code object.

        Returns:
            The compiled code object.
        """
        return builtins.compile(typing.cast(typing.Any, module), filename, "exec")

    @implements(exec)
    def exec(
        self,
        bytecode: CodeType,
        env: dict[str, Any],
    ) -> None:
        """Execute ``bytecode`` with ``env`` as both globals and locals.

        Args:
            bytecode: The code object to run.
            env: Namespace for execution; it is mutated in place and gains a
                ``__builtins__`` entry if it has none.
        """
        # Ensure builtins exist in the execution environment. The namespace of
        # the `builtins` module is used directly because the module-level
        # `__builtins__` name is only a CPython implementation detail.
        env.setdefault("__builtins__", builtins.__dict__)

        # Execute module-style so top-level defs land in `env`.
        builtins.exec(bytecode, env, env)
Loading
Loading