diff --git a/effectful/handlers/llm/encoding.py b/effectful/handlers/llm/encoding.py index 65c4491c..d0796d6c 100644 --- a/effectful/handlers/llm/encoding.py +++ b/effectful/handlers/llm/encoding.py @@ -1,9 +1,14 @@ +import ast import base64 +import inspect import io +import textwrap +import types import typing from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Mapping, MutableMapping, Sequence from dataclasses import dataclass +from types import CodeType from typing import Any import pydantic @@ -13,6 +18,7 @@ ) from PIL import Image +import effectful.handlers.llm.evaluation as evaluation from effectful.ops.semantics import _simple_type from effectful.ops.syntax import _CustomSingleDispatchCallable from effectful.ops.types import Operation, Term @@ -253,6 +259,236 @@ def deserialize(self, serialized_value: str) -> typing.Any: return typing.cast(typing.Any, adapter.validate_json(serialized_value)) +def _format_callable_type(callable_type: type[Callable]) -> str: + """Format a Callable type annotation as a string for LLM instructions.""" + args = typing.get_args(callable_type) + if not args: + return "Callable" + + # Callable[[arg1, arg2, ...], return_type] + if len(args) >= 2: + param_types = args[0] + return_type = args[-1] + + if param_types is ...: + params_str = "..." + elif isinstance(param_types, (list, tuple)): + params_str = ", ".join(getattr(t, "__name__", str(t)) for t in param_types) + else: + params_str = str(param_types) + + return_str = getattr(return_type, "__name__", str(return_type)) + return f"Callable[[{params_str}], {return_str}]" + + return str(callable_type) + + +class SynthesizedFunction(pydantic.BaseModel): + """Structured output for function synthesis. + + Pydantic model representing synthesized code with function name and module code. 
+ """ + + module_code: str = pydantic.Field( + ..., + description="Complete Python module code (no imports needed)", + ) + + +def _create_typed_synthesized_function( + callable_type: type[Callable], +) -> type[SynthesizedFunction]: + """Create a SynthesizedFunction subclass with type signature in the model description. + + Uses pydantic.create_model to ensure the description is included in the JSON schema + sent to the LLM, informing it of the expected function signature. + """ + type_signature = _format_callable_type(callable_type) + + description = f"""Given the specification above, generate a Python function satisfying the following specification and type signature. + +{type_signature} + + +1. Produce one block of Python code. +2. The function MUST have type annotations for all parameters and the return type. +3. The function definition must be the LAST statement - do not add any code after it. +4. Do not include usage examples or function calls. + +""" + + # Use pydantic.create_model to create a proper model with the description + # The __doc__ becomes the model's description in the JSON schema + model = pydantic.create_model( + "TypedSynthesizedFunction", + __base__=SynthesizedFunction, + __doc__=description, + ) + return model + + +def _validate_signature_ast( + func_ast: ast.FunctionDef | ast.AsyncFunctionDef, + expected_params: list[type] | None, +) -> None: + """Validate the function signature from AST before execution.""" + if expected_params is not None: + ast_params = func_ast.args.args + func_ast.args.posonlyargs + if len(ast_params) != len(expected_params): + raise ValueError( + f"decode() expected function with {len(expected_params)} parameters, " + f"got {len(ast_params)}" + ) + + +def _validate_signature_callable( + func: Callable, + expected_params: list[type] | None, + expected_return: type, +) -> None: + """Validate the function signature from runtime callable after execution. 
+ + The synthesized function must have type annotations for parameters and return type. + """ + sig = inspect.signature(func) + + if expected_params is not None: + actual_params = list(sig.parameters.values()) + if len(actual_params) != len(expected_params): + raise ValueError( + f"decode() expected function with {len(expected_params)} parameters, " + f"got {len(actual_params)}" + ) + + actual_return = sig.return_annotation + if actual_return is inspect.Parameter.empty: + raise ValueError( + "decode() requires synthesized function to have a return type annotation" + ) + + expected_name = getattr(expected_return, "__name__", str(expected_return)) + actual_name = getattr(actual_return, "__name__", str(actual_return)) + if expected_name != actual_name: + raise ValueError( + f"decode() expected function with return type {expected_name}, " + f"got {actual_name}" + ) + + +@dataclass +class CallableEncodable(Encodable[Callable, SynthesizedFunction]): + base: type[Callable] + enc: type[SynthesizedFunction] + ctx: Mapping[str, Any] + expected_params: list[type] | None = None + expected_return: type | None = None # None means decode is disabled + + def encode(self, t: Callable) -> SynthesizedFunction: + # (https://github.com/python/mypy/issues/14928) + if not isinstance(t, Callable): # type: ignore + raise TypeError(f"Expected callable, got {type(t)}") + + try: + source = inspect.getsource(t) + except (OSError, TypeError): + source = None + + if source: + return self.enc(module_code=textwrap.dedent(source)) + + # Source not available - create stub from name, signature, and docstring + # This is useful for builtins and C extensions + name = getattr(t, "__name__", None) + if not name: + raise RuntimeError( + f"Cannot encode callable {t}: no source code and no __name__" + ) + + try: + sig = inspect.signature(t) + sig_str = str(sig) + except (ValueError, TypeError): + # Some builtins don't have inspectable signatures + sig_str = "(...)" + + docstring = inspect.getdoc(t) + if not 
docstring: + raise RuntimeError( + f"Cannot encode callable {t}: no source code and no docstring" + ) + + # Format as a stub function with docstring + stub_code = f'''def {name}{sig_str}: + """{docstring}""" + ... +''' + return self.enc(module_code=stub_code) + + def decode(self, encoded_value: SynthesizedFunction) -> Callable: + # Decode requires a concrete return type for synthesis + if self.expected_return is None: + raise TypeError( + "Cannot decode/synthesize callable without a concrete type signature. " + "Use Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType] " + "with a concrete return type (not Any)." + ) + + filename = f"" + + module_code = encoded_value.module_code + + # Parse and validate AST before execution + module: ast.AST = evaluation.parse(module_code, filename) + + if not isinstance(module, ast.Module) or not module.body: + raise ValueError( + "decode() requires module code with at least one statement." + ) + + last_stmt = module.body[-1] + if not isinstance(last_stmt, ast.FunctionDef): + raise ValueError( + f"decode() requires the last statement to be a function definition, " + f"got {type(last_stmt).__name__}" + ) + + # Validate signature from AST before execution + _validate_signature_ast(last_stmt, self.expected_params) + + # Compile and execute + # https://docs.python.org/3/library/functions.html#exec + g: MutableMapping[str, Any] = {} + g.update(self.ctx or {}) + + bytecode: CodeType = evaluation.compile(module, filename) + evaluation.exec(bytecode, g) + + func_name = last_stmt.name + if func_name not in g: + raise ValueError( + f"decode() expected function '{func_name}' to be defined in globals" + ) + + result = g[func_name] + if not callable(result): + raise ValueError( + f"decode() expected '{func_name}' to be callable, got {type(result)}" + ) + + # Validate signature from runtime callable after execution + _validate_signature_callable(result, self.expected_params, self.expected_return) + + return result + + def 
serialize( + self, encoded_value: SynthesizedFunction + ) -> Sequence[OpenAIMessageContentListBlock]: + return [{"type": "text", "text": encoded_value.model_dump_json()}] + + def deserialize(self, serialized_value: str) -> SynthesizedFunction: + return SynthesizedFunction.model_validate_json(serialized_value) + + @Encodable.define.register(object) def _encodable_object[T, U]( ty: type[T], ctx: Mapping[str, Any] | None @@ -355,3 +591,36 @@ def _encodable_list[T, U]( return typing.cast( Encodable[T, U], ListEncodable(ty, encoded_ty, ctx, has_image, element_encoder) ) + + +@Encodable.define.register(Callable) +def _encodable_callable( + ty: type[Callable], ctx: Mapping[str, Any] | None +) -> Encodable[Callable, SynthesizedFunction]: + ctx = ctx or {} + + type_args = typing.get_args(ty) + + # Bare Callable without type args - allow encoding but disable decode + # this occurs when decoding the result of Tools which return callable (need to Encodable.define(return_type) for return type) + if not type_args: + assert ty is types.FunctionType, f"Callable must have type signatures {ty}" + typed_enc = _create_typed_synthesized_function(Callable[..., typing.Any]) # type: ignore[arg-type] + return CallableEncodable(ty, typed_enc, ctx) + + if len(type_args) < 2: + raise TypeError( + f"Callable type signature incomplete: {ty}. " + "Expected Callable[[ParamTypes...], ReturnType] or Callable[..., ReturnType]." + ) + + param_types, expected_return = type_args[0], type_args[-1] + + typed_enc = _create_typed_synthesized_function(ty) + + # Ellipsis means any params, skip param validation + expected_params: list[type] | None = None + if param_types is not ... 
and isinstance(param_types, (list, tuple)): + expected_params = list(param_types) + + return CallableEncodable(ty, typed_enc, ctx, expected_params, expected_return) diff --git a/effectful/handlers/llm/evaluation.py b/effectful/handlers/llm/evaluation.py new file mode 100644 index 00000000..271e55f6 --- /dev/null +++ b/effectful/handlers/llm/evaluation.py @@ -0,0 +1,88 @@ +import ast +import builtins +import linecache +import typing +from types import CodeType +from typing import Any + +from effectful.ops.syntax import ObjectInterpretation, defop, implements + + +@defop +def parse(source: str, filename: str) -> ast.Module: + """ + Parse source text into an AST. + + source: The Python source code to parse. + filename: The filename recorded in the resulting AST for tracebacks and tooling. + + Returns the parsed AST. + """ + raise NotImplementedError( + "An eval provider must be installed in order to parse code." + ) + + +@defop +def compile(module: ast.Module, filename: str) -> CodeType: + """ + Compile an AST into a Python code object. + + module: The AST to compile (typically produced by parse()). + filename: The filename recorded in the resulting code object (CodeType.co_filename), used in tracebacks and by inspect.getsource(). + + Returns the compiled code object. + """ + raise NotImplementedError( + "An eval provider must be installed in order to compile code." + ) + + +@defop +def exec( + bytecode: CodeType, + env: dict[str, Any], +) -> None: + """ + Execute a compiled code object. + + bytecode: A code object to execute (typically produced by compile()). + env: The namespace mapping used during execution. + """ + raise NotImplementedError( + "An eval provider must be installed in order to execute code." + ) + + +class UnsafeEvalProvider(ObjectInterpretation): + """UNSAFE provider that handles parse, comple and exec operations + by shelling out to python *without* any further checks. 
Only use for testing.""" + + @implements(parse) + def parse(self, source: str, filename: str) -> ast.Module: + # Cache source under `filename` so inspect.getsource() can retrieve it later. + # inspect uses f.__code__.co_filename -> linecache.getlines(filename) + linecache.cache[filename] = ( + len(source), + None, + source.splitlines(True), + filename, + ) + + return ast.parse(source, filename=filename, mode="exec") + + @implements(compile) + def compile(self, module: ast.AST, filename: str) -> CodeType: + return builtins.compile(typing.cast(typing.Any, module), filename, "exec") + + @implements(exec) + def exec( + self, + bytecode: CodeType, + env: dict[str, Any], + ) -> None: + # Ensure builtins exist in the execution environment. + env.setdefault("__builtins__", __builtins__) + + # Execute module-style so top-level defs land in `env`. + builtins.exec(bytecode, env, env) diff --git a/effectful/handlers/llm/synthesis.py b/effectful/handlers/llm/synthesis.py deleted file mode 100644 index 3db32fd7..00000000 --- a/effectful/handlers/llm/synthesis.py +++ /dev/null @@ -1,14 +0,0 @@ -from effectful.ops.syntax import ObjectInterpretation - - -class SynthesisError(Exception): - """Raised when program synthesis fails.""" - - def __init__(self, message, code=None): - super().__init__(message) - self.code = code - - -class ProgramSynthesis(ObjectInterpretation): - def __init__(self, *args, **kwargs): - raise NotImplementedError diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_adder_function.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_adder_function.json new file mode 100644 index 00000000..af02c2d6 --- /dev/null +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_adder_function.json @@ -0,0 +1,44 @@ +{ + "id": "chatcmpl-D3t07MhXj7upN73HTU9CZuGWVI26D", + "created": 1769818779, + "model": "gpt-4o-mini-2024-07-18", + 
"object": "chat.completion", + "system_fingerprint": "fp_1590f93f9d", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "{\"value\":{\"module_code\":\"def add_two_numbers(a: int, b: int) -> int:\\n return a + b\\n\"}}", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "provider_specific_fields": { + "refusal": null + }, + "annotations": [] + }, + "provider_specific_fields": {} + } + ], + "usage": { + "completion_tokens": 33, + "prompt_tokens": 605, + "total_tokens": 638, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + "text_tokens": null, + "image_tokens": null + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0, + "text_tokens": null, + "image_tokens": null + } + }, + "service_tier": "default" +} \ No newline at end of file diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_bool_return_type.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_bool_return_type.json new file mode 100644 index 00000000..334c226d --- /dev/null +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_bool_return_type.json @@ -0,0 +1,44 @@ +{ + "id": "chatcmpl-D3t0CeUsrUYRG4uFENTc97eSR8dNV", + "created": 1769818784, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_1590f93f9d", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "{\"value\":{\"module_code\":\"def is_even(number: int) -> bool:\\n return number % 2 == 0\\n\"}}", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "provider_specific_fields": { + "refusal": null + }, + "annotations": [] + }, + "provider_specific_fields": {} + } + ], + "usage": { + "completion_tokens": 32, + "prompt_tokens": 603, + 
"total_tokens": 635, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + "text_tokens": null, + "image_tokens": null + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0, + "text_tokens": null, + "image_tokens": null + } + }, + "service_tier": "default" +} \ No newline at end of file diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_counter_with_parameter.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_counter_with_parameter.json new file mode 100644 index 00000000..b0c5f9bf --- /dev/null +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_counter_with_parameter.json @@ -0,0 +1,53 @@ +{ + "id": "chatcmpl-D3t094sc9fvHTU8EcpX5U3jmycshT", + "created": 1769818781, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_1590f93f9d", + "choices": [ + { + "finish_reason": "tool_calls", + "index": 0, + "message": { + "content": null, + "role": "assistant", + "tool_calls": [ + { + "function": { + "arguments": "{\"char\":\"a\"}", + "name": "create_function" + }, + "id": "call_yHyAUG2fywMoQVlgDgfW5tni", + "type": "function" + } + ], + "function_call": null, + "provider_specific_fields": { + "refusal": null + }, + "annotations": [] + }, + "provider_specific_fields": {} + } + ], + "usage": { + "completion_tokens": 14, + "prompt_tokens": 591, + "total_tokens": 605, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + "text_tokens": null, + "image_tokens": null + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0, + "text_tokens": null, + "image_tokens": null + } + }, + "service_tier": "default" +} \ No newline at end of file diff --git 
a/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_counter_with_parameter_1.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_counter_with_parameter_1.json new file mode 100644 index 00000000..f1dc5cb2 --- /dev/null +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_counter_with_parameter_1.json @@ -0,0 +1,44 @@ +{ + "id": "chatcmpl-D3t09UG6mHSIp7eUBauqT1qp1D1Cz", + "created": 1769818781, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_1590f93f9d", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "{\"value\":{\"module_code\":\"def count_character_a(input_string: str) -> int:\\n return input_string.count('a')\\n\"}}", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "provider_specific_fields": { + "refusal": null + }, + "annotations": [] + }, + "provider_specific_fields": {} + } + ], + "usage": { + "completion_tokens": 34, + "prompt_tokens": 593, + "total_tokens": 627, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + "text_tokens": null, + "image_tokens": null + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0, + "text_tokens": null, + "image_tokens": null + } + }, + "service_tier": "default" +} \ No newline at end of file diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_counter_with_parameter_2.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_counter_with_parameter_2.json new file mode 100644 index 00000000..9970c3d0 --- /dev/null +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_counter_with_parameter_2.json @@ -0,0 +1,44 @@ +{ + "id": 
"chatcmpl-D3t0A8TB6nGWKFLinAIo6MtI8A9C7", + "created": 1769818782, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_1590f93f9d", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "{\"value\":{\"module_code\":\"def count_character_a(input_string: str) -> int:\\n return input_string.count('a')\\n\"}}", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "provider_specific_fields": { + "refusal": null + }, + "annotations": [] + }, + "provider_specific_fields": {} + } + ], + "usage": { + "completion_tokens": 34, + "prompt_tokens": 641, + "total_tokens": 675, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + "text_tokens": null, + "image_tokens": null + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0, + "text_tokens": null, + "image_tokens": null + } + }, + "service_tier": "default" +} \ No newline at end of file diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_string_processor.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_string_processor.json new file mode 100644 index 00000000..d49188fd --- /dev/null +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_string_processor.json @@ -0,0 +1,44 @@ +{ + "id": "chatcmpl-D3t08sdCPU3BvwMIBHiYaLfneCLCQ", + "created": 1769818780, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_1590f93f9d", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "{\"value\":{\"module_code\":\"def convert_to_uppercase_with_exclamations(input_string: str) -> str:\\n return input_string.upper() + '!!' 
\\n\"}}", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "provider_specific_fields": { + "refusal": null + }, + "annotations": [] + }, + "provider_specific_fields": {} + } + ], + "usage": { + "completion_tokens": 41, + "prompt_tokens": 602, + "total_tokens": 643, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + "text_tokens": null, + "image_tokens": null + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0, + "text_tokens": null, + "image_tokens": null + } + }, + "service_tier": "default" +} \ No newline at end of file diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_three_params.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_three_params.json new file mode 100644 index 00000000..4ae29502 --- /dev/null +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesize_three_params.json @@ -0,0 +1,44 @@ +{ + "id": "chatcmpl-D3t0DNoQDXPHAYMTMFDFalBuDkS6W", + "created": 1769818785, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_1590f93f9d", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "{\"value\":{\"module_code\":\"def multiply_three_numbers(a: int, b: int, c: int) -> int:\\n return a * b * c\\n\"}}", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "provider_specific_fields": { + "refusal": null + }, + "annotations": [] + }, + "provider_specific_fields": {} + } + ], + "usage": { + "completion_tokens": 39, + "prompt_tokens": 605, + "total_tokens": 644, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + "text_tokens": null, + "image_tokens": null + }, + "prompt_tokens_details": 
{ + "audio_tokens": 0, + "cached_tokens": 0, + "text_tokens": null, + "image_tokens": null + } + }, + "service_tier": "default" +} \ No newline at end of file diff --git a/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesized_function_roundtrip.json b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesized_function_roundtrip.json new file mode 100644 index 00000000..65a95290 --- /dev/null +++ b/tests/fixtures/tests_test_handlers_llm_provider.py__TestCallableSynthesis__test_synthesized_function_roundtrip.json @@ -0,0 +1,44 @@ +{ + "id": "chatcmpl-D3t0BrdzBcHyqvKaYpMKgolp3bdRO", + "created": 1769818783, + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "system_fingerprint": "fp_1590f93f9d", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "{\"value\":{\"module_code\":\"def add_two_integers(a: int, b: int) -> int:\\n return a + b\\n\"}}", + "role": "assistant", + "tool_calls": null, + "function_call": null, + "provider_specific_fields": { + "refusal": null + }, + "annotations": [] + }, + "provider_specific_fields": {} + } + ], + "usage": { + "completion_tokens": 35, + "prompt_tokens": 605, + "total_tokens": 640, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + "text_tokens": null, + "image_tokens": null + }, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0, + "text_tokens": null, + "image_tokens": null + } + }, + "service_tier": "default" +} \ No newline at end of file diff --git a/tests/test_handlers_llm.py b/tests/test_handlers_llm.py index 2c98a650..c4c8be2c 100644 --- a/tests/test_handlers_llm.py +++ b/tests/test_handlers_llm.py @@ -1,10 +1,7 @@ from collections.abc import Callable from typing import Annotated -import pytest - from effectful.handlers.llm import Template -from effectful.handlers.llm.synthesis 
import ProgramSynthesis from effectful.handlers.llm.template import IsRecursive from effectful.ops.semantics import NotHandled, handler from effectful.ops.syntax import ObjectInterpretation, implements @@ -119,22 +116,6 @@ def test_primes_decode_int(): assert isinstance(result, int) -@pytest.mark.xfail(reason="Synthesis handler not yet implemented") -def test_count_char_with_program_synthesis(): - """Test the count_char template with program synthesis.""" - mock_code = """ -def count_occurrences(s): - return s.count('a') -""" - mock_provider = SingleResponseLLMProvider(mock_code) - - with handler(mock_provider), handler(ProgramSynthesis()): - count_a = count_char("a") - assert callable(count_a) - assert count_a("banana") == 3 - assert count_a("cherry") == 0 - - class FailingThenSucceedingProvider[T](ObjectInterpretation): """Mock provider that fails a specified number of times before succeeding.""" diff --git a/tests/test_handlers_llm_encoding.py b/tests/test_handlers_llm_encoding.py index dd21436a..775d5f72 100644 --- a/tests/test_handlers_llm_encoding.py +++ b/tests/test_handlers_llm_encoding.py @@ -1,11 +1,14 @@ +from collections.abc import Callable from dataclasses import asdict, dataclass -from typing import NamedTuple, TypedDict +from typing import Any, NamedTuple, TypedDict import pydantic import pytest from PIL import Image -from effectful.handlers.llm.encoding import Encodable +from effectful.handlers.llm.encoding import Encodable, SynthesizedFunction +from effectful.handlers.llm.evaluation import UnsafeEvalProvider +from effectful.ops.semantics import handler from effectful.ops.types import Operation, Term @@ -718,3 +721,380 @@ class Person(pydantic.BaseModel): assert decoded_from_model == person assert isinstance(decoded_from_model, Person) assert isinstance(decoded_from_model.address, Address) + + +class TestCallableEncodable: + """Tests for CallableEncodable - encoding/decoding callables as SynthesizedFunction.""" + + def 
test_encode_decode_function(self): + def add(a: int, b: int) -> int: + return a + b + + # Use typed Callable with matching signature + encodable = Encodable.define(Callable[[int, int], int], {}) + encoded = encodable.encode(add) + assert isinstance(encoded, SynthesizedFunction) + assert "def add" in encoded.module_code + assert "return a + b" in encoded.module_code + + with handler(UnsafeEvalProvider()): + decoded = encodable.decode(encoded) + assert callable(decoded) + assert decoded(2, 3) == 5 + assert decoded.__name__ == "add" + + def test_decode_with_ellipsis_params(self): + # Callable[..., int] allows any params but validates return type + encodable = Encodable.define(Callable[..., int], {}) + + # Test decoding a function - must end with function def with return annotation + func_source = SynthesizedFunction( + module_code="def double(x) -> int:\n return x * 2" + ) + with handler(UnsafeEvalProvider()): + decoded = encodable.decode(func_source) + assert callable(decoded) + assert decoded(5) == 10 + + def test_decode_with_env(self): + # Test decoding a function that uses env variables + encodable = Encodable.define(Callable[..., int], {"factor": 3}) + source = SynthesizedFunction( + module_code="""def multiply(x) -> int: + return x * factor""" + ) + + with handler(UnsafeEvalProvider()): + decoded = encodable.decode(source) + assert callable(decoded) + assert decoded(4) == 12 + + def test_encode_non_callable_raises(self): + encodable = Encodable.define(Callable[..., int], {}) + with pytest.raises(TypeError, match="Expected callable"): + encodable.encode("not a callable") + + def test_encode_builtin_creates_stub(self): + encodable = Encodable.define(Callable[..., int], {}) + # Built-in functions don't have source code but have docstrings + encoded = encodable.encode(len) + assert isinstance(encoded, SynthesizedFunction) + assert "def len" in encoded.module_code + assert '"""' in encoded.module_code # docstring present + assert "..." 
in encoded.module_code # stub body + + def test_encode_builtin_no_docstring_raises(self): + # Create a callable without source and without docstring + class NoDocCallable: + __name__ = "nodoc" + __doc__ = None + + def __call__(self): + pass + + encodable = Encodable.define(Callable[..., int], {}) + with pytest.raises(RuntimeError, match="no source code and no docstring"): + encodable.encode(NoDocCallable()) + + def test_decode_no_function_at_end_raises(self): + encodable = Encodable.define(Callable[..., int], {}) + # Source code where last statement is not a function definition + source = SynthesizedFunction(module_code="x = 42") + with pytest.raises( + ValueError, match="last statement to be a function definition" + ): + with handler(UnsafeEvalProvider()): + encodable.decode(source) + + def test_decode_multiple_functions_uses_last(self): + encodable = Encodable.define(Callable[..., int], {}) + # Source code that defines multiple functions - should use the last one + source = SynthesizedFunction( + module_code="""def foo() -> int: + return 1 + +def bar() -> int: + return 2""" + ) + with handler(UnsafeEvalProvider()): + decoded = encodable.decode(source) + assert callable(decoded) + assert decoded.__name__ == "bar" + assert decoded() == 2 + + def test_decode_class_raises(self): + encodable = Encodable.define(Callable[..., int], {}) + # Classes are callable but the last statement must be a function definition + source = SynthesizedFunction( + module_code="""class Greeter: + def __init__(self, name): + self.name = name + + def greet(self): + return f"Hello, {self.name}!\"""" + ) + + with pytest.raises( + ValueError, match="last statement to be a function definition" + ): + with handler(UnsafeEvalProvider()): + encodable.decode(source) + + def test_roundtrip(self): + def greet(name: str) -> str: + return f"Hello, {name}!" 
+ + encodable = Encodable.define(Callable[[str], str], {}) + with handler(UnsafeEvalProvider()): + encoded = encodable.encode(greet) + decoded = encodable.decode(encoded) + + assert callable(decoded) + assert decoded("Alice") == "Hello, Alice!" + assert decoded.__name__ == "greet" + + def test_serialize_deserialize(self): + def add(a: int, b: int) -> int: + return a + b + + encodable = Encodable.define(Callable[[int, int], int], {}) + encoded = encodable.encode(add) + + # Test serialization + serialized = encodable.serialize(encoded) + assert len(serialized) == 1 + assert serialized[0]["type"] == "text" + assert "module_code" in serialized[0]["text"] + + # Test deserialization + deserialized = encodable.deserialize(serialized[0]["text"]) + assert isinstance(deserialized, SynthesizedFunction) + assert "def add" in deserialized.module_code + + def test_decode_validates_last_statement(self): + encodable = Encodable.define(Callable[..., int], {}) + + # Helper function followed by assignment - should fail + source = SynthesizedFunction( + module_code="""def helper(): + return 42 + +result = helper()""" + ) + with pytest.raises( + ValueError, match="last statement to be a function definition" + ): + with handler(UnsafeEvalProvider()): + encodable.decode(source) + + def test_typed_callable_includes_signature_in_docstring(self): + # Test that the enc type has the signature in its docstring + encodable = Encodable.define(Callable[[int, int], int], {}) + assert encodable.enc.__doc__ is not None + assert "Callable[[int, int], int]" in encodable.enc.__doc__ + assert "" in encodable.enc.__doc__ + + def test_typed_callable_validates_param_count(self): + encodable = Encodable.define(Callable[[int, int], int], {}) + + # Function with wrong number of parameters + source = SynthesizedFunction( + module_code="""def add(a: int) -> int: + return a""" + ) + with pytest.raises(ValueError, match="expected function with 2 parameters"): + with handler(UnsafeEvalProvider()): + 
encodable.decode(source) + + def test_typed_callable_validates_return_type(self): + encodable = Encodable.define(Callable[[int, int], int], {}) + + # Function with wrong return type + source = SynthesizedFunction( + module_code="""def add(a: int, b: int) -> str: + return str(a + b)""" + ) + with pytest.raises(ValueError, match="expected function with return type int"): + with handler(UnsafeEvalProvider()): + encodable.decode(source) + + def test_typed_callable_requires_return_annotation(self): + encodable = Encodable.define(Callable[[int, int], int], {}) + + # Function missing return type annotation + source = SynthesizedFunction( + module_code="""def add(a: int, b: int): + return a + b""" + ) + with pytest.raises( + ValueError, + match="requires synthesized function to have a return type annotation", + ): + with handler(UnsafeEvalProvider()): + encodable.decode(source) + + def test_typed_callable_accepts_correct_signature(self): + encodable = Encodable.define(Callable[[int, int], int], {}) + + # Function with correct signature + source = SynthesizedFunction( + module_code="""def add(a: int, b: int) -> int: + return a + b""" + ) + with handler(UnsafeEvalProvider()): + result = encodable.decode(source) + assert callable(result) + assert result(2, 3) == 5 + + def test_ellipsis_callable_skips_param_validation(self): + # Callable[..., int] should skip param validation but still validate return + encodable = Encodable.define(Callable[..., int], {}) + + source = SynthesizedFunction( + module_code="""def anything(a, b, c, d, e) -> int: + return 42""" + ) + with handler(UnsafeEvalProvider()): + result = encodable.decode(source) + assert callable(result) + assert result(1, 2, 3, 4, 5) == 42 + + def test_typed_callable_json_schema_includes_signature(self): + # Test that the JSON schema includes the type signature for the LLM + encodable = Encodable.define(Callable[[int, int], int], {}) + + # Get the JSON schema from the enc model + schema = encodable.enc.model_json_schema() 
+
+        # The description should contain the type signature
+        assert "description" in schema
+        assert "Callable[[int, int], int]" in schema["description"]
+        assert "type annotations" in schema["description"]
+        assert "LAST statement" in schema["description"]
+
+    def test_typed_callable_json_schema_different_signatures(self):
+        # Test that different type signatures produce different schemas
+        enc1 = Encodable.define(Callable[[str], str], {})
+        enc2 = Encodable.define(Callable[[int, int, int], bool], {})
+
+        schema1 = enc1.enc.model_json_schema()
+        schema2 = enc2.enc.model_json_schema()
+
+        assert "Callable[[str], str]" in schema1["description"]
+        assert "Callable[[int, int, int], bool]" in schema2["description"]
+
+    def test_validates_param_count_via_ast(self):
+        # Test that param validation happens via AST analysis
+        encodable = Encodable.define(Callable[[int, int], int], {})
+
+        # Function with 3 params when 2 expected
+        source = SynthesizedFunction(
+            module_code="""def add(a: int, b: int, c: int) -> int:
+    return a + b + c"""
+        )
+        with pytest.raises(ValueError, match="expected function with 2 parameters"):
+            with handler(UnsafeEvalProvider()):
+                encodable.decode(source)
+
+    def test_validates_param_count_zero_params(self):
+        # Test callable with no params
+        encodable = Encodable.define(Callable[[], int], {})
+
+        # Function with params when 0 expected
+        source = SynthesizedFunction(
+            module_code="""def get_value(x: int) -> int:
+    return x"""
+        )
+        with pytest.raises(ValueError, match="expected function with 0 parameters"):
+            with handler(UnsafeEvalProvider()):
+                encodable.decode(source)
+
+    def test_validates_accepts_zero_params(self):
+        # Test callable with no params - correct signature
+        encodable = Encodable.define(Callable[[], int], {})
+
+        source = SynthesizedFunction(
+            module_code="""def get_value() -> int:
+    return 42"""
+        )
+        with handler(UnsafeEvalProvider()):
+            result = encodable.decode(source)
+        assert callable(result)
+        assert result() == 42
+
+    def 
test_ellipsis_callable_json_schema_includes_signature(self):
+        # Test that Callable[..., int] has signature in schema
+        encodable = Encodable.define(Callable[..., int], {})
+
+        schema = encodable.enc.model_json_schema()
+        assert "description" in schema
+        assert "Callable[[...], int]" in schema["description"]
+        assert "type annotations" in schema["description"]
+
+    def test_ellipsis_callable_validates_return_type(self):
+        # Callable[..., int] should still validate return type
+        encodable = Encodable.define(Callable[..., int], {})
+
+        source = SynthesizedFunction(
+            module_code="""def get_value() -> str:
+    return "wrong type\""""
+        )
+        with pytest.raises(ValueError, match="expected function with return type int"):
+            with handler(UnsafeEvalProvider()):
+                encodable.decode(source)
+
+    def test_callable_with_single_param(self):
+        encodable = Encodable.define(Callable[[str], int], {})
+
+        source = SynthesizedFunction(
+            module_code="""def count_chars(s: str) -> int:
+    return len(s)"""
+        )
+        with handler(UnsafeEvalProvider()):
+            result = encodable.decode(source)
+        assert callable(result)
+        assert result("hello") == 5
+
+    def test_callable_with_many_params(self):
+        encodable = Encodable.define(Callable[[int, int, int, int], int], {})
+
+        source = SynthesizedFunction(
+            module_code="""def sum_four(a: int, b: int, c: int, d: int) -> int:
+    return a + b + c + d"""
+        )
+        with handler(UnsafeEvalProvider()):
+            result = encodable.decode(source)
+        assert callable(result)
+        assert result(1, 2, 3, 4) == 10
+
+    def test_callable_with_bool_return(self):
+        encodable = Encodable.define(Callable[[int], bool], {})
+
+        source = SynthesizedFunction(
+            module_code="""def is_positive(x: int) -> bool:
+    return x > 0"""
+        )
+        with handler(UnsafeEvalProvider()):
+            result = encodable.decode(source)
+        assert callable(result)
+        assert result(5) is True
+        assert result(-1) is False
+
+    def test_callable_type_variations_schema(self):
+        # Test various callable type variations have correct schemas
+        test_cases = [
+
(Callable[[], int], "Callable[[], int]"), + (Callable[[str], str], "Callable[[str], str]"), + (Callable[[int, str], bool], "Callable[[int, str], bool]"), + (Callable[..., int], "Callable[[...], int]"), + (Callable[..., Any], "Callable[[...], Any]"), + ] + + for callable_type, expected_sig in test_cases: + encodable = Encodable.define(callable_type, {}) + schema = encodable.enc.model_json_schema() + assert "description" in schema, f"No description for {callable_type}" + assert expected_sig in schema["description"], ( + f"Expected {expected_sig} in schema for {callable_type}, " + f"got: {schema['description'][:100]}..." + ) diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index 66c7af7b..0a766231 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -5,6 +5,7 @@ """ import functools +import inspect import json import os from collections.abc import Callable @@ -25,7 +26,8 @@ call_assistant, completion, ) -from effectful.handlers.llm.synthesis import ProgramSynthesis, SynthesisError +from effectful.handlers.llm.encoding import Encodable, SynthesizedFunction +from effectful.handlers.llm.evaluation import UnsafeEvalProvider from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import NotHandled @@ -240,29 +242,6 @@ def test_with_config_params(self, request): assert isinstance(result, str) -@pytest.mark.xfail(reason="Program synthesis not implemented") -class TestProgramSynthesis: - """Tests for ProgramSynthesis handler functionality.""" - - @pytest.mark.xfail - @requires_openai - @retry_on_error(error=SynthesisError, n=3) - def test_generates_callable(self, request): - """Test ProgramSynthesis handler generates executable code.""" - with ( - handler(ReplayLiteLLMProvider(request, model="gpt-4o-mini")), - handler(ProgramSynthesis()), - handler(LimitLLMCallsHandler(max_calls=1)), - ): - count_func = 
create_function("a") - - assert callable(count_func) - # Test the generated function - assert count_func("banana") == 3 - assert count_func("cherry") == 0 - assert count_func("aardvark") == 3 - - def smiley_face() -> Image.Image: bmp = [ "00000000", @@ -367,3 +346,184 @@ def test_litellm_caching_selective(request): p1 = simple_prompt("apples") p2 = simple_prompt("apples") assert p1 != p2, "when caching is not enabled, llm outputs should be different" + + +# ============================================================================ +# Callable Synthesis Tests +# ============================================================================ + + +@Template.define +def synthesize_adder() -> Callable[[int, int], int]: + """Generate a Python function that adds two integers together. + + The function should take two integer parameters and return their sum. + """ + raise NotHandled + + +@Template.define +def synthesize_string_processor() -> Callable[[str], str]: + """Generate a Python function that converts a string to uppercase + and adds exclamation marks at the end. + """ + raise NotHandled + + +@Template.define +def synthesize_counter(char: str) -> Callable[[str], int]: + """Generate a Python function that counts occurrences of the character '{char}' + in a given input string. + + The function should be case-sensitive. + """ + raise NotHandled + + +@Template.define +def synthesize_is_even() -> Callable[[int], bool]: + """Generate a Python function that checks if a number is even. + + Return True if the number is divisible by 2, False otherwise. + """ + raise NotHandled + + +@Template.define +def synthesize_three_param_func() -> Callable[[int, int, int], int]: + """Generate a Python function that takes exactly three integer parameters + and returns their product (multiplication). 
+ """ + raise NotHandled + + +class TestCallableSynthesis: + """Tests for synthesizing callable functions via LLM.""" + + @requires_openai + def test_synthesize_adder_function(self, request): + """Test that LLM can synthesize a simple addition function with correct signature.""" + with ( + handler(ReplayLiteLLMProvider(request, model="gpt-4o-mini")), + handler(UnsafeEvalProvider()), + handler(LimitLLMCallsHandler(max_calls=1)), + ): + add_func = synthesize_adder() + + assert callable(add_func) + assert add_func(2, 3) == 5 + assert add_func(0, 0) == 0 + assert add_func(-1, 1) == 0 + assert add_func(100, 200) == 300 + + @requires_openai + def test_synthesize_string_processor(self, request): + """Test that LLM can synthesize a string processing function.""" + with ( + handler(ReplayLiteLLMProvider(request, model="gpt-4o-mini")), + handler(UnsafeEvalProvider()), + handler(LimitLLMCallsHandler(max_calls=1)), + ): + process_func = synthesize_string_processor() + + assert callable(process_func) + result = process_func("hello") + assert isinstance(result, str) + assert "HELLO" in result + assert "!" 
in result + + @requires_openai + def test_synthesize_counter_with_parameter(self, request): + """Test that LLM can synthesize a parameterized counting function.""" + with ( + handler(ReplayLiteLLMProvider(request, model="gpt-4o-mini")), + handler(UnsafeEvalProvider()), + handler(LimitLLMCallsHandler(max_calls=3)), + ): + count_a = synthesize_counter("a") + + assert callable(count_a) + assert count_a("banana") == 3 + assert count_a("cherry") == 0 + assert count_a("aardvark") == 3 + assert count_a("AAA") == 0 # case-sensitive + + @requires_openai + def test_callable_type_signature_in_schema(self, request): + """Test that the callable type signature is communicated to the LLM.""" + + # Verify that the enc type includes the signature in its docstring + encodable = Encodable.define(Callable[[int, int], int], {}) + assert encodable.enc.__doc__ is not None + assert "Callable[[int, int], int]" in encodable.enc.__doc__ + + encodable2 = Encodable.define(Callable[[str], str], {}) + assert encodable2.enc.__doc__ is not None + assert "Callable[[str], str]" in encodable2.enc.__doc__ + + @requires_openai + def test_synthesized_function_roundtrip(self, request): + """Test that a synthesized function can be encoded and decoded.""" + + with ( + handler(ReplayLiteLLMProvider(request, model="gpt-4o-mini")), + handler(UnsafeEvalProvider()), + handler(LimitLLMCallsHandler(max_calls=1)), + ): + # Synthesize a function + add_func = synthesize_adder() + assert callable(add_func) + + # Encode it back to SynthesizedFunction + encodable = Encodable.define(Callable[[int, int], int], {}) + encoded = encodable.encode(add_func) + assert isinstance(encoded, SynthesizedFunction) + assert "def " in encoded.module_code + + # Decode it again and verify it still works + decoded = encodable.decode(encoded) + assert callable(decoded) + assert decoded(5, 7) == 12 + + @requires_openai + def test_synthesize_bool_return_type(self, request): + """Test that LLM respects bool return type in signature.""" + + 
with ( + handler(ReplayLiteLLMProvider(request, model="gpt-4o-mini")), + handler(UnsafeEvalProvider()), + handler(LimitLLMCallsHandler(max_calls=1)), + ): + is_even = synthesize_is_even() + + assert callable(is_even) + # Verify return type annotation + sig = inspect.signature(is_even) + assert sig.return_annotation == bool + + # Verify behavior + assert is_even(2) is True + assert is_even(3) is False + assert is_even(0) is True + assert is_even(-4) is True + + @requires_openai + def test_synthesize_three_params(self, request): + """Test that LLM respects the exact number of parameters in signature.""" + + with ( + handler(ReplayLiteLLMProvider(request, model="gpt-4o-mini")), + handler(UnsafeEvalProvider()), + handler(LimitLLMCallsHandler(max_calls=1)), + ): + multiply_three = synthesize_three_param_func() + + assert callable(multiply_three) + # Verify parameter count + sig = inspect.signature(multiply_three) + assert len(sig.parameters) == 3 + + # Verify behavior + assert multiply_three(2, 3, 4) == 24 + assert multiply_three(1, 1, 1) == 1 + assert multiply_three(5, 0, 10) == 0