diff --git a/effectful/handlers/llm/doctest.py b/effectful/handlers/llm/doctest.py new file mode 100644 index 00000000..27e41dc5 --- /dev/null +++ b/effectful/handlers/llm/doctest.py @@ -0,0 +1,296 @@ +"""Doctest semantic constraints for Templates. + +Provides a :class:`DoctestHandler` that uses ``>>>`` examples in template +docstrings as semantic constraints rather than literal prompts. + +**Case 1 (tool-calling)**: When the template returns a non-Callable type, a +calibration loop runs the doctest inputs through the LLM once per template +definition and caches the entire conversation (including any incorrect +attempts) as a few-shot prefix for future calls, emulating a mini ReAct +agent that learns from its mistakes. + +**Case 2 (code generation)**: When the template returns a ``Callable`` type, +the generated code is required to pass the doctests as post-hoc validation. + +In both cases, ``>>>`` examples are stripped from the prompt sent to the LLM +so it cannot memorise the expected outputs. +""" + +import ast +import collections +import contextlib +import doctest +import typing +from collections.abc import Mapping +from typing import Any + +from effectful.handlers.llm.completions import ( + Message, + _make_message, + call_user, +) +from effectful.handlers.llm.encoding import CallableEncodable, Encodable +from effectful.handlers.llm.evaluation import test +from effectful.handlers.llm.template import Template +from effectful.ops.semantics import fwd +from effectful.ops.syntax import ObjectInterpretation, implements + +_SENTINEL = object() + + +class DoctestHandler(ObjectInterpretation): + """Use ``>>>`` examples in template docstrings as semantic constraints. + + Install with ``handler(DoctestHandler())`` alongside a provider and an + eval provider. See the module docstring for the two cases handled. + """ + + # Per-template extraction cache (stripped template + examples). 
+ _extraction_cache: dict[Template, tuple[str, list[doctest.Example]]] + + # Case 1: calibration conversation prefix, cached per template. + _prefix_cache: dict[Template, list[Message]] + + # Case 2: per-call cached doctest examples for test() validation. + _doctest_stack: list[list[doctest.Example]] + + # Re-entrancy guard: set of templates currently being calibrated. + _calibrating: set[Template] + + def __init__(self) -> None: + self._extraction_cache = {} + self._doctest_stack = [] + self._prefix_cache = {} + self._calibrating = set() + + # -- helpers ------------------------------------------------------------ + + @classmethod + def extract_doctests(cls, docstring: str) -> tuple[str, list[doctest.Example]]: + """Separate a docstring into text-without-examples and a list of examples. + + Uses :class:`doctest.DocTestParser` to identify ``>>>`` blocks, then + reconstructs the docstring with those blocks removed. + + Returns ``(stripped_text, examples)`` where *stripped_text* is the + docstring with all interactive examples removed. + """ + parser = doctest.DocTestParser() + parts = parser.parse(docstring) + text_parts = [p for p in parts if isinstance(p, str)] + examples = [p for p in parts if isinstance(p, doctest.Example)] + return "".join(text_parts), examples + + @staticmethod + def _parse_template_call( + example: doctest.Example, template_name: str + ) -> tuple[list[Any] | None, dict[str, Any] | None]: + """Extract positional and keyword args from a doctest example. + + Returns ``(args, kwargs)`` if the example is a call to + *template_name*, or ``(None, None)`` otherwise. 
+ """ + source = example.source.strip() + try: + tree = ast.parse(source, mode="eval") + except SyntaxError: + return None, None + + expr = tree.body + if not isinstance(expr, ast.Call): + return None, None + if not isinstance(expr.func, ast.Name): + return None, None + if expr.func.id != template_name: + return None, None + + try: + pos_args = [ast.literal_eval(a) for a in expr.args] + kw_args = { + kw.arg: ast.literal_eval(kw.value) + for kw in expr.keywords + if kw.arg is not None + } + except (ValueError, TypeError): + return None, None + + return pos_args, kw_args + + @contextlib.contextmanager + def _bind_history( + self, + template: Template, + history: collections.OrderedDict[str, Message], + ): + """Temporarily bind *history* to ``template.__history__``. + + Uses the same attribute that :class:`Agent` binds via ``__get__``. + The provider reads and writes back to it, so messages accumulate. + """ + old = getattr(template, "__history__", _SENTINEL) + template.__history__ = history # type: ignore[attr-defined] + try: + yield + finally: + if old is _SENTINEL: + try: + del template.__history__ # type: ignore[attr-defined] + except AttributeError: + pass + else: + template.__history__ = old # type: ignore[attr-defined] + + # -- Template.__apply__ ------------------------------------------------- + + @implements(Template.__apply__) + def _handle_template[**P, T]( + self, + template: Template[P, T], + *_args: P.args, + **_kwargs: P.kwargs, + ) -> T: + if template not in self._extraction_cache: + self._extraction_cache[template] = self.extract_doctests( + template.__prompt_template__ + ) + _, examples = self._extraction_cache[template] + + if not examples: + return fwd() + + if isinstance( + Encodable.define(template.__signature__.return_annotation), + CallableEncodable, + ): + # Case 2 – code generation: push cached examples for test(). + self._doctest_stack.append(examples) + return fwd() + + # Case 1 – tool-calling: calibration + prefix. 
+ if template not in self._calibrating and template not in self._prefix_cache: + self._calibrate(template, examples) + + prefix = self._prefix_cache.get(template, []) + if prefix: + # Pre-populate history with the cached calibration prefix + # (Agent-style); the provider will copy it and prepend the + # system message, so the LLM sees: + # system → prefix user/assistant turns → actual user message. + prefix_history: collections.OrderedDict[str, Message] = ( + collections.OrderedDict((m["id"], m) for m in prefix) + ) + with self._bind_history(template, prefix_history): + return fwd() + + return fwd() + + # -- call_user ---------------------------------------------------------- + + @implements(call_user) + def _strip_prompt( + self, + template: str, + env: Mapping[str, Any], + ) -> Message: + """Strip ``>>>`` examples from the prompt before the LLM sees it.""" + stripped, _ = self.extract_doctests(template) + return fwd(stripped, env) + + # -- test (Case 2 validation) ------------------------------------------- + + @implements(test) + def _run_from_stack(self, obj: object, ctx: typing.Mapping[str, Any]) -> None: + if not self._doctest_stack: + return + examples = self._doctest_stack.pop() + if not examples: + return + + name = ( + f"{getattr(obj, '__name__', obj.__class__.__name__)}.__template_doctest__" + ) + test_case = doctest.DocTest( + examples=examples, + globs=dict(ctx), + name=name, + filename=None, + lineno=0, + docstring=None, + ) + + output: list[str] = [] + runner = doctest.DocTestRunner(verbose=False) + runner.run(test_case, out=output.append) + results = runner.summarize(verbose=False) + if results.failed: + report = "".join(output).strip() + if not report: + report = ( + f"{results.failed} doctest(s) failed " + f"out of {results.attempted} attempted." 
+ ) + raise TypeError(f"doctest failed:\n{report}") + + # -- Case 1 calibration ------------------------------------------------- + + def _calibrate( + self, + template: Template, + examples: list[doctest.Example], + ) -> None: + """Run calibration as a mini ReAct agent with Agent-style history. + + Reuses the same persistent-history mechanism as :class:`Agent`: a + shared :class:`~collections.OrderedDict` bound to + ``template.__history__`` that accumulates messages across calls. + Each doctest example is evaluated in order; the LLM sees all prior + conversation turns (including any corrective feedback for incorrect + answers) when processing subsequent examples, enabling it to learn + from the full experience. + """ + self._calibrating.add(template) + + shared_history: collections.OrderedDict[str, Message] = ( + collections.OrderedDict() + ) + + with self._bind_history(template, shared_history): + try: + for example in examples: + self._run_calibration_example(template, example, shared_history) + finally: + self._calibrating.discard(template) + + self._prefix_cache[template] = [ + m for m in shared_history.values() if m["role"] != "system" + ] + + def _run_calibration_example( + self, + template: Template, + example: doctest.Example, + history: collections.OrderedDict[str, Message], + ) -> None: + """Evaluate one doctest example and append corrective feedback.""" + call_args, call_kwargs = self._parse_template_call(example, template.__name__) + if call_args is None or call_kwargs is None: + return + + result = template(*call_args, **call_kwargs) + + checker = doctest.OutputChecker() + actual = repr(result) + "\n" + optionflags = sum(f for f, v in example.options.items() if v) + + if not checker.check_output(example.want, actual, optionflags): + feedback = _make_message( + { + "role": "user", + "content": ( + f"That was incorrect. " + f"Expected {example.want.strip()!r} " + f"but got {repr(result)!r}." 
+ ), + } + ) + history[feedback["id"]] = feedback diff --git a/effectful/handlers/llm/encoding.py b/effectful/handlers/llm/encoding.py index bcfa15a4..0c4da622 100644 --- a/effectful/handlers/llm/encoding.py +++ b/effectful/handlers/llm/encoding.py @@ -583,6 +583,11 @@ def decode(self, encoded_value: SynthesizedFunction) -> Callable: # Validate signature from runtime callable after execution _validate_signature_callable(result, self.expected_params, self.expected_return) + # Run doctests from the original template docstring (if any) + module_obj = types.ModuleType(filename) + module_obj.__dict__.update(g) + evaluation.test(module_obj, module_obj.__dict__) + return result def serialize( diff --git a/effectful/handlers/llm/evaluation.py b/effectful/handlers/llm/evaluation.py index 07348cc9..f3f378d7 100644 --- a/effectful/handlers/llm/evaluation.py +++ b/effectful/handlers/llm/evaluation.py @@ -96,6 +96,19 @@ def exec( ) +@defop +def test(obj: object, ctx: typing.Mapping[str, Any]) -> None: + """ + Run doctests for a synthesized program using the current doctest stack. + + obj: The synthesized module object. + ctx: The namespace used to run doctest examples. + + No-op by default; install a DoctestHandler to actually run doctests. + """ + pass + + # Type checking implementation def type_to_ast(typ: Any) -> ast.expr: """Convert a Python type to an AST expression. diff --git a/tests/conftest.py b/tests/conftest.py index e6ad3e07..e80b4f8f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,3 +18,19 @@ def pytest_runtest_call(item): pytest.xfail(str(e)) else: raise e + + +def pytest_collection_modifyitems(config, items): + """Remove auto-collected doctests from LLM template functions. + + Template docstrings contain ``>>>`` examples that serve as LLM prompts + for the DoctestHandler, not as standalone doctests for pytest to run. 
# True only when OPENAI_API_KEY is set to a non-empty value.  Coerced to bool
# so the module global holds a flag rather than the secret key string itself.
HAS_OPENAI_KEY = bool(os.environ.get("OPENAI_API_KEY"))

# Marker for tests that need a live OpenAI backend; skipped without a key.
requires_openai = pytest.mark.skipif(
    not HAS_OPENAI_KEY, reason="OPENAI_API_KEY environment variable not set"
)
class TestExtractDoctests:
    """Unit tests for ``DoctestHandler.extract_doctests``."""

    def test_strips_examples(self):
        """Both examples are returned and no prompt marker is left behind."""
        text = "Compute something.\n\n >>> foo(1)\n 2\n >>> foo(3)\n 4\n"
        remaining, found = DoctestHandler.extract_doctests(text)

        assert ">>>" not in remaining
        assert len(found) == 2
        first, second = found
        assert first.source.strip() == "foo(1)"
        assert first.want == "2\n"
        assert second.source.strip() == "foo(3)"
        assert second.want == "4\n"

    def test_no_examples(self):
        """A docstring without examples comes back unchanged."""
        text = "Just a description.\nNo examples here.\n"
        remaining, found = DoctestHandler.extract_doctests(text)

        assert remaining == text
        assert found == []

    def test_preserves_non_example_text(self):
        """Prose before and after an example survives the stripping."""
        text = "Title.\n\nSome details.\n\n >>> f(1)\n 42\n\nMore text.\n"
        remaining, found = DoctestHandler.extract_doctests(text)

        for fragment in ("Title.", "Some details.", "More text."):
            assert fragment in remaining
        assert ">>>" not in remaining
        assert len(found) == 1
+ doctest_handler._doctest_stack.append( + [_doctest.Example("count_char('banana')\n", "4\n")] + ) + with ( + handler(UnsafeEvalProvider()), + handler(doctest_handler), + ): + with pytest.raises(TypeError, match="doctest failed"): + encodable.decode(func_source) + + @requires_openai + def test_template_doctest_runs(self): + provider = LiteLLMProvider(model="gpt-4o-mini") + with ( + handler(provider), + handler(UnsafeEvalProvider()), + handler(DoctestHandler()), + ): + with pytest.raises(ResultDecodingError, match="doctest failed"): + synthesize_counter_with_doctest("a") + + @requires_openai + def test_nested_synthesis_doctest_runs(self): + provider = LiteLLMProvider(model="gpt-4o-mini") + with ( + handler(provider), + handler(UnsafeEvalProvider()), + handler(DoctestHandler()), + ): + with pytest.raises(ResultDecodingError, match="doctest failed"): + synthesize_outer("o") + + +# --------------------------------------------------------------------------- +# Unit tests: Case 1 – calibration +# --------------------------------------------------------------------------- + + +@Template.define +def summarize(text: str) -> str: + """Summarize the following text into a single short sentence: '{text}' + + >>> summarize("The quick brown fox jumps over the lazy dog.") + 'A fox jumps over a dog.' 
class TestCase1Calibration:
    """Tests for Case 1 (tool-calling) calibration and prefix caching."""

    def test_callable_detection(self):
        """Templates returning Callable should be Case 2, others Case 1."""
        from effectful.handlers.llm.encoding import CallableEncodable, Encodable

        def is_callable_return(t):
            return isinstance(
                Encodable.define(t.__signature__.return_annotation),
                CallableEncodable,
            )

        assert is_callable_return(synthesize_counter_with_doctest)
        assert not is_callable_return(summarize)

    def test_extraction_cache_populated(self):
        """A cache entry hands back the very objects that were stored."""
        dh = DoctestHandler()
        assert summarize not in dh._extraction_cache

        # Populate via extract_doctests
        stripped, examples = DoctestHandler.extract_doctests(
            summarize.__prompt_template__
        )
        dh._extraction_cache[summarize] = (stripped, examples)
        assert ">>>" not in stripped
        assert len(examples) == 1

        # Compare element-wise: an ``is`` test against a newly built tuple
        # can never succeed, so it would not verify anything.
        cached_stripped, cached_examples = dh._extraction_cache[summarize]
        assert cached_stripped is stripped
        assert cached_examples is examples

    def test_bind_history_restores_state(self):
        """_bind_history should restore template.__history__ after use."""
        import collections

        dh = DoctestHandler()

        # Template starts without __history__
        assert not hasattr(summarize, "__history__")

        history = collections.OrderedDict()
        with dh._bind_history(summarize, history):
            assert getattr(summarize, "__history__") is history

        # Cleaned up after context exit
        assert not hasattr(summarize, "__history__")

    @requires_openai
    def test_case1_calibration_integration(self):
        """End-to-end: calibration should cache a prefix for tool-calling."""
        provider = LiteLLMProvider(model="gpt-4o-mini")
        dh = DoctestHandler()
        with handler(provider), handler(dh):
            # This should trigger calibration for the summarize template
            result = summarize("The quick brown fox jumps over the lazy dog.")

        # After the call, summarize should have a cached prefix
        assert summarize in dh._prefix_cache
        assert isinstance(result, str)
        assert len(result) > 0

        # Calibration should clean up: no lingering __history__ on template
        assert not hasattr(summarize, "__history__")