diff --git a/AGENTS.md b/AGENTS.md index 39b906a..3d5a46e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -38,22 +38,23 @@ Override the LLM endpoint with `OPENAI_BASE_URL` (defaults to `https://api.opena **Module map:** - [src/lib.rs](src/lib.rs) — module re-exports. -- [src/builder.rs](src/builder.rs) — `AgentWorkerBuilder` fluent builder; wires LLM + tools into a Temporal `Worker`. +- [src/builder.rs](src/builder.rs) — `AgentWorkerBuilder` fluent builder; wires LLM + tools + memory provider into a Temporal `Worker`. - [src/workflow.rs](src/workflow.rs) — `AgentWorkflow` with `#[run]`, `#[signal] add_user_message`, `#[query] get_state`, `#[query] turn_count`. Owns the ReAct loop. - [src/activities.rs](src/activities.rs) — `AgentActivities::llm_chat` and `AgentActivities::execute_tool`. The *only* place LLM providers and tool implementations execute. - [src/llm.rs](src/llm.rs) — translation between local `Message`/`ToolSchema` types and AutoAgents `ChatMessage`/`LlmTool`; native-tool-call parsing with fenced-JSON fallback. The only file that touches `autoagents_llm` types in the hot path (`src/llm.rs:6`). -- [src/state.rs](src/state.rs) — `AgentInput`, `AgentOutput`, `AgentState`, `Message`, `ToolCall`, `ToolResult`, `ToolSchema`, `LlmResponse`, `StopReason`, plus `compact()`. +- [src/state.rs](src/state.rs) — `AgentInput`, `AgentOutput`, `AgentState`, `Message`, `ToolCall`, `ToolResult`, `ToolSchema`, `LlmResponse`, `StopReason`. +- [src/memory.rs](src/memory.rs) — `MemoryProvider` trait, default `SlidingWindowMemory` impl, and the `compact_sliding_window` kernel. Pluggable compaction strategy consulted by the workflow before every turn. - [src/tool.rs](src/tool.rs) — `ToolRegistry` (immutable name→impl map) and its builder. - [src/error.rs](src/error.rs) — `AgentError` with `is_retryable()` to distinguish transient vs. permanent. -- [src/prelude.rs](src/prelude.rs) — convenience re-exports including AutoAgents traits (`ToolT`, `LLMProvider`, `ToolRuntime`, `ToolCallError`). +- [src/prelude.rs](src/prelude.rs) — convenience re-exports including AutoAgents traits (`ToolT`, `LLMProvider`, `ToolRuntime`, `ToolCallError`) and memory types (`MemoryProvider`, `SlidingWindowMemory`). -**Public API surface (what a user actually touches):** `AgentWorkerBuilder`, `AgentWorkflow`, `AgentInput`/`AgentOutput`, `ToolRegistry`. Users supply their own `Arc` and `Arc` from AutoAgents. +**Public API surface (what a user actually touches):** `AgentWorkerBuilder`, `AgentWorkflow`, `AgentInput`/`AgentOutput`, `ToolRegistry`, `MemoryProvider`/`SlidingWindowMemory`. Users supply their own `Arc` and `Arc` from AutoAgents. **Non-obvious behaviors to preserve when editing:** -- **History compaction.** When `AgentState::history.len()` exceeds `CONTINUE_AS_NEW_THRESHOLD = 200` (`src/workflow.rs:36`), the workflow calls `continue_as_new` with a compacted state: summary prepended to the system prompt, last 20 messages kept (`src/state.rs::compact`). Any change to the message shape needs to round-trip through `compact()`. +- **History compaction is pluggable.** The workflow consults `MemoryProvider::should_compact` before every turn; on `true` it calls `MemoryProvider::compact` and `continue_as_new` with the returned `AgentInput`. Default provider is `SlidingWindowMemory` (`compact_threshold = 200`, `keep_recent = 20`), preserving the legacy hardcoded behavior. Override via `AgentWorkerBuilder::memory(Arc::new(SlidingWindowMemory::new().with_compact_threshold(N).with_keep_recent(K)))` or supply your own `Arc`. Trait impls MUST be pure and sync — they run inside the deterministic workflow body. The kernel summarizer lives at `src/memory.rs::compact_sliding_window`; any change to the `Message` shape needs to round-trip through it. - **Tool error semantics.** Tool-side failures return `Ok(ToolResult { error: Some(...) })` so the LLM can see and recover from them (`src/activities.rs:59-88`). Only infrastructure errors (missing tool, serde failure) surface as activity `Err`, which Temporal retries. -- **`WORKER_TOOL_CATALOG`.** A process-global `OnceCell` set once at worker init in `build_worker` (`src/builder.rs:34`). The deterministic workflow body reads it on every replay, so it must be set before the worker starts and never mutated after. +- **Process-global worker config (`WORKER_TOOL_CATALOG`, `WORKER_MEMORY`).** Two `OnceCell`s in `src/builder.rs` published by `build_worker`. The deterministic workflow body reads them on every replay, so they must be set before the worker starts and never mutated after. Building a second worker in the same process with a *different* catalog (compared by `PartialEq`) or a different memory `Arc` (compared by `Arc::ptr_eq`) returns `AgentError::Other` — multi-worker setups in one process must share the same `Arc` and register identical tools in the same order. - **Activity timeouts.** Set inside `AgentWorkflow::run` at `src/workflow.rs:66-74`: LLM activity 120s start-to-close / 30s heartbeat, tool activity **3600s** start-to-close (generous on purpose — supports human-in-the-loop tools that block on stdin/HTTP/async-completion). - **Mid-conversation user input.** The `add_user_message` signal pushes into `pending_user_messages`, drained at the top of each loop iteration (`src/workflow.rs:145-152`). Don't mutate `history` directly from signal handlers — that races with the in-flight `llm_chat` activity. - **Dual LLM response parsing.** `src/llm.rs` tries native tool calls first, then falls back to a fenced `\`\`\`tool_calls` JSON block so non-OpenAI providers still work. @@ -62,6 +63,34 @@ Override the LLM endpoint with `OPENAI_BASE_URL` (defaults to `https://api.opena - Tools must be side-effect-safe on retry. - `LLMProvider` and `ToolT` impls must be `Send + Sync + 'static`. - Never invoke `LLMProvider` or `ToolT` from workflow code — only from activities. +- `MemoryProvider` impls must be pure, sync, and stateless (config-only) — `should_compact` and `compact` are called inside the workflow body and must return identical results on replay for the same `AgentState`. + +## Documentation maintenance + +After any change large enough to alter the public API surface, observable +behavior, defaults, or feature set, update the user-facing docs in the same +PR — stale docs are worse than no docs because they actively mislead. +Specifically: + +- **[AGENTS.md](AGENTS.md)** (this file) — update the module map, public API + surface line, non-obvious behaviors, and determinism contract whenever any + of them change. Add new module entries here as soon as you create them. +- **[README.md](README.md)** — update the features list, examples list, + user-facing determinism contract, and any code snippets affected by API + changes. If you add a feature with its own knobs (caching, fallback, + memory backends, etc.), give it a short dedicated section like + "Pluggable memory backends" so users can find it without reading the + source. +- **[examples/](examples/)** — when adding a new top-level feature, ship + one runnable example that exercises it (model the new example on + `simple_math_agent` — same worker/client/status mode template, same + three-terminal flow). Register it in `Cargo.toml` under a new + `[[example]]` entry and add a one-line description plus a runnable + invocation to README.md's "Running the examples" block. + +Rule of thumb: if you touched `src/lib.rs` re-exports, `src/prelude.rs`, or +`AgentWorkerBuilder`'s public API, you owe at least one edit to each of the +three above. ## Version pins diff --git a/Cargo.toml b/Cargo.toml index bf1f8ba..bf4daa7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,3 +70,7 @@ path = "examples/pipelined_math_agent/main.rs" [[example]] name = "structured_output_agent" path = "examples/structured_output_agent/main.rs" + +[[example]] +name = "tunable_memory_agent" +path = "examples/tunable_memory_agent/main.rs" diff --git a/README.md b/README.md index 363c469..784635e 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,10 @@ LLM tokens. - `AgentWorkerBuilder` for one-line worker setup. - Provider-agnostic: bring your own `Arc` (OpenAI, Anthropic, Ollama, etc. — anything supported by `autoagents_llm`). +- **Pluggable memory backends** via the `MemoryProvider` trait — default + `SlidingWindowMemory` matches the legacy hardcoded behavior; swap in + custom strategies through `AgentWorkerBuilder::memory`. See + [Pluggable memory backends](#pluggable-memory-backends). - **Human-in-the-loop as a regular tool** — the library does not special-case any tool name. See [Human-in-the-loop tools](#human-in-the-loop-tools). @@ -159,9 +163,50 @@ by it. See `examples/pipelined_math_agent` for a runnable demo (`add` tool + `PipelineBuilder(CacheLayer → FallbackLayer)` around two OpenAI models). +## Pluggable memory backends + +History compaction is governed by an `Arc` published +to the worker via `AgentWorkerBuilder::memory`. The default — used when +`.memory(...)` is not called — is `SlidingWindowMemory` with +`compact_threshold = 200` and `keep_recent = 20`, which matches the +legacy hardcoded behavior. + +```rust,ignore +use std::sync::Arc; +use temporal_agent_rs::prelude::*; + +let memory: Arc = Arc::new( + SlidingWindowMemory::new() + .with_compact_threshold(50) + .with_keep_recent(10), +); + +AgentWorkerBuilder::new(client) + .llm(llm) + .tool(my_tool) + .memory(memory) + .build_worker(&runtime)?; +``` + +**Trait contract.** Implementations MUST be pure and synchronous — +`should_compact` and `compact` run inside the deterministic workflow +body and must return identical results on every replay for the same +`AgentState`. Per-conversation state belongs in `AgentState` (which +Temporal persists in workflow history), never in fields on the provider. + +**Multi-worker setups.** Running multiple workers in the same process on +the same queue requires sharing the *same* `Arc` — +the builder fails fast (via `Arc::ptr_eq`) on mismatching instances to +prevent the second worker from silently inheriting the first worker's +provider while replay diverges. + +See `examples/tunable_memory_agent` for a runnable demo of a tuned +`SlidingWindowMemory` plus a minimal custom `MemoryProvider` impl +(`KeepEverythingMemory`) gated behind a `KEEP_EVERYTHING=1` env switch. + ## Running the examples -Three examples ship with the crate: +Five examples ship with the crate: - `simple_math_agent` — minimal autonomous loop with a single `add` tool. - `interactive_math_agent` — adds an `ask_user` tool so the agent can pause @@ -169,6 +214,11 @@ Three examples ship with the crate: - `pipelined_math_agent` — same `add` tool, but the provider is wrapped with `PipelineBuilder → CacheLayer → FallbackLayer` to demonstrate the composition pattern described above. +- `structured_output_agent` — forces a JSON-schema-shaped final answer via + `AgentInput::output_schema`. +- `tunable_memory_agent` — demonstrates a tuned `SlidingWindowMemory` and a + custom `MemoryProvider` impl; aggressive thresholds make `continue_as_new` + compaction observable on a short conversation. ```bash # Terminal 1: local Temporal dev server (install via `brew install temporal` or temporal.io) @@ -190,6 +240,17 @@ cargo run --example interactive_math_agent -- client # run the client twice with the same prompt to observe the cache layer. OPENAI_API_KEY=sk-... cargo run --example pipelined_math_agent -- worker cargo run --example pipelined_math_agent -- client + +# Structured output — final answer constrained by a JSON schema. +OPENAI_API_KEY=sk-... cargo run --example structured_output_agent -- worker +cargo run --example structured_output_agent -- client + +# Pluggable memory backends — aggressive SlidingWindowMemory so compaction +# fires mid-run. Use the `status` sub-command to watch history.len() and the +# "Prior conversation summary" marker appear in the system prompt. +OPENAI_API_KEY=sk-... cargo run --example tunable_memory_agent -- worker +cargo run --example tunable_memory_agent -- client +cargo run --example tunable_memory_agent -- status ``` The Temporal Web UI is at http://localhost:8233. Click into the workflow to @@ -319,6 +380,10 @@ When you write tools and provider configs: - Never call your `LLMProvider` or your `ToolT` from inside workflow code. The workflow holds tools by name; the only path to invocation is the `execute_tool` activity. +- `MemoryProvider` impls must be **pure, sync, and stateless** (config + only) — `should_compact` and `compact` run inside the workflow body and + must return identical results on replay for the same `AgentState`. Keep + conversation state in `AgentState`, never on the provider. ## Version compatibility diff --git a/examples/tunable_memory_agent/main.rs b/examples/tunable_memory_agent/main.rs new file mode 100644 index 0000000..04de5fb --- /dev/null +++ b/examples/tunable_memory_agent/main.rs @@ -0,0 +1,279 @@ +#![allow(clippy::large_futures)] +//! Pluggable memory backend demo. +//! +//! Same ReAct loop as `simple_math_agent`, but the worker passes a +//! deliberately aggressive [`SlidingWindowMemory`] config to +//! [`AgentWorkerBuilder::memory`] so you can observe history compaction +//! firing on a short conversation instead of waiting for the 200-message +//! default to kick in. +//! +//! ## What to watch +//! +//! Compaction triggers when `AgentState::history.len() > compact_threshold`. +//! This example uses `compact_threshold = 10` and `keep_recent = 4`, so after +//! roughly four reasoning turns (system + user + a handful of +//! assistant/tool exchanges) the workflow will: +//! +//! 1. Build a synthetic summary of older messages, +//! 2. Call `continue_as_new` with that summary prepended to the system +//! prompt and only the most recent 4 messages retained, +//! 3. Resume with a fresh event history (and the turn counter reset to 0). +//! +//! Inspect this via `status` while the workflow is running. After a +//! `continue_as_new` the Temporal Web UI will show a new run with a much +//! shorter event history and the new run's `input.system_prompt` will +//! contain the `"[Prior conversation summary, N messages dropped]"` block. +//! +//! ## Writing your own provider +//! +//! [`KeepEverythingMemory`] below is a complete custom `MemoryProvider` +//! implementation in 8 lines. Run the worker with `KEEP_EVERYTHING=1` to +//! activate it — useful for short-lived debugging sessions where you don't +//! want compaction at all. (Don't ship this to production: event history +//! grows unbounded.) +//! +//! Run with three terminals: +//! +//! ```bash +//! # Terminal 1: local Temporal server +//! temporal server start-dev +//! +//! # Terminal 2: worker +//! OPENAI_API_KEY=sk-... cargo run --example tunable_memory_agent -- worker +//! +//! # Terminal 3: client (multi-step prompt that should trigger compaction) +//! cargo run --example tunable_memory_agent -- client +//! +//! # Terminal 3 again, while the workflow is mid-flight: +//! cargo run --example tunable_memory_agent -- status +//! ``` + +use std::sync::Arc; + +use async_trait::async_trait; +use autoagents_core::tool::{ToolCallError, ToolInputT, ToolRuntime}; +use autoagents_derive::{ToolInput, tool}; +use autoagents_llm::backends::openai::OpenAI; +use autoagents_llm::builder::LLMBuilder; +use serde::Deserialize; +use serde_json::{Value, json}; +use temporal_agent_rs::prelude::*; +use temporalio_client::{ + Client, ClientOptions, Connection, WorkflowGetResultOptions, WorkflowQueryOptions, + WorkflowStartOptions, envconfig::LoadClientConfigProfileOptions, +}; +use temporalio_common::telemetry::TelemetryOptions; +use temporalio_sdk_core::{CoreRuntime, RuntimeOptions}; + +const WORKFLOW_ID: &str = "tunable-memory-demo-1"; + +#[derive(Deserialize, ToolInput)] +struct AddArgs { + #[input(description = "First addend")] + a: f64, + #[input(description = "Second addend")] + b: f64, +} + +#[tool(name = "add", description = "Add two numbers", input = AddArgs)] +#[derive(Default, Clone)] +struct Add; + +#[async_trait] +impl ToolRuntime for Add { + async fn execute(&self, args: Value) -> Result { + let parsed: AddArgs = serde_json::from_value(args)?; + Ok(json!({ "sum": parsed.a + parsed.b })) + } +} + +/// Example custom [`MemoryProvider`] that never compacts. +/// +/// Useful for short debugging runs where you want the full conversation +/// to survive in workflow history. NOT suitable for long-running +/// workflows — event history will grow unbounded. +/// +/// Implementations MUST be pure and synchronous: `should_compact` and +/// `compact` are invoked inside the deterministic workflow body and have +/// to return identical results on every replay for the same `AgentState`. +#[derive(Debug)] +struct KeepEverythingMemory; + +impl MemoryProvider for KeepEverythingMemory { + fn should_compact(&self, _state: &AgentState) -> bool { + false + } + + fn compact(&self, state: &AgentState) -> AgentInput { + // Never called (should_compact always returns false), but the trait + // requires an impl. Returning the original input is the safest + // no-op. + state.input.clone() + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info,temporal_agent_rs=debug".into()), + ) + .init(); + + let mode = std::env::args().nth(1).unwrap_or_else(|| "worker".into()); + let client = connect().await?; + + match mode.as_str() { + "worker" => run_worker(client).await, + "client" => run_client(client).await, + "status" => run_status(client).await, + other => Err(anyhow::anyhow!( + "unknown mode '{other}', expected one of: worker | client | status" + )), + } +} + +async fn connect() -> anyhow::Result { + let (conn_opts, client_opts) = + ClientOptions::load_from_config(LoadClientConfigProfileOptions::default()) + .map_err(|e| anyhow::anyhow!("load client config: {e}"))?; + let connection = Connection::connect(conn_opts).await?; + Ok(Client::new(connection, client_opts)?) +} + +async fn run_worker(client: Client) -> anyhow::Result<()> { + let api_key = std::env::var("OPENAI_API_KEY") + .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY must be set for the worker"))?; + + let base_url = + std::env::var("OPENAI_BASE_URL").unwrap_or_else(|_| "https://api.openai.com/v1".into()); + + let llm: Arc = LLMBuilder::::new() + .api_key(api_key) + .base_url(base_url) + .model("gpt-4o-mini") + .build()?; + + let runtime_opts = RuntimeOptions::builder() + .telemetry_options(TelemetryOptions::builder().build()) + .build() + .map_err(|e| anyhow::anyhow!("build runtime options: {e}"))?; + let runtime = CoreRuntime::new_assume_tokio(runtime_opts)?; + + // Aggressive compaction so a short demo conversation actually trips it. + // Production defaults are 200 / 20; bumping `compact_threshold` down to + // 10 means continue_as_new will fire after only a few tool exchanges. + // + // Set `KEEP_EVERYTHING=1` to swap in the custom never-compact provider + // instead (see [`KeepEverythingMemory`]). + let keep_everything = std::env::var("KEEP_EVERYTHING").is_ok(); + let memory: Arc = if keep_everything { + Arc::new(KeepEverythingMemory) + } else { + Arc::new( + SlidingWindowMemory::new() + .with_keep_recent(4) + .with_compact_threshold(10), + ) + }; + + let mut worker = AgentWorkerBuilder::new(client) + .llm(llm) + .tool(Arc::new(Add)) + .queue("agents") + .memory(memory) + .build_worker(&runtime)?; + + if keep_everything { + tracing::info!("worker started on queue 'agents' (KeepEverythingMemory — no compaction)"); + } else { + tracing::info!( + compact_threshold = 10, + keep_recent = 4, + "worker started on queue 'agents' (aggressive memory compaction)" + ); + } + worker.run().await?; + Ok(()) +} + +async fn run_client(client: Client) -> anyhow::Result<()> { + let input = AgentInput { + system_prompt: + "You are a meticulous math assistant. You can call the `add` tool to compute. \ + Call the tool ONE addition at a time — never combine multiple operations into one \ + call. After each tool result, state the running total before the next call. \ + When all additions are done, output the final answer in words." + .into(), + // Chain of additions chosen so the model is forced to issue several + // sequential tool calls, each appending an assistant+tool pair to + // history — enough exchanges to trip a 10-message compact_threshold. + user_message: "Compute step by step: 1.5 + 2.5, then add 3, then add 4.25, then add 5.75. \ + Give me the final sum." + .into(), + max_turns: 12, + output_schema: None, + }; + + let handle = client + .start_workflow( + AgentWorkflow::run, + input, + WorkflowStartOptions::new("agents", WORKFLOW_ID).build(), + ) + .await?; + + tracing::info!( + workflow_id = WORKFLOW_ID, + "started workflow; awaiting result (compaction will fire mid-run)" + ); + + let out: AgentOutput = handle + .get_result(WorkflowGetResultOptions::default()) + .await?; + println!(); + println!("=== AgentOutput ==="); + println!("final_answer : {}", out.final_answer); + println!("stop_reason : {:?}", out.stop_reason); + println!("turns_used : {}", out.turns_used); + println!("tool_calls : {}", out.tool_calls); + println!(); + println!("Note: turns_used / tool_calls reflect ONLY the post-compaction run."); + println!("Earlier turns were folded into the system prompt summary."); + Ok(()) +} + +/// Inspect the running workflow's live state. Useful for catching the +/// continue-as-new boundary: before compaction `history.len()` grows past +/// 10; immediately after, it drops to `keep_recent (4) + system + user` +/// and `turn` resets to 0. +async fn run_status(client: Client) -> anyhow::Result<()> { + let handle = client.get_workflow_handle::(WORKFLOW_ID.to_string()); + let state: AgentState = handle + .query( + AgentWorkflow::get_state, + (), + WorkflowQueryOptions::default(), + ) + .await?; + println!("turn : {}", state.turn); + println!("tool_calls : {}", state.tool_calls_executed); + println!("history.len(): {}", state.history.len()); + let sys = state.input.system_prompt.as_str(); + let has_summary = sys.contains("Prior conversation summary"); + println!( + "compacted? : {}", + if has_summary { "yes" } else { "no (yet)" } + ); + println!("history (last 4):"); + for msg in state.history.iter().rev().take(4).rev() { + let preview = if msg.content.len() > 120 { + format!("{}…", &msg.content[..120]) + } else { + msg.content.clone() + }; + println!(" [{:?}] {preview}", msg.role); + } + Ok(()) +} diff --git a/src/builder.rs b/src/builder.rs index 235523f..74c2240 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -16,10 +16,11 @@ use autoagents_llm::LLMProvider; use temporalio_client::Client; use temporalio_sdk::{Worker, WorkerOptions}; use temporalio_sdk_core::CoreRuntime; -use tokio::sync::OnceCell; +use tokio::sync::{OnceCell, SetError}; use crate::activities::AgentActivities; use crate::error::AgentError; +use crate::memory::{MemoryProvider, SlidingWindowMemory}; use crate::state::ToolSchema; use crate::tool::{ToolRegistry, ToolRegistryBuilder}; use crate::workflow::AgentWorkflow; @@ -33,6 +34,12 @@ use crate::workflow::AgentWorkflow; /// reads return the same value on every replay). pub(crate) static WORKER_TOOL_CATALOG: OnceCell> = OnceCell::const_new(); +/// Memory provider made available to the workflow side. +/// +/// Same replay-safety story as [`WORKER_TOOL_CATALOG`]: set once at worker +/// start, read deterministically from the workflow body on every replay. +pub(crate) static WORKER_MEMORY: OnceCell> = OnceCell::const_new(); + /// Fluent builder for an agent-aware Temporal [`Worker`]. /// /// Required calls: [`AgentWorkerBuilder::new`] → [`AgentWorkerBuilder::llm`] @@ -43,6 +50,7 @@ pub struct AgentWorkerBuilder { llm: Option>, tools: ToolRegistryBuilder, queue: String, + memory: Option>, } impl AgentWorkerBuilder { @@ -54,6 +62,7 @@ impl AgentWorkerBuilder { llm: None, tools: ToolRegistry::builder(), queue: "agents".to_string(), + memory: None, } } @@ -78,6 +87,15 @@ impl AgentWorkerBuilder { self } + /// Override the memory provider. Defaults to + /// [`SlidingWindowMemory::default`], which preserves the legacy hardcoded + /// compaction behavior. + #[must_use] + pub fn memory(mut self, memory: Arc) -> Self { + self.memory = Some(memory); + self + } + /// Construct the Temporal worker with `AgentWorkflow` + `AgentActivities` /// registered. /// @@ -94,9 +112,61 @@ impl AgentWorkerBuilder { // loop register their own `ask_user`-style tool whose `execute()` // blocks until an answer is delivered (see the math_agent example). let catalog = registry.to_schemas(); - // OnceCell::set is fallible if already set; in long-running test - // processes we tolerate re-initialization with the same data. - let _ = WORKER_TOOL_CATALOG.set(catalog); + // OnceCell::set is fallible if already set. We tolerate re-init with + // the same data (a long-running test process rebuilding workers with + // identical config) but fail-fast on mismatch: silently binding a + // second worker to the first worker's catalog would break + // determinism the moment the second worker's workflows replayed. + if let Err(set_err) = WORKER_TOOL_CATALOG.set(catalog) { + let rejected = match set_err { + SetError::AlreadyInitializedError(v) => v, + SetError::InitializingError(_) => { + return Err(AgentError::Other( + "WORKER_TOOL_CATALOG is being initialized concurrently".into(), + )); + } + }; + let existing = WORKER_TOOL_CATALOG + .get() + .expect("AlreadyInitialized so get must succeed"); + if existing != &rejected { + return Err(AgentError::Other( + "WORKER_TOOL_CATALOG was previously initialized with a different tool \ + catalog; multiple workers in the same process must register identical \ + tools (and in the same order)" + .into(), + )); + } + } + + let memory: Arc = self + .memory + .unwrap_or_else(|| Arc::new(SlidingWindowMemory::default())); + // Same rationale as above. For memory we compare via `Arc::ptr_eq` + // because behavioral equality of a `dyn MemoryProvider` cannot be + // checked through the trait. Users running multiple workers in one + // process must therefore pass a shared `Arc` to every builder. + if let Err(set_err) = WORKER_MEMORY.set(memory.clone()) { + let rejected = match set_err { + SetError::AlreadyInitializedError(v) => v, + SetError::InitializingError(_) => { + return Err(AgentError::Other( + "WORKER_MEMORY is being initialized concurrently".into(), + )); + } + }; + let existing = WORKER_MEMORY + .get() + .expect("AlreadyInitialized so get must succeed"); + if !Arc::ptr_eq(existing, &rejected) { + return Err(AgentError::Other( + "WORKER_MEMORY was previously initialized with a different provider \ + instance; multiple workers in the same process must share the same \ + Arc" + .into(), + )); + } + } let activities = AgentActivities::new(llm, registry); diff --git a/src/lib.rs b/src/lib.rs index c46ab2b..2a765c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,6 +36,7 @@ pub mod activities; pub mod builder; pub mod error; pub mod llm; +pub mod memory; pub mod prelude; pub mod state; pub mod tool; @@ -44,6 +45,9 @@ pub mod workflow; pub use crate::activities::AgentActivities; pub use crate::builder::AgentWorkerBuilder; pub use crate::error::AgentError; +pub use crate::memory::{ + DEFAULT_COMPACT_THRESHOLD, DEFAULT_KEEP_RECENT, MemoryProvider, SlidingWindowMemory, +}; pub use crate::state::{ AgentInput, AgentOutput, AgentState, LlmChatInput, LlmResponse, Message, Role, StopReason, ToolCall, ToolResult, ToolSchema, diff --git a/src/memory.rs b/src/memory.rs new file mode 100644 index 0000000..dd47fa8 --- /dev/null +++ b/src/memory.rs @@ -0,0 +1,410 @@ +//! Pluggable history-compaction strategies for `AgentWorkflow`. +//! +//! The workflow consults a [`MemoryProvider`] before every reasoning turn to +//! decide whether to `continue_as_new` with a compacted state. The provider is +//! configured once per worker via +//! [`AgentWorkerBuilder::memory`](crate::AgentWorkerBuilder::memory); the +//! default is [`SlidingWindowMemory`], which reproduces the legacy hardcoded +//! behavior (drop everything older than the last 20 messages once history +//! exceeds 200). +//! +//! # Determinism contract +//! +//! Implementations run inside the deterministic workflow body. They MUST be: +//! +//! - **Pure**: no I/O, no clocks, no randomness — given the same `AgentState`, +//! `should_compact` and `compact` must return the same result on every +//! replay. +//! - **Synchronous**: the workflow body is sync-only. If you need network +//! memory (vector store, etc.), that work belongs in an activity called by +//! a future workflow-side `prepare` hook. +//! - **Stateless across calls**: all conversation state lives in `AgentState` +//! and is restored from event history. Provider instances are configuration +//! holders only. +//! +//! # Backward compatibility +//! +//! [`SlidingWindowMemory::default`] uses [`DEFAULT_COMPACT_THRESHOLD`] (200) +//! and [`DEFAULT_KEEP_RECENT`] (20), matching the constants that lived in +//! `src/workflow.rs` prior to this module. Workers that do not call +//! `.memory(...)` get this default and behave identically to earlier releases. + +use crate::state::{AgentInput, AgentState, Role}; + +/// Default history length above which [`SlidingWindowMemory`] compacts. +pub const DEFAULT_COMPACT_THRESHOLD: usize = 200; + +/// Default number of most-recent messages [`SlidingWindowMemory`] preserves +/// verbatim when compacting. +pub const DEFAULT_KEEP_RECENT: usize = 20; + +/// Strategy for deciding when to compact agent history and how to do it. +/// +/// See the module-level docs for the determinism contract every implementation +/// must uphold. +pub trait MemoryProvider: std::fmt::Debug + Send + Sync + 'static { + /// Called every iteration of the agent loop. Return `true` to trigger a + /// `continue_as_new` with the [`AgentInput`] returned by [`Self::compact`]. + /// + /// MUST be pure and deterministic — it is replayed verbatim from history. + fn should_compact(&self, state: &AgentState) -> bool; + + /// Produce the [`AgentInput`] that seeds the next workflow run after + /// `continue_as_new`. Only called when [`Self::should_compact`] returned + /// `true`. + /// + /// The returned value is serialized into workflow history; its byte shape + /// IS the compaction. MUST be pure and deterministic. + fn compact(&self, state: &AgentState) -> AgentInput; +} + +/// FIFO sliding-window compaction. Drops everything older than the last +/// `keep_recent` messages once history exceeds `compact_threshold`, prepending +/// a synthetic text summary of the dropped turns to the system prompt. +/// +/// This is the default provider and matches the pre-v0.2 hardcoded behavior. +#[derive(Debug, Clone)] +pub struct SlidingWindowMemory { + compact_threshold: usize, + keep_recent: usize, +} + +impl SlidingWindowMemory { + /// Construct with [`DEFAULT_COMPACT_THRESHOLD`] / [`DEFAULT_KEEP_RECENT`]. + #[must_use] + pub fn new() -> Self { + Self { + compact_threshold: DEFAULT_COMPACT_THRESHOLD, + keep_recent: DEFAULT_KEEP_RECENT, + } + } + + /// Override the history length at which compaction fires. + /// + /// Panics if `n <= keep_recent`, which would trigger compaction every + /// iteration. + #[must_use] + pub fn with_compact_threshold(mut self, n: usize) -> Self { + assert!( + n > self.keep_recent, + "compact_threshold ({}) must be greater than keep_recent ({})", + n, + self.keep_recent + ); + self.compact_threshold = n; + self + } + + /// Override the number of messages preserved verbatim after compaction. + /// + /// Panics if `n >= compact_threshold`, which would trigger compaction + /// every iteration. + #[must_use] + pub fn with_keep_recent(mut self, n: usize) -> Self { + assert!( + n < self.compact_threshold, + "keep_recent ({}) must be less than compact_threshold ({})", + n, + self.compact_threshold + ); + self.keep_recent = n; + self + } + + /// Inspect the current compaction trigger. + #[must_use] + pub fn compact_threshold(&self) -> usize { + self.compact_threshold + } + + /// Inspect the current keep-recent setting. + #[must_use] + pub fn keep_recent(&self) -> usize { + self.keep_recent + } +} + +impl Default for SlidingWindowMemory { + fn default() -> Self { + Self::new() + } +} + +impl MemoryProvider for SlidingWindowMemory { + fn should_compact(&self, state: &AgentState) -> bool { + state.history.len() > self.compact_threshold + } + + fn compact(&self, state: &AgentState) -> AgentInput { + compact_sliding_window(state, self.keep_recent) + } +} + +/// Pure sliding-window compaction kernel. +/// +/// Preserves the system prompt, summarizes everything before the last +/// `keep_recent` messages into a synthetic text block appended to the system +/// prompt, and threads through the original `max_turns` / `output_schema`. +/// +/// Exposed for advanced users who want to call the kernel directly from a +/// custom [`MemoryProvider`] without re-implementing the summarization format. +pub fn compact_sliding_window(state: &AgentState, keep_recent: usize) -> AgentInput { + let mut summary_lines = Vec::new(); + let total = state.history.len(); + let drop_until = total.saturating_sub(keep_recent); + + for msg in state.history.iter().take(drop_until) { + let line = match msg.role { + Role::System if summary_lines.is_empty() => continue, + Role::User => format!("user: {}", truncate(&msg.content, 200)), + Role::Assistant if !msg.tool_calls.is_empty() => { + let names: Vec<&str> = msg.tool_calls.iter().map(|c| c.name.as_str()).collect(); + format!("assistant: called tools [{}]", names.join(", ")) + } + Role::Assistant => format!("assistant: {}", truncate(&msg.content, 200)), + Role::Tool => format!("tool: {}", truncate(&msg.content, 120)), + Role::System => continue, + }; + summary_lines.push(line); + } + + let summary = if summary_lines.is_empty() { + String::new() + } else { + format!( + "\n\n[Prior conversation summary, {} messages dropped]\n{}", + drop_until, + summary_lines.join("\n") + ) + }; + + let recent_user = state + .history + .iter() + .rev() + .find(|m| m.role == Role::User) + .map(|m| m.content.clone()) + .unwrap_or_default(); + + AgentInput { + system_prompt: format!("{}{}", state.input.system_prompt, summary), + user_message: recent_user, + max_turns: state.input.max_turns, + output_schema: state.input.output_schema.clone(), + } +} + +fn truncate(s: &str, max: usize) -> String { + if s.len() <= max { + return s.to_string(); + } + let mut boundary = max; + while boundary > 0 && !s.is_char_boundary(boundary) { + boundary -= 1; + } + format!("{}…", &s[..boundary]) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::state::{Message, ToolCall}; + use autoagents_llm::chat::StructuredOutputFormat; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + + fn sample_schema() -> StructuredOutputFormat { + StructuredOutputFormat { + name: "weather_report".into(), + description: Some("Structured weather observation".into()), + schema: Some(serde_json::json!({ + "type": "object", + "properties": { + "city": { "type": "string" }, + "temperature_c": { "type": "number" }, + }, + "required": ["city", "temperature_c"] + })), + strict: Some(true), + } + } + + fn populated_state(turns: u32, schema: Option) -> AgentState { + let mut state = AgentState::new(AgentInput { + system_prompt: "sys".into(), + user_message: "u0".into(), + max_turns: 50, + output_schema: schema, + }); + for i in 1..turns { + state.history.push(Message::user(format!("u{i}"))); + state.history.push(Message::assistant_text(format!("a{i}"))); + } + state + } + + #[test] + fn sliding_window_default_uses_published_constants() { + let m = SlidingWindowMemory::default(); + assert_eq!(m.compact_threshold(), DEFAULT_COMPACT_THRESHOLD); + assert_eq!(m.keep_recent(), DEFAULT_KEEP_RECENT); + } + + #[test] + fn sliding_window_should_compact_at_threshold_boundary() { + // Lower keep_recent first so the threshold setter's invariant check + // accepts the smaller threshold. + let m = SlidingWindowMemory::new() + .with_keep_recent(3) + .with_compact_threshold(10); + let mut state = AgentState::new(AgentInput::default()); + // AgentState::new seeds 2 messages (system + user); top up to 10. + while state.history.len() < 10 { + state.history.push(Message::user("x")); + } + assert!( + !m.should_compact(&state), + "len == threshold must not trigger" + ); + state.history.push(Message::user("x")); + assert!(m.should_compact(&state), "len > threshold must trigger"); + } + + #[test] + fn sliding_window_does_not_compact_short_history() { + let m = SlidingWindowMemory::default(); + let empty = AgentState::default(); + assert!(!m.should_compact(&empty)); + let small = AgentState::new(AgentInput { + system_prompt: "sys".into(), + user_message: "hi".into(), + max_turns: 5, + output_schema: None, + }); + assert!(!m.should_compact(&small)); + } + + #[test] + fn sliding_window_compact_preserves_system_prompt_and_recent_user() { + let state = populated_state(30, None); + let m = SlidingWindowMemory::new() + .with_compact_threshold(50) + .with_keep_recent(10); + let compacted = m.compact(&state); + assert!(compacted.system_prompt.starts_with("sys")); + assert!( + compacted + .system_prompt + .contains("Prior conversation summary") + ); + assert_eq!(compacted.max_turns, state.input.max_turns); + } + + #[test] + fn sliding_window_compact_preserves_output_schema() { + let schema = sample_schema(); + let state = populated_state(30, Some(schema.clone())); + let m = SlidingWindowMemory::new() + .with_compact_threshold(50) + .with_keep_recent(10); + let compacted = m.compact(&state); + assert_eq!(compacted.output_schema, Some(schema)); + } + + #[test] + fn sliding_window_compact_with_custom_keep_recent() { + // Tool-call assistant lines render with their tool names; verify the + // synthetic summary mentions the dropped tool call when keep_recent + // excludes it. + let mut state = AgentState::new(AgentInput { + system_prompt: "sys".into(), + user_message: "u0".into(), + max_turns: 50, + output_schema: None, + }); + state + .history + .push(Message::assistant_with_tools(vec![ToolCall { + id: "c1".into(), + name: "search".into(), + args: serde_json::json!({}), + }])); + for i in 1..20 { + state.history.push(Message::user(format!("u{i}"))); + state.history.push(Message::assistant_text(format!("a{i}"))); + } + let m = SlidingWindowMemory::new() + .with_compact_threshold(100) + .with_keep_recent(5); + let compacted = m.compact(&state); + assert!( + compacted + .system_prompt + .contains("assistant: called tools [search]"), + "summary should mention dropped tool call, got: {}", + compacted.system_prompt + ); + } + + #[test] + #[should_panic(expected = "keep_recent")] + fn sliding_window_panics_on_keep_recent_ge_threshold() { + let _ = SlidingWindowMemory::new() + .with_compact_threshold(50) + .with_keep_recent(100); + } + + #[test] + #[should_panic(expected = "compact_threshold")] + fn sliding_window_panics_on_threshold_le_keep_recent() { + let _ = SlidingWindowMemory::new() + .with_keep_recent(50) + .with_compact_threshold(10); + } + + #[test] + fn truncate_respects_utf8_char_boundary() { + // 'é' is 2 bytes; slicing at byte 2 would split it and panic. + let t = truncate("héllo world", 2); + assert_eq!(t, "h…"); + // Already short — returned unchanged. + assert_eq!(truncate("hi", 10), "hi"); + // Emoji boundary (4 bytes for 🦀). + assert_eq!(truncate("🦀rust", 2), "…"); + } + + #[derive(Debug, Default)] + struct PassthroughMemory { + invoked: AtomicBool, + } + + impl MemoryProvider for PassthroughMemory { + fn should_compact(&self, _state: &AgentState) -> bool { + self.invoked.store(true, Ordering::SeqCst); + false + } + + fn compact(&self, state: &AgentState) -> AgentInput { + state.input.clone() + } + } + + #[test] + fn custom_memory_provider_compiles_as_arc_dyn() { + // Compile-only — confirms the trait stays dyn-compatible. + let _: Arc = Arc::new(PassthroughMemory::default()); + } + + #[test] + fn custom_memory_provider_is_invoked_via_arc_dyn() { + let provider: Arc = Arc::new(PassthroughMemory::default()); + let state = AgentState::default(); + assert!(!provider.should_compact(&state)); + // Downcast not possible through `dyn`, but the side effect on the + // inner struct is observable via a separate Arc to the same impl. + let owned = Arc::new(PassthroughMemory::default()); + let typed: Arc = owned.clone(); + let _ = typed.should_compact(&state); + assert!(owned.invoked.load(Ordering::SeqCst)); + } +} diff --git a/src/prelude.rs b/src/prelude.rs index 0f40cfc..a811564 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -13,6 +13,7 @@ pub use crate::activities::AgentActivities; pub use crate::builder::AgentWorkerBuilder; pub use crate::error::AgentError; +pub use crate::memory::{MemoryProvider, SlidingWindowMemory}; pub use crate::state::{ AgentInput, AgentOutput, AgentState, LlmResponse, Message, Role, StopReason, ToolCall, ToolResult, ToolSchema, diff --git a/src/state.rs b/src/state.rs index 36a4475..dbbc4b6 100644 --- a/src/state.rs +++ b/src/state.rs @@ -197,73 +197,13 @@ pub struct LlmChatInput { } /// Description of a tool sent to the LLM so it knows what it can call. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ToolSchema { pub name: String, pub description: String, pub args_schema: serde_json::Value, } -/// Compact long histories so `continue_as_new` doesn't grow the event history -/// unbounded. Keeps the system prompt, a synthetic summary marker, and the -/// most recent `keep_recent` messages. -pub fn compact(state: &AgentState, keep_recent: usize) -> AgentInput { - let mut summary_lines = Vec::new(); - let total = state.history.len(); - let drop_until = total.saturating_sub(keep_recent); - - for msg in state.history.iter().take(drop_until) { - let line = match msg.role { - Role::System if summary_lines.is_empty() => continue, - Role::User => format!("user: {}", truncate(&msg.content, 200)), - Role::Assistant if !msg.tool_calls.is_empty() => { - let names: Vec<&str> = msg.tool_calls.iter().map(|c| c.name.as_str()).collect(); - format!("assistant: called tools [{}]", names.join(", ")) - } - Role::Assistant => format!("assistant: {}", truncate(&msg.content, 200)), - Role::Tool => format!("tool: {}", truncate(&msg.content, 120)), - Role::System => continue, - }; - summary_lines.push(line); - } - - let summary = if summary_lines.is_empty() { - String::new() - } else { - format!( - "\n\n[Prior conversation summary, {} messages dropped]\n{}", - drop_until, - summary_lines.join("\n") - ) - }; - - let recent_user = state - .history - .iter() - .rev() - .find(|m| m.role == Role::User) - .map(|m| m.content.clone()) - .unwrap_or_default(); - - AgentInput { - system_prompt: format!("{}{}", state.input.system_prompt, summary), - user_message: recent_user, - max_turns: state.input.max_turns, - output_schema: state.input.output_schema.clone(), - } -} - -fn truncate(s: &str, max: usize) -> String { - if s.len() <= max { - return s.to_string(); - } - let mut boundary = max; - while boundary > 0 && !s.is_char_boundary(boundary) { - boundary -= 1; - } - format!("{}…", &s[..boundary]) -} - #[cfg(test)] mod tests { use super::*; @@ -282,39 +222,6 @@ mod tests { assert_eq!(s.turn, 0); } - #[test] - fn compact_keeps_system_and_recent() { - let mut state = AgentState::new(AgentInput { - system_prompt: "sys".into(), - user_message: "u0".into(), - max_turns: 50, - output_schema: None, - }); - for i in 1..30 { - state.history.push(Message::user(format!("u{i}"))); - state.history.push(Message::assistant_text(format!("a{i}"))); - } - let compacted = compact(&state, 10); - assert!(compacted.system_prompt.starts_with("sys")); - assert!( - compacted - .system_prompt - .contains("Prior conversation summary") - ); - assert_eq!(compacted.max_turns, state.input.max_turns); - } - - #[test] - fn truncate_respects_utf8_char_boundary() { - // 'é' is 2 bytes; slicing at byte 2 would split it and panic. - let t = truncate("héllo world", 2); - assert_eq!(t, "h…"); - // Already short — returned unchanged. - assert_eq!(truncate("hi", 10), "hi"); - // Emoji boundary (4 bytes for 🦀). - assert_eq!(truncate("🦀rust", 2), "…"); - } - #[test] fn message_roundtrips_through_json() { let m = Message::assistant_with_tools(vec![ToolCall { @@ -386,20 +293,4 @@ mod tests { assert_eq!(parsed.max_turns, 3); assert!(parsed.output_schema.is_none()); } - - #[test] - fn compact_preserves_output_schema() { - let mut state = AgentState::new(AgentInput { - system_prompt: "sys".into(), - user_message: "u0".into(), - max_turns: 50, - output_schema: Some(sample_schema()), - }); - for i in 1..30 { - state.history.push(Message::user(format!("u{i}"))); - state.history.push(Message::assistant_text(format!("a{i}"))); - } - let compacted = compact(&state, 10); - assert_eq!(compacted.output_schema, state.input.output_schema); - } } diff --git a/src/workflow.rs b/src/workflow.rs index d7225cc..fcb0704 100644 --- a/src/workflow.rs +++ b/src/workflow.rs @@ -17,6 +17,7 @@ //! [`ToolT`]: autoagents_core::tool::ToolT //! [`LLMProvider`]: autoagents_llm::LLMProvider +use std::sync::Arc; use std::time::Duration; use temporalio_macros::{workflow, workflow_methods}; @@ -26,18 +27,11 @@ use temporalio_sdk::{ }; use crate::activities::AgentActivities; +use crate::memory::{MemoryProvider, SlidingWindowMemory}; use crate::state::{ - AgentInput, AgentOutput, AgentState, LlmChatInput, LlmResponse, Message, StopReason, compact, + AgentInput, AgentOutput, AgentState, LlmChatInput, LlmResponse, Message, StopReason, }; -/// Hard cap on history length before we `continue_as_new` to keep the event -/// history small. Picked conservatively; tune via the worker config in a -/// future release. -pub const CONTINUE_AS_NEW_THRESHOLD: usize = 200; - -/// How many recent messages [`compact`] keeps when rotating to a new run. -pub const COMPACT_KEEP_RECENT: usize = 20; - /// Durable AI agent workflow. /// /// Each invocation runs a ReAct loop until the model emits a final answer or @@ -82,8 +76,7 @@ impl AgentWorkflow { } }); - let (turn, max_turns, history_len) = - ctx.state(|s| (s.state.turn, s.state.input.max_turns, s.state.history.len())); + let (turn, max_turns) = ctx.state(|s| (s.state.turn, s.state.input.max_turns)); if turn >= max_turns { let out = ctx.state(|s| { @@ -92,8 +85,18 @@ impl AgentWorkflow { return Ok(out); } - if history_len > CONTINUE_AS_NEW_THRESHOLD { - let next_input = ctx.state(|s| compact(&s.state, COMPACT_KEEP_RECENT)); + // Resolve the memory provider once per iteration. In production + // this is set by `build_worker`; the `unwrap_or_else` fallback + // mirrors `WORKER_TOOL_CATALOG` usage below so workflow-only + // unit-test paths still work. + let memory: Arc = crate::builder::WORKER_MEMORY + .get() + .cloned() + .unwrap_or_else(|| Arc::new(SlidingWindowMemory::default())); + + if ctx.state(|s| memory.should_compact(&s.state)) { + let next_input = ctx.state(|s| memory.compact(&s.state)); + let history_len = ctx.state(|s| s.state.history.len()); tracing::info!(history_len, "compacting and continuing as new"); ctx.continue_as_new(&next_input, ContinueAsNewOptions::default())?; unreachable!(); // continue_as_new always returns Err diff --git a/tests/prelude_memory_exports.rs b/tests/prelude_memory_exports.rs new file mode 100644 index 0000000..84ca888 --- /dev/null +++ b/tests/prelude_memory_exports.rs @@ -0,0 +1,13 @@ +//! Compile-only assertions for the memory-provider re-exports in the prelude. +//! +//! Removing or renaming any of the listed items from `temporal_agent_rs::prelude` +//! will break this test crate, surfacing the change as an API break. + +#[allow(dead_code, unused_imports, unused_variables)] +fn _prelude_memory_exports_compile() { + use std::sync::Arc; + use temporal_agent_rs::prelude::{MemoryProvider, SlidingWindowMemory}; + + let _: SlidingWindowMemory = SlidingWindowMemory::default(); + let _: Arc = Arc::new(SlidingWindowMemory::default()); +}