diff --git a/docs/source/multi_agent_example.py b/docs/source/multi_agent_example.py new file mode 100644 index 00000000..74533d22 --- /dev/null +++ b/docs/source/multi_agent_example.py @@ -0,0 +1,315 @@ +"""Multi-agent system using choreographic endpoint projection. + +Demonstrates: +- Choreographic programming: one function describes the entire workflow +- Automatic endpoint projection: each agent gets its own thread +- Crash tolerance: Ctrl-C and restart, agents resume where they left off +- Scatter: two coder agents share the implementation work via claim-based pull +- PersistentAgent for automatic checkpointing and context compaction + +The scenario: a team of agents collaboratively builds a small Python library. +An architect agent breaks the project into module specs, two coder agents +implement the modules in parallel (via scatter), and two reviewer agents +review modules in parallel and request fixes if needed. + +Usage:: + + # First run — agents start working + python docs/source/multi_agent_example.py + + # Ctrl-C mid-run, then restart — agents pick up where they left off + python docs/source/multi_agent_example.py + +Requirements: + pip install effectful[llm] + export OPENAI_API_KEY=... # or any LiteLLM-supported provider + +""" + +import json +import logging +from pathlib import Path +from typing import Literal, TypedDict + +from effectful.handlers.llm import Template, Tool +from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler +from effectful.handlers.llm.multi import ( + Choreography, + ChoreographyError, + PersistentTaskQueue, + scatter, +) +from effectful.handlers.llm.persistence import PersistenceHandler, PersistentAgent +from effectful.ops.types import NotHandled + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(threadName)s] %(message)s", + datefmt="%H:%M:%S", +) +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +WORKSPACE = Path("./multi_agent_workspace") +STATE_DIR = WORKSPACE / ".state" +OUTPUT_DIR = WORKSPACE / "output" +MODEL = "gpt-4o-mini" + +# The project to build +PROJECT_SPEC = """\ +Build a small Python utility library called 'textkit' with these modules: +1. textkit/slugify.py — convert strings to URL-safe slugs +2. textkit/wrap.py — word-wrap text to a given width +3. textkit/redact.py — redact email addresses and phone numbers from text +Each module should have a clear public API, docstrings, and at least 3 +test cases written as a separate test_.py file. +""" + + +# --------------------------------------------------------------------------- +# Structured types — constrained decoding for LLM output +# --------------------------------------------------------------------------- + + +class ModuleSpec(TypedDict): + """Schema for architect planning output — constrained decoding ensures valid shape.""" + + module_path: str + description: str + public_api: str + test_path: str + + +class PlanResult(TypedDict): + """Wrapper for list output — LiteLLM requires a root object, not bare array.""" + + modules: list[ModuleSpec] + + +class ReviewResult(TypedDict): + """Schema for reviewer output — verdict constrained to PASS or NEEDS_FIXES.""" + + verdict: Literal["PASS", "NEEDS_FIXES"] + feedback: str + + +# --------------------------------------------------------------------------- +# Agents +# --------------------------------------------------------------------------- + + +class ArchitectAgent(PersistentAgent): + """You are a software architect. Given a project specification, you break + it into individual module implementation tasks. Each task should specify + the module filename, its public API, and what tests to write. + Be concrete and specific — the coder will follow your spec exactly. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._output_dir = OUTPUT_DIR + + @Tool.define + def read_existing_files(self) -> str: + """List files already written to the output directory.""" + if not self._output_dir.exists(): + return "No files yet." + files = sorted(self._output_dir.rglob("*.py")) + if not files: + return "No Python files yet." + return "\n".join(str(f.relative_to(self._output_dir)) for f in files) + + @Template.define + def plan_modules(self, project_spec: str) -> PlanResult: + """Given this project specification, output a plan with a "modules" list. + Each module spec has: module_path, description, public_api, test_path. + + Use `read_existing_files` to check what's already been written + and skip those. + + Project spec: + {project_spec}""" + raise NotHandled + + +class CoderAgent(PersistentAgent): + """You are an expert Python developer. Given a module specification, + you write clean, well-documented Python code. You also write thorough + test files. Output ONLY the Python source code, no markdown fences. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._output_dir = OUTPUT_DIR + + @Tool.define + def read_file(self, path: str) -> str: + """Read a file from the output directory.""" + full = self._output_dir / path + if full.exists(): + return full.read_text() + return f"File not found: {path}" + + @Tool.define + def write_file(self, path: str, content: str) -> str: + """Write a file to the output directory.""" + full = self._output_dir / path + full.parent.mkdir(parents=True, exist_ok=True) + full.write_text(content) + return f"Wrote {len(content)} chars to {path}" + + @Template.define + def implement_module(self, module_spec: str) -> str: + """Implement the following module specification. Use `write_file` + to write both the module and its test file. Use `read_file` to + check existing code if needed. + + Specification: + {module_spec}""" + raise NotHandled + + +class ReviewerAgent(PersistentAgent): + """You are a senior code reviewer. You review Python modules for + correctness, style, edge cases, and test coverage. Be specific + about issues and provide actionable feedback. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._output_dir = OUTPUT_DIR + + @Tool.define + def read_file(self, path: str) -> str: + """Read a file from the output directory.""" + full = self._output_dir / path + if full.exists(): + return full.read_text() + return f"File not found: {path}" + + @Template.define + def review_module(self, module_path: str, test_path: str) -> ReviewResult: + """Review the module at {module_path} and its tests at {test_path}. + Use `read_file` to read them. Return verdict "PASS" or "NEEDS_FIXES" + and feedback. If NEEDS_FIXES, explain exactly what to change.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Choreographic program — the entire multi-agent workflow in one function +# --------------------------------------------------------------------------- + + +def build_project( + project_spec: str, + architect: ArchitectAgent, + coder: CoderAgent, + reviewer: ReviewerAgent, +) -> list[ReviewResult]: + """Choreographic program describing the full build workflow. + + 1. Architect breaks the project into module specs. + 2. Coders implement modules in parallel (scatter distributes via claim-based pull). + 3. Reviewers review modules in parallel; coders fix in parallel until all pass. + """ + # Step 1: Architect plans modules + plan = architect.plan_modules(project_spec) + + # Step 2: Scatter implementation across coders + # Each module becomes a task in the queue; coders claim until none remain. + scatter( + plan["modules"], + coder, + lambda c, mod: c.implement_module(json.dumps(mod, indent=2)), + ) + + # Step 3: Review loop — keep fixing until reviewers accept all modules + while True: + reviews: list[ReviewResult] = scatter( + plan["modules"], + reviewer, + lambda r, mod: r.review_module(mod["module_path"], mod["test_path"]), + ) + + needs_fixes = [ + (mod, review) + for mod, review in zip(plan["modules"], reviews) + if review["verdict"] == "NEEDS_FIXES" + ] + + if not needs_fixes: + return reviews + + # Scatter fixes across coders, then re-review + scatter( + needs_fixes, + coder, + lambda c, pair: c.implement_module( + json.dumps( + {**pair[0], "fix_feedback": pair[1]["feedback"]}, + indent=2, + ) + ), + ) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + WORKSPACE.mkdir(parents=True, exist_ok=True) + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + # Create agents + architect = ArchitectAgent(agent_id="architect") + coder1 = CoderAgent(agent_id="coder-1") + coder2 = CoderAgent(agent_id="coder-2") + reviewer1 = ReviewerAgent(agent_id="reviewer-1") + reviewer2 = ReviewerAgent(agent_id="reviewer-2") + + # Build the choreography — all boilerplate (threads, queues, signal + # handling, crash recovery) is handled automatically. + choreo = Choreography( + build_project, + agents=[architect, coder1, coder2, reviewer1, reviewer2], + queue=PersistentTaskQueue(STATE_DIR / "task_queue.db"), + handlers=[ + LiteLLMProvider(model=MODEL), + RetryLLMHandler(), + PersistenceHandler(STATE_DIR / "checkpoints.db"), + ], + ) + + log.info("Starting multi-agent build (Ctrl-C to pause, re-run to resume)") + + try: + reviews = choreo.run( + project_spec=PROJECT_SPEC, + architect=architect, + coder=[coder1, coder2], + reviewer=[reviewer1, reviewer2], + ) + except ChoreographyError as e: + log.error("Choreography failed: %s", e) + return + + # Summary + output_files = list(OUTPUT_DIR.rglob("*.py")) + passed = sum(1 for r in reviews if r["verdict"] == "PASS") + log.info( + "Done: %d modules reviewed (%d passed), %d output files", + len(reviews), + passed, + len(output_files), + ) + for f in output_files: + log.info(" %s", f.relative_to(WORKSPACE)) + + +if __name__ == "__main__": + main() diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index fc6ca47a..e1108772 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -6,6 +6,7 @@ import inspect import string import textwrap +import threading import traceback import typing import uuid @@ -25,7 +26,7 @@ ) from effectful.handlers.llm.encoding import DecodedToolCall, Encodable -from effectful.handlers.llm.template import Template, Tool +from effectful.handlers.llm.template import Template, Tool, get_bound_agent from effectful.internals.unification import nested_type from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements @@ -71,6 +72,30 @@ def append_message(message: Message): pass +@Operation.define +def get_agent_history(agent_id: str) -> collections.OrderedDict[str, Message]: + """Get the message history for an agent. Returns empty OrderedDict by default.""" + return collections.OrderedDict() + + +class AgentHistoryHandler(ObjectInterpretation): + """Handler that stores per-agent message histories in memory. + + Install this handler to give :class:`Agent` instances persistent + in-memory histories across template calls:: + + with handler(AgentHistoryHandler()), handler(LiteLLMProvider()): + bot.ask("question") # history accumulates across calls + """ + + def __init__(self) -> None: + self._histories: dict[str, collections.OrderedDict[str, Message]] = {} + + @implements(get_agent_history) + def _get(self, agent_id: str) -> collections.OrderedDict[str, Message]: + return self._histories.setdefault(agent_id, collections.OrderedDict()) + + def _make_message(content: dict) -> Message: m_id = content.get("id") or str(uuid.uuid1()) message = typing.cast(Message, {**content, "id": m_id}) @@ -442,7 +467,11 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message: class LiteLLMProvider(ObjectInterpretation): - """Implements templates using the LiteLLM API.""" + """Implements templates using the LiteLLM API. + + Also provides per-agent message history storage via + :func:`get_agent_history`. + """ config: collections.abc.Mapping[str, typing.Any] @@ -451,6 +480,19 @@ def __init__(self, model="gpt-4o", **config): "model": model, **inspect.signature(litellm.completion).bind_partial(**config).kwargs, } + self._histories: dict[str, collections.OrderedDict[str, Message]] = {} + self._tls = threading.local() + + def _get_depths(self) -> dict[str, int]: + if not hasattr(self._tls, "depths"): + self._tls.depths = {} + return self._tls.depths + + @implements(get_agent_history) + def _get_agent_history( + self, agent_id: str + ) -> collections.OrderedDict[str, Message]: + return self._histories.setdefault(agent_id, collections.OrderedDict()) @implements(Template.__apply__) def _call[**P, T]( @@ -464,29 +506,50 @@ def _call[**P, T]( # Create response_model with env so tools passed as arguments are available response_model = Encodable.define(template.__signature__.return_annotation, env) - history: collections.OrderedDict[str, Message] = getattr( - template, "__history__", collections.OrderedDict() - ) # type: ignore - history_copy = history.copy() + # Get history: from agent history handler if bound to an agent, else fresh + agent = get_bound_agent(template) + if agent is not None: + agent_id = agent.__agent_id__ + history = get_agent_history(agent_id) + else: + agent_id = None + history = collections.OrderedDict() + + # Track nesting depth per agent so only the outermost call writes back. + # Inner calls work on their own copy but discard it on return. + # See: TestNestedTemplateCalling.test_only_outermost_writes_to_history + depths = self._get_depths() + if agent_id is not None: + depth = depths.get(agent_id, 0) + depths[agent_id] = depth + 1 + is_outermost = depth == 0 + else: + depth = 0 + is_outermost = False - with handler({_get_history: lambda: history_copy}): - call_system(template) - - message: Message = call_user(template.__prompt_template__, env) - - # loop based on: https://cookbook.openai.com/examples/reasoning_function_calls - tool_calls: list[DecodedToolCall] = [] - result: T | None = None - while message["role"] != "assistant" or tool_calls: - message, tool_calls, result = call_assistant( - template.tools, response_model, **self.config - ) - for tool_call in tool_calls: - message = call_tool(tool_call) + history_copy = history.copy() try: - _get_history() - except NotImplementedError: - history.clear() - history.update(history_copy) - return typing.cast(T, result) + with handler({_get_history: lambda: history_copy}): + call_system(template) + + message: Message = call_user(template.__prompt_template__, env) + + # loop based on: https://cookbook.openai.com/examples/reasoning_function_calls + tool_calls: list[DecodedToolCall] = [] + result: T | None = None + while message["role"] != "assistant" or tool_calls: + message, tool_calls, result = call_assistant( + template.tools, response_model, **self.config + ) + for tool_call in tool_calls: + message = call_tool(tool_call) + + # Only outermost call writes back to canonical history + if is_outermost: + history.clear() + history.update(history_copy) + return typing.cast(T, result) + finally: + if agent_id is not None: + depths[agent_id] = depth diff --git a/effectful/handlers/llm/multi.py b/effectful/handlers/llm/multi.py new file mode 100644 index 00000000..019c78fb --- /dev/null +++ b/effectful/handlers/llm/multi.py @@ -0,0 +1,993 @@ +"""Choreographic programming for multi-agent LLM systems. + +Write a single function describing how agents interact from a global +perspective, then run it with automatic endpoint projection (EPP). +Each agent gets its own thread, inter-agent communication is +handled automatically via a persistent :class:`TaskQueue`, and the +entire process is crash-tolerant and restartable. + +**How it works.** The choreographic program is a plain Python function +whose arguments are agent instances. All agent threads run this same +function. The :class:`EndpointProjection` handler intercepts +:attr:`~effectful.handlers.llm.template.Template.__apply__`: + +- When it is the current agent's template: claim a task in the + queue, execute via ``fwd``, and store the result. +- When it is another agent's template: poll the queue until + the result appears. + +Each statement in the choreography is assigned an incrementing step ID. +Completed steps are persisted to disk. On restart, the program re-runs +from the start; completed steps return their cached results instantly, +and execution resumes from the first incomplete step. + +Example — sequential choreography with a review loop:: + + from pathlib import Path + from typing import Literal, TypedDict + + from effectful.handlers.llm import Template + from effectful.handlers.llm.completions import LiteLLMProvider, RetryLLMHandler + from effectful.handlers.llm.multi import Choreography + from effectful.handlers.llm.persistence import PersistenceHandler, PersistentAgent + from effectful.ops.types import NotHandled + + class ModuleSpec(TypedDict): + module_path: str + description: str + + class PlanResult(TypedDict): + modules: list[ModuleSpec] + + class ReviewResult(TypedDict): + verdict: Literal["PASS", "NEEDS_FIXES"] + feedback: str + + class Architect(PersistentAgent): + \"\"\"You are a software architect.\"\"\" + + @Template.define + def plan_modules(self, project_spec: str) -> PlanResult: + \"\"\"Break this project into modules: {project_spec}\"\"\" + raise NotHandled + + class Coder(PersistentAgent): + \"\"\"You are a Python developer.\"\"\" + + @Template.define + def implement_module(self, spec: str) -> str: + \"\"\"Implement the module: {spec}\"\"\" + raise NotHandled + + class Reviewer(PersistentAgent): + \"\"\"You are a code reviewer.\"\"\" + + @Template.define + def review_code(self, code: str) -> ReviewResult: + \"\"\"Review this code: {code}\"\"\" + raise NotHandled + + def build_codebase( + project_spec: str, + architect: Architect, + coder: Coder, + reviewer: Reviewer, + ) -> str: + plan = architect.plan_modules(project_spec) + code = coder.implement_module(str(plan)) + while True: + result = reviewer.review_code(code) + if result["verdict"] == "PASS": + return code + code = coder.implement_module(result["feedback"]) + + architect = Architect(agent_id="architect") + coder = Coder(agent_id="coder") + reviewer = Reviewer(agent_id="reviewer") + + choreo = Choreography( + build_codebase, + agents=[architect, coder, reviewer], + queue=PersistentTaskQueue(Path("./state/task_queue.db")), + handlers=[ + LiteLLMProvider(model="gpt-4o-mini"), + RetryLLMHandler(), + PersistenceHandler(Path("./state/checkpoints.db")), + ], + ) + # Kill at any point, restart, and it resumes where it left off. + result = choreo.run( + project_spec="Build a URL slugify library", + architect=architect, + coder=coder, + reviewer=reviewer, + ) + +Example — parallel scatter across multiple coders:: + + from effectful.handlers.llm.multi import Choreography, PersistentTaskQueue, scatter + + def build_parallel( + project_spec: str, + architect: Architect, + coder: Coder, + reviewer: Reviewer, + ) -> list[ReviewResult]: + plan = architect.plan_modules(project_spec) + # Each module becomes a task; coders claim from the queue + # until none remain — natural load balancing. + codes = scatter( + plan["modules"], coder, + lambda coder, mod: coder.implement_module(str(mod)), + ) + return [reviewer.review_code(code) for code in codes] + + coder1 = Coder(agent_id="coder-1") + coder2 = Coder(agent_id="coder-2") + coder3 = Coder(agent_id="coder-3") + + choreo = Choreography( + build_parallel, + agents=[architect, coder1, coder2, coder3, reviewer], + queue=PersistentTaskQueue(Path("./state/task_queue.db")), + handlers=[LiteLLMProvider(model="gpt-4o-mini"), RetryLLMHandler()], + ) + # Pass coder as a list — scatter distributes across all three + reviews = choreo.run( + project_spec="Build textkit with slugify, wrap, and redact modules", + architect=architect, + coder=[coder1, coder2, coder3], + reviewer=reviewer, + ) + +""" + +import abc +import contextlib +import json +import sqlite3 +import threading +import time +import uuid +from collections.abc import Callable, Sequence +from enum import StrEnum +from pathlib import Path +from typing import Any + +from effectful.handlers.llm.template import Agent, Template, get_bound_agent +from effectful.ops.semantics import fwd, handler +from effectful.ops.syntax import ObjectInterpretation, implements +from effectful.ops.types import Interpretation, Operation + +# ── TaskQueue ────────────────────────────────────────────────────── + + +class TaskStatus(StrEnum): + PENDING = "pending" + CLAIMED = "claimed" + DONE = "done" + FAILED = "failed" + + +class TaskQueue(abc.ABC): + """Abstract task queue with claim-based ownership. + + Subclasses implement persistent (file-based) or in-memory storage. + All methods are thread-safe. + """ + + @abc.abstractmethod + def submit( + self, + task_type: str, + payload: dict, + task_id: str | None = None, + ) -> str: + """Add a new task. Returns the task ID. + + Idempotent when *task_id* is specified: if a task with that ID + already exists (in any state), the call is a no-op. + """ + + @abc.abstractmethod + def claim(self, task_type: str, owner: str) -> dict | None: + """Atomically claim the next pending task of the given type. + + Returns the task dict if one was claimed, or ``None``. + """ + + @abc.abstractmethod + def claim_by_prefix(self, prefix: str, owner: str) -> dict | None: + """Claim any pending task whose ID starts with *prefix*.""" + + @abc.abstractmethod + def complete(self, task_id: str, owner: str, result: Any = None) -> None: + """Mark a claimed task as done with *result*.""" + + @abc.abstractmethod + def fail(self, task_id: str, owner: str, error: str) -> None: + """Mark a claimed task as failed.""" + + @abc.abstractmethod + def get_result(self, task_id: str) -> Any | None: + """Return the result of a completed task, or ``None``.""" + + @abc.abstractmethod + def release_stale_claims(self, owner: str) -> int: + """Release tasks claimed by *owner* back to pending. + + Call on startup to reclaim work from a prior crashed session. + """ + + @abc.abstractmethod + def pending_count(self, task_type: str | None = None) -> int: + """Count pending tasks, optionally filtered by type.""" + + @abc.abstractmethod + def all_done(self) -> bool: + """``True`` if no pending or claimed tasks remain.""" + + +class InMemoryTaskQueue(TaskQueue): + """In-memory task queue for testing or ephemeral workflows. + + Not crash-tolerant — all state is lost when the process exits. + Thread-safe via a single lock. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._tasks: dict[str, dict] = {} # task_id -> task dict + + def submit( + self, + task_type: str, + payload: dict, + task_id: str | None = None, + ) -> str: + if task_id is None: + task_id = str(uuid.uuid4())[:8] + with self._lock: + if task_id in self._tasks: + return task_id + self._tasks[task_id] = { + "id": task_id, + "type": task_type, + "payload": payload, + "status": TaskStatus.PENDING, + "owner": "", + "result": None, + } + return task_id + + def claim(self, task_type: str, owner: str) -> dict | None: + with self._lock: + for task_id in sorted(self._tasks): + task = self._tasks[task_id] + if task["status"] == TaskStatus.PENDING and task["type"] == task_type: + task["status"] = TaskStatus.CLAIMED + task["owner"] = owner + return dict(task) + return None + + def claim_by_prefix(self, prefix: str, owner: str) -> dict | None: + with self._lock: + for task_id in sorted(self._tasks): + task = self._tasks[task_id] + if task["status"] == TaskStatus.PENDING and task_id.startswith(prefix): + task["status"] = TaskStatus.CLAIMED + task["owner"] = owner + return dict(task) + return None + + def complete(self, task_id: str, owner: str, result: Any = None) -> None: + with self._lock: + task = self._tasks.get(task_id) + if task is None or task["status"] != TaskStatus.CLAIMED: + return + task["status"] = TaskStatus.DONE + task["result"] = result + + def fail(self, task_id: str, owner: str, error: str) -> None: + with self._lock: + task = self._tasks.get(task_id) + if task is None or task["status"] != TaskStatus.CLAIMED: + return + task["status"] = TaskStatus.FAILED + task["result"] = {"error": error} + + def get_result(self, task_id: str) -> Any | None: + with self._lock: + task = self._tasks.get(task_id) + if task is not None and task["status"] == TaskStatus.DONE: + return task["result"] + return None + + def release_stale_claims(self, owner: str) -> int: + count = 0 + with self._lock: + for task in self._tasks.values(): + if task["status"] == TaskStatus.CLAIMED and task["owner"] == owner: + task["status"] = TaskStatus.PENDING + task["owner"] = "" + count += 1 + return count + + def pending_count(self, task_type: str | None = None) -> int: + with self._lock: + return sum( + 1 + for t in self._tasks.values() + if t["status"] == TaskStatus.PENDING + and (task_type is None or t["type"] == task_type) + ) + + def all_done(self) -> bool: + with self._lock: + return not any( + t["status"] in (TaskStatus.PENDING, TaskStatus.CLAIMED) + for t in self._tasks.values() + ) + + +def _init_queue_db(conn: sqlite3.Connection) -> None: + """Create the tasks table and configure WAL mode for crash tolerance.""" + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute( + """ + CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + payload TEXT NOT NULL DEFAULT '{}', + status TEXT NOT NULL DEFAULT 'pending', + owner TEXT NOT NULL DEFAULT '', + result TEXT + ) + """ + ) + conn.execute("CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_tasks_status_type ON tasks(status, type)" + ) + conn.commit() + + +class PersistentTaskQueue(TaskQueue): + """SQLite-backed task queue with claim-based ownership. + + All task state is stored in a single SQLite database using WAL + journal mode for crash tolerance. If the process is killed + mid-transaction, SQLite's journal-based recovery ensures the + database remains consistent. + + Claiming a task atomically updates its status from ``pending`` to + ``claimed`` inside a transaction, preventing double-claiming even + across process restarts. + + The queue is fully crash-tolerant: call + :meth:`release_stale_claims` on restart to reclaim work from a + prior crashed session. + + Args: + db_path: Path to the SQLite database file. + """ + + def __init__(self, db_path: Path): + self._db_path = Path(db_path) + self._lock = threading.Lock() + self._db_initialized = False + self._init_lock = threading.Lock() + + @property + def db_path(self) -> Path: + """Path to the SQLite database file.""" + return self._db_path + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(str(self._db_path), timeout=10) + conn.execute("PRAGMA busy_timeout=5000") + if not self._db_initialized: + with self._init_lock: + if not self._db_initialized: + _init_queue_db(conn) + self._db_initialized = True + return conn + + def submit( + self, + task_type: str, + payload: dict, + task_id: str | None = None, + ) -> str: + if task_id is None: + task_id = str(uuid.uuid4())[:8] + payload_json = json.dumps(payload, default=str) + conn = self._connect() + try: + conn.execute( + """ + INSERT OR IGNORE INTO tasks (id, type, payload, status, owner, result) + VALUES (?, ?, ?, ?, '', NULL) + """, + (task_id, task_type, payload_json, TaskStatus.PENDING), + ) + conn.commit() + finally: + conn.close() + return task_id + + def claim(self, task_type: str, owner: str) -> dict | None: + with self._lock: + conn = self._connect() + try: + row = conn.execute( + """ + SELECT id, type, payload, status, owner, result + FROM tasks + WHERE status = ? AND type = ? + ORDER BY id LIMIT 1 + """, + (TaskStatus.PENDING, task_type), + ).fetchone() + if row is None: + return None + task_id = row[0] + conn.execute( + "UPDATE tasks SET status = ?, owner = ? WHERE id = ?", + (TaskStatus.CLAIMED, owner, task_id), + ) + conn.commit() + return { + "id": task_id, + "type": row[1], + "payload": json.loads(row[2]), + "status": TaskStatus.CLAIMED, + "owner": owner, + "result": json.loads(row[5]) if row[5] is not None else None, + } + finally: + conn.close() + + def claim_by_prefix(self, prefix: str, owner: str) -> dict | None: + with self._lock: + conn = self._connect() + try: + row = conn.execute( + """ + SELECT id, type, payload, status, owner, result + FROM tasks + WHERE status = ? AND id LIKE ? + ORDER BY id LIMIT 1 + """, + (TaskStatus.PENDING, prefix + "%"), + ).fetchone() + if row is None: + return None + task_id = row[0] + conn.execute( + "UPDATE tasks SET status = ?, owner = ? WHERE id = ?", + (TaskStatus.CLAIMED, owner, task_id), + ) + conn.commit() + return { + "id": task_id, + "type": row[1], + "payload": json.loads(row[2]), + "status": TaskStatus.CLAIMED, + "owner": owner, + "result": json.loads(row[5]) if row[5] is not None else None, + } + finally: + conn.close() + + def complete(self, task_id: str, owner: str, result: Any = None) -> None: + result_json = json.dumps(result, default=str) + conn = self._connect() + try: + conn.execute( + """ + UPDATE tasks SET status = ?, result = ? + WHERE id = ? AND status = ? + """, + (TaskStatus.DONE, result_json, task_id, TaskStatus.CLAIMED), + ) + conn.commit() + finally: + conn.close() + + def fail(self, task_id: str, owner: str, error: str) -> None: + error_json = json.dumps({"error": error}, default=str) + conn = self._connect() + try: + conn.execute( + """ + UPDATE tasks SET status = ?, result = ? + WHERE id = ? AND status = ? + """, + (TaskStatus.FAILED, error_json, task_id, TaskStatus.CLAIMED), + ) + conn.commit() + finally: + conn.close() + + def get_result(self, task_id: str) -> Any | None: + conn = self._connect() + try: + row = conn.execute( + "SELECT result FROM tasks WHERE id = ? AND status = ?", + (task_id, TaskStatus.DONE), + ).fetchone() + if row is None: + return None + return json.loads(row[0]) if row[0] is not None else None + finally: + conn.close() + + def release_stale_claims(self, owner: str) -> int: + with self._lock: + conn = self._connect() + try: + cursor = conn.execute( + """ + UPDATE tasks SET status = ?, owner = '' + WHERE status = ? AND owner = ? + """, + (TaskStatus.PENDING, TaskStatus.CLAIMED, owner), + ) + conn.commit() + return cursor.rowcount + finally: + conn.close() + + def pending_count(self, task_type: str | None = None) -> int: + conn = self._connect() + try: + if task_type is None: + row = conn.execute( + "SELECT COUNT(*) FROM tasks WHERE status = ?", + (TaskStatus.PENDING,), + ).fetchone() + else: + row = conn.execute( + "SELECT COUNT(*) FROM tasks WHERE status = ? AND type = ?", + (TaskStatus.PENDING, task_type), + ).fetchone() + return row[0] if row else 0 + finally: + conn.close() + + def all_done(self) -> bool: + conn = self._connect() + try: + row = conn.execute( + "SELECT COUNT(*) FROM tasks WHERE status IN (?, ?)", + (TaskStatus.PENDING, TaskStatus.CLAIMED), + ).fetchone() + return row[0] == 0 if row else True + finally: + conn.close() + + +# ── scatter ──────────────────────────────────────────────────────── + + +@Operation.define +def scatter(items: list, agent: Agent, fn: Callable) -> list: + """Distribute *items* by calling ``fn(agent, item)`` for each item. + + **Default** (no EPP handler): sequential + ``[fn(agent, item) for item in items]``. + + **Under** :class:`EndpointProjection`: each item becomes a task in + the queue. When a list of agents is passed for the same role + (e.g. ``coder=[coder1, coder2]``), agents claim tasks until none + remain — providing natural load balancing with crash recovery. + On restart, completed items are returned from cache; only + remaining items are re-executed. + + .. warning:: + + ``fn`` should only call templates on the assigned agent. + Cross-agent template calls inside scatter are not supported. + """ + return [fn(agent, item) for item in items] + + +@Operation.define +def fan_out(groups: list[tuple[list, Agent, Callable]]) -> list[list]: + """Run multiple scatter-like operations concurrently. + + Each element of *groups* is a ``(items, agent, fn)`` triple — the + same arguments you would pass to :func:`scatter`. Returns a list + of result lists, one per group, in the same order as *groups*. + + **Default** (no EPP handler): sequential execution of each group:: + + [ + [fn(agent, item) for item in items] + for items, agent, fn in groups + ] + + **Under** :class:`EndpointProjection`: all groups' items are + submitted as tasks under a single step ID. Agents from *every* + group claim and execute work concurrently, so a spec-writer, + tester, and prover can all be working at the same time rather + than waiting for the previous scatter to finish. + + Example:: + + spec_results, test_results, proof_results = fan_out([ + (spec_tasks, spec_writer, + lambda w, b: w.write_spec(json.dumps(b, indent=2))), + (test_tasks, tester, + lambda t, b: t.write_tests_and_validate(json.dumps(b, indent=2))), + (proof_tasks, prover, + lambda p, b: p.prove_theorem(json.dumps(b, indent=2))), + ]) + + .. warning:: + + ``fn`` should only call templates on the assigned agent. + Cross-agent template calls inside fan_out are not supported. + """ + return [[fn(agent, item) for item in items] for items, agent, fn in groups] + + +# ── Endpoint Projection ─────────────────────────────────────────── + + +class ChoreographyError(Exception): + """Raised when a choreography fails due to an agent error.""" + + +class CancelledError(Exception): + """Raised inside an agent thread when the choreography is cancelled.""" + + +class EndpointProjection(ObjectInterpretation): + """Handler that projects a choreographic program onto a single agent. + + Each template call in the choreography is assigned a step ID + (incrementing counter). Steps become tasks in the + :class:`TaskQueue`. + + - **Own agent's templates**: check if the step is already done + (cached); if not, claim the task, execute, and store the result. + - **Other agent's templates**: poll the queue until the result + appears. + - **Unbound templates**: execute directly on all threads. + + Also handles :func:`scatter` for data-parallel distribution via + claim-based pull. + """ + + def __init__( + self, + agent: Agent, + queue: TaskQueue, + agent_ids: frozenset[str], + poll_interval: float = 0.1, + cancel_event: threading.Event | None = None, + ) -> None: + self._agent = agent + self._agent_id = agent.__agent_id__ + self._queue = queue + self._agent_ids = agent_ids + self._poll = poll_interval + self._step = 0 + self._in_scatter = False + self._cancel = cancel_event + + def _next_step(self) -> str: + step_id = f"step-{self._step:04d}" + self._step += 1 + return step_id + + def _check_cancelled(self) -> None: + if self._cancel is not None and self._cancel.is_set(): + raise CancelledError("Choreography cancelled") + + def _wait_result(self, step_id: str) -> Any: + """Poll queue until task result is available.""" + while True: + self._check_cancelled() + r = self._queue.get_result(step_id) + if r is not None: + return r + time.sleep(self._poll) + + @implements(Template.__apply__) + def _call[**P, T]( + self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs + ) -> T: + bound = get_bound_agent(template) + + # Inside scatter: execute directly, no task management + if self._in_scatter: + if bound and bound.__agent_id__ == self._agent_id: + return fwd(template, *args, **kwargs) + raise RuntimeError( + f"Cross-agent call in scatter: {self._agent_id} -> " + f"{bound.__agent_id__ if bound else '?'}" + ) + + step_id = self._next_step() + + if bound is not None and bound.__agent_id__ == self._agent_id: + # My template — check done cache, or claim and execute + cached = self._queue.get_result(step_id) + if cached is not None: + return cached + + self._queue.submit( + task_type=template.__name__, + payload={"agent": self._agent_id}, + task_id=step_id, + ) + task = self._queue.claim(template.__name__, self._agent_id) + if task is None: + # Already claimed (e.g. restarted while another thread + # is executing) — poll for result + return self._wait_result(step_id) + + try: + result = fwd(template, *args, **kwargs) + self._queue.complete(step_id, self._agent_id, result) + return result + except Exception as e: + self._queue.fail(step_id, self._agent_id, str(e)) + raise + + elif bound is not None: + # Another agent's template — poll for result + return self._wait_result(step_id) + + else: + # Unbound template — execute directly + return fwd(template, *args, **kwargs) + + @implements(scatter) + def _scatter(self, items: list, agent: Agent, fn: Callable) -> list: + step_id = self._next_step() + + # agent may be a single Agent or a list of Agents (passed + # transparently from choreo.run kwargs). Normalize to a list. + agents = agent if isinstance(agent, list) else [agent] + scatter_ids = {a.__agent_id__ for a in agents} + + # Submit one task per item. All agent threads execute this + # loop, but submit() is idempotent on task_id — the + # deterministic ID (step_id:index) ensures each task is + # created exactly once regardless of how many threads call it. + for i in range(len(items)): + self._queue.submit( + task_type=f"scatter-{step_id}", + payload={"item_index": i}, + task_id=f"{step_id}:{i:04d}", + ) + + # If I'm a scatter agent, claim and execute until none left + if self._agent_id in scatter_ids: + while True: + task = self._queue.claim_by_prefix(f"{step_id}:", self._agent_id) + if task is None: + break + idx = task["payload"]["item_index"] + self._in_scatter = True + try: + result = fn(self._agent, items[idx]) + self._queue.complete(task["id"], self._agent_id, result) + except Exception as e: + self._queue.fail(task["id"], self._agent_id, str(e)) + raise + finally: + self._in_scatter = False + + # Gather all results (blocking until done) + return [self._wait_result(f"{step_id}:{i:04d}") for i in range(len(items))] + + @implements(fan_out) + def _fan_out(self, groups: list[tuple[list, Agent, Callable]]) -> list[list]: + step_id = self._next_step() + + # For each group, normalize agents and build a mapping from + # agent_id → list of (group_index, items, fn) so each agent + # knows which groups it participates in. + group_agents: list[set[str]] = [] + group_fns: list[Callable] = [] + group_items: list[list] = [] + + for g, (items, agent, fn) in enumerate(groups): + agents = agent if isinstance(agent, list) else [agent] + group_agents.append({a.__agent_id__ for a in agents}) + group_fns.append(fn) + group_items.append(items) + + # Submit all tasks across all groups. Deterministic IDs: + # {step_id}:g{group}:{item_index} + for g in range(len(groups)): + for i in range(len(group_items[g])): + self._queue.submit( + task_type=f"fan-{step_id}:g{g}", + payload={"group": g, "item_index": i}, + task_id=f"{step_id}:g{g}:{i:04d}", + ) + + # Claim and execute tasks from my groups + my_groups = [g for g in range(len(groups)) if self._agent_id in group_agents[g]] + for g in my_groups: + prefix = f"{step_id}:g{g}:" + while True: + task = self._queue.claim_by_prefix(prefix, self._agent_id) + if task is None: + break + idx = task["payload"]["item_index"] + self._in_scatter = True + try: + result = group_fns[g](self._agent, group_items[g][idx]) + self._queue.complete(task["id"], self._agent_id, result) + except Exception as e: + self._queue.fail(task["id"], self._agent_id, str(e)) + raise + finally: + self._in_scatter = False + + # Gather results per group (blocking) + return [ + [ + self._wait_result(f"{step_id}:g{g}:{i:04d}") + for i in range(len(group_items[g])) + ] + for g in range(len(groups)) + ] + + +# ── Choreography runner ─────────────────────────────────────────── + + +class Choreography: + """Run a choreographic program with endpoint projection. + + Each agent gets its own thread. Template calls are routed via + the :class:`TaskQueue`: the owning agent claims and executes, + others poll for results. On restart, completed steps are + returned from cache. + + Args: + program: The choreographic function. All agent threads run + this same function; EPP makes each thread behave + differently. + agents: The agents participating in the choreography. + queue: The task queue to use. Defaults to + :class:`InMemoryTaskQueue` if not provided. Pass a + :class:`PersistentTaskQueue` for crash tolerance. + handlers: Handler instances to install per-thread beneath + the EPP handler (e.g. LLM provider, retry handler, + persistence handler). + poll_interval: Seconds between polling for task results + (default 0.1). + + Example:: + + choreo = Choreography( + build_codebase, + agents=[architect, coder, reviewer], + queue=PersistentTaskQueue(Path("./state/task_queue.db")), + handlers=[ + LiteLLMProvider(model="gpt-4o-mini"), + RetryLLMHandler(), + PersistenceHandler(Path("./state/checkpoints.db")), + ], + ) + result = choreo.run( + project_spec="Build a library...", + architect=architect, + coder=coder, + reviewer=reviewer, + ) + """ + + def __init__( + self, + program: Callable[..., Any], + agents: Sequence[Agent], + queue: TaskQueue | None = None, + handlers: Sequence[Interpretation | ObjectInterpretation] | None = None, + poll_interval: float = 0.1, + ) -> None: + self.program = program + self.agents = list(agents) + self.handlers = list(handlers or []) + self.poll_interval = poll_interval + self._queue = queue if queue is not None else InMemoryTaskQueue() + + @property + def queue(self) -> TaskQueue: + """The underlying task queue (for inspection or manual ops).""" + return self._queue + + def project( + self, + agent: Agent, + cancel_event: threading.Event | None = None, + ) -> EndpointProjection: + """Return the EPP handler for a specific agent. + + Useful for manual thread management:: + + proj = choreo.project(agent) + with handler(provider), handler(proj): + result = choreo.program(**kwargs) + """ + return EndpointProjection( + agent, + self._queue, + frozenset(a.__agent_id__ for a in self.agents), + self.poll_interval, + cancel_event=cancel_event, + ) + + def run(self, **kwargs: Any) -> Any: + """Run the choreography to completion. + + Keyword arguments are forwarded to the choreographic function. + Returns the result (identical across all agent threads). + + On restart after a crash, completed steps return cached + results; stale claims are released and re-executed. + + Raises: + ChoreographyError: If any agent thread fails. + """ + # Release stale claims from prior crashed run + for agent in self.agents: + self._queue.release_stale_claims(agent.__agent_id__) + + cancel = threading.Event() + results: dict[str, Any] = {} + errors: list[tuple[str, BaseException]] = [] + lock = threading.Lock() + + def agent_main(agent: Agent) -> None: + try: + proj = self.project(agent, cancel_event=cancel) + result = self._run_with_handlers(proj, **kwargs) + with lock: + results[agent.__agent_id__] = result + except CancelledError: + pass # another agent failed; this thread was cancelled + except BaseException as e: + cancel.set() # signal other threads to stop + with lock: + errors.append((agent.__agent_id__, e)) + + threads = [] + for agent in self.agents: + t = threading.Thread( + target=agent_main, + args=(agent,), + name=f"choreo-{agent.__agent_id__}", + daemon=True, + ) + t.start() + threads.append(t) + + for t in threads: + t.join() + + if errors: + agent_id, exc = errors[0] + raise ChoreographyError(f"Agent '{agent_id}' failed: {exc}") from exc + + # All agents compute the same result; return any. + return next(iter(results.values())) + + def _run_with_handlers(self, proj: EndpointProjection, **kwargs: Any) -> Any: + """Install handlers and EPP, then run the program.""" + with contextlib.ExitStack() as stack: + for h in self.handlers: + stack.enter_context(handler(h)) + # EPP outermost — intercepts before providers + stack.enter_context(handler(proj)) + return self.program(**kwargs) diff --git a/effectful/handlers/llm/persistence.py b/effectful/handlers/llm/persistence.py new file mode 100644 index 00000000..1d93ae9a --- /dev/null +++ b/effectful/handlers/llm/persistence.py @@ -0,0 +1,444 @@ +import dataclasses +import json +import sqlite3 +import threading +from collections import OrderedDict +from pathlib import Path +from typing import Any + +from effectful.handlers.llm.completions import get_agent_history +from effectful.handlers.llm.template import Agent, Template, get_bound_agent +from effectful.ops.semantics import fwd +from effectful.ops.syntax import ObjectInterpretation, implements +from effectful.ops.types import NotHandled + + +@Template.define +def summarize_context(transcript: str) -> str: + """Summarise the following conversation transcript into a concise + context summary. Preserve key facts, decisions, and any + information the agent would need to continue working. + + Transcript: + {transcript}""" + raise NotHandled + + +class PersistentAgent(Agent): + """An :class:`Agent` whose history can be persisted by :class:`PersistenceHandler`. + + This is a lightweight marker class. All persistence *behaviour* + (checkpointing, handoff, DB I/O) lives in :class:`PersistenceHandler`, + a composable handler following the same pattern as + :class:`~effectful.handlers.llm.completions.RetryLLMHandler`. + + Unlike plain :class:`Agent` (which uses ``id(self)`` by default), + ``PersistentAgent`` **requires** a stable ``agent_id`` so that + checkpoints can be matched across process restarts. + + Override :meth:`checkpoint_state` and :meth:`restore_state` to persist + custom subclass state alongside the message history. + + **Usage**:: + + from pathlib import Path + from effectful.handlers.llm.persistence import PersistentAgent, PersistenceHandler + from effectful.handlers.llm import Template + from effectful.handlers.llm.completions import LiteLLMProvider + from effectful.ops.semantics import handler + from effectful.ops.types import NotHandled + + class ResearchBot(PersistentAgent): + \"""You are a research assistant that remembers prior sessions.\""" + + @Template.define + def ask(self, question: str) -> str: + \"""Answer: {question}\""" + raise NotHandled + + bot = ResearchBot(agent_id="research-bot") + + with handler(LiteLLMProvider()), handler(PersistenceHandler(Path("./state/checkpoints.db"))): + bot.ask("What is the capital of France?") + # Kill process here, restart, and the bot resumes with context. + """ + + def __init__(self, *, agent_id: str): + self.__agent_id__ = agent_id + + def checkpoint_state(self) -> dict[str, Any]: + """Return a JSON-serialisable dict of subclass state to persist. + + The default implementation serialises all + :func:`dataclasses.dataclass` fields. Override this (and + :meth:`restore_state`) for custom serialisation. + """ + if not dataclasses.is_dataclass(self): + return {} + state: dict[str, Any] = {} + for f in dataclasses.fields(self): + val = getattr(self, f.name) + try: + json.dumps(val) + state[f.name] = val + except (TypeError, ValueError): + pass # skip non-serialisable fields + return state + + def restore_state(self, state: dict[str, Any]) -> None: + """Restore subclass state from *state* dict. + + The default implementation sets each key as an attribute. + Override this (and :meth:`checkpoint_state`) for custom + deserialisation. + """ + for key, value in state.items(): + setattr(self, key, value) + + +def _init_db(conn: sqlite3.Connection) -> None: + """Create the checkpoints table and configure WAL mode for crash tolerance.""" + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute( + """ + CREATE TABLE IF NOT EXISTS checkpoints ( + agent_id TEXT PRIMARY KEY, + handoff TEXT NOT NULL DEFAULT '', + state TEXT NOT NULL DEFAULT '{}', + history TEXT NOT NULL DEFAULT '[]' + ) + """ + ) + conn.commit() + + +class PersistenceHandler(ObjectInterpretation): + """Handler that persists :class:`PersistentAgent` history to a SQLite database. + + Install alongside + :class:`~effectful.handlers.llm.completions.LiteLLMProvider`:: + + with handler(LiteLLMProvider()), handler(PersistenceHandler(Path("./state/checkpoints.db"))): + bot.ask("question") + + Uses SQLite WAL mode for crash tolerance. If the process is killed + mid-write, SQLite's journal-based recovery ensures the database + remains consistent. + + All state is read from and written to the database directly — no + in-memory caching. This makes the handler stateless (aside from + nesting depth tracking) and easy to reason about. + + **Automatic checkpointing**: + + - **Before** each top-level template call: saves a checkpoint with a + handoff note describing the in-progress work. + - **After** each successful call: clears the handoff and saves again. + - **On failure**: saves the checkpoint (with handoff) so the next + session can resume. + + **Crash recovery**: on the next run, the handoff note from the prior + crash is injected into the system prompt so the LLM can resume. + + **Nested calls** (e.g. a tool calling another template on the same + agent) are passed through without additional checkpointing. + + Composes with :class:`~effectful.handlers.llm.completions.RetryLLMHandler` + and :class:`CompactionHandler`:: + + with ( + handler(LiteLLMProvider()), + handler(RetryLLMHandler()), + handler(CompactionHandler()), + handler(PersistenceHandler(Path("./state/checkpoints.db"))), + ): + bot.ask("question") + + **Crash recovery example**:: + + from pathlib import Path + from effectful.handlers.llm.persistence import PersistentAgent, PersistenceHandler + from effectful.handlers.llm import Template + from effectful.handlers.llm.completions import LiteLLMProvider + from effectful.ops.semantics import handler + from effectful.ops.types import NotHandled + + class Bot(PersistentAgent): + \"""You are a helpful assistant.\""" + + @Template.define + def work(self, task: str) -> str: + \"""Do: {task}\""" + raise NotHandled + + bot = Bot(agent_id="worker") + persist = PersistenceHandler(Path("./state/checkpoints.db")) + + # Session 1 — process crashes mid-call + with handler(LiteLLMProvider(model="gpt-4o-mini")), handler(persist): + bot.work("step 1") # completes, checkpointed + bot.work("step 2") # process killed here + + # Session 2 — restart with the same db_path + bot2 = Bot(agent_id="worker") + persist2 = PersistenceHandler(Path("./state/checkpoints.db")) + with handler(LiteLLMProvider(model="gpt-4o-mini")), handler(persist2): + # History from session 1 is restored automatically. + # The handoff note "Executing work ..." tells the LLM what + # was in progress when the crash occurred. + bot2.work("step 2") # resumes with full context + + Use :meth:`save` for manual checkpointing outside the automatic flow + (e.g. after initialising agent state in a choreography). + + Args: + db_path: Path to the SQLite database file. + """ + + def __init__(self, db_path: Path) -> None: + self._db_path = Path(db_path) + self._tls = threading.local() + self._db_lock = threading.Lock() + self._db_initialized = False + + def _connect(self) -> sqlite3.Connection: + """Open a new SQLite connection to the checkpoint database. + + Each call returns a fresh connection, making it safe to use from + any thread. WAL mode and table creation are applied once on the + first call (guarded by ``_db_lock``). + """ + conn = sqlite3.connect(str(self._db_path)) + conn.execute("PRAGMA busy_timeout=5000") + if not self._db_initialized: + with self._db_lock: + if not self._db_initialized: + _init_db(conn) + self._db_initialized = True + return conn + + @property + def db_path(self) -> Path: + """Path to the SQLite database file.""" + return self._db_path + + def _get_depths(self) -> dict[str, int]: + if not hasattr(self._tls, "depths"): + self._tls.depths = {} + return self._tls.depths + + def _load_row(self, agent_id: str) -> tuple[str, str, str] | None: + """Read a checkpoint row from the database. + + Returns ``(handoff, state_json, history_json)`` or ``None``. + """ + conn = self._connect() + try: + row = conn.execute( + "SELECT handoff, state, history FROM checkpoints WHERE agent_id = ?", + (agent_id,), + ).fetchone() + finally: + conn.close() + return row + + def _ensure_loaded(self, agent: PersistentAgent) -> bool: + """Load an agent's checkpoint from the database into the in-process history. + + Safe to call multiple times — only loads once per agent (tracked + via thread-local ``_loaded`` set to avoid re-seeding history that + is already live in memory). + + Returns ``True`` if a checkpoint was found and loaded. + """ + agent_id = agent.__agent_id__ + loaded = self._get_loaded() + if agent_id in loaded: + return False + loaded.add(agent_id) + + row = self._load_row(agent_id) + if row is None: + return False + + _handoff, state_json, history_json = row + agent.restore_state(json.loads(state_json)) + stored = get_agent_history(agent_id) + stored.clear() + stored.update({msg["id"]: msg for msg in json.loads(history_json)}) + return True + + def _get_loaded(self) -> set[str]: + if not hasattr(self._tls, "loaded"): + self._tls.loaded = set() + return self._tls.loaded + + def save(self, agent: PersistentAgent, handoff: str = "") -> Path: + """Write an agent's current state to the database and return the db path.""" + agent_id = agent.__agent_id__ + history = get_agent_history(agent_id) + state_json = json.dumps(agent.checkpoint_state(), default=str) + history_json = json.dumps(list(history.values()), default=str) + + conn = self._connect() + try: + conn.execute( + """ + INSERT INTO checkpoints (agent_id, handoff, state, history) + VALUES (?, ?, ?, ?) + ON CONFLICT(agent_id) DO UPDATE SET + handoff = excluded.handoff, + state = excluded.state, + history = excluded.history + """, + (agent_id, handoff, state_json, history_json), + ) + conn.commit() + finally: + conn.close() + + return self.db_path + + def _get_handoff(self, agent_id: str) -> str: + """Return the current handoff note for *agent_id* (reads from DB).""" + row = self._load_row(agent_id) + if row is None: + return "" + return row[0] + + @implements(Template.__apply__) + def _call[**P, T]( + self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs + ) -> T: + agent = get_bound_agent(template) + if not isinstance(agent, PersistentAgent): + return fwd(template, *args, **kwargs) + + agent_id = agent.__agent_id__ + self._ensure_loaded(agent) + + # Nesting: only checkpoint for outermost call per agent + depths = self._get_depths() + depth = depths.get(agent_id, 0) + depths[agent_id] = depth + 1 + is_outermost = depth == 0 + + try: + if is_outermost: + # Inject prior-session handoff into system prompt + prior_handoff = self._get_handoff(agent_id) + if prior_handoff: + template.__system_prompt__ = ( + f"{template.__system_prompt__}\n\n" + f"[HANDOFF FROM PRIOR SESSION] {prior_handoff}" + ) + + # Record current call as handoff for crash recovery + current_handoff = ( + f"Executing {template.__name__} with args={repr(args)[:200]}" + ) + self.save(agent, handoff=current_handoff) + + result = fwd(template, *args, **kwargs) + + if is_outermost: + self.save(agent, handoff="") + + return result + except BaseException: + if is_outermost: + # Preserve handoff so next session knows what was in progress + self.save(agent, handoff=current_handoff) + raise + finally: + depths[agent_id] = depth + + +class CompactionHandler(ObjectInterpretation): + """Handler that compacts agent history when it exceeds a threshold. + + After each top-level template call on an :class:`Agent`, if the + message history exceeds ``max_history_len``, older messages are + summarised into a single context-summary message via an LLM call.:: + + with handler(LiteLLMProvider()), handler(CompactionHandler(max_history_len=20)): + agent.ask("question") # history auto-compacted after call + """ + + def __init__(self, max_history_len: int = 50) -> None: + self._max_history_len = max_history_len + self._tls = threading.local() + + def _get_depths(self) -> dict[str, int]: + if not hasattr(self._tls, "depths"): + self._tls.depths = {} + return self._tls.depths + + @implements(Template.__apply__) + def _call[**P, T]( + self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs + ) -> T: + agent = get_bound_agent(template) + if not isinstance(agent, Agent): + return fwd(template, *args, **kwargs) + + agent_id = agent.__agent_id__ + depths = self._get_depths() + depth = depths.get(agent_id, 0) + depths[agent_id] = depth + 1 + is_outermost = depth == 0 + + try: + result = fwd(template, *args, **kwargs) + + if is_outermost: + history = get_agent_history(agent_id) + if len(history) > self._max_history_len: + self._compact(agent_id, history) + + return result + finally: + depths[agent_id] = depth + + def _compact(self, agent_id: str, history: OrderedDict[str, Any]) -> None: + keep_recent = max(self._max_history_len // 2, 4) + items = list(history.items()) + if len(items) <= keep_recent: + return + + split = len(items) - keep_recent + # Never split between a tool_use and its tool_result(s). + while split > 0 and items[split][1].get("role") == "tool": + split -= 1 + if split <= 0: + return + + old_items = items[:split] + recent_items = items[split:] + + old_text_parts: list[str] = [] + for _, msg in old_items: + role = msg.get("role", "unknown") + content = msg.get("content", "") + if isinstance(content, list): + text_parts = [p.get("text", "") for p in content if isinstance(p, dict)] + content = " ".join(text_parts) + if content: + old_text_parts.append(f"[{role}]: {content}") + old_transcript = "\n".join(old_text_parts) + + if not old_transcript.strip(): + return + + summary = summarize_context(old_transcript) + + summary_msg: dict[str, Any] = { + "id": f"compaction-{agent_id}", + "role": "user", + "content": f"[CONTEXT SUMMARY FROM PRIOR CONVERSATION]\n{summary}", + } + history.clear() + history[summary_msg["id"]] = summary_msg + for key, msg in recent_items: + history[key] = msg diff --git a/effectful/handlers/llm/template.py b/effectful/handlers/llm/template.py index 93e7f085..3bb418d5 100644 --- a/effectful/handlers/llm/template.py +++ b/effectful/handlers/llm/template.py @@ -5,7 +5,7 @@ import string import types import typing -from collections import ChainMap, OrderedDict +from collections import ChainMap from collections.abc import Callable, Mapping, MutableMapping from typing import Annotated, Any @@ -257,8 +257,6 @@ def __get__[S](self, instance: S | None, owner: type[S] | None = None): self_param_name = list(self.__signature__.parameters.keys())[0] result.__context__ = self.__context__.new_child({self_param_name: instance}) if isinstance(instance, Agent): - assert isinstance(result, Template) and not hasattr(result, "__history__") - result.__history__ = instance.__history__ # type: ignore[attr-defined] result.__system_prompt__ = "\n\n".join( part for part in ( @@ -378,18 +376,31 @@ def send(self, user_input: str) -> str: """ - __history__: OrderedDict[str, Mapping[str, Any]] __system_prompt__: str + @functools.cached_property + def __agent_id__(self) -> str: + return str(id(self)) + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - if not hasattr(cls, "__history__"): - prop = functools.cached_property(lambda _: OrderedDict()) - prop.__set_name__(cls, "__history__") - cls.__history__ = prop if not hasattr(cls, "__system_prompt__"): sp = functools.cached_property( lambda self: inspect.getdoc(type(self)) or "" ) sp.__set_name__(cls, "__system_prompt__") cls.__system_prompt__ = sp + + +def get_bound_agent(template: Template) -> "Agent | None": + """Extract the bound :class:`Agent` instance from a template, if any. + + Bound method templates have a first context map with exactly one entry + (``{self_param_name: instance}``), while standalone templates have a + larger map (module globals). + """ + ctx = getattr(template, "__context__", None) + if ctx is None or not ctx.maps or len(ctx.maps[0]) != 1: + return None + val = next(iter(ctx.maps[0].values())) + return val if isinstance(val, Agent) else None diff --git a/tests/test_handlers_llm_encoding.py b/tests/test_handlers_llm_encoding.py index 9cccf93d..adadee23 100644 --- a/tests/test_handlers_llm_encoding.py +++ b/tests/test_handlers_llm_encoding.py @@ -686,7 +686,6 @@ def test_callable_encode_non_callable(): def test_callable_encode_no_source_no_docstring(): - class _NoDocCallable: __name__ = "nodoc" __doc__ = None diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index eec2fc78..528f2853 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -10,6 +10,7 @@ import json import os import re +import sqlite3 from collections.abc import Callable from enum import StrEnum from pathlib import Path @@ -37,9 +38,15 @@ call_assistant, call_tool, completion, + get_agent_history, ) from effectful.handlers.llm.encoding import Encodable, SynthesizedFunction from effectful.handlers.llm.evaluation import UnsafeEvalProvider +from effectful.handlers.llm.persistence import ( + CompactionHandler, + PersistenceHandler, + PersistentAgent, +) from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import NotHandled @@ -1797,17 +1804,17 @@ def simple_task(self, instruction: str) -> str: agent = SimpleAgent() - # No outer _get_history handler: LiteLLMProvider._call detects this is the - # outermost template and writes back to the agent's __history__. + provider = LiteLLMProvider(model="test") with ( - handler(LiteLLMProvider(model="test")), + handler(provider), handler(mock_handler), ): result = agent.simple_task("go") assert result == "done" - # Agent's __history__ should have messages written back (system + user + assistant) - assert len(agent.__history__) >= 2 + # Agent's history should have messages written back (system + user + assistant) + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) >= 2 class TestAgentCrossTemplateRecovery: @@ -1862,8 +1869,9 @@ def _completion(self, model, messages=None, **kwargs): agent = TestAgent() - with handler(TwoPhaseCompletionHandler()): - with handler(LiteLLMProvider(model="test")): + provider = LiteLLMProvider(model="test") + with handler(provider): + with handler(TwoPhaseCompletionHandler()): # First call should fail with tool execution error with pytest.raises(ToolCallExecutionError): agent.step_with_tool("stage 1") @@ -1874,7 +1882,7 @@ def _completion(self, model, messages=None, **kwargs): assert result == "summary result" # Verify history doesn't contain messages from the failed call - history = agent.__history__ + history = provider._histories.get(agent.__agent_id__, {}) for msg in history.values(): tool_calls = msg.get("tool_calls") if tool_calls: @@ -1916,12 +1924,14 @@ def do_work(self, task: str) -> str: mock = MockCompletionHandler(responses) agent = CleanupAgent() + provider = LiteLLMProvider(model="test") with pytest.raises(ToolCallExecutionError): - with handler(LiteLLMProvider(model="test")), handler(mock): + with handler(provider), handler(mock): agent.do_work("go") # Agent history should be empty — all messages from failed call pruned - assert len(agent.__history__) == 0 + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) == 0 def test_agent_history_preserved_for_successful_calls(self): """Successful calls should leave messages in agent history.""" @@ -1943,12 +1953,14 @@ def greet(self, name: str) -> str: mock = MockCompletionHandler(responses) agent = SuccessAgent() - with handler(LiteLLMProvider(model="test")), handler(mock): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(mock): result = agent.greet("world") assert result == "Hello!" # History should contain messages from the successful call - assert len(agent.__history__) >= 2 # user + assistant at minimum + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) >= 2 # user + assistant at minimum def test_agent_multiple_successful_calls_accumulate_history(self): """Multiple successful calls should accumulate in agent history.""" @@ -1977,14 +1989,16 @@ def _completion(self, model, messages=None, **kwargs): agent = ChatAgent() - with handler(LiteLLMProvider(model="test")), handler(MultiResponseHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(MultiResponseHandler()): r1 = agent.chat("first") r2 = agent.chat("second") assert r1 == "reply 1" assert r2 == "reply 2" # History should have messages from both calls - assert len(agent.__history__) >= 4 # 2 * (user + assistant) + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) >= 4 # 2 * (user + assistant) def test_agent_error_then_success_accumulates_only_success(self): """After a failed call, only the subsequent successful call's messages remain.""" @@ -2025,19 +2039,21 @@ def _completion(self, model, messages=None, **kwargs): agent = RecoveryAgent() - with handler(LiteLLMProvider(model="test")), handler(PhaseHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(PhaseHandler()): with pytest.raises(ToolCallExecutionError): agent.risky("step 1") - history_after_error = len(agent.__history__) + history_after_error = len(get_agent_history(agent.__agent_id__)) assert history_after_error == 0 result = agent.safe("step 2") assert result == "safe result" # Only messages from the successful call should be in history - assert len(agent.__history__) >= 2 - assert len(agent.__history__) > history_after_error + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) >= 2 + assert len(history) > history_after_error class TestAgentSystemMessageDeduplication: @@ -2109,13 +2125,15 @@ def _completion(self, model, messages=None, **kwargs): agent = SystemMsgAgent() - with handler(LiteLLMProvider(model="test")), handler(MultiHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(MultiHandler()): agent.do("a") agent.do("b") agent.do("c") agent.do("d") - system_msgs = [m for m in agent.__history__.values() if m["role"] == "system"] + history = provider._histories.get(agent.__agent_id__, {}) + system_msgs = [m for m in history.values() if m["role"] == "system"] assert len(system_msgs) == 1, ( f"Expected exactly 1 system message, got {len(system_msgs)}" ) @@ -2151,14 +2169,16 @@ def _completion(self, model, messages=None, **kwargs): agent = MemoryAgent() - with handler(LiteLLMProvider(model="test")), handler(MemoryHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(MemoryHandler()): agent.chat("first") agent.chat("second") agent.chat("third") # History should have: 1 system + 3 user + 3 assistant = 7 - assert len(agent.__history__) == 7 - roles = [m["role"] for m in agent.__history__.values()] + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) == 7 + roles = [m["role"] for m in history.values()] assert roles.count("system") == 1 assert roles.count("user") == 3 assert roles.count("assistant") == 3 @@ -2187,12 +2207,597 @@ def _completion(self, model, messages=None, **kwargs): agent = OrderAgent() - with handler(LiteLLMProvider(model="test")), handler(OrderHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(OrderHandler()): agent.step(1) agent.step(2) agent.step(3) - messages = list(agent.__history__.values()) + history = provider._histories.get(agent.__agent_id__, {}) + messages = list(history.values()) assert messages[0]["role"] == "system", ( "System message should be the first message in history" ) + + +# --------------------------------------------------------------------------- +# Integration tests: Agent & PersistentAgent with real LLM +# --------------------------------------------------------------------------- + + +class _PlainHelper(Agent): + """You are a concise helper. Reply with at most 10 words.""" + + @Template.define + def answer(self, q: str) -> str: + """Answer concisely: {q}""" + raise NotHandled + + +class _PersistentOrchestrator(Agent): + """You are an orchestrator. Use `ask_helper` to get answers. + + Reply with the helper's answer verbatim — do NOT call the tool more + than once. + """ + + def __init__(self, helper: _PlainHelper): + self._helper = helper + + @Tool.define + def ask_helper(self, question: str) -> str: + """Ask the helper agent a question. Call this exactly once.""" + return self._helper.answer(question) + + @Template.define + def orchestrate(self, task: str) -> str: + """Task: {task}. Call `ask_helper` once, then return the answer.""" + raise NotHandled + + +@requires_openai +def test_plain_agent_simple_call_integration(): + """Plain Agent makes a single LLM call and returns a string.""" + helper = _PlainHelper() + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=30)), + handler(LimitLLMCallsHandler(max_calls=1)), + ): + result = helper.answer("What is 2+2?") + + assert isinstance(result, str) + assert len(result) > 0 + + +@requires_openai +def test_agent_nested_tool_call_integration(): + """An Agent delegates to another Agent via a tool call. + + The orchestrator calls ask_helper (one tool round-trip) then returns. + LimitLLMCallsHandler caps total LLM calls at 4 to prevent runaway + recursion. + """ + helper = _PlainHelper() + orchestrator = _PersistentOrchestrator(helper) + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=60)), + handler(LimitLLMCallsHandler(max_calls=4)), + ): + result = orchestrator.orchestrate("What is the capital of France?") + + assert isinstance(result, str) + assert len(result) > 0 + + +@requires_openai +def test_persistent_agent_with_persistence_integration(tmp_path): + """PersistentAgent checkpoints to disk after a real LLM call.""" + + class Bot(PersistentAgent): + """You are a concise assistant. Reply in at most 10 words.""" + + @Template.define + def ask(self, q: str) -> str: + """Answer: {q}""" + raise NotHandled + + bot = Bot(agent_id="integration-bot") + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=30)), + handler(LimitLLMCallsHandler(max_calls=1)), + handler(persist), + ): + result = bot.ask("Say hello") + + assert isinstance(result, str) + db_path = tmp_path / "checkpoints.db" + assert db_path.exists() + conn = sqlite3.connect(str(db_path)) + row = conn.execute( + "SELECT handoff, history FROM checkpoints WHERE agent_id = ?", + ("integration-bot",), + ).fetchone() + conn.close() + assert row is not None + assert len(json.loads(row[1])) > 0 + assert row[0] == "" + + +@requires_openai +def test_persistent_and_plain_agent_cooperate_integration(tmp_path): + """A plain Agent and PersistentAgent work together via tool delegation. + + The persistent orchestrator calls ask_helper (a plain Agent) via a tool. + LimitLLMCallsHandler caps calls at 5 to prevent runaway tool loops. + """ + + class Orchestrator(PersistentAgent): + """You orchestrate tasks. Use `ask_helper` exactly once, then + return the helper's answer verbatim. Do NOT call tools more than once. + """ + + def __init__(self, helper_agent: _PlainHelper, **kwargs): + super().__init__(**kwargs) + self._helper = helper_agent + + @Tool.define + def ask_helper(self, question: str) -> str: + """Ask the helper a question. Call exactly once.""" + return self._helper.answer(question) + + @Template.define + def run(self, task: str) -> str: + """Task: {task}. Use `ask_helper` once, then return the result.""" + raise NotHandled + + helper = _PlainHelper() + orch = Orchestrator(helper, agent_id="orch") + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=60)), + handler(LimitLLMCallsHandler(max_calls=4)), + handler(persist), + ): + result = orch.run("What is 3 * 7?") + + assert isinstance(result, str) + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + row = conn.execute( + "SELECT 1 FROM checkpoints WHERE agent_id = ?", ("orch",) + ).fetchone() + conn.close() + assert row is not None + + +@requires_openai +def test_compaction_after_multiple_calls_integration(): + """History is compacted after enough calls exceed the threshold.""" + helper = _PlainHelper() + provider = LiteLLMProvider(model="gpt-4o-mini", max_tokens=30) + + with ( + handler(provider), + handler(LimitLLMCallsHandler(max_calls=4)), + handler(CompactionHandler(max_history_len=4)), + ): + # Each call adds ~2 msgs (user + assistant). + # After the 2nd call history exceeds 4 msgs, triggering compaction. + for i in range(2): + helper.answer(f"What is {i} + 1?") + + history = provider._histories.get(helper.__agent_id__, {}) + first_msg = next(iter(history.values())) + assert "CONTEXT SUMMARY" in first_msg["content"] + + +@requires_openai +def test_compaction_with_persistence_integration(tmp_path): + """Compaction and persistence compose: compacted history is checkpointed.""" + + class Bot(PersistentAgent): + """You are a concise assistant. Reply in at most 10 words.""" + + @Template.define + def ask(self, q: str) -> str: + """Answer: {q}""" + raise NotHandled + + bot = Bot(agent_id="compact-bot") + provider = LiteLLMProvider(model="gpt-4o-mini", max_tokens=30) + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(provider), + handler(LimitLLMCallsHandler(max_calls=4)), + handler(CompactionHandler(max_history_len=4)), + handler(persist), + ): + for i in range(2): + bot.ask(f"What is {i} + 1?") + + # Compacted history should be persisted to disk + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + row = conn.execute( + "SELECT history FROM checkpoints WHERE agent_id = ?", ("compact-bot",) + ).fetchone() + conn.close() + history = json.loads(row[0]) + first_msg = history[0] + assert "CONTEXT SUMMARY" in first_msg["content"] + + +class _ToolAgent(Agent): + """You are a concise assistant. Answer in at most 10 words.""" + + @Tool.define + def lookup(self, query: str) -> str: + """Look up factual information about a topic.""" + return f"Result: The answer to '{query}' is 42." + + @Template.define + def ask(self, question: str) -> str: + """Answer: {question}. You MUST use the lookup tool first.""" + raise NotHandled + + +@pytest.mark.parametrize( + "model", + [ + pytest.param("gpt-4o-mini", marks=requires_openai), + pytest.param("claude-haiku-4-5-20251001", marks=requires_anthropic), + ], +) +def test_compaction_with_tool_calls_does_not_break_api(model): + """After compaction of history containing tool pairs, subsequent calls succeed. + + Each tool-using call generates ~4 messages (user, assistant/tool_use, + tool/result, assistant/final). With max_history_len=4, compaction fires + after the 1st call. The 2nd call must succeed — if compaction orphaned + a tool_result the API would reject the conversation. + """ + bot = _ToolAgent() + provider = LiteLLMProvider(model=model, max_tokens=60) + + with ( + handler(provider), + handler(RetryLLMHandler(stop=tenacity.stop_after_attempt(2))), + handler(LimitLLMCallsHandler(max_calls=8)), + handler(CompactionHandler(max_history_len=4)), + ): + bot.ask("What is the meaning of life?") + # Compaction should have fired. This call must not fail. + result = bot.ask("Summarize what you told me.") + + assert isinstance(result, str) + assert len(result) > 0 + + # Verify no orphaned tool_result messages in final history. + history = provider._histories.get(bot.__agent_id__, {}) + tool_use_ids: set[str] = set() + for msg in history.values(): + # Anthropic format: tool_use blocks in content + content = msg.get("content", "") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_use": + tool_use_ids.add(block["id"]) + # OpenAI format: tool_calls field on assistant messages + for tc in msg.get("tool_calls") or []: + tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) + if tc_id: + tool_use_ids.add(tc_id) + + for msg in history.values(): + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id", "") + assert tc_id in tool_use_ids, ( + f"Orphaned tool_result with tool_call_id={tc_id!r}" + ) + + +# --------------------------------------------------------------------------- +# Integration tests: SQLite persistence +# --------------------------------------------------------------------------- + + +@requires_openai +def test_sqlite_persistence_crash_recovery_integration(tmp_path): + """After a simulated crash, a new session resumes from the SQLite checkpoint.""" + + class Bot(PersistentAgent): + """You are a concise assistant. Reply in at most 10 words.""" + + @Template.define + def ask(self, q: str) -> str: + """Answer: {q}""" + raise NotHandled + + # Session 1: successful call + bot = Bot(agent_id="crash-test-bot") + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=30)), + handler(LimitLLMCallsHandler(max_calls=1)), + handler(persist), + ): + result1 = bot.ask("What is 2+2?") + + assert isinstance(result1, str) + + # Verify SQLite DB exists and has data + db_path = tmp_path / "checkpoints.db" + assert db_path.exists() + conn = sqlite3.connect(str(db_path)) + result = conn.execute("PRAGMA integrity_check").fetchone()[0] + assert result == "ok" + row = conn.execute( + "SELECT handoff, history FROM checkpoints WHERE agent_id = ?", + ("crash-test-bot",), + ).fetchone() + conn.close() + assert row is not None + assert row[0] == "" # handoff cleared after success + assert len(json.loads(row[1])) > 0 + + # Session 2: new process loads from SQLite and continues + bot2 = Bot(agent_id="crash-test-bot") + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=30)), + handler(LimitLLMCallsHandler(max_calls=1)), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + result2 = bot2.ask("What did I just ask?") + + assert isinstance(result2, str) + + # History should have grown (session 2 sees session 1 messages) + conn = sqlite3.connect(str(db_path)) + row = conn.execute( + "SELECT history FROM checkpoints WHERE agent_id = ?", + ("crash-test-bot",), + ).fetchone() + conn.close() + history = json.loads(row[0]) + assert len(history) > 3 # at least system + user + assistant from each session + + +@requires_openai +def test_sqlite_persistence_db_integrity_after_compaction_integration(tmp_path): + """Compacted history is correctly persisted to SQLite.""" + + class Bot(PersistentAgent): + """You are a concise assistant. Reply in at most 10 words.""" + + @Template.define + def ask(self, q: str) -> str: + """Answer: {q}""" + raise NotHandled + + bot = Bot(agent_id="compact-sqlite-bot") + provider = LiteLLMProvider(model="gpt-4o-mini", max_tokens=30) + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(provider), + handler(LimitLLMCallsHandler(max_calls=4)), + handler(CompactionHandler(max_history_len=4)), + handler(persist), + ): + for i in range(2): + bot.ask(f"What is {i} + 1?") + + # Verify SQLite DB integrity + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + result = conn.execute("PRAGMA integrity_check").fetchone()[0] + assert result == "ok" + row = conn.execute( + "SELECT history FROM checkpoints WHERE agent_id = ?", + ("compact-sqlite-bot",), + ).fetchone() + conn.close() + history = json.loads(row[0]) + first_msg = history[0] + assert "CONTEXT SUMMARY" in first_msg["content"] + + +# --------------------------------------------------------------------------- +# Integration tests: Multi-agent choreography (based on multi_agent_example.py) +# --------------------------------------------------------------------------- + + +@requires_openai +def test_multi_agent_choreography_integration(tmp_path): + """Multi-agent choreography: architect plans, coder implements, reviewer reviews. + + Based on docs/source/multi_agent_example.py but with constrained scope + (one module, capped LLM calls) so the test completes quickly. + """ + from typing import Literal, TypedDict + + from effectful.handlers.llm.multi import Choreography, scatter + from effectful.handlers.llm.persistence import PersistenceHandler, PersistentAgent + + class ModuleSpec(TypedDict): + module_path: str + description: str + + class PlanResult(TypedDict): + modules: list[ModuleSpec] + + class ReviewResult(TypedDict): + verdict: Literal["PASS", "NEEDS_FIXES"] + feedback: str + + class Architect(PersistentAgent): + """You are a software architect. Given a project spec, output a plan + with a 'modules' list. Each module has module_path and description. + Output exactly ONE module.""" + + @Template.define + def plan_modules(self, project_spec: str) -> PlanResult: + """Plan modules for: {project_spec}. Return exactly one module.""" + raise NotHandled + + class Coder(PersistentAgent): + """You are a Python developer. Write clean code. Reply with ONLY + the Python source code, no markdown.""" + + @Template.define + def implement_module(self, module_spec: str) -> str: + """Implement: {module_spec}""" + raise NotHandled + + class Reviewer(PersistentAgent): + """You are a code reviewer. Review for correctness and style.""" + + @Template.define + def review_code(self, code: str) -> ReviewResult: + """Review this code. Return verdict PASS or NEEDS_FIXES: {code}""" + raise NotHandled + + MAX_REVIEW_ROUNDS = 3 + + def build_project( + project_spec: str, + architect: Architect, + coders: list[Coder], + reviewer: Reviewer, + ) -> list[ReviewResult]: + # Step 1: Architect plans modules + plan = architect.plan_modules(project_spec) + + # Step 2: Scatter implementation across coders + codes = scatter( + plan["modules"], + coders, + lambda coder, mod: coder.implement_module(str(mod)), + ) + + # Step 3: Review each module; if NEEDS_FIXES, re-implement with feedback + reviews: list[ReviewResult] = [] + for mod, code in zip(plan["modules"], codes): + for _round in range(MAX_REVIEW_ROUNDS): + review = reviewer.review_code(code) + if review["verdict"] == "PASS": + break + # Re-implement with reviewer feedback + fix_spec = f"{mod}\nFix feedback: {review['feedback']}" + code = coders[0].implement_module(fix_spec) + reviews.append(review) + + return reviews + + architect = Architect(agent_id="architect") + coder1 = Coder(agent_id="coder-1") + reviewer = Reviewer(agent_id="reviewer") + + from effectful.handlers.llm.multi import PersistentTaskQueue + + state_dir = tmp_path / "state" + state_dir.mkdir() + choreo = Choreography( + build_project, + agents=[architect, coder1, reviewer], + queue=PersistentTaskQueue(state_dir / "task_queue.db"), + handlers=[ + LiteLLMProvider(model="gpt-4o-mini", max_tokens=200), + LimitLLMCallsHandler(max_calls=8), + PersistenceHandler(state_dir / "checkpoints.db"), + ], + poll_interval=0.05, + ) + + reviews = choreo.run( + project_spec="Build a single-function module: slugify(text) -> str", + architect=architect, + coders=[coder1], + reviewer=reviewer, + ) + + assert isinstance(reviews, list) + assert len(reviews) >= 1 + for r in reviews: + assert r["verdict"] in ("PASS", "NEEDS_FIXES") + assert isinstance(r["feedback"], str) + + +# --------------------------------------------------------------------------- +# Integration tests: SQLite task queue crash tolerance +# --------------------------------------------------------------------------- + + +@requires_openai +def test_sqlite_task_queue_crash_recovery_integration(tmp_path): + """Choreography with SQLite queue: completed steps survive restart.""" + from effectful.handlers.llm.multi import Choreography, PersistentTaskQueue + from effectful.handlers.llm.persistence import PersistenceHandler, PersistentAgent + + class Bot(PersistentAgent): + """You are a concise assistant. Reply in at most 10 words.""" + + @Template.define + def step1(self, q: str) -> str: + """Answer briefly: {q}""" + raise NotHandled + + @Template.define + def step2(self, context: str) -> str: + """Given context, reply in one word: {context}""" + raise NotHandled + + bot = Bot(agent_id="crash-q-bot") + + db_path = tmp_path / "task_queue.db" + + # Run 1: complete full choreography + def two_step(bot): + r1 = bot.step1("What is 2+2?") + return bot.step2(r1) + + choreo1 = Choreography( + two_step, + agents=[bot], + queue=PersistentTaskQueue(db_path), + handlers=[ + LiteLLMProvider(model="gpt-4o-mini", max_tokens=30), + RetryLLMHandler(), + LimitLLMCallsHandler(max_calls=3), + PersistenceHandler(tmp_path / "checkpoints.db"), + ], + poll_interval=0.05, + ) + result1 = choreo1.run(bot=bot) + assert isinstance(result1, str) + + # Verify SQLite DB exists and has task data + conn = sqlite3.connect(str(db_path)) + integrity = conn.execute("PRAGMA integrity_check").fetchone()[0] + assert integrity == "ok" + done_count = conn.execute( + "SELECT COUNT(*) FROM tasks WHERE status = 'done'" + ).fetchone()[0] + conn.close() + assert done_count == 2 # step-0000 and step-0001 + + # Run 2: new Choreography on same DB — steps return cached results + bot2 = Bot(agent_id="crash-q-bot") + choreo2 = Choreography( + two_step, + agents=[bot2], + queue=PersistentTaskQueue(db_path), + handlers=[ + LiteLLMProvider(model="gpt-4o-mini", max_tokens=30), + RetryLLMHandler(), + LimitLLMCallsHandler(max_calls=0), # no LLM calls allowed — must use cache + PersistenceHandler(tmp_path / "checkpoints.db"), + ], + poll_interval=0.05, + ) + result2 = choreo2.run(bot=bot2) + assert result2 == result1 # exact same cached result diff --git a/tests/test_handlers_llm_template.py b/tests/test_handlers_llm_template.py index 7c3bd5bc..cff3da06 100644 --- a/tests/test_handlers_llm_template.py +++ b/tests/test_handlers_llm_template.py @@ -11,12 +11,15 @@ from effectful.handlers.llm import Agent, Template, Tool from effectful.handlers.llm.completions import ( DEFAULT_SYSTEM_PROMPT, + AgentHistoryHandler, LiteLLMProvider, RetryLLMHandler, + call_assistant, call_user, completion, + get_agent_history, ) -from effectful.ops.semantics import handler +from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import NotHandled @@ -281,10 +284,10 @@ def test_history_contains_all_messages_after_two_calls(self): bot.send("a") bot.send("b") - # After two complete calls the history should have: - # call 1: system, user, assistant (3) - # call 2: system, user, assistant (3) - assert len(bot.__history__) >= 4 + # After two complete calls the history should have: + # call 1: system, user, assistant (3) + # call 2: system, user, assistant (3) + assert len(get_agent_history(bot.__agent_id__)) >= 4 def test_message_ids_are_unique(self): mock = MockCompletionHandler( @@ -299,8 +302,8 @@ def test_message_ids_are_unique(self): bot.send("a") bot.send("b") - ids = list(bot.__history__.keys()) - assert len(ids) == len(set(ids)), "message IDs must be unique" + ids = list(get_agent_history(bot.__agent_id__).keys()) + assert len(ids) == len(set(ids)), "message IDs must be unique" class TestAgentIsolation: @@ -314,20 +317,24 @@ def test_two_agents_have_independent_histories(self): ] ) bot1 = ChatBot() + bot1.__agent_id__ = "bot1" bot2 = ChatBot() + bot2.__agent_id__ = "bot2" with handler(LiteLLMProvider()), handler(mock): bot1.send("msg for bot1") bot2.send("msg for bot2") - # bot2's call should NOT contain bot1's messages — only system + user - assert len(mock.received_messages[1]) == len(mock.received_messages[0]) + # bot2's call should NOT contain bot1's messages — only system + user + assert len(mock.received_messages[1]) == len(mock.received_messages[0]) - # Each bot made exactly one call, so their histories should be equal in size - assert len(bot1.__history__) == len(bot2.__history__) + # Each bot made exactly one call, so their histories should be equal in size + h1 = get_agent_history(bot1.__agent_id__) + h2 = get_agent_history(bot2.__agent_id__) + assert len(h1) == len(h2) - # Histories share no message IDs - assert set(bot1.__history__.keys()).isdisjoint(set(bot2.__history__.keys())) + # Histories share no message IDs + assert set(h1.keys()).isdisjoint(set(h2.keys())) def test_non_agent_template_gets_fresh_sequence(self): @Template.define @@ -494,13 +501,13 @@ class ValidDocAgent(Agent): ) -class TestAgentCachedProperty: - """__history__ is lazily created per instance without requiring __init__.""" +class TestAgentHistoryViaHandler: + """History is managed via AgentHistoryHandler, not as a cached property.""" - def test_no_init_required(self): + def test_history_defaults_to_empty(self): class MinimalAgent(Agent): - """You are a minimal cached-property test agent. - Your goal is to expose lazily initialized Agent state. + """You are a minimal history-handler test agent. + Your goal is to expose handler-managed Agent history. """ @Template.define @@ -509,9 +516,10 @@ def greet(self, name: str) -> str: raise NotHandled agent = MinimalAgent() - # Should be an OrderedDict, created on first access - assert isinstance(agent.__history__, collections.OrderedDict) - assert len(agent.__history__) == 0 + with handler(AgentHistoryHandler()): + history = get_agent_history(agent.__agent_id__) + assert isinstance(history, collections.OrderedDict) + assert len(history) == 0 def test_subclass_with_own_init(self): class CustomAgent(Agent): @@ -529,13 +537,20 @@ def greet(self) -> str: agent = CustomAgent("Alice") assert agent.name == "Alice" - assert isinstance(agent.__history__, collections.OrderedDict) + with handler(AgentHistoryHandler()): + assert isinstance( + get_agent_history(agent.__agent_id__), collections.OrderedDict + ) def test_history_is_per_instance(self): a = ChatBot() + a.__agent_id__ = "a" b = ChatBot() - a.__history__["fake"] = {"id": "fake", "role": "user", "content": "x"} - assert "fake" not in b.__history__ + b.__agent_id__ = "b" + with handler(AgentHistoryHandler()): + hist_a = get_agent_history(a.__agent_id__) + hist_a["fake"] = {"id": "fake", "role": "user", "content": "x"} + assert "fake" not in get_agent_history(b.__agent_id__) class TestAgentWithToolCalls: @@ -570,13 +585,13 @@ def compute(self, question: str) -> str: with handler(LiteLLMProvider()), handler(mock): result = agent.compute("what is 2+3?") - assert result == "The answer is 5" + assert result == "The answer is 5" - # History should contain: system, user, assistant (tool_call), - # tool (result), assistant (final) - roles = [m["role"] for m in agent.__history__.values()] - assert "tool" in roles - assert roles.count("assistant") == 2 + # History should contain: system, user, assistant (tool_call), + # tool (result), assistant (final) + roles = [m["role"] for m in get_agent_history(agent.__agent_id__).values()] + assert "tool" in roles + assert roles.count("assistant") == 2 class TestAgentWithRetryHandler: @@ -611,13 +626,13 @@ def pick_number(self) -> int: ): result = agent.pick_number() - assert result == 42 + assert result == 42 - # The malformed assistant message and error feedback from the retry - # should NOT appear in the agent's history. Only the final successful - # assistant message should be there. - roles = {m["role"] for m in agent.__history__.values()} - assert {"user", "assistant"} == roles - {"system"} + # The malformed assistant message and error feedback from the retry + # should NOT appear in the agent's history. Only the final successful + # assistant message should be there. + roles = {m["role"] for m in get_agent_history(agent.__agent_id__).values()} + assert {"user", "assistant"} == roles - {"system"} class TestNestedTemplateCalling: @@ -626,7 +641,7 @@ class TestNestedTemplateCalling: When a Template triggers a tool call whose implementation invokes another Template on the same Agent, the inner call must: - work on a fresh copy of the agent's history - - NOT write its messages back to agent.__history__ + - NOT write its messages back to the agent's history - return its result correctly so the outer template can continue """ @@ -647,7 +662,7 @@ def test_same_agent_nested_template_via_tool(self): assert result == "all good" def test_only_outermost_writes_to_history(self): - """Inner template's messages are absent from agent.__history__.""" + """Inner template's messages are absent from agent history.""" mock = MockCompletionHandler( [ make_tool_call_response("self__nested_tool", '{"payload": "demo"}'), @@ -660,14 +675,14 @@ def test_only_outermost_writes_to_history(self): with handler(LiteLLMProvider()), handler(mock): agent.outer("demo") - roles = [m["role"] for m in agent.__history__.values()] - # Outer call produces: user, assistant(tool_call), tool, assistant(final) - # Inner call's user + assistant are NOT written back - assert set(roles) <= {"system", "user", "assistant", "tool"} - assert roles.count("system") == 1 - assert roles.count("user") == 1 - assert roles.count("assistant") == 2 # tool_call + final - assert roles.count("tool") == 1 + roles = [m["role"] for m in get_agent_history(agent.__agent_id__).values()] + # Outer call produces: user, assistant(tool_call), tool, assistant(final) + # Inner call's user + assistant are NOT written back + assert set(roles) <= {"system", "user", "assistant", "tool"} + assert roles.count("system") == 1 + assert roles.count("user") == 1 + assert roles.count("assistant") == 2 # tool_call + final + assert roles.count("tool") == 1 def test_inner_template_gets_fresh_messages(self): """The nested template's LLM call sees only its own system + user, @@ -709,7 +724,7 @@ def test_inner_template_sees_prior_completed_history(self): agent.outer("first") agent.outer("second") - # After first call, agent.__history__ has 2 messages (user + assistant). + # After first call, agent history has 2 messages (user + assistant). # Second outer call (call 1): starts from history(2) + own user = 3. # Inner call (call 2): starts from history(2) + own user = 3. # Both see the same base history. If inner saw the outer's in-flight @@ -747,6 +762,44 @@ def test_sequential_call_after_nested_sees_history(self): second_call_roles = [m["role"] for m in mock.received_messages[3]] assert second_call_roles.count("assistant") >= 2 # from first call's history + def test_inner_success_outer_failure_no_history_leak(self): + """When inner call succeeds but outer fails, canonical history must + not be left with inner call's stale messages.""" + + class LimitCallsHandler(ObjectInterpretation): + def __init__(self, max_calls): + self.max_calls = max_calls + self.count = 0 + + @implements(call_assistant) + def _call(self, *args, **kwargs): + self.count += 1 + if self.count > self.max_calls: + raise RuntimeError(f"Exceeded {self.max_calls} calls") + return fwd() + + mock = MockCompletionHandler( + [ + make_tool_call_response("self__nested_tool", '{"payload": "demo"}'), + make_text_response("inner"), + make_text_response("outer"), # won't be reached + ] + ) + agent = _DesignerAgent() + provider = LiteLLMProvider() + # Allow 2 calls: outer's first + inner's. Outer's 2nd call (#3) fails. + limiter = LimitCallsHandler(max_calls=2) + + with pytest.raises(RuntimeError, match="Exceeded"): + with handler(provider), handler(mock), handler(limiter): + agent.outer("demo") + + with handler(provider): + history = get_agent_history(agent.__agent_id__) + # Canonical history should be empty — outer never completed. + # It must NOT contain inner call's system/user/assistant. + assert len(history) == 0 + # --------------------------------------------------------------------------- # Template method and scoping tests (moved from test_handlers_llm_template.py) @@ -1185,8 +1238,6 @@ def static_method(x: int) -> int: # static_method remains a plain Template accessible on class and instance assert isinstance(MyAgent.static_method, Template) assert isinstance(agent.static_method, Template) - # static_method should NOT have __history__ set - assert not hasattr(MyAgent.static_method, "__history__") def test_agent_skips_classmethod_template(self): """Agent.__init_subclass__ does not wrap classmethod Templates @@ -1211,8 +1262,6 @@ def class_method(cls) -> str: agent = MyAgent() assert isinstance(agent.instance_method, Template) assert isinstance(MyAgent.class_method, Template) - # class_method should NOT have __history__ set - assert not hasattr(MyAgent.class_method, "__history__") def test_template_formatting_scoped(): diff --git a/tests/test_multi_agent_epp.py b/tests/test_multi_agent_epp.py new file mode 100644 index 00000000..459271e9 --- /dev/null +++ b/tests/test_multi_agent_epp.py @@ -0,0 +1,1791 @@ +"""Tests for effectful.handlers.llm.multi — choreographic EPP with TaskQueue.""" + +import itertools +import shutil +import threading +import time +from pathlib import Path +from typing import Any + +import pytest + +from effectful.handlers.llm import Template +from effectful.handlers.llm.multi import ( + Choreography, + ChoreographyError, + EndpointProjection, + InMemoryTaskQueue, + PersistentTaskQueue, + TaskStatus, + fan_out, + scatter, +) +from effectful.handlers.llm.persistence import PersistentAgent +from effectful.handlers.llm.template import get_bound_agent +from effectful.ops.semantics import handler +from effectful.ops.syntax import ObjectInterpretation, implements +from effectful.ops.types import NotHandled + +# ── Fixtures and helpers ────────────────────────────────────────── + +STATE_DIR = Path("/tmp/test_multi_epp") + +# Default timeout for all concurrent tests (seconds). +# Concurrency bugs often manifest as hangs — this catches them. +THREAD_TIMEOUT = 10 + + +@pytest.fixture(autouse=True) +def clean_state(): + if STATE_DIR.exists(): + shutil.rmtree(STATE_DIR) + STATE_DIR.mkdir(parents=True) + yield + if STATE_DIR.exists(): + shutil.rmtree(STATE_DIR) + + +class MockLLM(ObjectInterpretation): + """Mock LLM handler that returns canned responses.""" + + def __init__(self, responses: dict[str, Any]): + self._responses = responses + self.calls: list[str] = [] + + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + bound = get_bound_agent(template) + key = ( + f"{bound.__agent_id__}.{template.__name__}" if bound else template.__name__ + ) + self.calls.append(key) + return self._responses.get( + key, self._responses.get(template.__name__, f"mock-{template.__name__}") + ) + + +class DelayedMockLLM(ObjectInterpretation): + """Mock LLM that introduces per-agent delays to force scheduling orderings. + + ``delays`` maps agent_id to a sleep duration (seconds) applied before + each template call for that agent. This lets tests deterministically + force one thread to run before another. + """ + + def __init__(self, responses: dict[str, Any], delays: dict[str, float]): + self._responses = responses + self._delays = delays + self.calls: list[str] = [] + self._lock = threading.Lock() + + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + bound = get_bound_agent(template) + agent_id = bound.__agent_id__ if bound else None + if agent_id and agent_id in self._delays: + time.sleep(self._delays[agent_id]) + key = ( + f"{bound.__agent_id__}.{template.__name__}" if bound else template.__name__ + ) + with self._lock: + self.calls.append(key) + return self._responses.get( + key, self._responses.get(template.__name__, f"mock-{template.__name__}") + ) + + +class FailingMockLLM(ObjectInterpretation): + """Mock LLM that raises on specific agent.template keys.""" + + def __init__(self, responses: dict[str, Any], fail_on: set[str]): + self._responses = responses + self._fail_on = fail_on + + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + bound = get_bound_agent(template) + key = ( + f"{bound.__agent_id__}.{template.__name__}" if bound else template.__name__ + ) + if key in self._fail_on: + raise RuntimeError(f"Simulated failure on {key}") + return self._responses.get( + key, self._responses.get(template.__name__, f"mock-{template.__name__}") + ) + + +def _run_threads_with_timeout(targets, timeout=THREAD_TIMEOUT): + """Start threads and join with timeout. Raises if any thread hangs.""" + threads = [threading.Thread(target=t, daemon=True) for t in targets] + for t in threads: + t.start() + for t in threads: + t.join(timeout=timeout) + if t.is_alive(): + raise TimeoutError( + f"Thread {t.name} did not finish within {timeout}s — " + "possible deadlock or infinite poll" + ) + + +# ── Agent definitions ───────────────────────────────────────────── + + +class Architect(PersistentAgent): + """Plans modules.""" + + @Template.define + def plan(self, spec: str) -> str: + """Plan modules for: {spec}""" + raise NotHandled + + +class Coder(PersistentAgent): + """Writes code.""" + + @Template.define + def implement(self, spec: str) -> str: + """Implement: {spec}""" + raise NotHandled + + +class Reviewer(PersistentAgent): + """Reviews code.""" + + @Template.define + def review(self, code: str) -> str: + """Review: {code}""" + raise NotHandled + + +class TesterAgent(PersistentAgent): + """Writes tests.""" + + @Template.define + def write_tests(self, spec: str) -> str: + """Write tests for: {spec}""" + raise NotHandled + + +class Prover(PersistentAgent): + """Proves theorems.""" + + @Template.define + def prove(self, spec: str) -> str: + """Prove: {spec}""" + raise NotHandled + + +# ── TaskQueue tests ─────────────────────────────────────────────── + +# Counter to avoid directory collisions between parametrized persistent tests. +_ptq_counter = 0 + + +def _make_persistent_queue(): + global _ptq_counter + _ptq_counter += 1 + return PersistentTaskQueue(STATE_DIR / f"q-{_ptq_counter}.db") + + +@pytest.fixture(params=["persistent", "in_memory"]) +def make_queue(request): + """Parametrized fixture — runs each test against both queue backends.""" + if request.param == "persistent": + return _make_persistent_queue + else: + return InMemoryTaskQueue + + +class TestTaskQueue: + def test_submit_and_claim(self, make_queue): + tq = make_queue() + tid = tq.submit("code", {"file": "main.py"}, task_id="t1") + assert tid == "t1" + + task = tq.claim("code", "worker1") + assert task is not None + assert task["id"] == "t1" + assert task["status"] == TaskStatus.CLAIMED + + # Can't claim again + assert tq.claim("code", "worker2") is None + + def test_idempotent_submit(self, make_queue): + tq = make_queue() + tq.submit("code", {}, task_id="t1") + tq.submit("code", {}, task_id="t1") # no-op + assert tq.pending_count() == 1 + + def test_complete_and_get_result(self, make_queue): + tq = make_queue() + tq.submit("code", {}, task_id="t1") + tq.claim("code", "w1") + tq.complete("t1", "w1", {"output": "hello"}) + assert tq.get_result("t1") == {"output": "hello"} + + def test_release_stale_claims(self, make_queue): + tq = make_queue() + tq.submit("code", {}, task_id="t1") + tq.claim("code", "crashed_worker") + assert tq.pending_count() == 0 + + released = tq.release_stale_claims("crashed_worker") + assert released == 1 + assert tq.pending_count() == 1 + + # Can re-claim + task = tq.claim("code", "new_worker") + assert task is not None + + def test_claim_by_prefix(self, make_queue): + tq = make_queue() + tq.submit("scatter", {}, task_id="step-0001:0000") + tq.submit("scatter", {}, task_id="step-0001:0001") + tq.submit("other", {}, task_id="step-0002") + + task = tq.claim_by_prefix("step-0001:", "w1") + assert task is not None + assert task["id"].startswith("step-0001:") + + task2 = tq.claim_by_prefix("step-0001:", "w1") + assert task2 is not None + assert task2["id"] != task["id"] + + assert tq.claim_by_prefix("step-0001:", "w1") is None + + def test_all_done(self, make_queue): + tq = make_queue() + assert tq.all_done() + + tq.submit("code", {}, task_id="t1") + assert not tq.all_done() + + tq.claim("code", "w1") + assert not tq.all_done() # claimed but not done + + tq.complete("t1", "w1", "result") + assert tq.all_done() + + def test_fail(self, make_queue): + tq = make_queue() + tq.submit("code", {}, task_id="t1") + tq.claim("code", "w1") + tq.fail("t1", "w1", "boom") + # Failed tasks are not pending/claimed, so all_done is True + assert tq.all_done() + # get_result returns None for failed tasks + assert tq.get_result("t1") is None + + def test_concurrent_claims(self, make_queue): + """Multiple threads claiming — no double claims.""" + tq = make_queue() + n_tasks = 20 + for i in range(n_tasks): + tq.submit("work", {"i": i}, task_id=f"t{i:03d}") + + claimed: list[dict] = [] + lock = threading.Lock() + + def claimer(owner): + while True: + task = tq.claim("work", owner) + if task is None: + break + with lock: + claimed.append(task) + + threads = [threading.Thread(target=claimer, args=(f"w{i}",)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=THREAD_TIMEOUT) + + # Each task claimed exactly once + ids = [t["id"] for t in claimed] + assert len(ids) == n_tasks + assert len(set(ids)) == n_tasks + + +# ── EPP tests ───────────────────────────────────────────────────── + + +class TestEndpointProjection: + def _run_choreo( + self, + agents, + choreo_fn, + responses, + *, + mock_cls=None, + timeout=THREAD_TIMEOUT, + **kwargs, + ): + """Helper: run a choreography with mock LLM. + + *mock_cls* can be a callable ``(responses) -> ObjectInterpretation`` + to inject custom mock behaviour (e.g. delays). + """ + tq = InMemoryTaskQueue() + ids = frozenset(a.__agent_id__ for a in agents) + results: dict[str, Any] = {} + llm_calls: dict[str, list[str]] = {} + errors: list[tuple[str, Exception]] = [] + lock = threading.Lock() + + def run_agent(agent): + try: + mock = mock_cls(responses) if mock_cls else MockLLM(responses) + tq.release_stale_claims(agent.__agent_id__) + epp = EndpointProjection(agent, tq, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo_fn(**kwargs) + with lock: + results[agent.__agent_id__] = r + llm_calls[agent.__agent_id__] = list( + mock.calls if hasattr(mock, "calls") else [] + ) + except Exception as e: + with lock: + errors.append((agent.__agent_id__, e)) + + _run_threads_with_timeout( + [lambda a=a: run_agent(a) for a in agents], + timeout=timeout, + ) + + return results, llm_calls, errors + + def test_basic_sequential(self): + """Two agents: planner plans, worker executes.""" + planner = Architect(agent_id="arch") + worker = Coder(agent_id="coder") + + def choreo(spec, arch, coder): + plan = arch.plan(spec) + return coder.implement(plan) + + results, calls, errors = self._run_choreo( + [planner, worker], + choreo, + {"plan": "the plan", "implement": "code"}, + spec="build it", + arch=planner, + coder=worker, + ) + + assert not errors, errors + assert results["arch"] == "code" + assert results["coder"] == "code" + # Planner executed plan, worker executed implement + assert "arch.plan" in calls["arch"] + assert "coder.implement" in calls["coder"] + + def test_all_agents_same_result(self): + """All agent threads produce the same result.""" + arch = Architect(agent_id="a") + coder = Coder(agent_id="c") + reviewer = Reviewer(agent_id="r") + + def choreo(arch, coder, reviewer): + plan = arch.plan("spec") + code = coder.implement(plan) + return reviewer.review(code) + + results, _, errors = self._run_choreo( + [arch, coder, reviewer], + choreo, + {"plan": "plan", "implement": "code", "review": "PASS"}, + arch=arch, + coder=coder, + reviewer=reviewer, + ) + + assert not errors + assert results["a"] == results["c"] == results["r"] == "PASS" + + def test_while_loop(self): + """Control flow: reviewer retries, then passes.""" + arch = Architect(agent_id="arch") + coder = Coder(agent_id="coder") + reviewer = Reviewer(agent_id="rev") + + review_count = {"n": 0} + + class LoopMock(ObjectInterpretation): + def __init__(self): + self.calls: list[str] = [] + + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + self.calls.append(template.__name__) + if template.__name__ == "plan": + return "plan" + elif template.__name__ == "implement": + return f"code-v{len(self.calls)}" + elif template.__name__ == "review": + review_count["n"] += 1 + return "RETRY" if review_count["n"] <= 1 else "PASS" + return "?" + + def choreo(arch, coder, reviewer): + plan = arch.plan("spec") + code = coder.implement(plan) + while True: + verdict = reviewer.review(code) + if verdict == "PASS": + return code + code = coder.implement(verdict) + + tq = InMemoryTaskQueue() + ids = frozenset(["arch", "coder", "rev"]) + results: dict[str, Any] = {} + errors: list = [] + lock = threading.Lock() + + def run(agent): + try: + mock = LoopMock() + epp = EndpointProjection(agent, tq, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(arch=arch, coder=coder, reviewer=reviewer) + with lock: + results[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in [arch, coder, reviewer]]) + + assert not errors, errors + assert all(r == results["arch"] for r in results.values()) + + def test_crash_recovery(self): + """Pre-cache step 0, restart: step 0 from cache, step 1 fresh.""" + tq = InMemoryTaskQueue() + arch = Architect(agent_id="arch") + coder = Coder(agent_id="coder") + + # Simulate prior run: step 0 done + tq.submit("plan", {"agent": "arch"}, task_id="step-0000") + tq.claim("plan", "arch") + tq.complete("step-0000", "arch", "cached-plan") + + def choreo(arch, coder): + plan = arch.plan("spec") + return coder.implement(plan) + + ids = frozenset(["arch", "coder"]) + results: dict[str, Any] = {} + llm_calls: dict[str, list[str]] = {} + lock = threading.Lock() + errors: list = [] + + def run(agent): + try: + mock = MockLLM({"plan": "SHOULD NOT RUN", "implement": "fresh-code"}) + tq.release_stale_claims(agent.__agent_id__) + epp = EndpointProjection(agent, tq, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(arch=arch, coder=coder) + with lock: + results[agent.__agent_id__] = r + llm_calls[agent.__agent_id__] = list(mock.calls) + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in [arch, coder]]) + + assert not errors, errors + assert results["arch"] == "fresh-code" + assert results["coder"] == "fresh-code" + # arch should NOT have called LLM for plan + assert "arch.plan" not in llm_calls.get("arch", []) + # coder should have called implement + assert "coder.implement" in llm_calls.get("coder", []) + + +# ── Ordering permutation tests ──────────────────────────────────── + + +class TestOrderingPermutations: + """Run choreographies under every possible thread-scheduling order. + + Uses controlled delays to deterministically force one agent to + execute before another. For N agents there are N! orderings; + all must produce the same result. + """ + + def _run_with_ordering(self, agents, choreo_fn, responses, ordering, **kwargs): + """Run *choreo_fn* with agents delayed so they execute in *ordering*.""" + # Give each agent a staggered delay: first in ordering gets 0, + # second gets a small delay, etc. + delays = {agent.__agent_id__: i * 0.03 for i, agent in enumerate(ordering)} + tq = InMemoryTaskQueue() + ids = frozenset(a.__agent_id__ for a in agents) + results: dict[str, Any] = {} + errors: list[tuple[str, Exception]] = [] + lock = threading.Lock() + + def run_agent(agent): + try: + mock = DelayedMockLLM(responses, delays) + tq.release_stale_claims(agent.__agent_id__) + epp = EndpointProjection(agent, tq, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo_fn(**kwargs) + with lock: + results[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append((agent.__agent_id__, e)) + + _run_threads_with_timeout([lambda a=a: run_agent(a) for a in agents]) + return results, errors + + def test_two_agent_all_orderings(self): + """Two agents, two orderings — both must agree on the result.""" + arch = Architect(agent_id="arch") + coder = Coder(agent_id="coder") + agents = [arch, coder] + responses = {"plan": "the-plan", "implement": "the-code"} + + def choreo(arch, coder): + plan = arch.plan("spec") + return coder.implement(plan) + + all_results = [] + for perm in itertools.permutations(agents): + results, errors = self._run_with_ordering( + agents, + choreo, + responses, + perm, + arch=arch, + coder=coder, + ) + assert not errors, f"Ordering {[a.__agent_id__ for a in perm]}: {errors}" + all_results.append(results) + + # All orderings must produce the same result for every agent + for r in all_results: + assert r["arch"] == r["coder"] == "the-code" + + def test_three_agent_all_orderings(self): + """Three agents, six orderings — all must agree.""" + arch = Architect(agent_id="a") + coder = Coder(agent_id="c") + reviewer = Reviewer(agent_id="r") + agents = [arch, coder, reviewer] + responses = {"plan": "plan", "implement": "code", "review": "PASS"} + + def choreo(arch, coder, reviewer): + plan = arch.plan("spec") + code = coder.implement(plan) + return reviewer.review(code) + + expected = "PASS" + for perm in itertools.permutations(agents): + results, errors = self._run_with_ordering( + agents, + choreo, + responses, + perm, + arch=arch, + coder=coder, + reviewer=reviewer, + ) + assert not errors, f"Ordering {[a.__agent_id__ for a in perm]}: {errors}" + for aid, r in results.items(): + assert r == expected, ( + f"Agent {aid} got {r!r} != {expected!r} " + f"with ordering {[a.__agent_id__ for a in perm]}" + ) + + def test_scatter_all_orderings(self): + """Scatter with two coders + reviewer, all orderings agree.""" + c1 = Coder(agent_id="c1") + c2 = Coder(agent_id="c2") + rev = Reviewer(agent_id="rev") + agents = [c1, c2, rev] + responses = {"implement": "code", "review": "ok"} + + def choreo(items, coders, reviewer): + codes = scatter(items, coders, lambda c, m: c.implement(m)) + return [reviewer.review(c) for c in codes] + + for perm in itertools.permutations(agents): + results, errors = self._run_with_ordering( + agents, + choreo, + responses, + perm, + items=["A", "B"], + coders=[c1, c2], + reviewer=rev, + ) + assert not errors, f"Ordering {[a.__agent_id__ for a in perm]}: {errors}" + for aid, r in results.items(): + assert r == ["ok", "ok"], ( + f"Agent {aid} got {r!r} with ordering " + f"{[a.__agent_id__ for a in perm]}" + ) + + +# ── Race condition tests ────────────────────────────────────────── + + +class TestRaceConditions: + def test_concurrent_claims_no_double_execution(self): + """Multiple workers racing to claim the same task type — + exactly one wins, no double execution.""" + tq = InMemoryTaskQueue() + tq.submit("work", {"data": "x"}, task_id="t1") + + claimed_by: list[str] = [] + lock = threading.Lock() + barrier = threading.Barrier(5) + + def try_claim(worker_id): + barrier.wait() # all threads start claiming simultaneously + task = tq.claim("work", worker_id) + if task is not None: + with lock: + claimed_by.append(worker_id) + + _run_threads_with_timeout([lambda w=f"w{i}": try_claim(w) for i in range(5)]) + assert len(claimed_by) == 1, f"Double-claim: {claimed_by}" + + def test_concurrent_submit_idempotent(self): + """Multiple threads submitting the same task_id — only one task created.""" + tq = InMemoryTaskQueue() + barrier = threading.Barrier(5) + + def submit(i): + barrier.wait() + tq.submit("work", {"thread": i}, task_id="same-id") + + _run_threads_with_timeout([lambda i=i: submit(i) for i in range(5)]) + assert tq.pending_count() == 1 + + def test_step_counters_stay_in_sync(self): + """All agent threads must assign the same step ID to each + choreographic statement, even under different scheduling.""" + arch = Architect(agent_id="arch") + coder = Coder(agent_id="coder") + + step_ids_seen: dict[str, list[str]] = {"arch": [], "coder": []} + lock = threading.Lock() + + class StepTrackingMock(ObjectInterpretation): + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + return "result" + + tq = InMemoryTaskQueue() + ids = frozenset(["arch", "coder"]) + + def run_agent(agent): + mock = StepTrackingMock() + tq.release_stale_claims(agent.__agent_id__) + epp = EndpointProjection(agent, tq, ids, poll_interval=0.02) + with handler(mock), handler(epp): + arch.plan("spec") + coder.implement("plan") + # After execution, check the step counter + with lock: + step_ids_seen[agent.__agent_id__].append(epp._step) + + _run_threads_with_timeout([lambda a=a: run_agent(a) for a in [arch, coder]]) + # Both agents should have advanced through the same number of steps + assert step_ids_seen["arch"] == step_ids_seen["coder"] + + def test_many_agents_many_steps(self): + """Stress test: 5 coders scatter over 20 items.""" + coders = [Coder(agent_id=f"c{i}") for i in range(5)] + items = list(range(20)) + responses = {"implement": "done"} + + def choreo(items, coders): + return scatter(items, coders, lambda c, m: c.implement(str(m))) + + mock = MockLLM(responses) + c = Choreography( + choreo, + agents=coders, + queue=InMemoryTaskQueue(), + handlers=[mock], + poll_interval=0.02, + ) + result = c.run(items=items, coders=coders) + assert result == ["done"] * 20 + + +# ── Edge case tests ─────────────────────────────────────────────── + + +class TestEdgeCases: + def test_single_agent_choreography(self): + """A choreography with only one agent works without deadlock.""" + coder = Coder(agent_id="solo") + + def choreo(coder): + return coder.implement("just me") + + mock = MockLLM({"implement": "solo-code"}) + c = Choreography( + choreo, + agents=[coder], + queue=InMemoryTaskQueue(), + handlers=[mock], + poll_interval=0.02, + ) + result = c.run(coder=coder) + assert result == "solo-code" + + def test_empty_scatter(self): + """Scatter over an empty list returns [] without hanging.""" + c1 = Coder(agent_id="c1") + c2 = Coder(agent_id="c2") + + def choreo(items, coders): + return scatter(items, coders, lambda c, m: c.implement(m)) + + mock = MockLLM({"implement": "code"}) + c = Choreography( + choreo, + agents=[c1, c2], + queue=InMemoryTaskQueue(), + handlers=[mock], + poll_interval=0.02, + ) + result = c.run(items=[], coders=[c1, c2]) + assert result == [] + + def test_agent_error_propagates(self): + """An exception in one agent's template propagates as ChoreographyError.""" + arch = Architect(agent_id="arch") + coder = Coder(agent_id="coder") + + def choreo(arch, coder): + plan = arch.plan("spec") + return coder.implement(plan) + + mock = FailingMockLLM( + responses={"implement": "code"}, + fail_on={"arch.plan"}, + ) + c = Choreography( + choreo, + agents=[arch, coder], + queue=InMemoryTaskQueue(), + handlers=[mock], + poll_interval=0.02, + ) + with pytest.raises(ChoreographyError, match="arch"): + c.run(arch=arch, coder=coder) + + def test_scatter_single_worker(self): + """Scatter with one worker still completes all items.""" + coder = Coder(agent_id="c1") + + call_count = {"n": 0} + call_lock = threading.Lock() + + class CountingMock(ObjectInterpretation): + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + with call_lock: + call_count["n"] += 1 + return f"result-{call_count['n']}" + + def choreo(items, coders): + return scatter(items, coders, lambda c, m: c.implement(m)) + + c = Choreography( + choreo, + agents=[coder], + queue=InMemoryTaskQueue(), + handlers=[CountingMock()], + poll_interval=0.02, + ) + result = c.run(items=["a", "b", "c"], coders=[coder]) + assert len(result) == 3 + # All results should be non-None + assert all(r is not None for r in result) + + def test_scatter_error_propagates(self): + """An error inside a scatter item propagates as ChoreographyError.""" + c1 = Coder(agent_id="c1") + + class ScatterFailMock(ObjectInterpretation): + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + raise RuntimeError("scatter item failed") + + def choreo(items, coders): + return scatter(items, coders, lambda c, m: c.implement(m)) + + c = Choreography( + choreo, + agents=[c1], + queue=InMemoryTaskQueue(), + handlers=[ScatterFailMock()], + poll_interval=0.02, + ) + with pytest.raises(ChoreographyError): + c.run(items=["x"], coders=[c1]) + + def test_result_none_vs_not_done(self): + """get_result returns None for non-existent tasks, not for tasks + whose result *is* None — ensuring poll loops don't confuse the two.""" + tq = InMemoryTaskQueue() + # Non-existent task + assert tq.get_result("nonexistent") is None + + # Task with an actual result of a falsy value + tq.submit("work", {}, task_id="t1") + tq.claim("work", "w1") + tq.complete("t1", "w1", 0) # result is 0 (falsy but not None) + assert tq.get_result("t1") == 0 + + def test_repeated_runs_deterministic(self): + """Running the same choreography 5 times gives identical results.""" + arch = Architect(agent_id="arch") + coder = Coder(agent_id="coder") + + def choreo(arch, coder): + plan = arch.plan("spec") + return coder.implement(plan) + + results = [] + for i in range(5): + q_dir = STATE_DIR / "q" + if q_dir.exists(): + shutil.rmtree(q_dir) + mock = MockLLM({"plan": "plan", "implement": "code"}) + c = Choreography( + choreo, + agents=[arch, coder], + queue=InMemoryTaskQueue(), + handlers=[mock], + poll_interval=0.02, + ) + results.append(c.run(arch=arch, coder=coder)) + + assert all(r == results[0] for r in results) + + +# ── Scatter tests ───────────────────────────────────────────────── + + +class TestScatter: + def test_scatter_distributes_work(self): + """Scatter distributes items across agents via claim mechanism.""" + c1 = Coder(agent_id="c1") + c2 = Coder(agent_id="c2") + rev = Reviewer(agent_id="rev") + + tq = InMemoryTaskQueue() + ids = frozenset(["c1", "c2", "rev"]) + items = ["A", "B", "C", "D"] + + def choreo(items, coders, reviewer): + codes = scatter(items, coders, lambda c, m: c.implement(m)) + return [reviewer.review(code) for code in codes] + + results: dict[str, Any] = {} + llm_calls: dict[str, list[str]] = {} + errors: list = [] + lock = threading.Lock() + + def run(agent): + try: + mock = MockLLM({"implement": "code", "review": "PASS"}) + tq.release_stale_claims(agent.__agent_id__) + epp = EndpointProjection(agent, tq, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(items, [c1, c2], rev) + with lock: + results[agent.__agent_id__] = r + llm_calls[agent.__agent_id__] = list(mock.calls) + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in [c1, c2, rev]]) + + assert not errors, errors + assert results["c1"] == results["c2"] == results["rev"] + # Total execute calls across coders should be 4 + c1_impl = [c for c in llm_calls.get("c1", []) if "implement" in c] + c2_impl = [c for c in llm_calls.get("c2", []) if "implement" in c] + assert len(c1_impl) + len(c2_impl) == 4 + # Reviewer did all 4 reviews + rev_reviews = [c for c in llm_calls.get("rev", []) if "review" in c] + assert len(rev_reviews) == 4 + + def test_scatter_crash_recovery(self): + """Scatter with some items cached: only remaining items executed.""" + tq = InMemoryTaskQueue() + c1 = Coder(agent_id="c1") + c2 = Coder(agent_id="c2") + + items = ["A", "B", "C", "D"] + + # Pre-cache items 0 and 1 via submit/claim/complete + for i in range(2): + step = f"step-0000:{i:04d}" + tq.submit("scatter-step-0000", {"item_index": i}, task_id=step) + tq.claim("scatter-step-0000", "prior-worker") + tq.complete(step, "prior-worker", f"cached-{items[i]}") + + def choreo(items, coders): + return scatter(items, coders, lambda c, m: c.implement(m)) + + ids = frozenset(["c1", "c2"]) + results: dict[str, Any] = {} + llm_calls: dict[str, list[str]] = {} + errors: list = [] + lock = threading.Lock() + + def run(agent): + try: + mock = MockLLM({"implement": "fresh"}) + tq.release_stale_claims(agent.__agent_id__) + epp = EndpointProjection(agent, tq, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(items, [c1, c2]) + with lock: + results[agent.__agent_id__] = r + llm_calls[agent.__agent_id__] = list(mock.calls) + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in [c1, c2]]) + + assert not errors, errors + r = results["c1"] + assert r[0] == "cached-A" + assert r[1] == "cached-B" + assert r[2] == "fresh" + assert r[3] == "fresh" + + # Only 2 LLM calls total (items 2 and 3) + total = sum(len(calls) for calls in llm_calls.values()) + assert total == 2 + + +# ── Choreography runner tests ───────────────────────────────────── + + +class TestChoreography: + def test_run(self): + """High-level Choreography.run() orchestrates everything.""" + arch = Architect(agent_id="arch") + coder = Coder(agent_id="coder") + + def choreo(arch, coder): + plan = arch.plan("spec") + return coder.implement(plan) + + mock = MockLLM({"plan": "plan", "implement": "code"}) + c = Choreography( + choreo, + agents=[arch, coder], + queue=InMemoryTaskQueue(), + handlers=[mock], + poll_interval=0.02, + ) + result = c.run(arch=arch, coder=coder) + assert result == "code" + + def test_run_restart(self): + """Run, restart, verify cached results are used.""" + arch = Architect(agent_id="arch") + coder = Coder(agent_id="coder") + + def choreo(arch, coder): + plan = arch.plan("spec") + return coder.implement(plan) + + # Shared persistent queue survives across Choreography instances + shared_queue = PersistentTaskQueue(STATE_DIR / "restart_q.db") + + mock1 = MockLLM({"plan": "plan-v1", "implement": "code-v1"}) + c1 = Choreography( + choreo, + agents=[arch, coder], + queue=shared_queue, + handlers=[mock1], + poll_interval=0.02, + ) + result1 = c1.run(arch=arch, coder=coder) + assert result1 == "code-v1" + + # "Restart" — new Choreography, same persistent queue + # Even with different responses, should use cached + mock2 = MockLLM({"plan": "plan-v2", "implement": "code-v2"}) + c2 = Choreography( + choreo, + agents=[arch, coder], + queue=shared_queue, + handlers=[mock2], + poll_interval=0.02, + ) + result2 = c2.run(arch=arch, coder=coder) + # Both steps were cached from first run + assert result2 == "code-v1" + + def test_scatter_with_choreography(self): + """Choreography with scatter.""" + c1 = Coder(agent_id="c1") + c2 = Coder(agent_id="c2") + + def choreo(items, coders): + return scatter(items, coders, lambda c, m: c.implement(m)) + + mock = MockLLM({"implement": "code"}) + c = Choreography( + choreo, + agents=[c1, c2], + queue=InMemoryTaskQueue(), + handlers=[mock], + poll_interval=0.02, + ) + result = c.run(items=["A", "B", "C"], coders=[c1, c2]) + assert result == ["code", "code", "code"] + + +# ── SQLite crash tolerance tests ───────────────────────────────── + + +class TestSQLiteTaskQueueCrashTolerance: + """Tests that verify SQLite-specific crash tolerance properties.""" + + def test_wal_mode_enabled(self): + """WAL journal mode is enabled for crash tolerance.""" + import sqlite3 + + tq = PersistentTaskQueue(STATE_DIR / "wal_test.db") + tq.submit("work", {}, task_id="t1") + + conn = sqlite3.connect(str(tq.db_path)) + mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + conn.close() + assert mode == "wal" + + def test_db_integrity_after_operations(self): + """Database passes integrity check after various operations.""" + import sqlite3 + + tq = PersistentTaskQueue(STATE_DIR / "integrity_test.db") + + # Submit, claim, complete, fail + tq.submit("work", {"data": "a"}, task_id="t1") + tq.submit("work", {"data": "b"}, task_id="t2") + tq.claim("work", "w1") + tq.complete("t1", "w1", {"result": "done"}) + tq.claim("work", "w1") + tq.fail("t2", "w1", "boom") + + conn = sqlite3.connect(str(tq.db_path)) + result = conn.execute("PRAGMA integrity_check").fetchone()[0] + conn.close() + assert result == "ok" + + def test_persistence_across_instances(self): + """Data survives creating a new PersistentTaskQueue on same db.""" + db_path = STATE_DIR / "persist_test.db" + + # Instance 1: submit and complete + tq1 = PersistentTaskQueue(db_path) + tq1.submit("work", {"key": "value"}, task_id="t1") + tq1.claim("work", "w1") + tq1.complete("t1", "w1", "result-1") + + # Instance 2: verify result survives + tq2 = PersistentTaskQueue(db_path) + assert tq2.get_result("t1") == "result-1" + assert tq2.all_done() + + def test_stale_claims_survive_restart(self): + """Claimed-but-not-completed tasks are recoverable after restart.""" + db_path = STATE_DIR / "stale_test.db" + + # Instance 1: submit and claim (simulating crash before completion) + tq1 = PersistentTaskQueue(db_path) + tq1.submit("work", {"data": "x"}, task_id="t1") + tq1.submit("work", {"data": "y"}, task_id="t2") + tq1.claim("work", "crashed_worker") + tq1.claim("work", "crashed_worker") + + # Instance 2: restart, release stale claims, re-claim + tq2 = PersistentTaskQueue(db_path) + assert tq2.pending_count() == 0 # both are claimed + released = tq2.release_stale_claims("crashed_worker") + assert released == 2 + assert tq2.pending_count() == 2 + + # Can re-claim and complete + task = tq2.claim("work", "new_worker") + assert task is not None + tq2.complete(task["id"], "new_worker", "recovered") + assert tq2.get_result(task["id"]) == "recovered" + + def test_partial_choreography_restart(self): + """Choreography can resume from a partially completed state.""" + db_path = STATE_DIR / "partial_choreo.db" + arch = Architect(agent_id="arch") + coder = Coder(agent_id="coder") + + # "Run 1": complete step 0 (plan), simulate crash before step 1 + tq1 = PersistentTaskQueue(db_path) + tq1.submit("plan", {"agent": "arch"}, task_id="step-0000") + tq1.claim("plan", "arch") + tq1.complete("step-0000", "arch", "the-plan") + # step 1 never submitted (crash) + + # "Run 2": restart with same db — step 0 cached, step 1 fresh + tq2 = PersistentTaskQueue(db_path) + + def choreo(arch, coder): + plan = arch.plan("spec") + return coder.implement(plan) + + ids = frozenset(["arch", "coder"]) + results: dict[str, Any] = {} + errors: list = [] + lock = threading.Lock() + + def run(agent): + try: + mock = MockLLM({"plan": "SHOULD-NOT-RUN", "implement": "fresh-code"}) + tq2.release_stale_claims(agent.__agent_id__) + epp = EndpointProjection(agent, tq2, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(arch=arch, coder=coder) + with lock: + results[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in [arch, coder]]) + + assert not errors, errors + # Step 0 was cached as "the-plan", step 1 ran fresh + assert results["arch"] == "fresh-code" + assert results["coder"] == "fresh-code" + + def test_concurrent_claims_across_connections(self): + """Multiple threads with separate connections cannot double-claim.""" + db_path = STATE_DIR / "concurrent_conn.db" + tq = PersistentTaskQueue(db_path) + + n_tasks = 20 + for i in range(n_tasks): + tq.submit("work", {"i": i}, task_id=f"t{i:03d}") + + claimed: list[dict] = [] + lock = threading.Lock() + + def claimer(owner): + while True: + task = tq.claim("work", owner) + if task is None: + break + with lock: + claimed.append(task) + + _run_threads_with_timeout([lambda w=f"w{i}": claimer(w) for i in range(5)]) + + ids = [t["id"] for t in claimed] + assert len(ids) == n_tasks + assert len(set(ids)) == n_tasks + + def test_scatter_crash_recovery_sqlite(self): + """Scatter with SQLite queue: cached items survive restart.""" + db_path = STATE_DIR / "scatter_crash.db" + + # "Run 1": complete scatter items 0 and 1 + tq1 = PersistentTaskQueue(db_path) + for i in range(2): + step = f"step-0000:{i:04d}" + tq1.submit("scatter-step-0000", {"item_index": i}, task_id=step) + tq1.claim("scatter-step-0000", "prior-worker") + tq1.complete(step, "prior-worker", f"cached-{i}") + + # "Run 2": scatter should pick up cached items 0,1 and run 2,3 fresh + tq2 = PersistentTaskQueue(db_path) + c1 = Coder(agent_id="c1") + c2 = Coder(agent_id="c2") + items = ["A", "B", "C", "D"] + + def choreo(items, coders): + return scatter(items, coders, lambda c, m: c.implement(m)) + + ids = frozenset(["c1", "c2"]) + results: dict[str, Any] = {} + errors: list = [] + lock = threading.Lock() + + def run(agent): + try: + mock = MockLLM({"implement": "fresh"}) + tq2.release_stale_claims(agent.__agent_id__) + epp = EndpointProjection(agent, tq2, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(items, [c1, c2]) + with lock: + results[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in [c1, c2]]) + + assert not errors, errors + r = results["c1"] + assert r[0] == "cached-0" + assert r[1] == "cached-1" + assert r[2] == "fresh" + assert r[3] == "fresh" + + def test_idempotent_submit_across_restarts(self): + """Re-submitting existing task_id after restart is a no-op.""" + db_path = STATE_DIR / "idempotent_restart.db" + + tq1 = PersistentTaskQueue(db_path) + tq1.submit("work", {"original": True}, task_id="t1") + + tq2 = PersistentTaskQueue(db_path) + tq2.submit("work", {"duplicate": True}, task_id="t1") # should be ignored + assert tq2.pending_count() == 1 + + # Verify original payload preserved + task = tq2.claim("work", "w1") + assert task is not None + assert task["payload"] == {"original": True} + + +# ── fan_out tests ───────────────────────────────────────────────── + + +class TestFanOutDefault: + """fan_out without EPP handler — sequential fallback.""" + + def test_default_sequential(self): + """Default fan_out runs groups sequentially.""" + coder = Coder(agent_id="c1") + tester = TesterAgent(agent_id="t1") + + mock = MockLLM( + { + "c1.implement": "code-result", + "t1.write_tests": "test-result", + } + ) + with handler(mock): + results = fan_out( + [ + (["a", "b"], coder, lambda c, m: c.implement(m)), + (["x"], tester, lambda t, m: t.write_tests(m)), + ] + ) + assert len(results) == 2 + assert results[0] == ["code-result", "code-result"] + assert results[1] == ["test-result"] + + def test_default_empty_groups(self): + """fan_out with empty item lists returns empty lists.""" + coder = Coder(agent_id="c1") + results = fan_out( + [ + ([], coder, lambda c, m: c.implement(m)), + ] + ) + assert results == [[]] + + def test_default_no_groups(self): + """fan_out with no groups returns empty list.""" + results = fan_out([]) + assert results == [] + + +class TestFanOutEPP: + """fan_out under EndpointProjection — concurrent execution.""" + + def test_concurrent_different_agent_types(self): + """Three different agent types work concurrently via fan_out.""" + coder = Coder(agent_id="c1") + tester = TesterAgent(agent_id="t1") + prover = Prover(agent_id="p1") + + mock = MockLLM( + { + "c1.implement": "impl-result", + "t1.write_tests": "test-result", + "p1.prove": "proof-result", + } + ) + queue = InMemoryTaskQueue() + agents = [coder, tester, prover] + ids = frozenset(a.__agent_id__ for a in agents) + + results_map: dict[str, Any] = {} + errors: list[Exception] = [] + lock = threading.Lock() + + def choreo(c, t, p): + return fan_out( + [ + (["a", "b"], c, lambda c, m: c.implement(m)), + (["x", "y", "z"], t, lambda t, m: t.write_tests(m)), + (["th1"], p, lambda p, m: p.prove(m)), + ] + ) + + def run(agent): + try: + epp = EndpointProjection(agent, queue, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(coder, tester, prover) + with lock: + results_map[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in agents]) + + assert not errors, errors + # All agents compute same result + for aid in ["c1", "t1", "p1"]: + r = results_map[aid] + assert r[0] == ["impl-result", "impl-result"] + assert r[1] == ["test-result", "test-result", "test-result"] + assert r[2] == ["proof-result"] + + def test_fan_out_with_multiple_workers_per_group(self): + """fan_out with multiple coders in one group, single tester in another.""" + c1 = Coder(agent_id="c1") + c2 = Coder(agent_id="c2") + t1 = TesterAgent(agent_id="t1") + + execution_log: list[str] = [] + log_lock = threading.Lock() + + class TrackingMockLLM(ObjectInterpretation): + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + bound = get_bound_agent(template) + key = f"{bound.__agent_id__}.{template.__name__}" if bound else "" + with log_lock: + execution_log.append(key) + return f"result-{key}" + + mock = TrackingMockLLM() + queue = InMemoryTaskQueue() + agents = [c1, c2, t1] + ids = frozenset(a.__agent_id__ for a in agents) + + results_map: dict[str, Any] = {} + errors: list[Exception] = [] + lock = threading.Lock() + + def choreo(coders, tester): + return fan_out( + [ + (["m1", "m2", "m3", "m4"], coders, lambda c, m: c.implement(m)), + (["t1", "t2"], tester, lambda t, m: t.write_tests(m)), + ] + ) + + def run(agent): + try: + epp = EndpointProjection(agent, queue, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo([c1, c2], t1) + with lock: + results_map[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in agents]) + + assert not errors, errors + # Coders split the 4 items; tester handles 2 items + r = results_map["c1"] + assert len(r[0]) == 4 # 4 code results + assert len(r[1]) == 2 # 2 test results + # Both coders should have executed some implement calls + coder_calls = [c for c in execution_log if c.endswith(".implement")] + tester_calls = [c for c in execution_log if c.endswith(".write_tests")] + assert len(coder_calls) == 4 + assert len(tester_calls) == 2 + + def test_fan_out_empty_group_no_hang(self): + """Empty items in one group don't cause hangs.""" + coder = Coder(agent_id="c1") + tester = TesterAgent(agent_id="t1") + + mock = MockLLM({"c1.implement": "code", "t1.write_tests": "test"}) + queue = InMemoryTaskQueue() + agents = [coder, tester] + ids = frozenset(a.__agent_id__ for a in agents) + + results_map: dict[str, Any] = {} + errors: list[Exception] = [] + lock = threading.Lock() + + def choreo(c, t): + return fan_out( + [ + ([], c, lambda c, m: c.implement(m)), # empty! + (["x"], t, lambda t, m: t.write_tests(m)), + ] + ) + + def run(agent): + try: + epp = EndpointProjection(agent, queue, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(coder, tester) + with lock: + results_map[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in agents]) + + assert not errors, errors + r = results_map["c1"] + assert r[0] == [] + assert r[1] == ["test"] + + def test_fan_out_step_counter_sync(self): + """fan_out consumes exactly one step ID across all agent threads.""" + coder = Coder(agent_id="c1") + reviewer = Reviewer(agent_id="r1") + + mock = MockLLM( + { + "c1.implement": "code", + "r1.review": "lgtm", + } + ) + queue = InMemoryTaskQueue() + agents = [coder, reviewer] + ids = frozenset(a.__agent_id__ for a in agents) + + results_map: dict[str, Any] = {} + errors: list[Exception] = [] + lock = threading.Lock() + + def choreo(c, r): + # fan_out uses 1 step, then review uses 1 step + fan_results = fan_out( + [ + (["a"], c, lambda c, m: c.implement(m)), + ] + ) + verdict = r.review(str(fan_results)) + return verdict + + def run(agent): + try: + epp = EndpointProjection(agent, queue, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(coder, reviewer) + with lock: + results_map[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in agents]) + + assert not errors, errors + # Both agents compute the same final result + assert results_map["c1"] == "lgtm" + assert results_map["r1"] == "lgtm" + + def test_fan_out_error_in_one_group(self): + """Error in one group's fn propagates as ChoreographyError.""" + coder = Coder(agent_id="c1") + tester = TesterAgent(agent_id="t1") + + class ErrorMockLLM(ObjectInterpretation): + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + bound = get_bound_agent(template) + key = f"{bound.__agent_id__}.{template.__name__}" if bound else "" + if key == "t1.write_tests": + raise RuntimeError("test failure!") + return "ok" + + mock = ErrorMockLLM() + + def program(c, t): + return fan_out( + [ + (["a"], c, lambda c, m: c.implement(m)), + (["x"], t, lambda t, m: t.write_tests(m)), + ] + ) + + choreo = Choreography( + program, + agents=[coder, tester], + handlers=[mock], + ) + with pytest.raises(ChoreographyError, match="test failure"): + choreo.run(c=coder, t=tester) + + def test_fan_out_all_orderings_agree(self): + """All thread scheduling orderings produce the same result.""" + for delays in [ + {"c1": 0.0, "t1": 0.05}, + {"c1": 0.05, "t1": 0.0}, + {"c1": 0.0, "t1": 0.0}, + ]: + coder = Coder(agent_id="c1") + tester = TesterAgent(agent_id="t1") + + mock = DelayedMockLLM( + {"c1.implement": "code", "t1.write_tests": "test"}, + delays, + ) + queue = InMemoryTaskQueue() + agents = [coder, tester] + ids = frozenset(a.__agent_id__ for a in agents) + + results_map: dict[str, Any] = {} + errors: list[Exception] = [] + lock = threading.Lock() + + def choreo(c, t): + return fan_out( + [ + (["a", "b"], c, lambda c, m: c.implement(m)), + (["x"], t, lambda t, m: t.write_tests(m)), + ] + ) + + def run(agent): + try: + epp = EndpointProjection(agent, queue, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(coder, tester) + with lock: + results_map[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in agents]) + + assert not errors, errors + assert results_map["c1"] == results_map["t1"] + + +class TestFanOutChoreography: + """fan_out via the high-level Choreography runner.""" + + def test_choreography_fan_out(self): + """fan_out works through Choreography.run().""" + coder = Coder(agent_id="c1") + tester = TesterAgent(agent_id="t1") + + mock = MockLLM({"c1.implement": "code-out", "t1.write_tests": "test-out"}) + + def program(c, t): + return fan_out( + [ + (["a", "b"], c, lambda c, m: c.implement(m)), + (["x"], t, lambda t, m: t.write_tests(m)), + ] + ) + + choreo = Choreography( + program, + agents=[coder, tester], + handlers=[mock], + ) + result = choreo.run(c=coder, t=tester) + assert result == [["code-out", "code-out"], ["test-out"]] + + def test_choreography_fan_out_with_scatter(self): + """fan_out and scatter can be mixed in the same choreography.""" + architect = Architect(agent_id="arch") + coder = Coder(agent_id="c1") + tester = TesterAgent(agent_id="t1") + reviewer = Reviewer(agent_id="r1") + + mock = MockLLM( + { + "arch.plan": "plan-result", + "c1.implement": "code", + "t1.write_tests": "test", + "r1.review": "lgtm", + } + ) + + def program(arch, c, t, r): + plan = arch.plan("project") + # fan_out: code and test in parallel + code_results, test_results = fan_out( + [ + (["m1", "m2"], c, lambda c, m: c.implement(m)), + (["t1"], t, lambda t, m: t.write_tests(m)), + ] + ) + # Then scatter reviews sequentially + reviews = scatter(code_results, r, lambda r, code: r.review(code)) + return {"plan": plan, "reviews": reviews, "tests": test_results} + + choreo = Choreography( + program, + agents=[architect, coder, tester, reviewer], + handlers=[mock], + ) + result = choreo.run(arch=architect, c=coder, t=tester, r=reviewer) + assert result["plan"] == "plan-result" + assert result["reviews"] == ["lgtm", "lgtm"] + assert result["tests"] == ["test"] + + +class TestFanOutCrashTolerance: + """fan_out crash recovery with PersistentTaskQueue.""" + + def test_fan_out_crash_recovery_with_cached_results(self): + """Pre-cached fan_out items are reused on restart.""" + db_path = STATE_DIR / "fan_out_crash.db" + + # "Run 1": simulate partial completion + tq1 = PersistentTaskQueue(db_path) + # Pre-cache group 0 items + tq1.submit( + "fan-step-0000:g0", + {"group": 0, "item_index": 0}, + task_id="step-0000:g0:0000", + ) + tq1.claim_by_prefix("step-0000:g0:0000", "prior-worker") + tq1.complete("step-0000:g0:0000", "prior-worker", "cached-code") + + tq1.submit( + "fan-step-0000:g1", + {"group": 1, "item_index": 0}, + task_id="step-0000:g1:0000", + ) + tq1.claim_by_prefix("step-0000:g1:0000", "prior-worker") + tq1.complete("step-0000:g1:0000", "prior-worker", "cached-test") + + # "Run 2": restart with one uncached item per group + tq2 = PersistentTaskQueue(db_path) + c1 = Coder(agent_id="c1") + t1 = TesterAgent(agent_id="t1") + + mock = MockLLM({"c1.implement": "fresh-code", "t1.write_tests": "fresh-test"}) + + def choreo(c, t): + return fan_out( + [ + (["A", "B"], c, lambda c, m: c.implement(m)), + (["X", "Y"], t, lambda t, m: t.write_tests(m)), + ] + ) + + agents = [c1, t1] + ids = frozenset(a.__agent_id__ for a in agents) + results_map: dict[str, Any] = {} + errors: list[Exception] = [] + lock = threading.Lock() + + def run(agent): + try: + tq2.release_stale_claims(agent.__agent_id__) + epp = EndpointProjection(agent, tq2, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo(c1, t1) + with lock: + results_map[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in agents]) + + assert not errors, errors + r = results_map["c1"] + assert r[0][0] == "cached-code" + assert r[0][1] == "fresh-code" + assert r[1][0] == "cached-test" + assert r[1][1] == "fresh-test" + + def test_fan_out_concurrent_no_double_execution(self): + """No item is executed twice even under concurrent claiming.""" + c1 = Coder(agent_id="c1") + c2 = Coder(agent_id="c2") + t1 = TesterAgent(agent_id="t1") + t2 = TesterAgent(agent_id="t2") + + executed: list[str] = [] + exec_lock = threading.Lock() + + class TrackingMock(ObjectInterpretation): + @implements(Template.__apply__) + def _call(self, template, *args, **kwargs): + bound = get_bound_agent(template) + key = f"{bound.__agent_id__}.{template.__name__}" if bound else "" + with exec_lock: + executed.append(key) + return f"result-{key}" + + mock = TrackingMock() + queue = InMemoryTaskQueue() + agents = [c1, c2, t1, t2] + ids = frozenset(a.__agent_id__ for a in agents) + + results_map: dict[str, Any] = {} + errors: list[Exception] = [] + lock = threading.Lock() + + items_code = [f"mod-{i}" for i in range(10)] + items_test = [f"test-{i}" for i in range(8)] + + def choreo(coders, testers): + return fan_out( + [ + (items_code, coders, lambda c, m: c.implement(m)), + (items_test, testers, lambda t, m: t.write_tests(m)), + ] + ) + + def run(agent): + try: + epp = EndpointProjection(agent, queue, ids, poll_interval=0.02) + with handler(mock), handler(epp): + r = choreo([c1, c2], [t1, t2]) + with lock: + results_map[agent.__agent_id__] = r + except Exception as e: + with lock: + errors.append(e) + + _run_threads_with_timeout([lambda a=a: run(a) for a in agents]) + + assert not errors, errors + # Exactly 10 implement + 8 write_tests calls + impl_calls = [c for c in executed if ".implement" in c] + test_calls = [c for c in executed if ".write_tests" in c] + assert len(impl_calls) == 10 + assert len(test_calls) == 8 + # All agents see the same result + for aid in ["c1", "c2", "t1", "t2"]: + assert results_map[aid] == results_map["c1"] diff --git a/tests/test_persistent_agent.py b/tests/test_persistent_agent.py new file mode 100644 index 00000000..2c4458f4 --- /dev/null +++ b/tests/test_persistent_agent.py @@ -0,0 +1,1612 @@ +"""Tests for PersistentAgent + PersistenceHandler + CompactionHandler. + +Checkpointing, compaction, crash recovery, nested calls, subclass state +persistence, and system prompt augmentation. +""" + +import dataclasses +import json +import sqlite3 +from collections import OrderedDict +from pathlib import Path +from typing import Any + +import pytest +from litellm import ModelResponse + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import ( + AgentHistoryHandler, + LiteLLMProvider, + RetryLLMHandler, + ToolCallExecutionError, + completion, + get_agent_history, +) +from effectful.handlers.llm.persistence import ( + CompactionHandler, + PersistenceHandler, + PersistentAgent, +) +from effectful.handlers.llm.template import get_bound_agent +from effectful.ops.semantics import handler +from effectful.ops.syntax import ObjectInterpretation, implements +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_text_response(content: str) -> ModelResponse: + return ModelResponse( + id="test", + choices=[ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + model="test-model", + ) + + +def make_tool_call_response( + tool_name: str, tool_args: str, tool_call_id: str = "call_1" +) -> ModelResponse: + return ModelResponse( + id="test", + choices=[ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": {"name": tool_name, "arguments": tool_args}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + model="test-model", + ) + + +class MockCompletionHandler(ObjectInterpretation): + """Returns pre-configured responses and captures messages sent to the LLM.""" + + def __init__(self, responses: list[ModelResponse]): + self.responses = responses + self.call_count = 0 + self.received_messages: list[list] = [] + + @implements(completion) + def _completion(self, model, messages=None, **kwargs): + self.received_messages.append(list(messages) if messages else []) + response = self.responses[min(self.call_count, len(self.responses) - 1)] + self.call_count += 1 + return response + + +def read_checkpoint(tmp_path: Path, agent_id: str) -> dict: + """Read a checkpoint from the SQLite database.""" + db_path = tmp_path / "checkpoints.db" + conn = sqlite3.connect(str(db_path)) + row = conn.execute( + "SELECT handoff, state, history FROM checkpoints WHERE agent_id = ?", + (agent_id,), + ).fetchone() + conn.close() + if row is None: + raise FileNotFoundError(f"No checkpoint for {agent_id}") + return { + "agent_id": agent_id, + "handoff": row[0], + "state": json.loads(row[1]), + "history": json.loads(row[2]), + } + + +def has_checkpoint(tmp_path: Path, agent_id: str) -> bool: + """Check if a checkpoint exists in the SQLite database.""" + db_path = tmp_path / "checkpoints.db" + if not db_path.exists(): + return False + conn = sqlite3.connect(str(db_path)) + row = conn.execute( + "SELECT 1 FROM checkpoints WHERE agent_id = ?", + (agent_id,), + ).fetchone() + conn.close() + return row is not None + + +# --------------------------------------------------------------------------- +# Test agents +# --------------------------------------------------------------------------- + + +class ChatBot(PersistentAgent): + """You are a persistent chat bot for testing.""" + + @Template.define + def send(self, user_input: str) -> str: + """User says: {user_input}""" + raise NotHandled + + +@dataclasses.dataclass +class StatefulBot(PersistentAgent): + """You are a stateful bot that tracks learned patterns.""" + + __agent_id__ = "StatefulBot" + + learned_patterns: list[str] = dataclasses.field(default_factory=list) + call_count: int = 0 + + @Template.define + def send(self, user_input: str) -> str: + """User says: {user_input}""" + raise NotHandled + + +class NestedBot(PersistentAgent): + """You are a nested-call test bot.""" + + @Template.define + def inner_check(self, payload: str) -> str: + """Check: {payload}. Do not use tools.""" + raise NotHandled + + @Tool.define + def check_tool(self, payload: str) -> str: + """Check payload by calling an inner template.""" + return self.inner_check(payload) + + @Template.define + def outer(self, payload: str) -> str: + """Call `check_tool` for: {payload}, then return final answer.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Tests: Agent.__agent_id__ +# --------------------------------------------------------------------------- + + +class TestAgentId: + """All Agent subclasses get agent_id.""" + + def test_plain_agent_defaults_to_id(self): + class PlainAgent(Agent): + """Plain.""" + + @Template.define + def ask(self, q: str) -> str: + """Q: {q}""" + raise NotHandled + + agent = PlainAgent() + assert agent.__agent_id__ == str(id(agent)) + + def test_persistent_agent_requires_agent_id(self): + bot = ChatBot(agent_id="my-chatbot") + assert bot.__agent_id__ == "my-chatbot" + + def test_dataclass_class_level_id(self): + """Dataclass subclasses can set __agent_id__ as a class attribute.""" + bot = StatefulBot() + assert bot.__agent_id__ == "StatefulBot" + + def test_bound_template_has_agent_via_context(self): + bot = ChatBot(agent_id="ChatBot") + bound = bot.send + assert get_bound_agent(bound) is bot + + +# --------------------------------------------------------------------------- +# Tests: basic persistence (PersistenceHandler) +# --------------------------------------------------------------------------- + + +class TestCheckpointing: + """PersistenceHandler save/load round-trip correctly.""" + + def test_save_creates_file(self, tmp_path: Path): + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + with handler(AgentHistoryHandler()): + path = persist.save(bot) + assert path.exists() + assert path.suffix == ".db" + + def test_save_round_trip_empty(self, tmp_path: Path): + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + with handler(AgentHistoryHandler()): + persist.save(bot) + + data = read_checkpoint(tmp_path, "ChatBot") + assert len(data["history"]) == 0 + assert data["handoff"] == "" + + def test_save_round_trip_with_history(self, tmp_path: Path): + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + with handler(AgentHistoryHandler()): + history = get_agent_history(bot.__agent_id__) + history["msg1"] = { + "id": "msg1", + "role": "user", + "content": "hello", + } + persist.save(bot, handoff="working on X") + + data = read_checkpoint(tmp_path, "ChatBot") + assert data["handoff"] == "working on X" + assert len(data["history"]) == 1 + assert data["history"][0]["content"] == "hello" + + def test_atomic_write(self, tmp_path: Path): + """Checkpoint write uses SQLite transactions for atomicity.""" + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + with handler(AgentHistoryHandler()): + path = persist.save(bot) + assert path.exists() + data = read_checkpoint(tmp_path, "ChatBot") + assert data["agent_id"] == "ChatBot" + + def test_custom_agent_id(self, tmp_path: Path): + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + class CustomBot(PersistentAgent): + """Custom.""" + + @Template.define + def ask(self, q: str) -> str: + """Q: {q}""" + raise NotHandled + + bot = CustomBot(agent_id="custom-bot") + with handler(AgentHistoryHandler()): + persist.save(bot) + data = read_checkpoint(tmp_path, "custom-bot") + assert data["agent_id"] == "custom-bot" + + +# --------------------------------------------------------------------------- +# Tests: subclass state persistence (checkpoint_state / restore_state) +# --------------------------------------------------------------------------- + + +class TestSubclassStatePersistence: + """Dataclass fields on subclasses are automatically persisted.""" + + def test_dataclass_fields_round_trip(self, tmp_path: Path): + """Dataclass state survives a save and is visible in the DB.""" + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = StatefulBot() + bot.learned_patterns = ["pattern A", "pattern B"] + bot.call_count = 5 + with handler(AgentHistoryHandler()): + persist.save(bot) + + data = read_checkpoint(tmp_path, "StatefulBot") + assert data["state"]["learned_patterns"] == ["pattern A", "pattern B"] + assert data["state"]["call_count"] == 5 + + def test_non_dataclass_has_empty_state(self): + """Non-dataclass subclass returns empty state dict.""" + bot = ChatBot(agent_id="ChatBot") + assert bot.checkpoint_state() == {} + + def test_non_serializable_fields_skipped(self, tmp_path: Path): + @dataclasses.dataclass + class WeirdBot(PersistentAgent): + """Bot with a non-serializable field.""" + + __agent_id__ = "WeirdBot" + + callback: object = dataclasses.field(default=None) + name: str = "test" + + @Template.define + def send(self, msg: str) -> str: + """Say: {msg}""" + raise NotHandled + + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = WeirdBot() + bot.callback = lambda x: x # not JSON serializable + bot.name = "Alice" + with handler(AgentHistoryHandler()): + persist.save(bot) + + data = read_checkpoint(tmp_path, "WeirdBot") + assert data["state"]["name"] == "Alice" + # callback is not JSON serializable, so it should be skipped + assert "callback" not in data["state"] + + def test_custom_checkpoint_restore(self, tmp_path: Path): + """Users can override checkpoint_state / restore_state.""" + + class CustomBot(PersistentAgent): + """Custom serialisation bot.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.data = {"counter": 0} + + def checkpoint_state(self): + return {"data": self.data} + + def restore_state(self, state): + self.data = state.get("data", {"counter": 0}) + + @Template.define + def send(self, msg: str) -> str: + """Say: {msg}""" + raise NotHandled + + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = CustomBot(agent_id="CustomBot") + bot.data["counter"] = 42 + with handler(AgentHistoryHandler()): + persist.save(bot) + + data = read_checkpoint(tmp_path, "CustomBot") + assert data["state"]["data"]["counter"] == 42 + + def test_state_saved_in_checkpoint_file(self, tmp_path: Path): + """The checkpoint DB contains state with subclass fields.""" + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = StatefulBot() + bot.learned_patterns = ["X"] + bot.call_count = 3 + with handler(AgentHistoryHandler()): + persist.save(bot) + + data = read_checkpoint(tmp_path, "StatefulBot") + assert "state" in data + assert data["state"]["learned_patterns"] == ["X"] + assert data["state"]["call_count"] == 3 + + +# --------------------------------------------------------------------------- +# Tests: automatic checkpointing around template calls +# --------------------------------------------------------------------------- + + +class TestAutomaticCheckpointing: + """Template calls on PersistentAgent trigger auto-checkpointing.""" + + def test_checkpoint_saved_after_successful_call(self, tmp_path: Path): + mock = MockCompletionHandler([make_text_response("hello")]) + bot = ChatBot(agent_id="ChatBot") + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("hi") + + data = read_checkpoint(tmp_path, "ChatBot") + assert len(data["history"]) > 0 + assert data["handoff"] == "" + + def test_checkpoint_saved_on_exception(self, tmp_path: Path): + class FailingMock(ObjectInterpretation): + @implements(completion) + def _completion(self, *args, **kwargs): + raise RuntimeError("boom") + + bot = ChatBot(agent_id="ChatBot") + with pytest.raises(RuntimeError, match="boom"): + with ( + handler(LiteLLMProvider()), + handler(FailingMock()), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("hi") + + data = read_checkpoint(tmp_path, "ChatBot") + assert "Executing send" in data["handoff"] + + def test_handoff_describes_current_call(self, tmp_path: Path): + """Before the template runs, handoff records what's in progress.""" + handoff_during_call = [] + + class SpyMock(ObjectInterpretation): + @implements(completion) + def _completion(self_, model, messages=None, **kwargs): + data = read_checkpoint(tmp_path, "ChatBot") + handoff_during_call.append(data["handoff"]) + return make_text_response("ok") + + bot = ChatBot(agent_id="ChatBot") + with ( + handler(LiteLLMProvider()), + handler(SpyMock()), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("hello") + + assert len(handoff_during_call) == 1 + assert "Executing send" in handoff_during_call[0] + + def test_history_persists_across_sessions(self, tmp_path: Path): + """A 'restart' (new handler + agent) sees prior history.""" + mock = MockCompletionHandler([make_text_response("reply1")]) + bot = ChatBot(agent_id="ChatBot") + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("first message") + + data_after_first = read_checkpoint(tmp_path, "ChatBot") + history_len_first = len(data_after_first["history"]) + + # "Restart" — new handler + new agent instance + mock2 = MockCompletionHandler([make_text_response("reply after restart")]) + bot2 = ChatBot(agent_id="ChatBot") + + with ( + handler(LiteLLMProvider()), + handler(mock2), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot2.send("second message") + + data_after_second = read_checkpoint(tmp_path, "ChatBot") + assert len(data_after_second["history"]) > history_len_first + + def test_second_call_sees_prior_history(self, tmp_path: Path): + mock = MockCompletionHandler( + [make_text_response("r1"), make_text_response("r2")] + ) + bot = ChatBot(agent_id="ChatBot") + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("a") + bot.send("b") + + assert len(mock.received_messages[1]) > len(mock.received_messages[0]) + + def test_dataclass_state_saved_around_template_calls(self, tmp_path: Path): + mock = MockCompletionHandler([make_text_response("ok")]) + bot = StatefulBot() + bot.call_count = 7 + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("test") + + data = read_checkpoint(tmp_path, "StatefulBot") + assert data["state"]["call_count"] == 7 + + +# --------------------------------------------------------------------------- +# Tests: crash recovery +# --------------------------------------------------------------------------- + + +class TestCrashRecovery: + """Handoff notes enable resumption after crashes.""" + + def test_handoff_survives_crash(self, tmp_path: Path): + class CrashMock(ObjectInterpretation): + @implements(completion) + def _completion(self, *args, **kwargs): + raise RuntimeError("process killed") + + bot = ChatBot(agent_id="ChatBot") + with pytest.raises(RuntimeError): + with ( + handler(LiteLLMProvider()), + handler(CrashMock()), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("important task") + + data = read_checkpoint(tmp_path, "ChatBot") + assert "Executing send" in data["handoff"] + + def test_system_prompt_includes_handoff(self, tmp_path: Path): + """After a crash, the next call's system prompt includes the handoff.""" + + # Simulate crash + class CrashMock(ObjectInterpretation): + @implements(completion) + def _completion(self, *args, **kwargs): + raise RuntimeError("crash") + + bot = ChatBot(agent_id="ChatBot") + with pytest.raises(RuntimeError): + with ( + handler(LiteLLMProvider()), + handler(CrashMock()), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("important task") + + # Next session: spy on system prompt + system_prompts = [] + + class SpyMock(ObjectInterpretation): + @implements(completion) + def _completion(self_, model, messages=None, **kwargs): + system_prompts.extend( + m.get("content", "") + for m in (messages or []) + if m.get("role") == "system" + ) + return make_text_response("resumed") + + bot2 = ChatBot(agent_id="ChatBot") + with ( + handler(LiteLLMProvider()), + handler(SpyMock()), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot2.send("resume") + + assert any("[HANDOFF FROM PRIOR SESSION]" in p for p in system_prompts) + + def test_handoff_cleared_on_success(self, tmp_path: Path): + # Create crash checkpoint + class CrashMock(ObjectInterpretation): + @implements(completion) + def _completion(self, *a, **kw): + raise RuntimeError("crash") + + bot = ChatBot(agent_id="ChatBot") + with pytest.raises(RuntimeError): + with ( + handler(LiteLLMProvider()), + handler(CrashMock()), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("crash task") + + # Successful run clears handoff + mock = MockCompletionHandler([make_text_response("done")]) + bot2 = ChatBot(agent_id="ChatBot") + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot2.send("new task") + + data = read_checkpoint(tmp_path, "ChatBot") + assert data["handoff"] == "" + + def test_dataclass_state_survives_crash(self, tmp_path: Path): + class CrashMock(ObjectInterpretation): + @implements(completion) + def _completion(self, *a, **kw): + raise RuntimeError("crash") + + bot = StatefulBot() + bot.learned_patterns = ["important insight"] + bot.call_count = 3 + + with pytest.raises(RuntimeError): + with ( + handler(LiteLLMProvider()), + handler(CrashMock()), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("boom") + + data = read_checkpoint(tmp_path, "StatefulBot") + assert data["state"]["learned_patterns"] == ["important insight"] + assert data["state"]["call_count"] == 3 + + +# --------------------------------------------------------------------------- +# Tests: nested template calls +# --------------------------------------------------------------------------- + + +class TestNestedCalls: + """Only outermost template call triggers checkpointing.""" + + def test_nested_template_via_tool_completes(self, tmp_path: Path): + mock = MockCompletionHandler( + [ + make_tool_call_response("self__check_tool", '{"payload": "demo"}'), + make_text_response("inner result"), + make_text_response("outer result"), + ] + ) + bot = NestedBot(agent_id="NestedBot") + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + result = bot.outer("demo") + + assert result == "outer result" + + def test_nested_call_does_not_double_checkpoint(self, tmp_path: Path): + save_count = 0 + original_save = PersistenceHandler.save + + def counting_save(self, agent, handoff=""): + nonlocal save_count + save_count += 1 + return original_save(self, agent, handoff=handoff) + + mock = MockCompletionHandler( + [ + make_tool_call_response("self__check_tool", '{"payload": "demo"}'), + make_text_response("inner"), + make_text_response("outer"), + ] + ) + bot = NestedBot(agent_id="NestedBot") + PersistenceHandler.save = counting_save # type: ignore[method-assign] + try: + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.outer("demo") + finally: + PersistenceHandler.save = original_save # type: ignore[method-assign] + + # Should be exactly 2: one before call, one after + assert save_count == 2 + + def test_handoff_cleared_after_nested_success(self, tmp_path: Path): + mock = MockCompletionHandler( + [ + make_tool_call_response("self__check_tool", '{"payload": "demo"}'), + make_text_response("inner"), + make_text_response("outer"), + ] + ) + bot = NestedBot(agent_id="NestedBot") + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.outer("demo") + + data = read_checkpoint(tmp_path, "NestedBot") + assert data["handoff"] == "" + + +# --------------------------------------------------------------------------- +# Tests: context compaction (CompactionHandler) +# --------------------------------------------------------------------------- + + +class TestContextCompaction: + """CompactionHandler compacts agent history after template calls.""" + + def test_compact_reduces_history(self): + history: OrderedDict[str, Any] = OrderedDict() + for i in range(10): + history[f"msg{i}"] = { + "id": f"msg{i}", + "role": "user" if i % 2 == 0 else "assistant", + "content": f"Message {i}", + } + + compaction = CompactionHandler(max_history_len=6) + mock = MockCompletionHandler( + [make_text_response("Summary of prior conversation.")] + ) + provider = LiteLLMProvider() + with handler(provider), handler(mock): + stored = get_agent_history("PlainBot") + stored.update(history) + compaction._compact("PlainBot", stored) + + result = provider._histories["PlainBot"] + assert len(result) < 10 + first_msg = next(iter(result.values())) + assert "CONTEXT SUMMARY" in first_msg["content"] + + def test_compaction_preserves_recent_messages(self): + history: OrderedDict[str, Any] = OrderedDict() + for i in range(10): + history[f"msg{i}"] = { + "id": f"msg{i}", + "role": "user" if i % 2 == 0 else "assistant", + "content": f"Message {i}", + } + + compaction = CompactionHandler(max_history_len=6) + keep_recent = max(6 // 2, 4) + mock = MockCompletionHandler([make_text_response("Summary.")]) + provider = LiteLLMProvider() + with handler(provider), handler(mock): + stored = get_agent_history("ChatBot") + stored.update(history) + compaction._compact("ChatBot", stored) + + result = provider._histories["ChatBot"] + remaining_ids = list(result.keys()) + for i in range(10 - keep_recent, 10): + assert f"msg{i}" in remaining_ids + + def test_compaction_triggered_by_template_call(self, tmp_path: Path): + bot = ChatBot(agent_id="ChatBot") + provider = LiteLLMProvider() + + with handler(provider): + history = get_agent_history(bot.__agent_id__) + for i in range(6): + history[f"old{i}"] = { + "id": f"old{i}", + "role": "user" if i % 2 == 0 else "assistant", + "content": f"Old message {i}", + } + + mock = MockCompletionHandler( + [ + make_text_response("new reply"), + make_text_response("Summary of old conversation."), + ] + ) + with ( + handler(provider), + handler(mock), + handler(CompactionHandler(max_history_len=4)), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("trigger compaction") + + result: OrderedDict[str, Any] = provider._histories.get( + "ChatBot", OrderedDict() + ) + assert len(result) <= 4 + 4 + + def test_compaction_works_on_plain_agent(self): + """CompactionHandler works on any Agent, not just PersistentAgent.""" + + class PlainBot(Agent): + """Plain bot.""" + + @Template.define + def send(self, msg: str) -> str: + """Say: {msg}""" + raise NotHandled + + bot = PlainBot() + provider = LiteLLMProvider() + + with handler(provider): + history = get_agent_history(bot.__agent_id__) + for i in range(10): + history[f"msg{i}"] = { + "id": f"msg{i}", + "role": "user" if i % 2 == 0 else "assistant", + "content": f"Message {i}", + } + + mock = MockCompletionHandler( + [make_text_response("reply"), make_text_response("Summary.")] + ) + with ( + handler(provider), + handler(mock), + handler(CompactionHandler(max_history_len=4)), + ): + bot.send("trigger") + + result = provider._histories.get(bot.__agent_id__, {}) + assert len(result) > 0, "history should not be empty after compaction" + first_msg = next(iter(result.values())) + assert "CONTEXT SUMMARY" in first_msg["content"] + assert len(result) <= 4 + 4 + + def test_compaction_triggered_naturally_on_plain_agent(self): + """CompactionHandler compacts after enough calls accumulate history. + + Makes multiple template calls on a plain Agent so that history + exceeds max_history_len, then verifies compaction fires and + produces a summary message. + """ + + class PlainBot(Agent): + """Plain bot.""" + + @Template.define + def send(self, msg: str) -> str: + """Say: {msg}""" + raise NotHandled + + # 4 calls × ~3 msgs each (system+user+assistant) = ~12 msgs + # Compaction threshold is 6, so it should trigger. + responses = [make_text_response(f"reply-{i}") for i in range(4)] + # Extra response for the summarize_context call during compaction + responses.append(make_text_response("Summary of conversation.")) + mock = MockCompletionHandler(responses) + + bot = PlainBot() + provider = LiteLLMProvider() + with ( + handler(provider), + handler(mock), + handler(CompactionHandler(max_history_len=6)), + ): + for i in range(4): + bot.send(f"message-{i}") + + history = provider._histories.get(bot.__agent_id__, {}) + # Should have been compacted: summary + recent messages + first_msg = next(iter(history.values())) + assert "CONTEXT SUMMARY" in first_msg["content"] + + def test_compaction_does_not_split_tool_use_tool_result_pairs(self): + """Compaction must not split tool_use/tool_result message pairs. + + If the cut point falls between an assistant message with tool_use + blocks and the corresponding tool_result message, the Anthropic API + rejects the conversation. This test constructs a history where the + naive positional split would do exactly that and asserts that both + messages end up on the same side of the cut. + """ + history: OrderedDict[str, Any] = OrderedDict() + + for i in range(7): + history[f"msg{i}"] = { + "id": f"msg{i}", + "role": "user" if i % 2 == 0 else "assistant", + "content": f"Message {i}", + } + + # msg7: assistant with tool_use (will be last item in old_items) + history["msg7"] = { + "id": "msg7", + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_call_xyz", + "name": "check_tool", + "input": {"payload": "test"}, + } + ], + } + + # msg8: tool_result (will be first item in recent_items) + history["msg8"] = { + "id": "msg8", + "role": "tool", + "tool_call_id": "tool_call_xyz", + "content": "tool result here", + } + + # msg9, msg10, msg11: padding so recent has 4 items + history["msg9"] = { + "id": "msg9", + "role": "assistant", + "content": "Response after tool", + } + history["msg10"] = { + "id": "msg10", + "role": "user", + "content": "Follow up question", + } + history["msg11"] = { + "id": "msg11", + "role": "assistant", + "content": "Final answer", + } + + compaction = CompactionHandler(max_history_len=8) + mock = MockCompletionHandler( + [make_text_response("Summary of prior conversation.")] + ) + provider = LiteLLMProvider() + with handler(provider), handler(mock): + stored = get_agent_history("ToolPairBot") + stored.update(history) + compaction._compact("ToolPairBot", stored) + + result = provider._histories["ToolPairBot"] + result_items = list(result.values()) + + # After compaction, there must be no orphaned tool_result messages. + tool_use_ids: set[str] = set() + for msg in result_items: + content = msg.get("content", "") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_use": + tool_use_ids.add(block["id"]) + + for msg in result_items: + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id", "") + assert tc_id in tool_use_ids, ( + f"Orphaned tool_result with tool_call_id={tc_id!r} after " + f"compaction — the matching tool_use was discarded. " + f"Remaining messages: {[m.get('id') for m in result_items]}" + ) + + def test_compaction_on_plain_agent_preserves_functionality(self): + """After compaction, the plain Agent still works for subsequent calls.""" + + class PlainBot(Agent): + """Plain bot.""" + + @Template.define + def send(self, msg: str) -> str: + """Say: {msg}""" + raise NotHandled + + responses = [make_text_response(f"reply-{i}") for i in range(4)] + # Compaction summary call fires after the 4th reply + responses.append(make_text_response("Summary.")) + # Then the 5th send() call + responses.append(make_text_response("reply-4")) + mock = MockCompletionHandler(responses) + + bot = PlainBot() + provider = LiteLLMProvider() + with ( + handler(provider), + handler(mock), + handler(CompactionHandler(max_history_len=6)), + ): + for i in range(4): + bot.send(f"msg-{i}") + # This call happens after compaction + result = bot.send("after-compaction") + + assert result == "reply-4" + + +# --------------------------------------------------------------------------- +# Tests: system prompt +# --------------------------------------------------------------------------- + + +class TestSystemPrompt: + """System prompt of PersistentAgent includes class docstring.""" + + def test_base_docstring_used(self): + bot = ChatBot(agent_id="ChatBot") + assert "persistent chat bot" in bot.__system_prompt__ + + def test_no_handoff_initially(self): + bot = ChatBot(agent_id="ChatBot") + assert "[HANDOFF" not in bot.__system_prompt__ + + +# --------------------------------------------------------------------------- +# Tests: agent isolation +# --------------------------------------------------------------------------- + + +class TestAgentIsolation: + """Multiple PersistentAgent instances are independent in the handler.""" + + def test_two_agents_independent(self, tmp_path: Path): + bot1 = ChatBot(agent_id="bot1") + + persist = PersistenceHandler(tmp_path / "checkpoints.db") + with handler(AgentHistoryHandler()): + persist.save(bot1, handoff="bot1 work") + + # bot2 was never saved — should not exist in DB + assert not has_checkpoint(tmp_path, "bot2") + data = read_checkpoint(tmp_path, "bot1") + assert data["handoff"] == "bot1 work" + + def test_same_db_different_agent_id(self, tmp_path: Path): + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot_a = ChatBot(agent_id="alpha") + bot_b = ChatBot(agent_id="beta") + + with handler(AgentHistoryHandler()): + persist.save(bot_a, handoff="alpha work") + persist.save(bot_b, handoff="beta work") + + data_a = read_checkpoint(tmp_path, "alpha") + data_b = read_checkpoint(tmp_path, "beta") + assert data_a["handoff"] == "alpha work" + assert data_b["handoff"] == "beta work" + + +# --------------------------------------------------------------------------- +# Tests: compatibility with RetryLLMHandler +# --------------------------------------------------------------------------- + + +class TestRetryCompatibility: + """PersistentAgent works with RetryLLMHandler and PersistenceHandler.""" + + def test_retry_then_success(self, tmp_path: Path): + mock = MockCompletionHandler( + [ + make_text_response('"not_an_int"'), + make_text_response('{"value": 42}'), + ] + ) + + class NumberBot(PersistentAgent): + """You are a number bot.""" + + @Template.define + def pick(self) -> int: + """Pick a number.""" + raise NotHandled + + bot = NumberBot(agent_id="NumberBot") + with ( + handler(LiteLLMProvider()), + handler(RetryLLMHandler()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + result = bot.pick() + + assert result == 42 + data = read_checkpoint(tmp_path, "NumberBot") + assert data["handoff"] == "" + + +# --------------------------------------------------------------------------- +# Tests: PersistenceHandler is optional +# --------------------------------------------------------------------------- + + +class TestWithoutHandler: + """PersistentAgent works without PersistenceHandler — no auto-checkpointing.""" + + def test_agent_works_without_persistence_handler(self): + mock = MockCompletionHandler([make_text_response("hello")]) + bot = ChatBot(agent_id="ChatBot") + + with handler(LiteLLMProvider()), handler(mock): + result = bot.send("hi") + + assert result == "hello" + + +# --------------------------------------------------------------------------- +# Tests: nested calls with failures + persistence +# --------------------------------------------------------------------------- + + +class TestNestedCallFailuresWithPersistence: + """Nested tool calls that fail should not corrupt persistence state.""" + + def test_nested_tool_failure_still_checkpoints(self, tmp_path: Path): + """If a nested tool raises, the outermost handler saves a crash checkpoint.""" + + class FailingBot(PersistentAgent): + """Bot whose tool always fails.""" + + @Template.define + def inner(self, payload: str) -> str: + """Check: {payload}""" + raise NotHandled + + @Tool.define + def failing_tool(self, payload: str) -> str: + """Check payload — always raises.""" + raise RuntimeError("tool exploded") + + @Template.define + def outer(self, payload: str) -> str: + """Call `failing_tool` for: {payload}.""" + raise NotHandled + + mock = MockCompletionHandler( + [ + make_tool_call_response("self__failing_tool", '{"payload": "boom"}'), + ] + ) + bot = FailingBot(agent_id="FailingBot") + + with pytest.raises(ToolCallExecutionError, match="tool exploded"): + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.outer("go") + + data = read_checkpoint(tmp_path, "FailingBot") + assert "Executing outer" in data["handoff"] + + def test_nested_tool_failure_then_recovery(self, tmp_path: Path): + """After a nested tool failure, next session resumes with handoff.""" + mock_crash = MockCompletionHandler( + [ + make_tool_call_response("self__check_tool", '{"payload": "crash"}'), + ] + ) + + class CrashInnerBot(PersistentAgent): + """Bot with crashing inner tool.""" + + call_count = 0 + + @Template.define + def inner_check(self, payload: str) -> str: + """Check: {payload}""" + raise NotHandled + + @Tool.define + def check_tool(self, payload: str) -> str: + """Check payload.""" + self.call_count += 1 + if self.call_count == 1: + raise RuntimeError("first call fails") + return self.inner_check(payload) + + @Template.define + def outer(self, payload: str) -> str: + """Call `check_tool` for: {payload}, then return answer.""" + raise NotHandled + + bot = CrashInnerBot(agent_id="CrashInnerBot") + + # Session 1: crash + with pytest.raises(ToolCallExecutionError, match="first call fails"): + with ( + handler(LiteLLMProvider()), + handler(mock_crash), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.outer("task") + + # Session 2: successful recovery + system_prompts: list[str] = [] + + class SpyMock(ObjectInterpretation): + @implements(completion) + def _completion(self_, model, messages=None, **kwargs): + system_prompts.extend( + m.get("content", "") + for m in (messages or []) + if m.get("role") == "system" + ) + return make_text_response("recovered") + + bot2 = CrashInnerBot(agent_id="CrashInnerBot") + with ( + handler(LiteLLMProvider()), + handler(SpyMock()), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + result = bot2.outer("retry") + + assert result == "recovered" + assert any("[HANDOFF FROM PRIOR SESSION]" in p for p in system_prompts) + + +# --------------------------------------------------------------------------- +# Tests: Agent and PersistentAgent coexistence +# --------------------------------------------------------------------------- + + +class TestAgentPersistentAgentCoexistence: + """Plain Agent and PersistentAgent work side-by-side.""" + + def test_plain_and_persistent_agent_in_same_handler(self, tmp_path: Path): + """Both agent types work under the same LiteLLMProvider.""" + + class PlainBot(Agent): + """Plain bot.""" + + @Template.define + def ask(self, q: str) -> str: + """Q: {q}""" + raise NotHandled + + mock = MockCompletionHandler( + [make_text_response("plain-reply"), make_text_response("persist-reply")] + ) + plain = PlainBot() + persistent = ChatBot(agent_id="ChatBot") + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + r1 = plain.ask("hello") + r2 = persistent.send("hello") + + assert r1 == "plain-reply" + assert r2 == "persist-reply" + # Only the PersistentAgent gets a checkpoint entry + assert has_checkpoint(tmp_path, "ChatBot") + assert not has_checkpoint(tmp_path, plain.__agent_id__) + + def test_persistent_agent_tool_calls_plain_agent(self, tmp_path: Path): + """A PersistentAgent's tool can delegate to a plain Agent. + + Mock response sequence: + 0: outer → tool_call(self__delegate, {"q": "sub-task"}) + 1: inner plain agent → "inner-answer" + 2: outer → "final-answer" (after getting tool result) + """ + + class InnerPlainAgent(Agent): + """Inner helper agent.""" + + @Template.define + def answer(self, q: str) -> str: + """Answer: {q}""" + raise NotHandled + + inner = InnerPlainAgent() + + class OuterPersistent(PersistentAgent): + """Outer persistent agent that delegates via tool.""" + + @Tool.define + def delegate(self, q: str) -> str: + """Delegate a sub-question to an inner agent.""" + return inner.answer(q) + + @Template.define + def process(self, task: str) -> str: + """Process: {task}. Use `delegate` for sub-questions.""" + raise NotHandled + + mock = MockCompletionHandler( + [ + make_tool_call_response("self__delegate", '{"q": "sub-task"}'), + make_text_response("inner-answer"), + make_text_response("final-answer"), + ] + ) + outer = OuterPersistent(agent_id="outer") + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + result = outer.process("do it") + + assert result == "final-answer" + data = read_checkpoint(tmp_path, "outer") + assert data["agent_id"] == "outer" + + def test_plain_agent_tool_calls_persistent_agent(self, tmp_path: Path): + """A plain Agent's tool can delegate to a PersistentAgent. + + Mock response sequence: + 0: outer plain → tool_call(self__delegate, {"q": "sub"}) + 1: inner persistent → "persisted-answer" + 2: outer plain → "done" (after getting tool result) + """ + + inner = ChatBot(agent_id="inner-bot") + + class OuterPlain(Agent): + """Plain agent that delegates to a persistent agent.""" + + @Tool.define + def delegate(self, q: str) -> str: + """Delegate to persistent bot.""" + return inner.send(q) + + @Template.define + def run(self, task: str) -> str: + """Run: {task}. Use `delegate` for sub-tasks.""" + raise NotHandled + + mock = MockCompletionHandler( + [ + make_tool_call_response("self__delegate", '{"q": "sub"}'), + make_text_response("persisted-answer"), + make_text_response("done"), + ] + ) + outer = OuterPlain() + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + result = outer.run("go") + + assert result == "done" + + def test_two_persistent_agents_cooperate(self, tmp_path: Path): + """Two PersistentAgents with different IDs work independently. + + Mock response sequence: + 0: planner → "the plan" + 1: executor → "executed" + """ + mock = MockCompletionHandler( + [make_text_response("the plan"), make_text_response("executed")] + ) + + planner = ChatBot(agent_id="planner") + executor = ChatBot(agent_id="executor") + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + plan = planner.send("make a plan") + result = executor.send(f"execute: {plan}") + + assert plan == "the plan" + assert result == "executed" + + # Each has independent history + planner_data = read_checkpoint(tmp_path, "planner") + executor_data = read_checkpoint(tmp_path, "executor") + assert len(planner_data["history"]) > 0 + assert len(executor_data["history"]) > 0 + + +# --------------------------------------------------------------------------- +# Tests: SQLite crash tolerance +# --------------------------------------------------------------------------- + + +class TestSQLiteCrashTolerance: + """SQLite-backed persistence is crash tolerant and restartable.""" + + def test_wal_mode_enabled(self, tmp_path: Path): + """Database uses WAL journal mode for crash tolerance.""" + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + with handler(AgentHistoryHandler()): + persist.save(bot) + + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + conn.close() + assert mode == "wal" + + def test_database_survives_incomplete_write(self, tmp_path: Path): + """Prior committed data survives if a subsequent write is interrupted.""" + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + + # First write: commit a valid checkpoint + with handler(AgentHistoryHandler()): + history = get_agent_history(bot.__agent_id__) + history["msg1"] = {"id": "msg1", "role": "user", "content": "hello"} + persist.save(bot) + + # Simulate an interrupted write by opening a connection, beginning + # a write, then rolling back (mimicking a crash before commit). + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + conn.execute( + "UPDATE checkpoints SET handoff = 'interrupted' WHERE agent_id = ?", + ("ChatBot",), + ) + # Do NOT commit — simulate crash by closing without commit + conn.close() + + # The prior committed data should still be intact + data = read_checkpoint(tmp_path, "ChatBot") + assert data["handoff"] == "" + assert len(data["history"]) == 1 + + def test_database_integrity_after_multiple_saves(self, tmp_path: Path): + """Multiple rapid saves produce a consistent database.""" + mock = MockCompletionHandler( + [make_text_response(f"reply-{i}") for i in range(3)] + ) + bot = ChatBot(agent_id="ChatBot") + + with ( + handler(LiteLLMProvider()), + handler(mock), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + for i in range(3): + bot.send(f"msg-{i}") + + # Verify integrity + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + result = conn.execute("PRAGMA integrity_check").fetchone()[0] + conn.close() + assert result == "ok" + + data = read_checkpoint(tmp_path, "ChatBot") + assert data["handoff"] == "" + assert len(data["history"]) > 0 + + def test_recovery_from_crash_mid_template_call(self, tmp_path: Path): + """After a crash mid-call, the DB has a handoff and can be reloaded.""" + + class CrashMock(ObjectInterpretation): + @implements(completion) + def _completion(self, *args, **kwargs): + raise RuntimeError("process killed") + + bot = ChatBot(agent_id="ChatBot") + with pytest.raises(RuntimeError): + with ( + handler(LiteLLMProvider()), + handler(CrashMock()), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + bot.send("important task") + + # Verify the DB is consistent and has the crash handoff + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + result = conn.execute("PRAGMA integrity_check").fetchone()[0] + assert result == "ok" + conn.close() + + data = read_checkpoint(tmp_path, "ChatBot") + assert "Executing send" in data["handoff"] + + # New process can load and resume + mock2 = MockCompletionHandler([make_text_response("recovered")]) + bot2 = ChatBot(agent_id="ChatBot") + with ( + handler(LiteLLMProvider()), + handler(mock2), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + result = bot2.send("resume") + + assert result == "recovered" + data2 = read_checkpoint(tmp_path, "ChatBot") + assert data2["handoff"] == "" + + def test_multiple_agents_single_db(self, tmp_path: Path): + """All agents share one DB file, not separate JSON files.""" + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bots = [ChatBot(agent_id=f"bot-{i}") for i in range(5)] + + with handler(AgentHistoryHandler()): + for bot in bots: + persist.save(bot) + + # Only one DB file, no JSON files + assert (tmp_path / "checkpoints.db").exists() + assert not list(tmp_path.glob("*.json")) + + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0] + conn.close() + assert count == 5 + + def test_save_is_idempotent(self, tmp_path: Path): + """Saving the same agent multiple times updates rather than duplicates.""" + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + + with handler(AgentHistoryHandler()): + persist.save(bot) + persist.save(bot, handoff="updated handoff") + persist.save(bot, handoff="final handoff") + + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + count = conn.execute( + "SELECT COUNT(*) FROM checkpoints WHERE agent_id = ?", ("ChatBot",) + ).fetchone()[0] + conn.close() + assert count == 1 + + data = read_checkpoint(tmp_path, "ChatBot") + assert data["handoff"] == "final handoff" + + +# --------------------------------------------------------------------------- +# Tests: Thread safety +# --------------------------------------------------------------------------- + + +class TestThreadSafety: + """PersistenceHandler is safe to use from multiple threads.""" + + def test_concurrent_saves_from_threads(self, tmp_path: Path): + """Multiple threads saving different agents concurrently don't corrupt the DB.""" + import concurrent.futures + + persist = PersistenceHandler(tmp_path / "checkpoints.db") + errors: list[Exception] = [] + + def save_agent(agent_id: str) -> None: + try: + bot = ChatBot(agent_id=agent_id) + hist = AgentHistoryHandler() + with handler(hist): + history = get_agent_history(agent_id) + for j in range(3): + history[f"{agent_id}-msg{j}"] = { + "id": f"{agent_id}-msg{j}", + "role": "user", + "content": f"msg {j} from {agent_id}", + } + persist.save(bot) + except Exception as e: + errors.append(e) + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool: + futures = [pool.submit(save_agent, f"agent-{i}") for i in range(8)] + concurrent.futures.wait(futures) + + assert errors == [], f"Thread errors: {errors}" + + # Verify DB integrity and all agents saved + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + result = conn.execute("PRAGMA integrity_check").fetchone()[0] + assert result == "ok" + count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0] + conn.close() + assert count == 8 + + def test_concurrent_reads_and_writes(self, tmp_path: Path): + """Readers and writers can operate concurrently without errors.""" + import concurrent.futures + + persist = PersistenceHandler(tmp_path / "checkpoints.db") + errors: list[Exception] = [] + + # Seed some data + with handler(AgentHistoryHandler()): + for i in range(4): + bot = ChatBot(agent_id=f"agent-{i}") + persist.save(bot) + + def writer(agent_id: str) -> None: + try: + bot = ChatBot(agent_id=agent_id) + with handler(AgentHistoryHandler()): + history = get_agent_history(agent_id) + history["update"] = { + "id": "update", + "role": "user", + "content": "updated", + } + persist.save(bot) + except Exception as e: + errors.append(e) + + def reader(agent_id: str) -> None: + try: + # Verify checkpoint is readable via direct DB access + read_checkpoint(tmp_path, agent_id) + except Exception as e: + errors.append(e) + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool: + futures = [] + for i in range(4): + futures.append(pool.submit(writer, f"agent-{i}")) + futures.append(pool.submit(reader, f"agent-{i}")) + concurrent.futures.wait(futures) + + assert errors == [], f"Thread errors: {errors}"