diff --git a/CHANGELOG.md b/CHANGELOG.md index acb7b08..aff0215 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,32 @@ ## Unreleased +## 0.9.3 + +- Added: `load_workspace_manifest` and `refresh_context` MCP tools for + workspace-first analysis. Aleph can now bind contexts to refreshable + workspace files or generated manifests, which is a better default for large + codebases and long-lived projects than loading raw files ad hoc. +- Added: Workspace bindings now persist through memory-pack save/load, and + `get_status` / `list_contexts` expose binding metadata for MCP clients. +- Added: `ALEPH_ACTION_POLICY` / `--action-policy` with `read-write` and + `read-only` modes. Read-only mode keeps repo search and file loading + available while blocking writes and subprocess execution. +- Refactored: Workspace-oriented MCP behavior extracted into + `mcp/workspace_contexts.py` and `mcp/workspace_tools.py`, plus + `mcp/context_tools.py` for session/context MCP behavior, continuing the + modularization of `mcp/local_server.py`. +- Refactored: Continued the MCP server modularization by extracting + `mcp/sub_query_orchestration.py`, `mcp/recipe_runtime.py`, + `mcp/node_bridge.py`, and `mcp/repl_injection.py`, reducing + `mcp/local_server.py` to orchestration plus thin compatibility wrappers. +- Docs: README and DEVELOPMENT now lead with the large-codebase workflow and + document refreshable workspaces plus the read-only action policy. +- Tests: Added MCP contract coverage for workspace manifests, refreshable + file-backed contexts, action-policy enforcement, bootstrap env handling, + sub-query orchestration, recipe runtime extraction, node bridge extraction, + and REPL injection extraction. + ## 0.9.2 - Refactored: Centralized Aleph, MCP, and sub-query env parsing through typed diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 80bb760..cfc5f67 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -8,9 +8,9 @@ Architecture and development workflow for Aleph. 
Aleph is an MCP server implementing the [Recursive Language Model](https://arxiv.org/abs/2512.24601) (RLM) paradigm for -document analysis. Instead of stuffing context into prompts, Aleph stores -documents in a sandboxed Python REPL and provides tools for iterative -exploration. +large-codebase, project, and document analysis. Instead of stuffing context +into prompts, Aleph stores working data in a sandboxed Python REPL and +provides tools for iterative exploration. --- @@ -24,11 +24,17 @@ aleph/ ├── cli.py # CLI entry points (aleph-rlm install/doctor) ├── mcp/ │ ├── local_server.py # MCP server (main entry point) -│ ├── tool_registry.py # Tool registration helpers +│ ├── admin_tools.py # Runtime configure / remote-server MCP tools │ ├── actions.py # Action tools (read/write/run) +│ ├── context_tools.py # Context load/list/diff/save/load MCP tools +│ ├── query_tools.py # Search / peek / semantic-search MCP tools │ ├── recipes.py # Recipe schema validation +│ ├── reasoning_tools.py # Status / evidence / finalize MCP tools +│ ├── workspace_contexts.py # Refreshable file / manifest bindings +│ ├── workspace_tools.py # Workspace-manifest and refresh MCP tools │ ├── session.py # Session serialization │ ├── workspace.py # Workspace root detection +│ ├── server_bootstrap.py # CLI/env bootstrap for MCP server runtime │ └── server.py # Compatibility entry point (aliases local_server) ├── repl/ │ ├── sandbox.py # REPLEnvironment -- sandboxed code execution @@ -65,8 +71,8 @@ pip install -e ".[dev,mcp]" # Run tests python3 -m pytest -q -# Run MCP server locally (with action tools enabled) -aleph --enable-actions --tool-docs concise +# Run MCP server locally (action tools enabled, but kept read-only) +aleph --enable-actions --action-policy read-only --tool-docs concise ``` --- @@ -90,12 +96,12 @@ The primary entry point for IDE integration. 
Exposes tools: | Category | Tools | |---------------------|-------------------------------------------------------------------| -| **Context** | `load_context`, `peek_context`, `search_context` | +| **Context** | `load_context`, `load_file`, `load_workspace_manifest`, `refresh_context` | | **Compute** | `exec_python`, `get_variable` | | **Recursion** | `sub_query` (RLM-style recursive calls) | | **Reasoning** | `think`, `evaluate_progress`, `summarize_so_far` | | **Output** | `finalize`, `get_evidence`, `get_status` | -| **Actions** | `run_command`, `read_file`, `write_file`, `run_tests` | +| **Actions** | `rg_search`, `read_file`, `run_command`, `write_file`, `run_tests` | ### Sandbox (`repl/sandbox.py`) @@ -109,7 +115,8 @@ The `REPLEnvironment` provides a sandboxed Python execution environment: - **Helper injection:** 100+ functions for document analysis The sandbox is best-effort, not hardened. For untrusted input, use container -isolation. +isolation and keep MCP action tools in `--action-policy read-only` unless you +explicitly need writes or subprocess execution. ### Sub-Query System (`sub_query/`) @@ -192,11 +199,14 @@ ruff check aleph tests ## Adding a New Tool -1. Add the tool function in `mcp/local_server.py` inside `_register_tools()` -2. Decorate with `@self.server.tool()` -3. Include comprehensive docstring (shown to AI users) -4. Update `_Session` if tool needs state tracking -5. Add tests in `tests/` +1. Prefer a dedicated module under `aleph/mcp/` (for example + `workspace_tools.py`, `reasoning_tools.py`) instead of adding more inline + closures to `local_server.py`. +2. Register the module from `AlephMCPServerLocal._register_tools()` +3. Decorate with `@self.server.tool()` +4. Include comprehensive docstring (shown to AI users) +5. Update `_Session` if the tool needs state tracking +6. 
Add tests in `tests/` Example: diff --git a/README.md b/README.md index 80ee099..cceb22a 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ Aleph is an [MCP server](https://modelcontextprotocol.io/) and skill for **Recursive Language Models** (RLMs). It keeps working state — search indexes, code execution, evidence, recursion — in a Python process outside the prompt -window, so the LLM reasons iteratively over repos, logs, documents, and data -without burning context on raw content. +window, so the LLM reasons iteratively over large codebases, long-lived +projects, logs, documents, and data without burning context on raw content. ```text +-----------------+ tool calls +-----------------------------+ @@ -25,7 +25,8 @@ Why Aleph: `exec_typescript` provide a persistent Node.js runtime over the same `ctx`. - **Recurse.** Sub-queries and recipes split complex work across multiple reasoning passes. -- **Persist.** Save sessions and resume long investigations later. +- **Keep workspaces warm.** Bind contexts back to files or generated workspace + manifests, refresh them, and resume long investigations later. ## Quick Start @@ -47,6 +48,12 @@ a structured RLM workflow. Install [`docs/prompts/aleph.md`](docs/prompts/aleph.md) into your client's command/skill folder — see [MCP_SETUP.md](MCP_SETUP.md) for exact paths. +If you are using action tools on a real repo, the safest default is: + +```bash +aleph --enable-actions --action-policy read-only +``` + ### Cursor Use **global** MCP (`aleph-rlm install cursor`) for `--workspace-mode any`, or @@ -83,10 +90,37 @@ aleph-rlm configure --profile codex # overwrite existing config See [docs/CONFIGURATION.md](docs/CONFIGURATION.md) for all env vars, CLI flags, and runtime `configure(...)` options. -## First Workflow +## Large Codebase Workflow + +If your main use case is a repo or multi-folder project, start by loading a +compact workspace manifest instead of throwing raw source files into the model +window. 
That gives the model a map of the project, lets it search aggressively, +and keeps the session refreshable as the repo changes. + +```python +load_workspace_manifest(paths=["src", "tests"], context_id="repo") +rg_search(pattern="FastAPI|APIRouter|router\\.", paths=["src", "tests"], load_context_id="routes") +load_file(path="pyproject.toml", context_id="pyproject") +exec_python(code=""" +files = [line for line in ctx.splitlines() if line.startswith("- ")] +summary = { + "indexed_entries": len(files), + "top_python_files": [line for line in files if "| python |" in line][:10], +} +""", context_id="repo") +get_variable(name="summary", context_id="repo") +refresh_context(context_id="repo") +``` + +Use `load_workspace_manifest` as the default front door for large codebases and +projects. Then pull in specific files with `load_file`, search the repo with +`rg_search`, and refresh the bound context when the workspace changes. Refreshes +preserve the session's reasoning state, evidence log, and tracked tasks. + +### Single File Workflow -Aleph is best when you load data once, do the heavy work inside Aleph, and only -pull back compact answers. +Aleph is also strong when you load one large file once, do the heavy work +inside Aleph, and only pull back compact answers. ```python load_file(path="/absolute/path/to/large_file.log", context_id="doc") @@ -160,6 +194,7 @@ variables. | Scenario | What Aleph Is Good At | |---|---| +| Large codebase / project analysis | Build a workspace map, search quickly, load only the files that matter, and keep the session refreshable | | Large log analysis | Load big files, trace patterns, correlate events | | Codebase navigation | Search symbols, inspect routes, trace behavior | | Data exploration | Analyze JSON, CSV, and mixed text with Python helpers | @@ -171,7 +206,7 @@ variables. 
| Category | Primary tools | What they do | |---|---|---| -| Load context | `load_context`, `load_file`, `list_contexts`, `diff_contexts` | Put data into Aleph memory and inspect what is loaded | +| Load context | `load_context`, `load_file`, `load_workspace_manifest`, `refresh_context`, `list_contexts`, `diff_contexts` | Put data into Aleph memory, bind it back to workspace assets, and inspect what is loaded | | Navigate | `search_context`, `semantic_search`, `peek_context`, `chunk_context`, `rg_search` | Find the relevant slice before asking for an answer | | Compute | `exec_python`, `exec_javascript`, `exec_typescript`, `get_variable` | Run Python or JS/TS over the full context and retrieve only the derived result | | Reason | `think`, `evaluate_progress`, `get_evidence`, `finalize` | Structure progress and close out with evidence | @@ -273,6 +308,9 @@ explicitly pull it back: - `exec_python` stdout, stderr, and return values are bounded independently. - `ALEPH_CONTEXT_POLICY=isolated` adds stricter session export/import rules and more defensive defaults. +- `ALEPH_ACTION_POLICY=read-only` (or `--action-policy read-only`) keeps action + tools in read-only mode: search and file loading still work, but writes and + subprocess execution are blocked. The safest pattern is always: diff --git a/aleph/__init__.py b/aleph/__init__.py index 35384d8..a2969f5 100644 --- a/aleph/__init__.py +++ b/aleph/__init__.py @@ -46,4 +46,4 @@ "BudgetStatus", ] -__version__ = "0.9.2" +__version__ = "0.9.3" diff --git a/aleph/mcp/action_tools.py b/aleph/mcp/action_tools.py new file mode 100644 index 0000000..876d8b1 --- /dev/null +++ b/aleph/mcp/action_tools.py @@ -0,0 +1,475 @@ +"""Action MCP tool registrations for the local server. + +Extracted from local_server.py to keep the server class focused on +orchestration while action tools live in their own module. 
def register_action_tools(
    owner: "AlephMCPServerLocal",
    *,
    format_error: Callable[
        [str, Literal["json", "markdown", "object"]], str | dict[str, Any]
    ],
    format_payload: Callable[
        [dict[str, Any], Literal["json", "markdown", "object"]], str | dict[str, Any]
    ],
) -> None:
    """Register the action-family MCP tools on ``owner``.

    Defines ``run_command``, ``rg_search``, ``read_file``, ``load_file``,
    ``write_file`` and ``run_tests`` as closures over ``owner`` and registers
    each one through ``owner._tool_decorator``.

    Every tool first asks ``owner._require_actions`` whether it may run and
    returns the resulting message as an error response when it may not.
    Path arguments are resolved with ``_scoped_path`` against the configured
    workspace root and workspace mode; a failure to resolve is reported as an
    error response rather than raised.

    Args:
        owner: Server instance providing ``action_config``, session access,
            subprocess execution and action recording.
        format_error: Renders an error message for the requested output mode.
        format_payload: Renders a result payload for the requested output mode.
    """
    # NOTE(review): the one-line tool docstrings below are presumably surfaced
    # to MCP clients as tool descriptions, so they are intentionally left
    # unchanged; implementation notes live in comments instead.
    _tool = owner._tool_decorator

    @_tool()
    async def run_command(
        cmd: str,
        cwd: str | None = None,
        timeout_seconds: float | None = None,
        shell: bool = False,
        confirm: bool = False,
        output: Literal["json", "markdown", "object"] = "json",
        context_id: str = "default",
        ctx: Context = None,  # type: ignore[assignment]
    ) -> str | dict[str, Any]:
        """Run a shell command."""
        err = owner._require_actions(confirm, requires_command=True)
        if err:
            return format_error(err, output=output)
        if ctx is not None:
            await owner._maybe_resolve_workspace_from_roots(ctx)

        session = owner._get_or_create_session(context_id)
        session.iterations += 1

        workspace_root = owner.action_config.workspace_root
        # Report an out-of-scope cwd as an error response, matching how the
        # other tools in this module handle _scoped_path failures.
        try:
            cwd_path = (
                _scoped_path(workspace_root, cwd, owner.action_config.workspace_mode)
                if cwd
                else workspace_root
            )
        except Exception as e:
            return format_error(str(e), output=output)
        timeout = (
            timeout_seconds
            if timeout_seconds is not None
            else owner.action_config.max_cmd_seconds
        )

        if shell:
            # Hand the raw string to the user's shell ($SHELL, else /bin/sh).
            user_shell = os.environ.get("SHELL", "/bin/sh")
            argv = [user_shell, "-lc", cmd]
        else:
            # shlex.split raises ValueError on unbalanced quotes; surface it
            # as a tool error instead of an unhandled exception.
            try:
                argv = shlex.split(cmd)
            except ValueError as e:
                return format_error(f"Invalid command: {e}", output=output)
        if not argv:
            return format_error("Empty command", output=output)

        payload = await owner._run_subprocess(
            argv=argv, cwd=cwd_path, timeout_seconds=timeout
        )
        # Expose the raw result to the session REPL for follow-up analysis.
        session.repl._namespace["last_command_result"] = payload
        owner._record_action(
            session,
            note="run_command",
            snippet=(payload.get("stdout") or payload.get("stderr") or "")[:200],
        )
        return format_payload(payload, output=output)

    @_tool()
    async def rg_search(
        pattern: str,
        paths: list[str] | str | None = None,
        glob: str | None = None,
        max_results: int = 200,
        load_context_id: str | None = None,
        confirm: bool = False,
        output: Literal["json", "markdown", "object"] = "json",
        context_id: str = "default",
        ctx: Context = None,  # type: ignore[assignment]
    ) -> str | dict[str, Any]:
        """Fast codebase search using ripgrep (rg) with fallback scanning."""
        err = owner._require_actions(confirm)
        if err:
            return format_error(err, output=output)
        if ctx is not None:
            await owner._maybe_resolve_workspace_from_roots(ctx)
        if not pattern:
            return format_error("pattern is required", output=output)
        if isinstance(paths, str):
            paths = [paths]

        session = owner._get_or_create_session(context_id)
        session.iterations += 1

        workspace_root = owner.action_config.workspace_root
        resolved_paths: list[Path] = []
        # With no explicit paths, search the workspace root itself.
        for p in paths or [str(workspace_root)]:
            try:
                resolved = _scoped_path(
                    workspace_root, p, owner.action_config.workspace_mode
                )
            except Exception as e:
                return format_error(str(e), output=output)
            resolved_paths.append(resolved)

        matches: list[dict[str, Any]] = []
        truncated = False
        used_rg = False
        payload: dict[str, Any] | None = None

        rg_bin = shutil.which("rg")
        if rg_bin:
            used_rg = True
            argv = [rg_bin, "--vimgrep"]
            if glob:
                argv.extend(["-g", glob])
            if max_results > 0:
                argv.extend(["-m", str(max_results)])
            # Pass the pattern via -e and terminate option parsing with "--"
            # so a caller-supplied pattern that starts with "-" can never be
            # interpreted as a ripgrep option.
            argv.extend(["-e", pattern, "--"])
            argv.extend(str(p) for p in resolved_paths)
            payload = await owner._run_subprocess(
                argv=argv,
                cwd=workspace_root,
                timeout_seconds=owner.action_config.max_cmd_seconds,
            )
            # The parser is also given max_results (presumably to enforce the
            # overall cap and report truncation -- confirm against owner).
            matches, truncated = owner._parse_rg_vimgrep(
                payload.get("stdout") or "", max_results
            )
        else:
            # No rg binary on PATH: fall back to the owner's Python scanner.
            matches, truncated = owner._python_rg_search(
                pattern,
                resolved_paths,
                glob,
                max_results,
            )

        hits_text = "\n".join(
            f"{m['path']}:{m['line']}:{m['column']}:{m['text']}" for m in matches
        )
        if load_context_id:
            # Optionally materialize the hits as their own text context.
            meta = owner._create_session(
                hits_text, load_context_id, ContentFormat.TEXT, DEFAULT_LINE_NUMBER_BASE
            )
            session.repl._namespace["last_rg_loaded_context"] = load_context_id
            load_note = f"Loaded {len(matches)} match(es) into '{load_context_id}'."
        else:
            meta = None
            load_note = None

        result_payload: dict[str, Any] = {
            "pattern": pattern,
            "paths": [str(p) for p in resolved_paths],
            "used_rg": used_rg,
            "match_count": len(matches),
            "truncated": truncated,
            "matches": matches,
        }
        if payload:
            # Subprocess details are only available when rg actually ran.
            result_payload["command"] = payload.get("argv")
            result_payload["timed_out"] = payload.get("timed_out", False)
            result_payload["stderr"] = payload.get("stderr", "")
        if load_context_id:
            result_payload["loaded_context_id"] = load_context_id
            result_payload["loaded_meta"] = {
                "size_chars": meta.size_chars if meta else 0,
                "size_lines": meta.size_lines if meta else 0,
            }
        if load_note:
            result_payload["note"] = load_note

        session.repl._namespace["last_rg_result"] = result_payload
        owner._record_action(
            session, note="rg_search", snippet=f"{pattern} ({len(matches)} matches)"
        )

        if output == "object":
            return result_payload
        if output == "json":
            return json.dumps(result_payload, ensure_ascii=False, indent=2)

        # Markdown output: header plus at most the first 20 matches.
        parts = [
            "## rg_search Results",
            f"Pattern: `{pattern}`",
            f"Matches: {len(matches)}" + (" (truncated)" if truncated else ""),
        ]
        if load_note:
            parts.append(load_note)
        if matches:
            parts.append("")
            parts.extend(
                [
                    f"- {m['path']}:{m['line']}:{m['column']}: {m['text']}"
                    for m in matches[:20]
                ]
            )
            if len(matches) > 20:
                parts.append(f"... {len(matches) - 20} more")
        return "\n".join(parts)

    @_tool()
    async def read_file(
        path: str,
        start_line: int = 1,
        limit: int = 200,
        include_raw: bool = False,
        line_number_base: int | None = None,
        confirm: bool = False,
        output: Literal["json", "markdown", "object"] = "json",
        context_id: str = "default",
        ctx: Context = None,  # type: ignore[assignment]
    ) -> str | dict[str, Any]:
        """Read file content (raw)."""
        err = owner._require_actions(confirm)
        if err:
            return format_error(err, output=output)
        if ctx is not None:
            await owner._maybe_resolve_workspace_from_roots(ctx)

        base_override: LineNumberBase | None = None
        if line_number_base is not None:
            try:
                base_override = _validate_line_number_base(line_number_base)
            except ValueError as e:
                return format_error(str(e), output=output)

        session = owner._get_or_create_session(context_id, base_override)
        session.iterations += 1
        try:
            base = _resolve_line_number_base(session, line_number_base)
        except ValueError as e:
            return format_error(str(e), output=output)

        # Tolerate start_line=0 from callers when numbering is 1-based.
        if base == 1 and start_line == 0:
            start_line = 1
        if start_line < base:
            return format_error(f"start_line must be >= {base}", output=output)

        try:
            p = _scoped_path(
                owner.action_config.workspace_root,
                path,
                owner.action_config.workspace_mode,
            )
        except Exception as e:
            return format_error(str(e), output=output)

        if not p.exists() or not p.is_file():
            return format_error(f"File not found: {path}", output=output)

        max_bytes = owner.action_config.max_read_bytes
        too_large = f"File too large to read (>{max_bytes} bytes): {path}"
        try:
            # Check the on-disk size first so an oversized file is rejected
            # without pulling it into memory.
            if p.stat().st_size > max_bytes:
                return format_error(too_large, output=output)
            data = p.read_bytes()
        except OSError as e:
            return format_error(f"Could not read file: {e}", output=output)
        # Re-check after reading in case the file grew since the stat call.
        if len(data) > max_bytes:
            return format_error(too_large, output=output)

        text = data.decode("utf-8", errors="replace")
        lines = text.splitlines()
        # Convert the caller's line number to a 0-based slice window.
        start_idx = max(0, start_line - base)
        end_idx = min(len(lines), start_idx + max(0, limit))
        slice_lines = lines[start_idx:end_idx]
        numbered = "\n".join(
            f"{i + start_idx + base:>6}\t{line}" for i, line in enumerate(slice_lines)
        )
        # For an empty slice, report end_line as the requested start_line.
        end_line = (
            (start_idx + len(slice_lines) - 1 + base) if slice_lines else start_line
        )

        payload: dict[str, Any] = {
            "path": str(p),
            "start_line": start_line,
            "end_line": end_line,
            "limit": limit,
            "total_lines": len(lines),
            "line_number_base": base,
            "content": numbered,
        }
        if include_raw:
            payload["content_raw"] = "\n".join(slice_lines)
        session.repl._namespace["last_read_file_result"] = payload
        owner._record_action(
            session, note="read_file", snippet=f"{path} ({start_line}-{end_line})"
        )
        return format_payload(payload, output=output)

    @_tool()
    async def load_file(
        path: str,
        context_id: str = "default",
        format: str = "auto",
        line_number_base: LineNumberBase = DEFAULT_LINE_NUMBER_BASE,
        confirm: bool = False,
        ctx: Context = None,  # type: ignore[assignment]
    ) -> str:
        """Load a workspace file into a context session."""
        # Imported at call time rather than at module import.
        from .workspace_contexts import make_file_binding

        # Unlike the other tools, errors here are plain "Error: ..." strings.
        err = owner._require_actions(confirm)
        if err:
            return f"Error: {err}"
        if ctx is not None:
            await owner._maybe_resolve_workspace_from_roots(ctx)

        try:
            base = _validate_line_number_base(line_number_base)
        except ValueError as e:
            return f"Error: {e}"

        try:
            p = _scoped_path(
                owner.action_config.workspace_root,
                path,
                owner.action_config.workspace_mode,
            )
        except Exception as e:
            return f"Error: {e}"

        if not p.exists() or not p.is_file():
            return f"Error: File not found: {path}"

        try:
            text, detected_fmt, warning = _load_text_from_path(
                p,
                owner.action_config.max_read_bytes,
                owner.action_config.max_cmd_seconds,
            )
        except ValueError as e:
            return f"Error: {e}"
        try:
            # "auto" defers to the format detected while loading the file.
            normalized_format = normalize_content_format(format, allow_auto=True)
            fmt = cast(
                ContentFormat,
                detected_fmt if normalized_format == "auto" else normalized_format,
            )
        except Exception as e:
            return f"Error: {e}"
        meta = owner._create_session(text, context_id, fmt, base)
        session = owner._sessions[context_id]
        # Remember which file backs this context via a workspace binding.
        session.workspace_binding = make_file_binding(
            p, owner.action_config.workspace_root
        )
        owner._record_action(session, note="load_file", snippet=str(p))
        return owner._format_context_loaded(context_id, meta, base, note=warning)

    @_tool()
    async def write_file(
        path: str,
        content: str,
        mode: Literal["overwrite", "append"] = "overwrite",
        confirm: bool = False,
        output: Literal["json", "markdown", "object"] = "json",
        context_id: str = "default",
        ctx: Context = None,  # type: ignore[assignment]
    ) -> str | dict[str, Any]:
        """Write file content."""
        err = owner._require_actions(confirm, requires_write=True)
        if err:
            return format_error(err, output=output)
        if ctx is not None:
            await owner._maybe_resolve_workspace_from_roots(ctx)

        session = owner._get_or_create_session(context_id)
        session.iterations += 1

        try:
            p = _scoped_path(
                owner.action_config.workspace_root,
                path,
                owner.action_config.workspace_mode,
            )
        except Exception as e:
            return format_error(str(e), output=output)

        # The size limit applies to the encoded bytes, not the character count.
        payload_bytes = content.encode("utf-8", errors="replace")
        if len(payload_bytes) > owner.action_config.max_write_bytes:
            return format_error(
                f"Content too large to write (>{owner.action_config.max_write_bytes} bytes)",
                output=output,
            )

        # Create missing parent directories before writing.
        p.parent.mkdir(parents=True, exist_ok=True)
        file_mode = "ab" if mode == "append" else "wb"
        with open(p, file_mode) as f:
            f.write(payload_bytes)

        payload: dict[str, Any] = {
            "path": str(p),
            "bytes_written": len(payload_bytes),
            "mode": mode,
        }
        session.repl._namespace["last_write_file_result"] = payload
        owner._record_action(
            session, note="write_file", snippet=f"{path} ({len(payload_bytes)} bytes)"
        )
        return format_payload(payload, output=output)

    @_tool()
    async def run_tests(
        runner: Literal["auto", "pytest"] = "auto",
        args: list[str] | None = None,
        cwd: str | None = None,
        confirm: bool = False,
        output: Literal["json", "markdown", "object"] = "json",
        context_id: str = "default",
        ctx: Context = None,  # type: ignore[assignment]
    ) -> str | dict[str, Any]:
        """Run project tests."""
        err = owner._require_actions(confirm, requires_command=True)
        if err:
            return format_error(err, output=output)
        if ctx is not None:
            await owner._maybe_resolve_workspace_from_roots(ctx)

        session = owner._get_or_create_session(context_id)
        session.iterations += 1

        workspace_root = owner.action_config.workspace_root
        # Report an out-of-scope cwd as an error response, matching how the
        # other tools in this module handle _scoped_path failures.
        try:
            cwd_path = (
                _scoped_path(workspace_root, cwd, owner.action_config.workspace_mode)
                if cwd
                else workspace_root
            )
        except Exception as e:
            return format_error(str(e), output=output)

        # Heuristics for test runner: "auto" currently always means pytest.
        runner_bin: str = str(runner)
        if runner == "auto":
            runner_bin = "pytest"

        argv: list[str] = [runner_bin]
        if args:
            argv.extend(args)

        payload = await owner._run_subprocess(
            argv=argv, cwd=cwd_path, timeout_seconds=owner.action_config.max_cmd_seconds
        )
        owner._record_action(
            session,
            note=f"run_tests: {runner}",
            snippet=(payload.get("stdout") or payload.get("stderr") or "")[:200],
        )
        return format_payload(payload, output=output)
+""" from __future__ import annotations import asyncio import fnmatch -import json -import os import re -import shlex -import shutil -import sys import time from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Iterable, Literal +from typing import Any, Iterable, Literal -from ..compat import normalize_content_format -from ..types import ContentFormat, ContextMetadata -from .formatting import _format_context_loaded, _format_error, _format_payload from .session import _Evidence, _Session -from .workspace import ( - DEFAULT_LINE_NUMBER_BASE, - DEFAULT_WORKSPACE_MODE, - LineNumberBase, - WorkspaceMode, - _detect_workspace_root, - _validate_line_number_base, -) +from .workspace import DEFAULT_WORKSPACE_MODE, WorkspaceMode, _detect_workspace_root + +ContextPolicy = Literal["trusted", "isolated"] +DEFAULT_CONTEXT_POLICY: ContextPolicy = "trusted" +ActionPolicy = Literal["read-write", "read-only"] +DEFAULT_ACTION_POLICY: ActionPolicy = "read-write" + +__all__ = [ + "ActionConfig", + "require_actions", + "record_action", + "run_subprocess", + "_parse_rg_vimgrep", + "_python_rg_search", +] @dataclass(slots=True) @@ -34,27 +38,41 @@ class ActionConfig: enabled: bool = False workspace_root: Path = field(default_factory=_detect_workspace_root) workspace_mode: WorkspaceMode = DEFAULT_WORKSPACE_MODE + context_policy: ContextPolicy = DEFAULT_CONTEXT_POLICY + action_policy: ActionPolicy = DEFAULT_ACTION_POLICY require_confirmation: bool = False max_cmd_seconds: float = 60.0 max_output_chars: int = 50_000 - max_read_bytes: int = 1_000_000_000 # Default 1GB. Increase if you have more RAM - the LLM only sees query results, not the file. 
def require_actions(
    action_config: ActionConfig,
    confirm: bool,
    *,
    requires_write: bool = False,
    requires_command: bool = False,
) -> str | None:
    """Return a refusal message if an action may not run, otherwise ``None``.

    Checks are applied in order: actions must be enabled, confirmation must
    be given when the config requires it, and under the ``read-only`` action
    policy neither process execution nor filesystem writes are permitted.

    Args:
        action_config: Configuration consulted for ``enabled``,
            ``require_confirmation`` and ``action_policy``.
        confirm: Whether the caller explicitly confirmed the action.
        requires_write: True when the action writes to the filesystem.
        requires_command: True when the action executes a process; takes
            precedence over ``requires_write`` when choosing the message.

    Returns:
        A human-readable reason the action is blocked, or ``None`` when the
        action is allowed.
    """
    if not action_config.enabled:
        return "Actions are disabled. Start the server with `--enable-actions`."

    confirmation_missing = action_config.require_confirmation and not confirm
    if confirmation_missing:
        return "Confirmation required. Re-run with confirm=true."

    # Only the read-only policy restricts anything beyond this point.
    if action_config.action_policy != "read-only":
        return None

    if requires_command:
        return (
            "Action policy is read-only. Process execution is blocked. "
            "Re-run with `--action-policy read-write` or `configure(action_policy='read-write')`."
        )
    if requires_write:
        return (
            "Action policy is read-only. Filesystem writes are blocked. "
            "Re-run with `--action-policy read-write` or `configure(action_policy='read-write')`."
        )
    return None
glob): + if glob_pattern and not fnmatch.fnmatch(path.name, glob_pattern): continue try: if path.stat().st_size > max_read_bytes: continue - text = path.read_text(encoding="utf-8", errors="replace") + with open(path, "r", encoding="utf-8", errors="replace") as handle: + for idx, line in enumerate(handle, start=1): + match = rx.search(line) + if not match: + continue + results.append( + { + "path": str(path), + "line": idx, + "column": match.start() + 1, + "text": line.rstrip("\n"), + } + ) + if limit is not None and len(results) >= limit: + truncated = True + return results, truncated except Exception: continue - for idx, line in enumerate(text.splitlines(), start=1): - match = rx.search(line) - if not match: - continue - results.append({ - "path": str(path), - "line": idx, - "column": match.start() + 1, - "text": line, - }) - if limit is not None and len(results) >= limit: - truncated = True - return results, truncated return results, truncated - - -def _resolve_line_number_base( - session: _Session | None, - value: int | None, -) -> LineNumberBase: - if session is not None: - if value is None: - return session.line_number_base - base = _validate_line_number_base(value) - if base != session.line_number_base: - raise ValueError("line_number_base does not match existing session") - return base - if value is None: - return DEFAULT_LINE_NUMBER_BASE - return _validate_line_number_base(value) - -def _resolve_scoped_path( - deps: ActionDeps, - path: str, -) -> tuple[Path | None, str | None]: - try: - return ( - deps.scoped_path( - deps.action_config.workspace_root, - path, - deps.action_config.workspace_mode, - ), - None, - ) - except Exception as e: - return None, str(e) - - -async def run_command( - deps: ActionDeps, - cmd: str, - cwd: str | None = None, - timeout_seconds: float | None = None, - shell: bool = False, - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", -) -> str | dict[str, Any]: - err = 
require_actions(deps.action_config, confirm) - if err: - return _format_error(err, output=output) - - session = deps.get_or_create_session(context_id, None) - session.iterations += 1 - - workspace_root = deps.action_config.workspace_root - cwd_path = ( - deps.scoped_path(workspace_root, cwd, deps.action_config.workspace_mode) - if cwd - else workspace_root - ) - timeout = timeout_seconds if timeout_seconds is not None else deps.action_config.max_cmd_seconds - - if shell: - user_shell = os.environ.get("SHELL", "/bin/sh") - argv = [user_shell, "-lc", cmd] - else: - argv = shlex.split(cmd) - if not argv: - return _format_error("Empty command", output=output) - - payload = await run_subprocess(action_config=deps.action_config, argv=argv, cwd=cwd_path, timeout_seconds=timeout) - session.repl._namespace["last_command_result"] = payload - record_action(session, note="run_command", snippet=(payload.get("stdout") or payload.get("stderr") or "")[:200]) - return _format_payload(payload, output=output) - - -async def rg_search( - deps: ActionDeps, - pattern: str, - paths: list[str] | str | None = None, - glob: str | None = None, - max_results: int = 200, - load_context_id: str | None = None, - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", -) -> str | dict[str, Any]: - err = require_actions(deps.action_config, confirm) - if err: - return _format_error(err, output=output) - if not pattern: - return _format_error("pattern is required", output=output) - if isinstance(paths, str): - paths = [paths] - - session = deps.get_or_create_session(context_id, None) - session.iterations += 1 - - workspace_root = deps.action_config.workspace_root - resolved_paths: list[Path] = [] - for p in paths or [str(workspace_root)]: - resolved, err = _resolve_scoped_path(deps, p) - if err: - return _format_error(err, output=output) - if resolved is not None: - resolved_paths.append(resolved) - - matches: list[dict[str, Any]] = [] - 
truncated = False - used_rg = False - payload: dict[str, Any] | None = None - - rg_bin = shutil.which("rg") - if rg_bin: - used_rg = True - argv = [rg_bin, "--vimgrep", pattern] - if glob: - argv.extend(["-g", glob]) - if max_results > 0: - argv.extend(["-m", str(max_results)]) - argv.extend(str(p) for p in resolved_paths) - payload = await run_subprocess( - action_config=deps.action_config, - argv=argv, - cwd=workspace_root, - timeout_seconds=deps.action_config.max_cmd_seconds, - ) - matches, truncated = _parse_rg_vimgrep(payload.get("stdout") or "", max_results) - else: - matches, truncated = _python_rg_search( - pattern, - resolved_paths, - glob, - max_results, - deps.action_config.max_read_bytes, - ) - - hits_text = "\n".join( - f"{m['path']}:{m['line']}:{m['column']}:{m['text']}" for m in matches - ) - if load_context_id: - meta = deps.create_session(hits_text, load_context_id, ContentFormat.TEXT, DEFAULT_LINE_NUMBER_BASE) - session.repl._namespace["last_rg_loaded_context"] = load_context_id - load_note = f"Loaded {len(matches)} match(es) into '{load_context_id}'." 
- else: - meta = None - load_note = None - - result_payload = { - "pattern": pattern, - "paths": [str(p) for p in resolved_paths], - "used_rg": used_rg, - "match_count": len(matches), - "truncated": truncated, - "matches": matches, - } - if payload: - result_payload["command"] = payload.get("argv") - result_payload["timed_out"] = payload.get("timed_out", False) - result_payload["stderr"] = payload.get("stderr", "") - if load_context_id: - result_payload["loaded_context_id"] = load_context_id - result_payload["loaded_meta"] = { - "size_chars": meta.size_chars if meta else 0, - "size_lines": meta.size_lines if meta else 0, - } - if load_note: - result_payload["note"] = load_note - - session.repl._namespace["last_rg_result"] = result_payload - record_action(session, note="rg_search", snippet=f"{pattern} ({len(matches)} matches)") - - if output == "object": - return result_payload - if output == "json": - return json.dumps(result_payload, ensure_ascii=False, indent=2) - - parts = [ - "## rg_search Results", - f"Pattern: `{pattern}`", - f"Matches: {len(matches)}" + (" (truncated)" if truncated else ""), - ] - if load_note: - parts.append(load_note) - if matches: - parts.append("") - parts.extend([f"- {m['path']}:{m['line']}:{m['column']}: {m['text']}" for m in matches[:20]]) - if len(matches) > 20: - parts.append(f"... 
{len(matches) - 20} more") - return "\n".join(parts) - - -async def read_file( - deps: ActionDeps, - path: str, - start_line: int = 1, - limit: int = 200, - include_raw: bool = False, - line_number_base: int | None = None, - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", -) -> str | dict[str, Any]: - err = require_actions(deps.action_config, confirm) - if err: - return _format_error(err, output=output) - - base_override: LineNumberBase | None = None - if line_number_base is not None: - try: - base_override = _validate_line_number_base(line_number_base) - except ValueError as e: - return _format_error(str(e), output=output) - - session = deps.get_or_create_session(context_id, base_override) - session.iterations += 1 - try: - base = _resolve_line_number_base(session, line_number_base) - except ValueError as e: - return _format_error(str(e), output=output) - - if base == 1 and start_line == 0: - start_line = 1 - if start_line < base: - return _format_error(f"start_line must be >= {base}", output=output) - - p, err = _resolve_scoped_path(deps, path) - if err or p is None: - return _format_error(err or "Invalid path", output=output) - - if not p.exists() or not p.is_file(): - return _format_error(f"File not found: {path}", output=output) - - data = p.read_bytes() - if len(data) > deps.action_config.max_read_bytes: - return _format_error( - f"File too large to read (>{deps.action_config.max_read_bytes} bytes): {path}", - output=output, - ) - - text = data.decode("utf-8", errors="replace") - lines = text.splitlines() - start_idx = max(0, start_line - base) - end_idx = min(len(lines), start_idx + max(0, limit)) - slice_lines = lines[start_idx:end_idx] - numbered = "\n".join( - f"{i + start_idx + base:>6}\t{line}" for i, line in enumerate(slice_lines) - ) - end_line = (start_idx + len(slice_lines) - 1 + base) if slice_lines else start_line - - payload: dict[str, Any] = { - "path": str(p), - "start_line": 
start_line, - "end_line": end_line, - "limit": limit, - "total_lines": len(lines), - "line_number_base": base, - "content": numbered, - } - if include_raw: - payload["content_raw"] = "\n".join(slice_lines) - session.repl._namespace["last_read_file_result"] = payload - record_action(session, note="read_file", snippet=f"{path} ({start_line}-{end_line})") - return _format_payload(payload, output=output) - - -async def load_file( - deps: ActionDeps, - path: str, - context_id: str = "default", - format: str = "auto", - line_number_base: LineNumberBase = DEFAULT_LINE_NUMBER_BASE, - confirm: bool = False, -) -> str: - err = require_actions(deps.action_config, confirm) - if err: - return f"Error: {err}" - - try: - base = _validate_line_number_base(line_number_base) - except ValueError as e: - return f"Error: {e}" - - p, err = _resolve_scoped_path(deps, path) - if err or p is None: - return f"Error: {err or 'Invalid path'}" - - if not p.exists() or not p.is_file(): - return f"Error: File not found: {path}" - - try: - text, detected_fmt, warning = deps.load_text_from_path( - p, - deps.action_config.max_read_bytes, - deps.action_config.max_cmd_seconds, - ) - except ValueError as e: - return f"Error: {e}" - try: - normalized_format = normalize_content_format(format, allow_auto=True) - fmt = detected_fmt if normalized_format == "auto" else normalized_format - except Exception as e: - return f"Error: {e}" - meta = deps.create_session(text, context_id, fmt, base) - session = deps.get_or_create_session(context_id, base) - record_action(session, note="load_file", snippet=str(p)) - return _format_context_loaded(context_id, meta, base, note=warning) - - -async def write_file( - deps: ActionDeps, - path: str, - content: str, - mode: Literal["overwrite", "append"] = "overwrite", - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", -) -> str | dict[str, Any]: - err = require_actions(deps.action_config, confirm) - if err: - 
return _format_error(err, output=output) - - session = deps.get_or_create_session(context_id, None) - session.iterations += 1 - - p, err = _resolve_scoped_path(deps, path) - if err or p is None: - return _format_error(err or "Invalid path", output=output) - - payload_bytes = content.encode("utf-8", errors="replace") - if len(payload_bytes) > deps.action_config.max_write_bytes: - return _format_error( - f"Content too large to write (>{deps.action_config.max_write_bytes} bytes)", - output=output, - ) - - p.parent.mkdir(parents=True, exist_ok=True) - file_mode = "ab" if mode == "append" else "wb" - with open(p, file_mode) as f: - f.write(payload_bytes) - - payload: dict[str, Any] = { - "path": str(p), - "bytes_written": len(payload_bytes), - "mode": mode, - } - session.repl._namespace["last_write_file_result"] = payload - record_action(session, note="write_file", snippet=f"{path} ({len(payload_bytes)} bytes)") - return _format_payload(payload, output=output) - - -async def run_tests( - deps: ActionDeps, - runner: Literal["auto", "pytest"] = "auto", - args: list[str] | None = None, - cwd: str | None = None, - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", -) -> str | dict[str, Any]: - err = require_actions(deps.action_config, confirm) - if err: - return _format_error(err, output=output) - - session = deps.get_or_create_session(context_id, None) - session.iterations += 1 - - runner_resolved = "pytest" if runner == "auto" else runner - if runner_resolved != "pytest": - return _format_error(f"Unsupported test runner: {runner_resolved}", output=output) - - argv = [sys.executable, "-m", "pytest", "-vv", "--tb=short", "--maxfail=20"] - if args: - argv.extend(args) - - cwd_path = deps.action_config.workspace_root - if cwd: - resolved_cwd, err = _resolve_scoped_path(deps, cwd) - if err or resolved_cwd is None: - return _format_error(err or "Invalid path", output=output) - cwd_path = resolved_cwd - - proc_payload 
= await run_subprocess( - action_config=deps.action_config, - argv=argv, - cwd=cwd_path, - timeout_seconds=deps.action_config.max_cmd_seconds, - ) - stdout = str(proc_payload.get("stdout") or "") - stderr = str(proc_payload.get("stderr") or "") - raw_output = stdout + ("\n" + stderr if stderr else "") - - passed = 0 - failed = 0 - errors = 0 - duration_ms = float(proc_payload.get("duration_ms") or 0.0) - exit_code = int(proc_payload.get("exit_code") or 0) - - m_passed = re.search(r"(\\d+)\\s+passed", raw_output) - if m_passed: - passed = int(m_passed.group(1)) - m_failed = re.search(r"(\\d+)\\s+failed", raw_output) - if m_failed: - failed = int(m_failed.group(1)) - m_errors = re.search(r"(\\d+)\\s+errors?", raw_output) - if m_errors: - errors = int(m_errors.group(1)) - - failures: list[dict[str, Any]] = [] - section_re = re.compile(r"^_{3,}\\s+(?P.+?)\\s+_{3,}\\s*$", re.MULTILINE) - matches = list(section_re.finditer(raw_output)) - for i, sm in enumerate(matches): - start = sm.end() - end = matches[i + 1].start() if i + 1 < len(matches) else len(raw_output) - block = raw_output[start:end].strip() - file = "" - line = 0 - file_line = re.search(r"^(?P.+?\\.py):(?P\\d+):", block, re.MULTILINE) - if file_line: - file = file_line.group("file") - try: - line = int(file_line.group("line")) - except Exception: - line = 0 - msg = "" - err_line = re.search(r"^E\\s+(.+)$", block, re.MULTILINE) - if err_line: - msg = err_line.group(1).strip() - - failures.append( - { - "file": file, - "line": line, - "test_name": sm.group("name").strip(), - "message": msg, - "traceback": block, - } - ) - - if exit_code != 0 and failed == 0 and errors == 0: - errors = 1 - - status = "passed" - if exit_code != 0: - status = "failed" if failed > 0 else "error" - - result: dict[str, Any] = { - "passed": passed, - "failed": failed, - "errors": errors, - "failures": failures, - "status": status, - "duration_ms": duration_ms, - "exit_code": exit_code, - "raw_output": raw_output, - "command": 
proc_payload, - } - - session.repl._namespace["last_test_result"] = result - summary_snippet = ( - f"status={status} passed={passed} failed={failed} errors={errors} " - f"failures={len(failures)} exit_code={exit_code}" - ) - record_action(session, note="run_tests", snippet=summary_snippet) - for f in failures[:10]: - record_action(session, note="test_failure", snippet=(f.get("message") or f.get("test_name") or "")[:200]) - return _format_payload(result, output=output) diff --git a/aleph/mcp/admin_tools.py b/aleph/mcp/admin_tools.py index 88e058d..f2622b0 100644 --- a/aleph/mcp/admin_tools.py +++ b/aleph/mcp/admin_tools.py @@ -31,6 +31,7 @@ async def configure( max_recipe_concurrency: int | None = None, tool_docs_mode: Literal["concise", "full"] | None = None, context_policy: Literal["trusted", "isolated"] | None = None, + action_policy: Literal["read-write", "read-only"] | None = None, workspace_root: str | None = None, output_feedback: str | None = None, ) -> str: @@ -57,6 +58,9 @@ async def configure( os.environ["ALEPH_MAX_RECIPE_CONCURRENCY"] = str(max_recipe_concurrency) if tool_docs_mode: owner.tool_docs_mode = tool_docs_mode + if action_policy is not None: + owner.action_config.action_policy = action_policy + os.environ["ALEPH_ACTION_POLICY"] = action_policy if context_policy is not None: old_policy = owner.context_policy owner.context_policy = context_policy diff --git a/aleph/mcp/context_tools.py b/aleph/mcp/context_tools.py new file mode 100644 index 0000000..6ffd478 --- /dev/null +++ b/aleph/mcp/context_tools.py @@ -0,0 +1,221 @@ +"""Context/session MCP tool registrations for the local server.""" + +from __future__ import annotations + +import asyncio +import difflib +import json +from typing import TYPE_CHECKING, Any, Callable, Literal, cast + +from ..compat import normalize_content_format +from ..types import ContentFormat +from .io_utils import _detect_format +from .session import load_memory_pack_payload +from .workspace import LineNumberBase, 
_validate_line_number_base +from .workspace_contexts import binding_summary + +if TYPE_CHECKING: + from .local_server import AlephMCPServerLocal + + +def register_context_tools( + owner: "AlephMCPServerLocal", + *, + format_error: Callable[[str, Literal["json", "markdown", "object"]], str | dict[str, Any]], +) -> None: + _tool = owner._tool_decorator + + @_tool() + async def load_context( + content: str | None = None, + context_id: str = "default", + format: str = "auto", + line_number_base: LineNumberBase = 1, + context: str | None = None, + ) -> str: + """Load context into an in-memory REPL session.""" + text = content if content is not None else context + if text is None: + return "Error: content is required" + try: + base = _validate_line_number_base(line_number_base) + except ValueError as exc: + return f"Error: {exc}" + + normalized_format = normalize_content_format(format, allow_auto=True) + fmt = cast( + ContentFormat, + _detect_format(text) if normalized_format == "auto" else normalized_format, + ) + meta = owner._create_session(text, context_id, fmt, base) + return owner._format_context_loaded(context_id, meta, base) + + @_tool() + async def list_contexts( + output: Literal["json", "markdown", "object"] = "json", + ) -> str | dict[str, Any]: + """List all active context sessions and their status.""" + items = [] + for cid, session in owner._sessions.items(): + summary = binding_summary(session.workspace_binding) + items.append({ + "id": cid, + "chars": session.meta.size_chars, + "lines": session.meta.size_lines, + "iterations": session.iterations, + "evidence": len(session.evidence), + "workspace_binding": session.workspace_binding, + "workspace_binding_summary": summary, + }) + + if output == "object": + return {"count": len(items), "items": items} + if output == "json": + return json.dumps({"count": len(items), "items": items}, indent=2) + + lines = [f"Found {len(items)} active context session(s):\n"] + for item in items: + binding_note = ( + f" 
[{item['workspace_binding_summary']}]" + if item["workspace_binding_summary"] + else "" + ) + lines.append( + f"- **{item['id']}**: {item['chars']:,} chars, " + f"{item['lines']:,} lines, {item['iterations']} iterations{binding_note}" + ) + return "\n".join(lines) + + @_tool() + async def diff_contexts( + a: str, + b: str, + context_lines: int = 3, + max_lines: int = 400, + output: Literal["markdown", "text"] = "markdown", + ) -> str: + """Compare two context sessions using unified diff.""" + if a not in owner._sessions: + return f"Error: Context '{a}' not found." + if b not in owner._sessions: + return f"Error: Context '{b}' not found." + + lines_a = str(owner._sessions[a].repl.get_variable("ctx") or "").splitlines() + lines_b = str(owner._sessions[b].repl.get_variable("ctx") or "").splitlines() + + diff = list( + difflib.unified_diff( + lines_a, + lines_b, + fromfile=f"context:{a}", + tofile=f"context:{b}", + n=context_lines, + lineterm="", + ) + ) + + if not diff: + return f"Contexts '{a}' and '{b}' are identical." + + if len(diff) > max_lines: + diff = diff[:max_lines] + ["... 
(diff truncated)"] + + diff_text = "\n".join(diff) + rendered = ( + f"### Diff: {a} vs {b}\n\n```diff\n{diff_text}\n```" + if output == "markdown" + else diff_text + ) + text, _ = owner._truncate_tool_text(rendered) + return text + + @_tool() + async def save_session( + path: str = "aleph_session.json", + context_id: str | None = None, + session_id: str = "default", + confirm: bool = False, + output: Literal["json", "markdown", "object"] = "json", + ) -> str | dict[str, Any]: + """Save session state to a file (Memory Pack).""" + err = owner._require_actions(confirm, requires_write=True) + if err: + return format_error(err, output=output) + if owner.context_policy == "isolated" and not confirm: + return format_error( + "Isolated policy requires confirm=true for session export (prevents accidental context leaks).\n" + "To proceed: save_session(path=..., confirm=true)\n" + "To switch policy: configure(context_policy='trusted')", + output=output, + ) + + payload, skipped = owner._build_memory_pack_payload() + try: + scoped_path = owner._scoped_path(path) + except Exception as exc: + return format_error(f"Invalid path: {exc}", output=output) + + try: + scoped_path.parent.mkdir(parents=True, exist_ok=True) + with open(scoped_path, "w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, ensure_ascii=False) + except Exception as exc: + return format_error(f"Failed to save: {exc}", output=output) + + message = f"Session saved to {path}." + if skipped: + message += f" Warning: skipped {len(skipped)} sessions due to serialization errors." 
+ + if output == "object": + return {"status": "success", "path": str(scoped_path), "skipped": skipped} + if output == "json": + return json.dumps({"status": "success", "path": str(scoped_path), "skipped": skipped}) + return message + + @_tool() + async def load_session( + path: str, + context_id: str | None = None, + session_id: str | None = None, + confirm: bool = False, + output: Literal["json", "markdown", "object"] = "json", + ) -> str | dict[str, Any]: + """Load session state from a file (Memory Pack).""" + err = owner._require_actions(confirm) + if err: + return format_error(err, output=output) + if owner.context_policy == "isolated" and not confirm: + return format_error( + "Isolated policy requires confirm=true for session import (prevents unvetted context rehydration).\n" + "To proceed: load_session(path=..., confirm=true)\n" + "To switch policy: configure(context_policy='trusted')", + output=output, + ) + + try: + scoped_path = owner._scoped_path(path) + with open(scoped_path, "r", encoding="utf-8") as handle: + payload = json.load(handle) + except Exception as exc: + return format_error(f"Failed to load: {exc}", output=output) + + try: + loaded, skipped = load_memory_pack_payload( + payload, + sessions=owner._sessions, + sandbox_config=owner.sandbox_config, + configure_session=owner._configure_session, + loop=asyncio.get_running_loop(), + close_node_repl=owner._close_node_repl, + ) + except ValueError as exc: + return format_error(str(exc), output=output) + + message = f"Loaded {len(loaded)} session(s) from {path}." + if skipped: + message += f" Skipped {len(skipped)} invalid session(s)." 
+ if output == "object": + return {"status": "success", "loaded": loaded, "skipped": skipped} + if output == "json": + return json.dumps({"status": "success", "loaded": loaded, "skipped": skipped}) + return message diff --git a/aleph/mcp/formatting.py b/aleph/mcp/formatting.py index dbac388..04ae9e2 100644 --- a/aleph/mcp/formatting.py +++ b/aleph/mcp/formatting.py @@ -3,29 +3,101 @@ from __future__ import annotations import json -from typing import Any, Literal +from typing import Any, Callable, Literal, cast -from ..types import ContextMetadata +from ..types import ContextMetadata, ExecutionResult +from .session import _coerce_context_to_text + +DEFAULT_TOOL_RESPONSE_MAX_CHARS = 10_000 +DEFAULT_TOOL_TRUNCATION_SUFFIX = "\n... [TRUNCATED]" + + +def _truncate_tool_text( + text: str, + *, + max_chars: int = DEFAULT_TOOL_RESPONSE_MAX_CHARS, + truncation_suffix: str = DEFAULT_TOOL_TRUNCATION_SUFFIX, +) -> tuple[str, bool]: + if max_chars <= 0 or len(text) <= max_chars: + return text, False + if max_chars <= len(truncation_suffix): + return truncation_suffix[:max_chars], True + + # Keep a compact prefix/suffix preview instead of a large contiguous head. + # This avoids spilling big raw blocks into model context while preserving signal. 
+ preview_each_side = min(400, max(0, (max_chars - len(truncation_suffix)) // 2)) + if preview_each_side == 0: + keep = max_chars - len(truncation_suffix) + return text[:keep] + truncation_suffix, True + return ( + text[:preview_each_side] + truncation_suffix + text[-preview_each_side:] + ), True def _format_payload( payload: dict[str, Any], output: Literal["json", "markdown", "object"], + *, + max_chars: int = DEFAULT_TOOL_RESPONSE_MAX_CHARS, + truncation_suffix: str = DEFAULT_TOOL_TRUNCATION_SUFFIX, + coerce_context_to_text: Callable[[Any], str] = _coerce_context_to_text, ) -> str | dict[str, Any]: + def _truncate_inline(text: str, limit: int) -> str: + return _truncate_tool_text( + text, + max_chars=limit, + truncation_suffix=truncation_suffix, + )[0] + + def _sanitize(value: Any, *, key: str | None = None) -> Any: + if key == "ctx": + text = coerce_context_to_text(value) + return { + "redacted": True, + "reason": "context_field_blocked", + "original_chars": len(text), + "value_preview": _truncate_inline(text, min(200, max_chars)), + } + + if isinstance(value, dict): + return {str(k): _sanitize(v, key=str(k)) for k, v in value.items()} + if isinstance(value, list): + return [_sanitize(v, key=key) for v in value] + if isinstance(value, str): + return _truncate_inline(value, max_chars) + return value + + safe_payload = cast(dict[str, Any], _sanitize(payload)) if output == "object": - return payload + return safe_payload + + rendered = json.dumps(safe_payload, ensure_ascii=False, indent=2) if output == "json": - return json.dumps(payload, ensure_ascii=False, indent=2) - return "```json\n" + json.dumps(payload, ensure_ascii=False, indent=2) + "\n```" + return _truncate_inline(rendered, max_chars) + + fence_overhead = len("```json\n\n```") + json_limit = max(0, max_chars - fence_overhead) + rendered = _truncate_inline(rendered, json_limit) + return "```json\n" + rendered + "\n```" def _format_error( message: str, output: Literal["json", "markdown", "object"], + *, + 
max_chars: int = DEFAULT_TOOL_RESPONSE_MAX_CHARS, + truncation_suffix: str = DEFAULT_TOOL_TRUNCATION_SUFFIX, + coerce_context_to_text: Callable[[Any], str] = _coerce_context_to_text, ) -> str | dict[str, Any]: if output == "markdown": return f"Error: {message}" - return _format_payload({"error": message}, output=output) + return _format_payload( + {"error": message}, + output=output, + max_chars=max_chars, + truncation_suffix=truncation_suffix, + coerce_context_to_text=coerce_context_to_text, + ) def _format_context_loaded( @@ -45,6 +117,140 @@ def _format_context_loaded( return msg +def _format_execution_result( + result: ExecutionResult, + *, + max_chars: int = DEFAULT_TOOL_RESPONSE_MAX_CHARS, + truncation_suffix: str = DEFAULT_TOOL_TRUNCATION_SUFFIX, +) -> str: + """Format sandboxed execution results for output.""" + if result.error: + text, _ = _truncate_tool_text( + f"## Execution Error\n\n{result.error}", + max_chars=max_chars, + truncation_suffix=truncation_suffix, + ) + return text + + res = ["## Execution Result\n"] + formatting_truncated = False + if result.stdout: + stdout_text, was_truncated = _truncate_tool_text( + result.stdout, + max_chars=max_chars, + truncation_suffix=truncation_suffix, + ) + formatting_truncated = formatting_truncated or was_truncated + res.append(f"**Output:**\n```\n{stdout_text}\n```") + if result.stderr: + stderr_text, was_truncated = _truncate_tool_text( + result.stderr, + max_chars=max_chars, + truncation_suffix=truncation_suffix, + ) + formatting_truncated = formatting_truncated or was_truncated + res.append(f"**Stderr:**\n```\n{stderr_text}\n```") + if result.return_value is not None: + rendered = repr(result.return_value) + rendered, was_truncated = _truncate_tool_text( + rendered, + max_chars=max_chars, + truncation_suffix=truncation_suffix, + ) + formatting_truncated = formatting_truncated or was_truncated + res.append(f"**Return Value:** `{rendered}`") + if result.variables_updated: + res.append( + f"\n**Variables 
Updated:** {', '.join(f'`{v}`' for v in result.variables_updated)}" + ) + + if result.truncated or formatting_truncated: + res.append("\n*Note: Output was truncated*") + + out = "\n".join(res) + out, _ = _truncate_tool_text( + out, + max_chars=max_chars, + truncation_suffix=truncation_suffix, + ) + return out + + +def _limit_json_items( + items: list[Any], + *, + max_chars: int = DEFAULT_TOOL_RESPONSE_MAX_CHARS, + to_jsonable: Callable[[Any], Any] | None = None, +) -> tuple[list[Any], bool]: + serializer = to_jsonable or _to_jsonable + used = 2 # [] delimiters + limited: list[Any] = [] + + for raw in items: + item = serializer(raw) + try: + encoded = json.dumps(item, ensure_ascii=False) + except Exception: + encoded = json.dumps(str(item), ensure_ascii=False) + + projected = used + len(encoded) + (1 if limited else 0) + if projected > max_chars: + return limited, True + + limited.append(item) + used = projected + + return limited, False + + +def _format_variable_value( + name: str, + value: Any, + *, + max_chars: int = DEFAULT_TOOL_RESPONSE_MAX_CHARS, + truncation_suffix: str = DEFAULT_TOOL_TRUNCATION_SUFFIX, + to_jsonable: Callable[[Any], Any] | None = None, +) -> Any: + serializer = to_jsonable or _to_jsonable + + if value is None or isinstance(value, (int, float, bool)): + return value + + if isinstance(value, str): + text, truncated = _truncate_tool_text( + value, + max_chars=max_chars, + truncation_suffix=truncation_suffix, + ) + if not truncated: + return value + return { + "name": name, + "truncated": True, + "original_chars": len(value), + "value_preview": text, + } + + jsonable = serializer(value) + try: + rendered = json.dumps(jsonable, ensure_ascii=False) + except Exception: + rendered = str(jsonable) + text, truncated = _truncate_tool_text( + rendered, + max_chars=max_chars, + truncation_suffix=truncation_suffix, + ) + if not truncated: + return jsonable + return { + "name": name, + "truncated": True, + "original_chars": len(rendered), + 
"value_preview": text, + } + + def _to_jsonable(obj: Any) -> Any: """Best-effort conversion of MCP/Pydantic objects into JSON-serializable data.""" if obj is None or isinstance(obj, (str, int, float, bool)): diff --git a/aleph/mcp/local_server.py b/aleph/mcp/local_server.py index 307eb7b..8526bf0 100644 --- a/aleph/mcp/local_server.py +++ b/aleph/mcp/local_server.py @@ -44,64 +44,80 @@ import asyncio from collections import OrderedDict -import difflib -import fnmatch import inspect import json import os -import re -import shutil -import shlex import sys -import time -from dataclasses import dataclass, field -from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Iterable, Literal, cast +from typing import TYPE_CHECKING, Any, Callable, Literal, cast if TYPE_CHECKING: from mcp.server.fastmcp import Context -from ..compat import normalize_content_format, normalize_output_feedback +from ..compat import normalize_output_feedback from ..config import AlephConfig -from ..core import Aleph -from ..observability import traced_span -from ..prompts.system import DEFAULT_SYSTEM_PROMPT -from ..providers.registry import get_provider -from ..repl import helpers as repl_helpers -from ..repl.node_runtime import NodeREPLEnvironment +from ..core import Aleph # noqa: F401 - compatibility for external patching/imports +from ..repl.node_runtime import NodeREPLEnvironment # noqa: F401 - re-export from ..repl.sandbox import REPLEnvironment, SandboxConfig -from ..types import AlephResponse, ContentFormat, ContextMetadata, ContextType, ExecutionResult -from ..sub_query import ( - SubQueryConfig, - detect_backend, +from ..types import ( + AlephResponse, + ContentFormat, + ContextMetadata, + ExecutionResult, ) -from ..sub_query.config import ( - resolve_codex_mode, - resolve_codex_model, - resolve_codex_profile, - resolve_codex_reasoning_effort, +from ..sub_query import SubQueryConfig +from ..sub_query.cli_backend import CLI_BACKENDS +from 
.actions import ( + ActionConfig, + _parse_rg_vimgrep as _parse_rg_vimgrep_impl, + _python_rg_search as _python_rg_search_impl, + record_action as _record_action_impl, + require_actions as _require_actions_impl, + run_subprocess as _run_subprocess_impl, ) -from ..sub_query.cli_backend import run_cli_sub_query, CLI_BACKENDS -from ..sub_query.codex_mcp_backend import ( - build_codex_mcp_tool_call, - compose_sub_query_prompt, - extract_codex_mcp_result_text, - suppress_mcp_notification_validation_logs, -) -from ..sub_query.api_backend import run_api_sub_query +from .action_tools import register_action_tools as _register_action_tools_module from .admin_tools import register_admin_tools -from .env_utils import DEFAULT_REMOTE_TOOL_TIMEOUT_SECONDS, _get_env_bool, _get_env_int -from .io_utils import _detect_format, _load_text_from_path +from .context_tools import register_context_tools as _register_context_tools_module +from .env_utils import DEFAULT_REMOTE_TOOL_TIMEOUT_SECONDS, _get_env_int +from .formatting import ( + DEFAULT_TOOL_RESPONSE_MAX_CHARS, + DEFAULT_TOOL_TRUNCATION_SUFFIX, + _format_context_loaded as _format_context_loaded_impl, + _format_error, + _format_execution_result as _format_execution_result_impl, + _format_payload, + _format_variable_value as _format_variable_value_impl, + _limit_json_items as _limit_json_items_impl, + _to_jsonable, + _truncate_tool_text as _truncate_tool_text_impl, +) +from .io_utils import _detect_format, _load_text_from_path # noqa: F401 +from .node_bridge import ( + close_node_repl as _close_node_repl_impl, + configure_node_repl as _configure_node_repl_impl, + get_or_create_node_repl as _get_or_create_node_repl_impl, + sync_session_from_node_repl as _sync_session_from_node_repl_impl, +) from .query_tools import register_query_tools as _register_query_tools_module -from .recipes import estimate_recipe as _estimate_recipe -from .recipes import validate_recipe as _validate_recipe -from .reasoning_tools import register_reasoning_tools 
as _register_reasoning_tools_module +from .repl_injection import ( + configure_session as _configure_session_impl, + inject_repl_config_helpers as _inject_repl_config_helpers_impl, + inject_repl_sub_aleph as _inject_repl_sub_aleph_impl, + inject_repl_sub_query as _inject_repl_sub_query_impl, +) +from .recipe_runtime import ( + compile_recipe_code as _compile_recipe_code_impl, + execute_recipe as _execute_recipe_impl, + recipe_context_slice as _recipe_context_slice_impl, + recipe_preview as _recipe_preview_impl, +) +from .reasoning_tools import ( + register_reasoning_tools as _register_reasoning_tools_module, +) from .remote_servers import ( _RemoteServerHandle, close_remote_server, ensure_remote_server, - register_remote_server, remote_call_tool, remote_list_tools, remote_tool_allowed, @@ -116,13 +132,41 @@ apply_sub_query_runtime_config, get_sub_query_config_snapshot, ) -from .workspace import roots_to_workspace_root +from .sub_query_orchestration import ( + build_sub_aleph_cli_prompt as _build_sub_aleph_cli_prompt_impl, + ensure_internal_codex_mcp_server as _ensure_internal_codex_mcp_server_impl, + ensure_streamable_http_server as _ensure_streamable_http_server_impl, + extract_final_answer as _extract_final_answer_impl, + format_streamable_http_url as _format_streamable_http_url_impl, + normalize_streamable_http_path as _normalize_streamable_http_path_impl, + run_internal_codex_mcp_query as _run_internal_codex_mcp_query_impl, + run_streamable_http_server as _run_streamable_http_server_impl, + run_sub_aleph as _run_sub_aleph_impl, + run_sub_query as _run_sub_query_impl, + wait_for_streamable_http_ready as _wait_for_streamable_http_ready_impl, +) +from . 
import workspace as _workspace +from .workspace import ( + DEFAULT_WORKSPACE_MODE, + LineNumberBase, + _detect_workspace_root, + _scoped_path, + _validate_line_number_base, + roots_to_workspace_root, +) +from .workspace_tools import register_workspace_tools from .session import ( - _Evidence, + MEMORY_PACK_RELATIVE_PATH, + _Evidence, # noqa: F401 - compatibility for external imports _Session, - _coerce_context_to_text, - _session_to_payload, - _session_from_payload, + _resolve_session_payload_id, # noqa: F401 + build_memory_pack_payload as _build_memory_pack_payload_impl, + create_session as _create_session_impl, + get_or_create_session as _get_or_create_session_impl, + load_memory_pack_payload as _load_memory_pack_payload, + replace_session_context as _replace_session_context_impl, + restore_session_state as _restore_session_state_impl, + snapshot_session_state as _snapshot_session_state_impl, ) __all__ = ["AlephMCPServerLocal", "main", "mcp"] @@ -130,16 +174,17 @@ mcp: Any -LineNumberBase = Literal[0, 1] -DEFAULT_LINE_NUMBER_BASE: LineNumberBase = 1 -WorkspaceMode = Literal["fixed", "git", "any"] -DEFAULT_WORKSPACE_MODE: WorkspaceMode = "fixed" +_find_git_root = _workspace._find_git_root +_nearest_existing_parent = _workspace._nearest_existing_parent + + ToolDocsMode = Literal["concise", "full"] DEFAULT_TOOL_DOCS_MODE: ToolDocsMode = "concise" ContextPolicy = Literal["trusted", "isolated"] DEFAULT_CONTEXT_POLICY: ContextPolicy = "trusted" -DEFAULT_TOOL_RESPONSE_MAX_CHARS = 10_000 -_TOOL_TRUNCATION_SUFFIX = "\n... 
[TRUNCATED]" +ActionPolicy = Literal["read-write", "read-only"] +DEFAULT_ACTION_POLICY: ActionPolicy = "read-write" +_TOOL_TRUNCATION_SUFFIX = DEFAULT_TOOL_TRUNCATION_SUFFIX def _normalize_context_policy( @@ -156,10 +201,24 @@ def _normalize_context_policy( return cast(ContextPolicy, default) +def _normalize_action_policy( + value: str | None, + default: str = DEFAULT_ACTION_POLICY, +) -> ActionPolicy: + if value is None: + return cast(ActionPolicy, default) + normalized = value.strip().lower() + if normalized in {"read-write", "workspace-write", "write"}: + return "read-write" + if normalized in {"read-only", "readonly", "safe"}: + return "read-only" + return cast(ActionPolicy, default) _ANALYZE_CACHE_MAX = 64 -_ANALYZE_CACHE: OrderedDict[tuple[int, int, ContentFormat], ContextMetadata] = OrderedDict() +_ANALYZE_CACHE: OrderedDict[tuple[int, int, ContentFormat], ContextMetadata] = ( + OrderedDict() +) def _analyze_text_context(text: str, fmt: ContentFormat) -> ContextMetadata: @@ -185,21 +244,8 @@ def _analyze_text_context(text: str, fmt: ContentFormat) -> ContextMetadata: return meta -_FINAL_RE = re.compile(r"FINAL\((.*?)\)", re.DOTALL) -_FINAL_VAR_RE = re.compile(r"FINAL_VAR\((.*?)\)", re.DOTALL) - - def _extract_final_answer(text: str) -> tuple[str, bool]: - match = _FINAL_RE.search(text) - if match: - return match.group(1).strip(), True - match_var = _FINAL_VAR_RE.search(text) - if match_var: - raw = match_var.group(1).strip() - if len(raw) >= 2 and ((raw[0] == raw[-1] == '"') or (raw[0] == raw[-1] == "'")): - raw = raw[1:-1].strip() - return raw, True - return text.strip(), False + return _extract_final_answer_impl(text) def _build_sub_aleph_cli_prompt( @@ -209,183 +255,13 @@ def _build_sub_aleph_cli_prompt( context_format: ContentFormat, cfg: AlephConfig, ) -> str: - meta = _analyze_text_context(context_slice, context_format) - system_template = cfg.system_prompt or DEFAULT_SYSTEM_PROMPT - system_prompt = system_template.format( + return 
_build_sub_aleph_cli_prompt_impl( query=query, - context_var=cfg.context_var_name, - context_format=meta.format.value, - context_size_chars=meta.size_chars, - context_size_lines=meta.size_lines, - context_size_tokens=meta.size_tokens_estimate, - context_preview="[OMITTED FOR CONTEXT ISOLATION]", - structure_hint=meta.structure_hint or "N/A", - ) - instructions = ( - "SINGLE-SHOT MODE (no live Python REPL in this call):\n" - "- Do not output code blocks.\n" - "- Answer directly and wrap the final answer in FINAL(...).\n" + context_slice=context_slice, + context_format=context_format, + cfg=cfg, + analyze_text_context=_analyze_text_context, ) - return f"{system_prompt}\n\n{instructions}\nQUERY:\n{query}" - - -def _resolve_env_dir(name: str, require_exists: bool = True) -> Path | None: - value = os.environ.get(name) - if value is None: - return None - value = value.strip() - if not value: - return None - try: - path = Path(value).expanduser() - except Exception: - return None - if require_exists and not path.exists(): - return None - try: - path = path.resolve() - except Exception: - pass - if path.is_file(): - return path.parent - return path - - -def _detect_workspace_root() -> Path: - env_root = _resolve_env_dir("ALEPH_WORKSPACE_ROOT", require_exists=False) - if env_root is not None: - return env_root - cwd = _resolve_env_dir("PWD") or _resolve_env_dir("INIT_CWD") or Path.cwd() - for parent in [cwd, *cwd.parents]: - if (parent / ".git").exists(): - return parent - return cwd - - -def _nearest_existing_parent(path: Path) -> Path: - for parent in [path, *path.parents]: - if parent.exists(): - return parent - return path - - -def _find_git_root(path: Path) -> Path | None: - start = _nearest_existing_parent(path) - if start.is_file(): - start = start.parent - for parent in [start, *start.parents]: - if (parent / ".git").exists(): - return parent - return None - - -def _scoped_path(workspace_root: Path, path: str, mode: WorkspaceMode) -> Path: - root = 
workspace_root.resolve() - p = Path(path) - if p.is_absolute(): - resolved = p.resolve() - else: - resolved = (root / p).resolve() - - if mode == "any": - return resolved - - if mode == "git": - git_root = _find_git_root(resolved) - if git_root is None: - raise ValueError(f"Path '{path}' is not inside a git repository (workspace mode: git)") - if not resolved.is_relative_to(git_root): - raise ValueError(f"Path '{path}' escapes git root '{git_root}'") - return resolved - - if not resolved.is_relative_to(root): - raise ValueError(f"Path '{path}' escapes workspace root '{root}'") - return resolved - - -def _format_payload( - payload: dict[str, Any], - output: Literal["json", "markdown", "object"], -) -> str | dict[str, Any]: - def _truncate_inline(text: str, limit: int) -> str: - if limit <= 0 or len(text) <= limit: - return text - if limit <= len(_TOOL_TRUNCATION_SUFFIX): - return _TOOL_TRUNCATION_SUFFIX[:limit] - preview_each_side = min(400, max(0, (limit - len(_TOOL_TRUNCATION_SUFFIX)) // 2)) - if preview_each_side == 0: - keep = limit - len(_TOOL_TRUNCATION_SUFFIX) - return text[:keep] + _TOOL_TRUNCATION_SUFFIX - return ( - text[:preview_each_side] - + _TOOL_TRUNCATION_SUFFIX - + text[-preview_each_side:] - ) - - def _sanitize(value: Any, *, key: str | None = None) -> Any: - if key == "ctx": - text = _coerce_context_to_text(value) - return { - "redacted": True, - "reason": "context_field_blocked", - "original_chars": len(text), - "value_preview": _truncate_inline(text, min(200, DEFAULT_TOOL_RESPONSE_MAX_CHARS)), - } - - if isinstance(value, dict): - return { - str(k): _sanitize(v, key=str(k)) - for k, v in value.items() - } - if isinstance(value, list): - return [_sanitize(v, key=key) for v in value] - if isinstance(value, str): - return _truncate_inline(value, DEFAULT_TOOL_RESPONSE_MAX_CHARS) - return value - - safe_payload = cast(dict[str, Any], _sanitize(payload)) - if output == "object": - return safe_payload - - rendered = json.dumps(safe_payload, 
ensure_ascii=False, indent=2) - if output == "json": - return _truncate_inline(rendered, DEFAULT_TOOL_RESPONSE_MAX_CHARS) - - fence_overhead = len("```json\n\n```") - json_limit = max(0, DEFAULT_TOOL_RESPONSE_MAX_CHARS - fence_overhead) - rendered = _truncate_inline(rendered, json_limit) - return "```json\n" + rendered + "\n```" - - -def _format_error( - message: str, - output: Literal["json", "markdown", "object"], -) -> str | dict[str, Any]: - if output == "markdown": - return f"Error: {message}" - return _format_payload({"error": message}, output=output) - - -def _validate_line_number_base(value: int) -> LineNumberBase: - if value not in (0, 1): - raise ValueError("line_number_base must be 0 or 1") - return cast(LineNumberBase, value) - - -def _resolve_line_number_base( - session: _Session | None, - value: int | None, -) -> LineNumberBase: - if session is not None: - if value is None: - return session.line_number_base - base = _validate_line_number_base(value) - if base != session.line_number_base: - raise ValueError("line_number_base does not match existing session") - return base - if value is None: - return DEFAULT_LINE_NUMBER_BASE - return _validate_line_number_base(value) def _to_internal_line_index(index: int | None, base: int) -> int | None: @@ -402,18 +278,6 @@ def _to_internal_line_index(index: int | None, base: int) -> int | None: return index - 1 -def _resolve_session_payload_id(session_payload: Any) -> str | None: - """Resolve a session identifier from a memory-pack session payload.""" - - if not isinstance(session_payload, dict): - return None - for key in ("id", "context_id", "session_id"): - value = session_payload.get(key) - if value is not None and str(value).strip(): - return str(value) - return None - - def _get_repl_helper(repl: REPLEnvironment, name: str) -> object | None: """Return a helper callable, preferring stable helper references.""" @@ -425,40 +289,6 @@ def _get_repl_helper(repl: REPLEnvironment, name: str) -> object | None: return 
cast(object | None, repl.get_variable(name)) -def _to_jsonable(obj: Any) -> Any: - """Best-effort conversion of MCP/Pydantic objects into JSON-serializable data.""" - if obj is None or isinstance(obj, (str, int, float, bool)): - return obj - if isinstance(obj, dict): - return {str(k): _to_jsonable(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple)): - return [_to_jsonable(v) for v in obj] - if hasattr(obj, "model_dump"): - try: - return obj.model_dump() - except Exception: - pass - if hasattr(obj, "__dict__"): - try: - return _to_jsonable(vars(obj)) - except Exception: - pass - return str(obj) - - -@dataclass(slots=True) -class ActionConfig: - enabled: bool = False - workspace_root: Path = field(default_factory=_detect_workspace_root) - workspace_mode: WorkspaceMode = DEFAULT_WORKSPACE_MODE - context_policy: ContextPolicy = DEFAULT_CONTEXT_POLICY - require_confirmation: bool = False - max_cmd_seconds: float = 60.0 - max_output_chars: int = 50_000 - max_read_bytes: int = 1_000_000_000 # Default 1GB. Increase if you have more RAM - the LLM only sees query results, not the file. - max_write_bytes: int = 100_000_000 # 100 MB - workspace_root_explicit: bool = False # True when set via CLI arg, env var, or configure() - class AlephMCPServerLocal: """MCP server for local AI reasoning. 
@@ -481,6 +311,10 @@ def __init__( self.action_config.context_policy, ) self.action_config.context_policy = self.context_policy + self.action_config.action_policy = _normalize_action_policy( + os.environ.get("ALEPH_ACTION_POLICY"), + self.action_config.action_policy, + ) self.output_feedback = normalize_output_feedback( os.environ.get("ALEPH_OUTPUT_FEEDBACK", "full") ) @@ -509,7 +343,9 @@ def __init__( # MCP roots-based workspace resolution (lazy, first action tool call) self._mcp_roots_resolved: bool = False self._workspace_root_source: str = ( - "explicit" if self.action_config.workspace_root_explicit else "auto-detected" + "explicit" + if self.action_config.workspace_root_explicit + else "auto-detected" ) # Import MCP lazily so it's an optional dependency @@ -517,7 +353,7 @@ def __init__( from mcp.server.fastmcp import Context as _MCPContext, FastMCP except Exception as e: raise RuntimeError( - "MCP support requires the `mcp` package. Install with `pip install \"aleph-rlm[mcp]\"`." + 'MCP support requires the `mcp` package. Install with `pip install "aleph-rlm[mcp]"`.' 
) from e self._MCPContext = _MCPContext @@ -536,7 +372,7 @@ def _auto_load_memory_pack(self) -> None: if self._auto_pack_loaded: return self._auto_pack_loaded = True - pack_path = self.action_config.workspace_root / ".aleph" / "memory_pack.json" + pack_path = self.action_config.workspace_root / MEMORY_PACK_RELATIVE_PATH if not pack_path.exists() or not pack_path.is_file(): return try: @@ -552,33 +388,23 @@ def _auto_load_memory_pack(self) -> None: if not isinstance(obj, dict): return - if obj.get("schema") != "aleph.memory_pack.v1": - return - sessions = obj.get("sessions") - if not isinstance(sessions, list): + try: + _load_memory_pack_payload( + obj, + sessions=self._sessions, + sandbox_config=self.sandbox_config, + configure_session=self._configure_session, + loop=None, + skip_existing=True, + ) + except Exception: return - for payload in sessions: - if not isinstance(payload, dict): - continue - session_id = payload.get("context_id") or payload.get("session_id") - resolved_id = str(session_id) if session_id else f"session_{len(self._sessions) + 1}" - if resolved_id in self._sessions: - continue - try: - session = _session_from_payload(payload, resolved_id, self.sandbox_config, loop=None) - except Exception: - continue - self._configure_session(session, resolved_id, loop=None) - self._sessions[resolved_id] = session def _normalize_streamable_http_path(self, path: str) -> str: - if not path: - return "/mcp" - return path if path.startswith("/") else f"/{path}" + return _normalize_streamable_http_path_impl(path) def _format_streamable_http_url(self, host: str, port: int, path: str) -> str: - connect_host = "127.0.0.1" if host in {"0.0.0.0", "::"} else host - return f"http://{connect_host}:{port}{path}" + return _format_streamable_http_url_impl(host, port, path) async def _wait_for_streamable_http_ready( self, @@ -586,48 +412,15 @@ async def _wait_for_streamable_http_ready( port: int, timeout_seconds: float = 2.0, ) -> tuple[bool, str]: - deadline = 
time.monotonic() + timeout_seconds - connect_host = "127.0.0.1" if host in {"0.0.0.0", "::"} else host - - while time.monotonic() < deadline: - if self._streamable_http_task and self._streamable_http_task.done(): - exc = self._streamable_http_task.exception() - if exc: - return False, f"Streamable HTTP server failed to start: {exc}" - return False, "Streamable HTTP server stopped unexpectedly." - try: - reader, writer = await asyncio.wait_for( - asyncio.open_connection(connect_host, port), - timeout=0.2, - ) - writer.close() - await writer.wait_closed() - return True, "" - except Exception: - await asyncio.sleep(0.05) - - return False, f"Timed out waiting for streamable HTTP server on {connect_host}:{port}." + return await _wait_for_streamable_http_ready_impl( + self, + host, + port, + timeout_seconds=timeout_seconds, + ) async def _run_streamable_http_server(self, host: str, port: int) -> None: - try: - import uvicorn - except Exception as exc: - raise RuntimeError( - "uvicorn is required for streamable HTTP transport. 
" - "Install with: pip install uvicorn" - ) from exc - - app = self.server.streamable_http_app() - config = uvicorn.Config( - app, - host=host, - port=port, - log_level="warning", - access_log=False, - lifespan="on", - ) - server = uvicorn.Server(config) - await server.serve() + await _run_streamable_http_server_impl(self, host, port) async def _ensure_streamable_http_server( self, @@ -635,65 +428,10 @@ async def _ensure_streamable_http_server( port: int, path: str, ) -> tuple[bool, str]: - normalized_path = self._normalize_streamable_http_path(path) - async with self._streamable_http_lock: - if self._streamable_http_task and not self._streamable_http_task.done(): - url = self._streamable_http_url or self._format_streamable_http_url( - host, - port, - normalized_path, - ) - return True, url - if self._streamable_http_task and self._streamable_http_task.done(): - self._streamable_http_task = None - self._streamable_http_url = None - - self.server.settings.host = host - self.server.settings.port = port - self.server.settings.streamable_http_path = normalized_path - - self._streamable_http_task = asyncio.create_task( - self._run_streamable_http_server(host, port) - ) - self._streamable_http_host = host - self._streamable_http_port = port - self._streamable_http_path = normalized_path - self._streamable_http_url = self._format_streamable_http_url( - host, - port, - normalized_path, - ) - - ok, err = await self._wait_for_streamable_http_ready(host, port) - if not ok: - return False, err - return True, self._streamable_http_url or self._format_streamable_http_url( - host, - port, - normalized_path, - ) + return await _ensure_streamable_http_server_impl(self, host, port, path) async def _ensure_internal_codex_mcp_server(self, cwd: Path | None) -> str: - server_id = "__aleph_internal_codex__" - handle = self._remote_servers.get(server_id) - if handle is None: - handle = register_remote_server( - self._remote_servers, - server_id, - command="codex", - args=["mcp-server", 
"-c", "mcp_servers={}"], - cwd=cwd, - allow_tools=["codex", "codex-reply"], - ) - elif handle.cwd != cwd: - await self._reset_remote_server_handle(handle) - handle.cwd = cwd - - with suppress_mcp_notification_validation_logs(): - ok, res = await self._ensure_remote_server(server_id) - if not ok: - raise RuntimeError(str(res)) - return server_id + return await _ensure_internal_codex_mcp_server_impl(self, cwd) async def _run_internal_codex_mcp_query( self, @@ -705,46 +443,16 @@ async def _run_internal_codex_mcp_query( mcp_server_name: str, thread_id: str | None = None, ) -> tuple[bool, str, str | None]: - full_prompt = compose_sub_query_prompt(prompt, context_slice) - - tool_name, arguments = build_codex_mcp_tool_call( - prompt=full_prompt, + return await _run_internal_codex_mcp_query_impl( + self, + prompt=prompt, cwd=cwd, mcp_server_url=mcp_server_url, mcp_server_name=mcp_server_name, - trust_mcp_server=True, - model=resolve_codex_model(self.sub_query_config.codex_model), - reasoning_effort=resolve_codex_reasoning_effort( - self.sub_query_config.codex_reasoning_effort - ), - profile=resolve_codex_profile(self.sub_query_config.codex_profile), + context_slice=context_slice, thread_id=thread_id, ) - try: - server_id = await self._ensure_internal_codex_mcp_server(cwd) - except Exception as e: - return False, f"Failed to start internal Codex MCP server: {e}", None - - with suppress_mcp_notification_validation_logs(): - ok, result = await self._remote_call_tool( - server_id, - tool_name, - arguments, - timeout_seconds=self.sub_query_config.cli_timeout_seconds, - ) - if not ok: - return False, str(result), None - - output, resolved_thread_id = extract_codex_mcp_result_text(result) - if not output: - output = json.dumps(_to_jsonable(result), ensure_ascii=True) - - if len(output) > self.sub_query_config.cli_max_output_chars: - output = output[: self.sub_query_config.cli_max_output_chars] + "\n...[truncated]" - - return True, output, resolved_thread_id - async def 
_run_sub_query( self, *, @@ -756,192 +464,16 @@ async def _run_sub_query( max_retries: int | None = None, retry_prompt: str | None = None, ) -> tuple[bool, str, bool, str]: - session = self._sessions.get(context_id) - if session: - session.iterations += 1 - - truncated = False - if context_slice and len(context_slice) > self.sub_query_config.max_context_chars: - context_slice = context_slice[: self.sub_query_config.max_context_chars] - truncated = True - - resolved_backend = backend - if backend == "auto": - resolved_backend = detect_backend(self.sub_query_config) - - allowed_backends = {"auto", "api", *CLI_BACKENDS} - if resolved_backend not in allowed_backends: - allowed_list = ", ".join(sorted(allowed_backends)) - return ( - False, - f"Unsupported backend '{resolved_backend}'. Choose from: {allowed_list}.", - truncated, - resolved_backend, - ) - - resolved_validation_regex = validation_regex - if resolved_validation_regex is None: - resolved_validation_regex = ( - self.sub_query_config.validation_regex - or os.environ.get("ALEPH_SUB_QUERY_VALIDATION_REGEX") - ) - if resolved_validation_regex is not None: - resolved_validation_regex = resolved_validation_regex.strip() - if not resolved_validation_regex: - resolved_validation_regex = None - - resolved_max_retries = self.sub_query_config.max_retries if max_retries is None else max_retries - if max_retries is None: - resolved_max_retries = _get_env_int("ALEPH_SUB_QUERY_MAX_RETRIES", resolved_max_retries) - - resolved_retry_prompt = ( - self.sub_query_config.retry_prompt if retry_prompt is None else retry_prompt + return await _run_sub_query_impl( + self, + prompt=prompt, + context_slice=context_slice, + context_id=context_id, + backend=backend, + validation_regex=validation_regex, + max_retries=max_retries, + retry_prompt=retry_prompt, ) - if retry_prompt is None: - env_retry_prompt = os.environ.get("ALEPH_SUB_QUERY_RETRY_PROMPT") - if env_retry_prompt: - resolved_retry_prompt = env_retry_prompt - - validation_re: 
re.Pattern[str] | None = None - if resolved_validation_regex: - try: - validation_re = re.compile(resolved_validation_regex, re.MULTILINE) - except re.error as e: - return False, f"Invalid validation regex: {e}", truncated, resolved_backend - - attempt = 0 - base_prompt = prompt - prompt_for_attempt = base_prompt - codex_thread_id: str | None = None - with traced_span( - "aleph.sub_query", - { - "aleph.context_id": context_id, - "aleph.sub_query.backend.requested": backend, - "aleph.sub_query.backend.resolved": resolved_backend, - "aleph.sub_query.context_chars": len(context_slice or ""), - "aleph.sub_query.context_truncated": truncated, - "aleph.sub_query.validation_enabled": bool(resolved_validation_regex), - }, - ) as span: - success = False - output = "" - try: - while True: - run_prompt = prompt_for_attempt - if resolved_backend in CLI_BACKENDS: - mcp_server_url = None - server_name = "aleph_shared" - share_session = _get_env_bool("ALEPH_SUB_QUERY_SHARE_SESSION", False) - if share_session and resolved_backend in {"claude", "codex", "gemini", "kimi"}: - host = os.environ.get("ALEPH_SUB_QUERY_HTTP_HOST", "127.0.0.1") - port = _get_env_int("ALEPH_SUB_QUERY_HTTP_PORT", 8765) - path = os.environ.get("ALEPH_SUB_QUERY_HTTP_PATH", "/mcp") - server_name = os.environ.get( - "ALEPH_SUB_QUERY_MCP_SERVER_NAME", - "aleph_shared", - ).strip() or "aleph_shared" - ok, url_or_err = await self._ensure_streamable_http_server(host, port, path) - if not ok: - return False, f"Failed to start streamable HTTP server: {url_or_err}", truncated, resolved_backend - mcp_server_url = url_or_err - run_prompt = ( - f"{run_prompt}\n\n" - f"[MCP tools are available via the live Aleph server. " - f"Use context_id={context_id!r} when calling tools. 
" - f"Tools are prefixed with `mcp__{server_name}__`.]" - ) - cwd = self.action_config.workspace_root if self.action_config.enabled else None - if resolved_backend == "codex" and resolve_codex_mode( - self.sub_query_config.codex_mode - ) == "mcp": - success, output, codex_thread_id = await self._run_internal_codex_mcp_query( - prompt=run_prompt, - context_slice=context_slice, - cwd=cwd, - mcp_server_url=mcp_server_url, - mcp_server_name=server_name, - thread_id=codex_thread_id, - ) - else: - success, output = await run_cli_sub_query( - prompt=run_prompt, - context_slice=context_slice, - backend=resolved_backend, # type: ignore[arg-type] - timeout=self.sub_query_config.cli_timeout_seconds, - cwd=cwd, - max_output_chars=self.sub_query_config.cli_max_output_chars, - max_context_chars=self.sub_query_config.max_context_chars, - mcp_server_url=mcp_server_url, - mcp_server_name=server_name, - trust_mcp_server=True, - claude_model=self.sub_query_config.claude_model, - claude_effort=self.sub_query_config.claude_effort, - codex_mode=self.sub_query_config.codex_mode, - codex_model=self.sub_query_config.codex_model, - codex_reasoning_effort=self.sub_query_config.codex_reasoning_effort, - codex_profile=self.sub_query_config.codex_profile, - ) - else: - success, output = await run_api_sub_query( - prompt=run_prompt, - context_slice=context_slice, - model=self.sub_query_config.api_model, - api_key_env=self.sub_query_config.api_key_env, - api_base_url_env=self.sub_query_config.api_base_url_env, - api_model_env=self.sub_query_config.api_model_env, - timeout=self.sub_query_config.api_timeout_seconds, - system_prompt=self.sub_query_config.system_prompt if self.sub_query_config.include_system_prompt else None, - max_context_chars=self.sub_query_config.max_context_chars, - ) - - if not success: - break - - if validation_re and not validation_re.search(output): - if attempt >= resolved_max_retries: - success = False - output = ( - f"Output failed validation regex 
{resolved_validation_regex!r} " - f"after {attempt + 1} attempt(s). Last output: {output}" - ) - break - attempt += 1 - prompt_for_attempt = ( - f"{base_prompt}\n\n" - f"{resolved_retry_prompt}\n" - f"Required format regex: {resolved_validation_regex}" - ) - continue - - break - except Exception as e: - span.record_exception(e) - success = False - output = f"{type(e).__name__}: {e}" - - span.set_attribute("aleph.sub_query.success", success) - span.set_attribute("aleph.sub_query.attempts", attempt + 1) - span.set_attribute("aleph.sub_query.output_chars", len(output)) - - if session: - note_parts = [f"backend={resolved_backend}"] - if resolved_validation_regex: - note_parts.append(f"validation={resolved_validation_regex!r}") - if attempt: - note_parts.append(f"retries={attempt}") - if truncated: - note_parts.append("truncated_context") - session.evidence.append(_Evidence( - source="sub_query", - line_range=None, - pattern=None, - snippet=output[:200] if success else f"[ERROR] {output[:150]}", - note=" ".join(note_parts), - )) - session.information_gain.append(1 if success else 0) - - return success, output, truncated, resolved_backend async def _run_sub_aleph( self, @@ -959,243 +491,30 @@ async def _run_sub_aleph( max_wall_time_seconds: float | None = None, temperature: float | None = None, ) -> tuple[AlephResponse, dict[str, object]]: - session = self._sessions.get(context_id) - if session: - session.iterations += 1 - session.max_depth_seen = max(session.max_depth_seen, current_depth) - - cfg = AlephConfig.from_env() - budget = cfg.to_budget() - if max_tokens is not None: - budget.max_tokens = max_tokens - if max_iterations is not None: - budget.max_iterations = max_iterations - if max_depth is not None: - budget.max_depth = max_depth - if max_wall_time_seconds is not None: - budget.max_wall_time_seconds = max_wall_time_seconds - if max_sub_queries is not None: - budget.max_sub_queries = max_sub_queries - - resolved_root = root_model or cfg.root_model - resolved_sub 
= sub_model or cfg.sub_model or resolved_root - - temp_val = 0.0 - if temperature is not None: - try: - temp_val = float(temperature) - except (TypeError, ValueError): - temp_val = 0.0 - - resolved_backend = detect_backend(self.sub_query_config) - truncated_context = False - start_time = time.perf_counter() - - if resolved_backend in CLI_BACKENDS: - cli_context = context_slice or "" - if cli_context and len(cli_context) > self.sub_query_config.max_context_chars: - cli_context = cli_context[: self.sub_query_config.max_context_chars] - truncated_context = True - - context_format = session.meta.format if session else ContentFormat.TEXT - prompt = _build_sub_aleph_cli_prompt( - query=query, - context_slice=cli_context, - context_format=context_format, - cfg=cfg, - ) - - mcp_server_url = None - server_name = "aleph_shared" - share_session = _get_env_bool("ALEPH_SUB_QUERY_SHARE_SESSION", False) - if share_session and resolved_backend in {"claude", "codex", "gemini", "kimi"}: - host = os.environ.get("ALEPH_SUB_QUERY_HTTP_HOST", "127.0.0.1") - port = _get_env_int("ALEPH_SUB_QUERY_HTTP_PORT", 8765) - path = os.environ.get("ALEPH_SUB_QUERY_HTTP_PATH", "/mcp") - server_name = os.environ.get( - "ALEPH_SUB_QUERY_MCP_SERVER_NAME", - "aleph_shared", - ).strip() or "aleph_shared" - ok, url_or_err = await self._ensure_streamable_http_server(host, port, path) - if not ok: - response = AlephResponse( - answer="", - success=False, - total_iterations=0, - max_depth_reached=0, - total_tokens=0, - total_cost_usd=0.0, - wall_time_seconds=time.perf_counter() - start_time, - trajectory=[], - error=f"Failed to start streamable HTTP server: {url_or_err}", - error_type="cli_error", - ) - else: - mcp_server_url = url_or_err - prompt = ( - f"{prompt}\n\n" - f"[MCP tools are available via the live Aleph server. " - f"Use context_id={context_id!r} when calling tools. 
" - f"Tools are prefixed with `mcp__{server_name}__`.]" - ) - - if mcp_server_url is not None or not share_session: - try: - cwd = self.action_config.workspace_root if self.action_config.enabled else None - if resolved_backend == "codex" and resolve_codex_mode( - self.sub_query_config.codex_mode - ) == "mcp": - success, output, _thread_id = await self._run_internal_codex_mcp_query( - prompt=prompt, - context_slice=cli_context if cli_context else None, - cwd=cwd, - mcp_server_url=mcp_server_url, - mcp_server_name=server_name, - ) - else: - success, output = await run_cli_sub_query( - prompt=prompt, - context_slice=cli_context if cli_context else None, - backend=resolved_backend, # type: ignore[arg-type] - timeout=self.sub_query_config.cli_timeout_seconds, - cwd=cwd, - max_output_chars=self.sub_query_config.cli_max_output_chars, - max_context_chars=self.sub_query_config.max_context_chars, - mcp_server_url=mcp_server_url, - mcp_server_name=server_name, - trust_mcp_server=True, - claude_model=self.sub_query_config.claude_model, - claude_effort=self.sub_query_config.claude_effort, - codex_mode=self.sub_query_config.codex_mode, - codex_model=self.sub_query_config.codex_model, - codex_reasoning_effort=self.sub_query_config.codex_reasoning_effort, - codex_profile=self.sub_query_config.codex_profile, - ) - except Exception as e: - success, output = False, f"{type(e).__name__}: {e}" - - wall_time = time.perf_counter() - start_time - if success: - answer, _ = _extract_final_answer(output) - if not answer: - response = AlephResponse( - answer="", - success=False, - total_iterations=current_depth, - max_depth_reached=current_depth, - total_tokens=0, - total_cost_usd=0.0, - wall_time_seconds=wall_time, - trajectory=[], - error="Empty response from CLI backend", - error_type="cli_error", - ) - else: - response = AlephResponse( - answer=answer, - success=True, - total_iterations=current_depth, - max_depth_reached=current_depth, - total_tokens=0, - total_cost_usd=0.0, - 
wall_time_seconds=wall_time, - trajectory=[], - ) - else: - response = AlephResponse( - answer="", - success=False, - total_iterations=current_depth, - max_depth_reached=current_depth, - total_tokens=0, - total_cost_usd=0.0, - wall_time_seconds=wall_time, - trajectory=[], - error=output, - error_type="cli_error", - ) - else: - try: - provider = get_provider(cfg.provider, api_key=cfg.api_key) - runner = Aleph( - provider=provider, - root_model=resolved_root, - sub_model=resolved_sub, - budget=budget, - sandbox_config=self.sandbox_config, - system_prompt=cfg.system_prompt, - enable_caching=cfg.enable_caching, - log_trajectory=cfg.log_trajectory, - ) - response = await runner.complete( - query=query, - context=context_slice or "", - root_model=resolved_root, - sub_model=resolved_sub, - budget=budget, - temperature=temp_val, - ) - except Exception as e: - response = AlephResponse( - answer="", - success=False, - total_iterations=0, - max_depth_reached=0, - total_tokens=0, - total_cost_usd=0.0, - wall_time_seconds=0.0, - trajectory=[], - error=str(e), - error_type="provider_error", - ) - - if session: - note_parts = [f"backend={resolved_backend}", f"models={resolved_root}/{resolved_sub}"] - if budget.max_depth is not None: - note_parts.append(f"max_depth={budget.max_depth}") - if truncated_context: - note_parts.append("truncated_context") - session.evidence.append(_Evidence( - source="sub_aleph", - line_range=None, - pattern=None, - snippet=response.answer[:200] if response.success else f"[ERROR] {str(response.error)[:150]}", - note=" ".join(note_parts), - )) - session.information_gain.append(1 if response.success else 0) - - meta: dict[str, object] = { - "root_model": resolved_root, - "sub_model": resolved_sub, - "budget": budget, - "temperature": temp_val, - "backend": resolved_backend, - "truncated_context": truncated_context, - } - return response, meta + return await _run_sub_aleph_impl( + self, + query=query, + context_slice=context_slice, + context_id=context_id, 
+ current_depth=current_depth, + root_model=root_model, + sub_model=sub_model, + max_depth=max_depth, + max_iterations=max_iterations, + max_tokens=max_tokens, + max_sub_queries=max_sub_queries, + max_wall_time_seconds=max_wall_time_seconds, + temperature=temperature, + analyze_text_context=_analyze_text_context, + ) @staticmethod def _recipe_preview(value: Any, limit: int = 180) -> str: - text = _coerce_context_to_text(value) - if len(text) <= limit: - return text - return text[: limit - 3] + "..." + return _recipe_preview_impl(value, limit=limit) @staticmethod def _recipe_context_slice(value: Any, context_field: str | None) -> str: - selected = value - if context_field: - if isinstance(value, dict): - selected = value.get(context_field) - elif isinstance(value, list): - extracted: list[Any] = [] - for item in value: - if isinstance(item, dict): - extracted.append(item.get(context_field)) - else: - extracted.append(item) - selected = extracted - return _coerce_context_to_text(selected) + return _recipe_context_slice_impl(value, context_field) async def _execute_recipe( self, @@ -1203,307 +522,17 @@ async def _execute_recipe( recipe: dict[str, Any], context_id_override: str | None = None, dry_run: bool = False, - progress_callback: Callable[[float, float | None, str | None], Any] | None = None, + progress_callback: Callable[[float, float | None, str | None], Any] + | None = None, ) -> tuple[bool, dict[str, Any]]: - normalized, errors = _validate_recipe(recipe) - if errors: - return False, {"errors": errors} - assert normalized is not None - - resolved_context_id = context_id_override or normalized["context_id"] - if resolved_context_id not in self._sessions: - return False, {"error": f"No context loaded with ID '{resolved_context_id}'."} - - estimate = _estimate_recipe(normalized) - if dry_run: - return True, { - "context_id": resolved_context_id, - "mode": "dry_run", - "recipe": normalized, - "estimate": estimate, - } - - session = 
self._sessions[resolved_context_id] - budget = normalized["budget"] - max_steps = int(budget["max_steps"]) - max_sub_queries = int(budget["max_sub_queries"]) - - current: Any = session.repl.get_variable("ctx") - variables: dict[str, Any] = {"ctx": current} - trace: list[dict[str, Any]] = [] - sub_queries_used = 0 - total_steps = float(len(normalized["steps"])) - - async def _report(progress: float, total: float | None = None, message: str | None = None) -> None: - if progress_callback is not None: - try: - result = progress_callback(progress, total, message) - if asyncio.iscoroutine(result): - await result - except Exception: - pass - - for step_index, step in enumerate(normalized["steps"], 1): - if step_index > max_steps: - return False, { - "error": f"Recipe exceeded budget.max_steps ({step_index} > {max_steps})", - "failed_step": step_index, - "trace": trace, - } - - session.iterations += 1 - - input_name = step.get("input") - if input_name: - if input_name not in variables: - return False, { - "error": f"Step {step_index}: input variable '{input_name}' not found.", - "failed_step": step_index, - "trace": trace, - } - current = variables[input_name] - - op = step["op"] - step_trace: dict[str, Any] = { - "step": step_index, - "op": op, - } - - try: - if op == "search": - current = repl_helpers.search( - current, - step["pattern"], - context_lines=step.get("context_lines", 2), - max_results=step.get("max_results", 20), - ) - step_trace["result_count"] = len(current) if isinstance(current, list) else 0 - - elif op == "peek": - current = repl_helpers.peek( - current, - start=step.get("start", 0), - end=step.get("end"), - ) - - elif op == "lines": - current = repl_helpers.lines( - current, - start=step.get("start", 0), - end=step.get("end"), - ) - - elif op == "take": - count = int(step["count"]) - if isinstance(current, str): - current = current[:count] - elif isinstance(current, (list, tuple)): - current = list(current)[:count] - else: - raise ValueError("take 
requires a list/tuple/string value") - - elif op == "chunk": - text = _coerce_context_to_text(current) - chunk_size = int(step["chunk_size"]) - overlap = int(step.get("overlap", 0)) - current = repl_helpers.chunk(text, chunk_size, overlap) - step_trace["result_count"] = len(current) - - elif op == "filter": - if not isinstance(current, list): - raise ValueError("filter requires current value to be a list") - field_name = step.get("field") - pattern = step.get("pattern") - contains = step.get("contains") - rx = re.compile(pattern) if pattern else None - out: list[Any] = [] - for item in current: - candidate: Any = item - if field_name: - if isinstance(item, dict): - candidate = item.get(field_name) - else: - candidate = None - candidate_text = _coerce_context_to_text(candidate) - matched = True - if rx is not None: - matched = bool(rx.search(candidate_text)) - if contains is not None: - matched = matched and contains in candidate_text - if matched: - out.append(item) - current = out - step_trace["result_count"] = len(current) - - elif op == "assign": - variables[step["name"]] = current - - elif op == "load": - name = step["name"] - if name not in variables: - raise ValueError(f"variable '{name}' not found") - current = variables[name] - - elif op == "map_sub_query": - if not isinstance(current, list): - raise ValueError("map_sub_query requires current value to be a list") - - limit = step.get("limit") - items = current[:limit] if isinstance(limit, int) else current - parallel = step.get("parallel", True) - continue_on_error = step.get("continue_on_error", False) - - remaining_budget = max_sub_queries - sub_queries_used - if len(items) > remaining_budget: - raise RuntimeError( - f"Recipe sub-query budget would be exceeded " - f"({sub_queries_used} + {len(items)} > {max_sub_queries})" - ) - - if parallel and len(items) > 1: - # Parallel execution with bounded concurrency - parallel_limit = max(1, min(self.max_recipe_concurrency, len(items))) - sem = 
asyncio.Semaphore(parallel_limit) - - async def _run_item(idx: int, item: object) -> tuple[int, bool, str]: - async with sem: - ctx_slice = self._recipe_context_slice(item, step.get("context_field")) - ok, out, _trunc, _bk = await self._run_sub_query( - prompt=step["prompt"], - context_slice=ctx_slice, - context_id=resolved_context_id, - backend=step.get("backend", "auto"), - ) - return idx, ok, out - - tasks = [_run_item(i, it) for i, it in enumerate(items)] - results = await asyncio.gather(*tasks, return_exceptions=True) - outputs: list[str] = [""] * len(items) - for r in results: - if isinstance(r, BaseException): - if not continue_on_error: - raise RuntimeError(f"sub_query failed: {r}") - outputs[0] = f"[ERROR] {r}" # placeholder - else: - idx, ok, item_output = r - if not ok and not continue_on_error: - raise RuntimeError(f"sub_query failed: {item_output}") - outputs[idx] = item_output if ok else f"[ERROR] {item_output}" - sub_queries_used += len(items) - else: - # Sequential fallback - outputs = [] - for item in items: - context_slice = self._recipe_context_slice(item, step.get("context_field")) - success, output, _truncated, _backend = await self._run_sub_query( - prompt=step["prompt"], - context_slice=context_slice, - context_id=resolved_context_id, - backend=step.get("backend", "auto"), - ) - sub_queries_used += 1 - if not success and not continue_on_error: - raise RuntimeError(f"sub_query failed: {output}") - outputs.append(output if success else f"[ERROR] {output}") - await _report( - float(step_index), - total_steps, - f"map_sub_query: {len(items)} items processed", - ) - - current = outputs - step_trace["sub_queries"] = len(outputs) - step_trace["parallel"] = parallel and len(items) > 1 - - elif op in {"sub_query", "aggregate"}: - if sub_queries_used >= max_sub_queries: - raise RuntimeError( - "Recipe sub-query budget exceeded " - f"({sub_queries_used} >= {max_sub_queries})" - ) - - if op == "aggregate" and isinstance(current, list): - context_slice = 
"\n\n".join( - _coerce_context_to_text(item) for item in current - ) - else: - context_slice = self._recipe_context_slice( - current, step.get("context_field") - ) - - success, output, _truncated, _backend = await self._run_sub_query( - prompt=step["prompt"], - context_slice=context_slice, - context_id=resolved_context_id, - backend=step.get("backend", "auto"), - ) - sub_queries_used += 1 - if not success: - raise RuntimeError(f"sub_query failed: {output}") - current = output - step_trace["sub_queries"] = 1 - - elif op == "finalize": - step_trace["status"] = "finalized" - trace.append(step_trace) - break - - else: - raise ValueError(f"unsupported op: {op}") - except Exception as e: - step_trace["status"] = "error" - step_trace["error"] = str(e) - trace.append(step_trace) - session.evidence.append( - _Evidence( - source="exec", - line_range=None, - pattern=None, - note=f"run_recipe failed at step {step_index}", - snippet=f"{op}: {str(e)[:180]}", - ) - ) - return False, { - "error": f"Step {step_index} ({op}) failed: {e}", - "failed_step": step_index, - "trace": trace, - "sub_queries_used": sub_queries_used, - "budget": budget, - "estimate": estimate, - } - - store_name = step.get("store") - if store_name: - variables[store_name] = current - - step_trace["status"] = "ok" - step_trace["preview"] = self._recipe_preview(current) - trace.append(step_trace) - await _report(float(step_index), total_steps, f"Step {step_index}/{int(total_steps)} ({op}) done") - - session.evidence.append( - _Evidence( - source="exec", - line_range=None, - pattern=None, - note=f"run_recipe completed ({len(trace)} steps)", - snippet=self._recipe_preview(current), - ) + return await _execute_recipe_impl( + self, + recipe=recipe, + context_id_override=context_id_override, + dry_run=dry_run, + progress_callback=progress_callback, ) - payload = { - "context_id": resolved_context_id, - "recipe_version": normalized["version"], - "step_count": len(normalized["steps"]), - "sub_queries_used": 
sub_queries_used, - "budget": budget, - "estimate": estimate, - "trace": trace, - "value": _to_jsonable(current), - "variables": sorted(variables.keys()), - } - return True, payload - async def _compile_recipe_code( self, *, @@ -1511,84 +540,12 @@ async def _compile_recipe_code( context_id: str = "default", language: str = "python", ) -> tuple[bool, dict[str, Any]]: - if context_id not in self._sessions: - return False, {"error": f"No context loaded with ID '{context_id}'."} - - session = self._sessions[context_id] - session.iterations += 1 - - if language in ("javascript", "typescript"): - node_repl = self._get_or_create_node_repl(context_id) - result = await node_repl.execute_async( - code, language=language, # type: ignore[arg-type] - ) - else: - result = await session.repl.execute_async(code) - - if result.error: - return False, { - "error": f"Recipe code execution failed: {result.error}", - "execution": { - "stderr": result.stderr, - "stdout": result.stdout, - }, - } - - if language in ("javascript", "typescript"): - candidate = result.return_value - if candidate is None: - maybe_node_repl = self._node_repls.get(context_id) - if maybe_node_repl is not None: - candidate = maybe_node_repl.get_variable("recipe") - self._sync_session_from_node_repl(context_id) - else: - candidate = result.return_value - if candidate is None: - candidate = session.repl.get_variable("recipe") - - if candidate is None: - return False, { - "error": ( - "Recipe code did not return a recipe value. " - "Return a RecipeBuilder/dict or assign to variable `recipe`." - ), - } - - compiled: Any = candidate - if isinstance(candidate, dict): - compiled = dict(candidate) - elif hasattr(candidate, "compile") and callable(getattr(candidate, "compile")): - compiled = candidate.compile() - elif hasattr(candidate, "to_dict") and callable(getattr(candidate, "to_dict")): - compiled = candidate.to_dict() - else: - return False, { - "error": ( - "Recipe code returned unsupported type. 
" - "Expected dict or object with compile()/to_dict()." - ), - "type": str(type(candidate)), - } - - normalized, errors = _validate_recipe(compiled) - if errors or normalized is None: - return False, { - "error": "Compiled recipe is invalid.", - "errors": errors, - "compiled": _to_jsonable(compiled), - } - - return True, { - "context_id": context_id, - "recipe": normalized, - "estimate": _estimate_recipe(normalized), - "execution": { - "variables_updated": result.variables_updated, - "execution_time_ms": result.execution_time_ms, - "stdout": result.stdout, - "stderr": result.stderr, - }, - } + return await _compile_recipe_code_impl( + self, + code=code, + context_id=context_id, + language=language, + ) def _get_sub_query_config_snapshot(self) -> dict[str, Any]: return get_sub_query_config_snapshot( @@ -1612,54 +569,13 @@ def _apply_sub_query_runtime_config( ) def _inject_repl_config_helpers(self, session: _Session) -> None: - def set_backend(backend: str) -> str: - ok, message = self._apply_sub_query_runtime_config(sub_query_backend=backend) - if not ok: - raise ValueError(message) - snapshot = self._get_sub_query_config_snapshot() - return ( - "sub_query_backend set to " - f"{snapshot['sub_query_backend']!r} " - f"(resolved: {snapshot['sub_query_backend_resolved']!r})" - ) - - def get_config() -> dict[str, Any]: - return self._get_sub_query_config_snapshot() - - session.repl.set_variable("set_backend", set_backend) - session.repl.set_variable("get_config", get_config) + _inject_repl_config_helpers_impl(self, session) def _inject_repl_sub_query(self, session: _Session, context_id: str) -> None: - async def sub_query(prompt: str, context_slice: str | None = None) -> str: - success, output, _truncated, _backend = await self._run_sub_query( - prompt=prompt, - context_slice=context_slice, - context_id=context_id, - backend="auto", - ) - if not success: - return f"[ERROR: sub_query failed: {output}]" - return output - - session.repl.inject_sub_query(sub_query) + 
_inject_repl_sub_query_impl(self, session, context_id) def _inject_repl_sub_aleph(self, session: _Session, context_id: str) -> None: - async def sub_aleph(query: str, context: ContextType | None = None) -> AlephResponse: - context_slice: str | None - if context is None: - context_slice = None - elif isinstance(context, str): - context_slice = context - else: - context_slice = _coerce_context_to_text(context) - response, _meta = await self._run_sub_aleph( - query=query, - context_slice=context_slice, - context_id=context_id, - ) - return response - - session.repl.inject_sub_aleph(sub_aleph) + _inject_repl_sub_aleph_impl(self, session, context_id) def _configure_session( self, @@ -1667,13 +583,11 @@ def _configure_session( context_id: str, loop: asyncio.AbstractEventLoop | None = None, ) -> None: - if loop is not None: - session.repl.set_loop(loop) - self._inject_repl_sub_query(session, context_id) - self._inject_repl_sub_aleph(session, context_id) - self._inject_repl_config_helpers(session) + _configure_session_impl(self, session, context_id, loop=loop) - async def _ensure_remote_server(self, server_id: str) -> tuple[bool, str | _RemoteServerHandle]: + async def _ensure_remote_server( + self, server_id: str + ) -> tuple[bool, str | _RemoteServerHandle]: return await ensure_remote_server(self._remote_servers, server_id) async def _reset_remote_server_handle(self, handle: _RemoteServerHandle) -> None: @@ -1716,161 +630,93 @@ def _format_context_loaded( line_number_base: LineNumberBase, note: str | None = None, ) -> str: - line_desc = "1-based" if line_number_base == 1 else "0-based" - msg = ( - f"Context loaded '{context_id}': {meta.size_chars:,} chars, " - f"{meta.size_lines:,} lines, ~{meta.size_tokens_estimate:,} tokens " - f"(line numbers {line_desc})." 
+ return _format_context_loaded_impl( + context_id, + meta, + line_number_base, + note=note, ) - if note: - msg += f"\nNote: {note}" - return msg - def _create_session( + def _snapshot_session_state(self, session: _Session) -> dict[str, Any]: + return _snapshot_session_state_impl(session) + + def _restore_session_state(self, session: _Session, state: dict[str, Any]) -> None: + _restore_session_state_impl(session, state) + + def _replace_session_context( self, context: str, context_id: str, fmt: ContentFormat, line_number_base: LineNumberBase, + *, + preserve_state: bool = False, ) -> ContextMetadata: - self._close_node_repl(context_id) - meta = _analyze_text_context(context, fmt) - repl = REPLEnvironment( + return _replace_session_context_impl( + sessions=self._sessions, context=context, - context_var_name="ctx", - config=self.sandbox_config, - loop=asyncio.get_running_loop(), + context_id=context_id, + fmt=fmt, + line_number_base=line_number_base, + sandbox_config=self.sandbox_config, + analyze_text_context=_analyze_text_context, + configure_session=self._configure_session, + close_node_repl=self._close_node_repl, + preserve_state=preserve_state, ) - repl.set_variable("line_number_base", line_number_base) - self._sessions[context_id] = _Session( - repl=repl, - meta=meta, + + def _create_session( + self, + context: str, + context_id: str, + fmt: ContentFormat, + line_number_base: LineNumberBase, + ) -> ContextMetadata: + return _create_session_impl( + sessions=self._sessions, + context=context, + context_id=context_id, + fmt=fmt, line_number_base=line_number_base, + sandbox_config=self.sandbox_config, + analyze_text_context=_analyze_text_context, + configure_session=self._configure_session, + close_node_repl=self._close_node_repl, ) - self._configure_session(self._sessions[context_id], context_id, loop=asyncio.get_running_loop()) - return meta def _get_or_create_session( self, context_id: str, line_number_base: LineNumberBase | None = None, ) -> _Session: - session 
= self._sessions.get(context_id) - if session is not None: - self._configure_session(session, context_id, loop=asyncio.get_running_loop()) - return session - - base = line_number_base if line_number_base is not None else DEFAULT_LINE_NUMBER_BASE - meta = _analyze_text_context("", ContentFormat.TEXT) - repl = REPLEnvironment( - context="", - context_var_name="ctx", - config=self.sandbox_config, - loop=asyncio.get_running_loop(), + return _get_or_create_session_impl( + sessions=self._sessions, + context_id=context_id, + line_number_base=line_number_base, + sandbox_config=self.sandbox_config, + analyze_text_context=_analyze_text_context, + configure_session=self._configure_session, ) - repl.set_variable("line_number_base", base) - session = _Session(repl=repl, meta=meta, line_number_base=base) - self._sessions[context_id] = session - self._configure_session(session, context_id, loop=asyncio.get_running_loop()) - return session def _close_node_repl(self, context_id: str) -> None: - node_repl = self._node_repls.pop(context_id, None) - if node_repl is not None: - node_repl.close() + _close_node_repl_impl(self._node_repls, context_id) def _configure_node_repl( self, - node_repl: NodeREPLEnvironment, + node_repl: "NodeREPLEnvironment", session: _Session, ) -> None: - def _get_helper(name: str) -> Callable[..., Any]: - fn = session.repl.get_helper(name) - if not callable(fn): - raise RuntimeError(f"{name} is not available in this REPL session") - return cast(Callable[..., Any], fn) - - def _get_callable(name: str) -> Callable[..., Any]: - fn = session.repl.get_variable(name) - if not callable(fn): - raise RuntimeError(f"{name} is not available in this REPL session") - return cast(Callable[..., Any], fn) - - node_repl.register_callback( - "sub_query", - lambda prompt, context_slice=None: _get_callable("sub_query")(prompt, context_slice), - ) - node_repl.register_callback( - "sub_query_map", - lambda prompts, context_slices=None, limit=None, parallel=True: 
_get_helper("sub_query_map")( - prompts, - context_slices=context_slices, - limit=limit, - parallel=parallel, - ), - ) - node_repl.register_callback( - "sub_query_batch", - lambda prompt, context_slices, limit=None: _get_helper("sub_query_batch")( - prompt, - context_slices, - limit=limit, - ), - ) - node_repl.register_callback( - "sub_query_strict", - lambda prompt, context_slice=None, validate_regex=None, max_retries=0, retry_prompt=None: _get_helper( - "sub_query_strict" - )( - prompt, - context_slice=context_slice, - validate_regex=validate_regex, - max_retries=max_retries, - retry_prompt=retry_prompt, - ), - ) - node_repl.register_callback( - "sub_aleph", - lambda query, context=None: _get_callable("sub_aleph")(query, context), - ) - node_repl.register_callback("set_backend", lambda backend: _get_callable("set_backend")(backend)) - node_repl.register_callback("get_config", lambda: _get_callable("get_config")()) - - def _get_or_create_node_repl(self, context_id: str) -> NodeREPLEnvironment: - if context_id not in self._sessions: - raise KeyError(context_id) - - session = self._sessions[context_id] - node_repl = self._node_repls.get(context_id) - current_ctx = session.repl.get_variable("ctx") - current_loop = asyncio.get_running_loop() - - if node_repl is None: - node_repl = NodeREPLEnvironment( - context=current_ctx, - context_var_name="ctx", - config=self.sandbox_config, - loop=current_loop, - ) - self._node_repls[context_id] = node_repl - else: - node_repl.set_loop(current_loop) + _configure_node_repl_impl(node_repl, session) - node_repl.sync_context(current_ctx, session.line_number_base) - self._configure_node_repl(node_repl, session) - return node_repl + def _get_or_create_node_repl(self, context_id: str) -> "NodeREPLEnvironment": + return _get_or_create_node_repl_impl( + self._node_repls, self._sessions, context_id, self.sandbox_config + ) def _sync_session_from_node_repl(self, context_id: str) -> list[dict[str, Any]]: - node_repl = 
self._node_repls.get(context_id) - if node_repl is None or context_id not in self._sessions: - return [] - - session = self._sessions[context_id] - ctx_value = node_repl.get_variable("ctx") - ctx_text = _coerce_context_to_text(ctx_value) - session.repl.set_variable("ctx", ctx_text) - session.meta = _analyze_text_context(ctx_text, session.meta.format) - return node_repl.drain_citations() + return _sync_session_from_node_repl_impl( + self._node_repls, self._sessions, context_id, _analyze_text_context + ) def _first_doc_line(self, fn: Any) -> str: doc = inspect.getdoc(fn) or "" @@ -1899,12 +745,19 @@ def decorator(fn: Any) -> Any: return decorator - def _require_actions(self, confirm: bool) -> str | None: - if not self.action_config.enabled: - return "Actions are disabled. Start the server with `--enable-actions`." - if self.action_config.require_confirmation and not confirm: - return "Confirmation required. Re-run with confirm=true." - return None + def _require_actions( + self, + confirm: bool, + *, + requires_write: bool = False, + requires_command: bool = False, + ) -> str | None: + return _require_actions_impl( + self.action_config, + confirm, + requires_write=requires_write, + requires_command=requires_command, + ) async def _maybe_resolve_workspace_from_roots(self, ctx: "Context") -> None: """Try to resolve workspace root from MCP client roots (lazy, once).""" @@ -1928,39 +781,24 @@ async def _maybe_resolve_workspace_from_roots(self, ctx: "Context") -> None: pass def _record_action(self, session: _Session | None, note: str, snippet: str) -> None: - if session is None: - return - evidence_before = len(session.evidence) - session.evidence.append( - _Evidence( - source="action", - line_range=None, - pattern=None, - note=note, - snippet=snippet[:200], - ) + _record_action_impl(session, note=note, snippet=snippet) + + def _scoped_path(self, path: str) -> Path: + return _scoped_path( + self.action_config.workspace_root, + path, + self.action_config.workspace_mode, ) 
- session.information_gain.append(len(session.evidence) - evidence_before) def _build_memory_pack_payload( self, *, include_ctx: bool = True, ) -> tuple[dict[str, Any], list[str]]: - sessions_payload: list[dict[str, Any]] = [] - skipped: list[str] = [] - for sid, sess in self._sessions.items(): - try: - sessions_payload.append(_session_to_payload(sid, sess, include_ctx=include_ctx)) - except Exception: - skipped.append(sid) - payload = { - "schema": "aleph.memory_pack.v1", - "created_at": datetime.now().isoformat(), - "sessions": sessions_payload, - "skipped": skipped, - } - return payload, skipped + return _build_memory_pack_payload_impl( + self._sessions, + include_ctx=include_ctx, + ) async def _run_subprocess( self, @@ -1968,63 +806,17 @@ async def _run_subprocess( cwd: Path, timeout_seconds: float, ) -> dict[str, Any]: - start = time.perf_counter() - proc = await asyncio.create_subprocess_exec( - *argv, - cwd=str(cwd), - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + return await _run_subprocess_impl( + self.action_config, + argv=argv, + cwd=cwd, + timeout_seconds=timeout_seconds, ) - timed_out = False - try: - stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=timeout_seconds) - except asyncio.TimeoutError: - timed_out = True - proc.kill() - stdout_b, stderr_b = await proc.communicate() - - duration_ms = (time.perf_counter() - start) * 1000.0 - stdout = stdout_b.decode("utf-8", errors="replace") - stderr = stderr_b.decode("utf-8", errors="replace") - if len(stdout) > self.action_config.max_output_chars: - stdout = stdout[: self.action_config.max_output_chars] + "\n... (truncated)" - if len(stderr) > self.action_config.max_output_chars: - stderr = stderr[: self.action_config.max_output_chars] + "\n... 
(truncated)" - - return { - "argv": argv, - "cwd": str(cwd), - "exit_code": proc.returncode, - "timed_out": timed_out, - "duration_ms": duration_ms, - "stdout": stdout, - "stderr": stderr, - } - - def _parse_rg_vimgrep(self, output: str, max_results: int) -> tuple[list[dict[str, Any]], bool]: - results: list[dict[str, Any]] = [] - truncated = False - limit = max_results if max_results > 0 else None - for line in output.splitlines(): - parts = line.split(":", 3) - if len(parts) < 4: - continue - path_str, line_str, col_str, text = parts - try: - line_no = int(line_str) - col_no = int(col_str) - except ValueError: - continue - results.append({ - "path": path_str, - "line": line_no, - "column": col_no, - "text": text, - }) - if limit is not None and len(results) >= limit: - truncated = True - break - return results, truncated + + def _parse_rg_vimgrep( + self, output: str, max_results: int + ) -> tuple[list[dict[str, Any]], bool]: + return _parse_rg_vimgrep_impl(output, max_results) def _python_rg_search( self, @@ -2033,55 +825,25 @@ def _python_rg_search( glob_pattern: str | None, max_results: int, ) -> tuple[list[dict[str, Any]], bool]: - results: list[dict[str, Any]] = [] - truncated = False - limit = max_results if max_results > 0 else None - rx = re.compile(pattern) - skip_dirs = {".git", ".venv", "node_modules", "dist", "build", "__pycache__", ".mypy_cache", ".pytest_cache"} - - def _iter_files(root: Path) -> Iterable[Path]: - if root.is_file(): - yield root - return - for path in root.rglob("*"): - if path.is_dir(): - continue - if any(part in skip_dirs for part in path.parts): - continue - yield path - - for root in roots: - for path in _iter_files(root): - if glob_pattern and not fnmatch.fnmatch(path.name, glob_pattern): - continue - try: - if path.stat().st_size > self.action_config.max_read_bytes: - continue - with open(path, "r", encoding="utf-8", errors="replace") as f: - for idx, line in enumerate(f, start=1): - match = rx.search(line) - if not match: - 
continue - results.append({ - "path": str(path), - "line": idx, - "column": match.start() + 1, - "text": line.rstrip("\n"), - }) - if limit is not None and len(results) >= limit: - truncated = True - return results, truncated - except Exception: - continue - return results, truncated + return _python_rg_search_impl( + pattern, + roots, + glob_pattern, + max_results, + self.action_config.max_read_bytes, + ) def _auto_save_memory_pack(self) -> None: if self.context_policy == "isolated": return + if self.action_config.action_policy == "read-only": + return if not self.action_config.enabled or not self._sessions: return payload, _ = self._build_memory_pack_payload(include_ctx=True) - out_bytes = json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8", errors="replace") + out_bytes = json.dumps(payload, ensure_ascii=False, indent=2).encode( + "utf-8", errors="replace" + ) if len(out_bytes) > self.action_config.max_write_bytes: return try: @@ -2102,616 +864,19 @@ def _auto_save_memory_pack(self) -> None: self._record_action(sess, note="auto_save_memory_pack", snippet=str(p)) def _register_core_tools(self) -> None: - _tool = self._tool_decorator - - @_tool() - async def load_context( - content: str | None = None, - context_id: str = "default", - format: str = "auto", - line_number_base: LineNumberBase = DEFAULT_LINE_NUMBER_BASE, - context: str | None = None, - ) -> str: - """Load context into an in-memory REPL session. - - The context is stored in a sandboxed Python environment as the variable `ctx`. - You can then use other tools to explore and process this context. 
- - Args: - content: The text/data to load - context_id: Identifier for this context session (default: "default") - format: Content format - "auto", "text", or "json" (default: "auto") - line_number_base: Line number base for this context (0 or 1) - context: Deprecated alias for content - - Returns: - Confirmation with context metadata - """ - text = content if content is not None else context - if text is None: - return "Error: content is required" - try: - base = _validate_line_number_base(line_number_base) - except ValueError as e: - return f"Error: {e}" - - normalized_format = normalize_content_format(format, allow_auto=True) - fmt = cast( - ContentFormat, - _detect_format(text) if normalized_format == "auto" else normalized_format, - ) - meta = self._create_session(text, context_id, fmt, base) - return self._format_context_loaded(context_id, meta, base) - - @_tool() - async def list_contexts( - output: Literal["json", "markdown", "object"] = "json", - ) -> str | dict[str, Any]: - """List all active context sessions and their status.""" - items = [] - for cid, session in self._sessions.items(): - items.append({ - "id": cid, - "chars": session.meta.size_chars, - "lines": session.meta.size_lines, - "iterations": session.iterations, - "evidence": len(session.evidence), - }) - - if output == "object": - return {"count": len(items), "items": items} - if output == "json": - return json.dumps({"count": len(items), "items": items}, indent=2) - - res = [f"Found {len(items)} active context session(s):\n"] - for item in items: - res.append(f"- **{item['id']}**: {item['chars']:,} chars, {item['lines']:,} lines, {item['iterations']} iterations") - return "\n".join(res) - - @_tool() - async def diff_contexts( - a: str, - b: str, - context_lines: int = 3, - max_lines: int = 400, - output: Literal["markdown", "text"] = "markdown", - ) -> str: - """Compare two context sessions using unified diff.""" - if a not in self._sessions: - return f"Error: Context '{a}' not found." 
- if b not in self._sessions: - return f"Error: Context '{b}' not found." - - lines_a = str(self._sessions[a].repl.get_variable("ctx") or "").splitlines() - lines_b = str(self._sessions[b].repl.get_variable("ctx") or "").splitlines() - - diff = list(difflib.unified_diff( - lines_a, lines_b, - fromfile=f"context:{a}", - tofile=f"context:{b}", - n=context_lines, - lineterm="" - )) - - if not diff: - return f"Contexts '{a}' and '{b}' are identical." - - if len(diff) > max_lines: - diff = diff[:max_lines] + ["... (diff truncated)"] - - diff_text = "\n".join(diff) - if output == "markdown": - rendered = f"### Diff: {a} vs {b}\n\n```diff\n{diff_text}\n```" - else: - rendered = diff_text - - text, _ = self._truncate_tool_text(rendered) - return text - - @_tool() - async def save_session( - path: str = "aleph_session.json", - context_id: str | None = None, - session_id: str = "default", - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - ) -> str | dict[str, Any]: - """Save session state to a file (Memory Pack).""" - err = self._require_actions(confirm) - if err: - return _format_error(err, output=output) - if self.context_policy == "isolated" and not confirm: - return _format_error( - "Isolated policy requires confirm=true for session export (prevents accidental context leaks).\n" - "To proceed: save_session(path=..., confirm=true)\n" - "To switch policy: configure(context_policy='trusted')", - output=output, - ) - - payload, skipped = self._build_memory_pack_payload() - try: - p = _scoped_path(self.action_config.workspace_root, path, self.action_config.workspace_mode) - except Exception as e: - return _format_error(f"Invalid path: {e}", output=output) - - try: - p.parent.mkdir(parents=True, exist_ok=True) - with open(p, "w", encoding="utf-8") as f: - json.dump(payload, f, indent=2, ensure_ascii=False) - except Exception as e: - return _format_error(f"Failed to save: {e}", output=output) - - msg = f"Session saved to {path}." 
- if skipped: - msg += f" Warning: skipped {len(skipped)} sessions due to serialization errors." - - if output == "object": - return {"status": "success", "path": str(p), "skipped": skipped} - if output == "json": - return json.dumps({"status": "success", "path": str(p), "skipped": skipped}) - return msg - - @_tool() - async def load_session( - path: str, - context_id: str | None = None, - session_id: str | None = None, - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - ) -> str | dict[str, Any]: - """Load session state from a file (Memory Pack).""" - err = self._require_actions(confirm) - if err: - return _format_error(err, output=output) - if self.context_policy == "isolated" and not confirm: - return _format_error( - "Isolated policy requires confirm=true for session import (prevents unvetted context rehydration).\n" - "To proceed: load_session(path=..., confirm=true)\n" - "To switch policy: configure(context_policy='trusted')", - output=output, - ) - - try: - p = _scoped_path(self.action_config.workspace_root, path, self.action_config.workspace_mode) - with open(p, "r", encoding="utf-8") as f: - payload = json.load(f) - except Exception as e: - return _format_error(f"Failed to load: {e}", output=output) - - if payload.get("schema") != "aleph.memory_pack.v1": - return _format_error("Invalid memory pack schema", output=output) - - loaded = [] - skipped: list[dict[str, str]] = [] - for sp in payload.get("sessions", []): - sid = _resolve_session_payload_id(sp) - if not sid: - skipped.append({"id": "", "error": "missing session identifier"}) - continue - try: - self._close_node_repl(sid) - session = _session_from_payload(sp, sid, self.sandbox_config, asyncio.get_running_loop()) - self._configure_session(session, sid, loop=asyncio.get_running_loop()) - self._sessions[sid] = session - loaded.append(sid) - except Exception as e: - skipped.append({"id": sid, "error": str(e)}) - - msg = f"Loaded {len(loaded)} session(s) from {path}." 
- if skipped: - msg += f" Skipped {len(skipped)} invalid session(s)." - if output == "object": - return {"status": "success", "loaded": loaded, "skipped": skipped} - if output == "json": - return json.dumps({"status": "success", "loaded": loaded, "skipped": skipped}) - return msg + _register_context_tools_module(self, format_error=_format_error) def _register_action_tools(self) -> None: - _tool = self._tool_decorator - - @_tool() - async def run_command( - cmd: str, - cwd: str | None = None, - timeout_seconds: float | None = None, - shell: bool = False, - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", - ctx: Context = None, # type: ignore[assignment] - ) -> str | dict[str, Any]: - """Run a shell command.""" - err = self._require_actions(confirm) - if err: - return _format_error(err, output=output) - if ctx is not None: - await self._maybe_resolve_workspace_from_roots(ctx) - - session = self._get_or_create_session(context_id) - session.iterations += 1 - - workspace_root = self.action_config.workspace_root - cwd_path = ( - _scoped_path(workspace_root, cwd, self.action_config.workspace_mode) - if cwd - else workspace_root - ) - timeout = timeout_seconds if timeout_seconds is not None else self.action_config.max_cmd_seconds - - if shell: - user_shell = os.environ.get("SHELL", "/bin/sh") - argv = [user_shell, "-lc", cmd] - else: - argv = shlex.split(cmd) - if not argv: - return _format_error("Empty command", output=output) - - payload = await self._run_subprocess(argv=argv, cwd=cwd_path, timeout_seconds=timeout) - session.repl._namespace["last_command_result"] = payload - self._record_action(session, note="run_command", snippet=(payload.get("stdout") or payload.get("stderr") or "")[:200]) - return _format_payload(payload, output=output) - - @_tool() - async def rg_search( - pattern: str, - paths: list[str] | str | None = None, - glob: str | None = None, - max_results: int = 200, - load_context_id: str | 
None = None, - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", - ctx: Context = None, # type: ignore[assignment] - ) -> str | dict[str, Any]: - """Fast codebase search using ripgrep (rg) with fallback scanning.""" - err = self._require_actions(confirm) - if err: - return _format_error(err, output=output) - if ctx is not None: - await self._maybe_resolve_workspace_from_roots(ctx) - if not pattern: - return _format_error("pattern is required", output=output) - if isinstance(paths, str): - paths = [paths] - - session = self._get_or_create_session(context_id) - session.iterations += 1 - - workspace_root = self.action_config.workspace_root - resolved_paths: list[Path] = [] - for p in paths or [str(workspace_root)]: - try: - resolved = _scoped_path(workspace_root, p, self.action_config.workspace_mode) - except Exception as e: - return _format_error(str(e), output=output) - resolved_paths.append(resolved) - - matches: list[dict[str, Any]] = [] - truncated = False - used_rg = False - payload: dict[str, Any] | None = None - - rg_bin = shutil.which("rg") - if rg_bin: - used_rg = True - argv = [rg_bin, "--vimgrep", pattern] - if glob: - argv.extend(["-g", glob]) - if max_results > 0: - argv.extend(["-m", str(max_results)]) - argv.extend(str(p) for p in resolved_paths) - payload = await self._run_subprocess( - argv=argv, - cwd=workspace_root, - timeout_seconds=self.action_config.max_cmd_seconds, - ) - matches, truncated = self._parse_rg_vimgrep(payload.get("stdout") or "", max_results) - else: - matches, truncated = self._python_rg_search( - pattern, - resolved_paths, - glob, - max_results, - ) - - hits_text = "\n".join( - f"{m['path']}:{m['line']}:{m['column']}:{m['text']}" for m in matches - ) - if load_context_id: - meta = self._create_session(hits_text, load_context_id, ContentFormat.TEXT, DEFAULT_LINE_NUMBER_BASE) - session.repl._namespace["last_rg_loaded_context"] = load_context_id - load_note = f"Loaded 
{len(matches)} match(es) into '{load_context_id}'." - else: - meta = None - load_note = None - - result_payload: dict[str, Any] = { - "pattern": pattern, - "paths": [str(p) for p in resolved_paths], - "used_rg": used_rg, - "match_count": len(matches), - "truncated": truncated, - "matches": matches, - } - if payload: - result_payload["command"] = payload.get("argv") - result_payload["timed_out"] = payload.get("timed_out", False) - result_payload["stderr"] = payload.get("stderr", "") - if load_context_id: - result_payload["loaded_context_id"] = load_context_id - result_payload["loaded_meta"] = { - "size_chars": meta.size_chars if meta else 0, - "size_lines": meta.size_lines if meta else 0, - } - if load_note: - result_payload["note"] = load_note - - session.repl._namespace["last_rg_result"] = result_payload - self._record_action(session, note="rg_search", snippet=f"{pattern} ({len(matches)} matches)") - - if output == "object": - return result_payload - if output == "json": - return json.dumps(result_payload, ensure_ascii=False, indent=2) - - parts = [ - "## rg_search Results", - f"Pattern: `{pattern}`", - f"Matches: {len(matches)}" + (" (truncated)" if truncated else ""), - ] - if load_note: - parts.append(load_note) - if matches: - parts.append("") - parts.extend([f"- {m['path']}:{m['line']}:{m['column']}: {m['text']}" for m in matches[:20]]) - if len(matches) > 20: - parts.append(f"... 
{len(matches) - 20} more") - return "\n".join(parts) - - @_tool() - async def read_file( - path: str, - start_line: int = 1, - limit: int = 200, - include_raw: bool = False, - line_number_base: int | None = None, - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", - ctx: Context = None, # type: ignore[assignment] - ) -> str | dict[str, Any]: - """Read file content (raw).""" - err = self._require_actions(confirm) - if err: - return _format_error(err, output=output) - if ctx is not None: - await self._maybe_resolve_workspace_from_roots(ctx) - - base_override: LineNumberBase | None = None - if line_number_base is not None: - try: - base_override = _validate_line_number_base(line_number_base) - except ValueError as e: - return _format_error(str(e), output=output) - - session = self._get_or_create_session(context_id, base_override) - session.iterations += 1 - try: - base = _resolve_line_number_base(session, line_number_base) - except ValueError as e: - return _format_error(str(e), output=output) - - if base == 1 and start_line == 0: - start_line = 1 - if start_line < base: - return _format_error(f"start_line must be >= {base}", output=output) - - try: - p = _scoped_path(self.action_config.workspace_root, path, self.action_config.workspace_mode) - except Exception as e: - return _format_error(str(e), output=output) - - if not p.exists() or not p.is_file(): - return _format_error(f"File not found: {path}", output=output) - - data = p.read_bytes() - if len(data) > self.action_config.max_read_bytes: - return _format_error( - f"File too large to read (>{self.action_config.max_read_bytes} bytes): {path}", - output=output, - ) - - text = data.decode("utf-8", errors="replace") - lines = text.splitlines() - start_idx = max(0, start_line - base) - end_idx = min(len(lines), start_idx + max(0, limit)) - slice_lines = lines[start_idx:end_idx] - numbered = "\n".join( - f"{i + start_idx + base:>6}\t{line}" for i, line in 
enumerate(slice_lines) - ) - end_line = (start_idx + len(slice_lines) - 1 + base) if slice_lines else start_line - - payload: dict[str, Any] = { - "path": str(p), - "start_line": start_line, - "end_line": end_line, - "limit": limit, - "total_lines": len(lines), - "line_number_base": base, - "content": numbered, - } - if include_raw: - payload["content_raw"] = "\n".join(slice_lines) - session.repl._namespace["last_read_file_result"] = payload - self._record_action(session, note="read_file", snippet=f"{path} ({start_line}-{end_line})") - return _format_payload(payload, output=output) - - @_tool() - async def load_file( - path: str, - context_id: str = "default", - format: str = "auto", - line_number_base: LineNumberBase = DEFAULT_LINE_NUMBER_BASE, - confirm: bool = False, - ctx: Context = None, # type: ignore[assignment] - ) -> str: - """Load a workspace file into a context session.""" - err = self._require_actions(confirm) - if err: - return f"Error: {err}" - if ctx is not None: - await self._maybe_resolve_workspace_from_roots(ctx) - - try: - base = _validate_line_number_base(line_number_base) - except ValueError as e: - return f"Error: {e}" - - try: - p = _scoped_path(self.action_config.workspace_root, path, self.action_config.workspace_mode) - except Exception as e: - return f"Error: {e}" - - if not p.exists() or not p.is_file(): - return f"Error: File not found: {path}" - - try: - text, detected_fmt, warning = _load_text_from_path( - p, - self.action_config.max_read_bytes, - self.action_config.max_cmd_seconds, - ) - except ValueError as e: - return f"Error: {e}" - try: - normalized_format = normalize_content_format(format, allow_auto=True) - fmt = cast( - ContentFormat, - detected_fmt if normalized_format == "auto" else normalized_format, - ) - except Exception as e: - return f"Error: {e}" - meta = self._create_session(text, context_id, fmt, base) - session = self._get_or_create_session(context_id, base) - self._record_action(session, note="load_file", 
snippet=str(p)) - return self._format_context_loaded(context_id, meta, base, note=warning) - - @_tool() - async def write_file( - path: str, - content: str, - mode: Literal["overwrite", "append"] = "overwrite", - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", - ctx: Context = None, # type: ignore[assignment] - ) -> str | dict[str, Any]: - """Write file content.""" - err = self._require_actions(confirm) - if err: - return _format_error(err, output=output) - if ctx is not None: - await self._maybe_resolve_workspace_from_roots(ctx) - - session = self._get_or_create_session(context_id) - session.iterations += 1 - - try: - p = _scoped_path(self.action_config.workspace_root, path, self.action_config.workspace_mode) - except Exception as e: - return _format_error(str(e), output=output) - - payload_bytes = content.encode("utf-8", errors="replace") - if len(payload_bytes) > self.action_config.max_write_bytes: - return _format_error( - f"Content too large to write (>{self.action_config.max_write_bytes} bytes)", - output=output, - ) - - p.parent.mkdir(parents=True, exist_ok=True) - file_mode = "ab" if mode == "append" else "wb" - with open(p, file_mode) as f: - f.write(payload_bytes) - - payload: dict[str, Any] = { - "path": str(p), - "bytes_written": len(payload_bytes), - "mode": mode, - } - session.repl._namespace["last_write_file_result"] = payload - self._record_action(session, note="write_file", snippet=f"{path} ({len(payload_bytes)} bytes)") - return _format_payload(payload, output=output) - - @_tool() - async def run_tests( - runner: Literal["auto", "pytest"] = "auto", - args: list[str] | None = None, - cwd: str | None = None, - confirm: bool = False, - output: Literal["json", "markdown", "object"] = "json", - context_id: str = "default", - ctx: Context = None, # type: ignore[assignment] - ) -> str | dict[str, Any]: - """Run project tests.""" - err = self._require_actions(confirm) - if err: - return 
_format_error(err, output=output) - if ctx is not None: - await self._maybe_resolve_workspace_from_roots(ctx) - - session = self._get_or_create_session(context_id) - session.iterations += 1 - - workspace_root = self.action_config.workspace_root - cwd_path = ( - _scoped_path(workspace_root, cwd, self.action_config.workspace_mode) - if cwd - else workspace_root - ) - - # Heuristics for test runner - runner_bin: str = str(runner) - if runner == "auto": - runner_bin = "pytest" - - argv: list[str] = [runner_bin] - if args: - argv.extend(args) - - payload = await self._run_subprocess(argv=argv, cwd=cwd_path, timeout_seconds=self.action_config.max_cmd_seconds) - self._record_action(session, note=f"run_tests: {runner}", snippet=(payload.get("stdout") or payload.get("stderr") or "")[:200]) - return _format_payload(payload, output=output) + _register_action_tools_module( + self, format_error=_format_error, format_payload=_format_payload + ) def _format_execution_result(self, result: ExecutionResult) -> str | dict[str, Any]: - """Format sandboxed execution results for output.""" - if result.error: - text, _ = self._truncate_tool_text(f"## Execution Error\n\n{result.error}") - return text - - res = ["## Execution Result\n"] - formatting_truncated = False - if result.stdout: - stdout_text, was_truncated = self._truncate_tool_text(result.stdout) - formatting_truncated = formatting_truncated or was_truncated - res.append(f"**Output:**\n```\n{stdout_text}\n```") - if result.stderr: - stderr_text, was_truncated = self._truncate_tool_text(result.stderr) - formatting_truncated = formatting_truncated or was_truncated - res.append(f"**Stderr:**\n```\n{stderr_text}\n```") - if result.return_value is not None: - rendered = repr(result.return_value) - rendered, was_truncated = self._truncate_tool_text(rendered) - formatting_truncated = formatting_truncated or was_truncated - res.append(f"**Return Value:** `{rendered}`") - if result.variables_updated: - res.append(f"\n**Variables 
Updated:** {', '.join(f'`{v}`' for v in result.variables_updated)}") - - if result.truncated or formatting_truncated: - res.append("\n*Note: Output was truncated*") - - out = "\n".join(res) - out, _ = self._truncate_tool_text(out) - return out + return _format_execution_result_impl( + result, + max_chars=self.max_tool_response_chars, + truncation_suffix=_TOOL_TRUNCATION_SUFFIX, + ) def _truncate_tool_text( self, @@ -2719,24 +884,13 @@ def _truncate_tool_text( *, max_chars: int | None = None, ) -> tuple[str, bool]: - limit = self.max_tool_response_chars if max_chars is None else max_chars - if limit <= 0 or len(text) <= limit: - return text, False - if limit <= len(_TOOL_TRUNCATION_SUFFIX): - return _TOOL_TRUNCATION_SUFFIX[:limit], True - - # Keep a compact prefix/suffix preview instead of a large contiguous head. - # This avoids spilling big raw blocks (for example long repeated characters) - # into the model context while still preserving enough signal for debugging. - preview_each_side = min(400, max(0, (limit - len(_TOOL_TRUNCATION_SUFFIX)) // 2)) - if preview_each_side == 0: - keep = limit - len(_TOOL_TRUNCATION_SUFFIX) - return text[:keep] + _TOOL_TRUNCATION_SUFFIX, True - return ( - text[:preview_each_side] - + _TOOL_TRUNCATION_SUFFIX - + text[-preview_each_side:] - ), True + return _truncate_tool_text_impl( + text, + max_chars=( + self.max_tool_response_chars if max_chars is None else max_chars + ), + truncation_suffix=_TOOL_TRUNCATION_SUFFIX, + ) def _limit_json_items( self, @@ -2744,55 +898,22 @@ def _limit_json_items( *, max_chars: int | None = None, ) -> tuple[list[Any], bool]: - limit = self.max_tool_response_chars if max_chars is None else max_chars - used = 2 # [] delimiters - limited: list[Any] = [] - - for raw in items: - item = _to_jsonable(raw) - try: - encoded = json.dumps(item, ensure_ascii=False) - except Exception: - encoded = json.dumps(str(item), ensure_ascii=False) - - projected = used + len(encoded) + (1 if limited else 0) - if projected > 
limit: - return limited, True - - limited.append(item) - used = projected - - return limited, False + return _limit_json_items_impl( + items, + max_chars=( + self.max_tool_response_chars if max_chars is None else max_chars + ), + to_jsonable=_to_jsonable, + ) def _format_variable_value(self, name: str, value: Any) -> Any: - if value is None or isinstance(value, (int, float, bool)): - return value - - if isinstance(value, str): - text, truncated = self._truncate_tool_text(value) - if not truncated: - return value - return { - "name": name, - "truncated": True, - "original_chars": len(value), - "value_preview": text, - } - - jsonable = _to_jsonable(value) - try: - rendered = json.dumps(jsonable, ensure_ascii=False) - except Exception: - rendered = str(jsonable) - text, truncated = self._truncate_tool_text(rendered) - if not truncated: - return jsonable - return { - "name": name, - "truncated": True, - "original_chars": len(rendered), - "value_preview": text, - } + return _format_variable_value_impl( + name, + value, + max_chars=self.max_tool_response_chars, + truncation_suffix=_TOOL_TRUNCATION_SUFFIX, + to_jsonable=_to_jsonable, + ) def _register_query_tools(self) -> None: _register_query_tools_module( @@ -2807,6 +928,9 @@ def _register_reasoning_tools(self) -> None: def _register_mcp_tools(self) -> None: register_admin_tools(self, format_error=_format_error) + def _register_workspace_tools(self) -> None: + register_workspace_tools(self) + def _register_tools(self) -> None: """Register all MCP tools.""" self._register_core_tools() @@ -2814,6 +938,7 @@ def _register_tools(self) -> None: self._register_query_tools() self._register_reasoning_tools() self._register_mcp_tools() + self._register_workspace_tools() async def run(self, transport: str = "stdio") -> None: """Run the MCP server.""" @@ -2822,6 +947,7 @@ async def run(self, transport: str = "stdio") -> None: await self.server.run_stdio_async() + def main() -> None: """CLI entry point: `aleph` or `python -m 
aleph.mcp.local_server`""" @@ -2840,7 +966,9 @@ def main() -> None: args, detect_workspace_root=_detect_workspace_root, normalize_context_policy=_normalize_context_policy, + normalize_action_policy=_normalize_action_policy, default_context_policy=DEFAULT_CONTEXT_POLICY, + default_action_policy=DEFAULT_ACTION_POLICY, sandbox_config_factory=SandboxConfig, action_config_factory=ActionConfig, ) diff --git a/aleph/mcp/node_bridge.py b/aleph/mcp/node_bridge.py new file mode 100644 index 0000000..b446c27 --- /dev/null +++ b/aleph/mcp/node_bridge.py @@ -0,0 +1,160 @@ +"""Node REPL bridge for the local MCP server. + +Manages the lifecycle of per-context NodeREPLEnvironment instances: +creation, context sync, callback registration, and teardown. +""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any, Callable, cast + +from ..repl.node_runtime import NodeREPLEnvironment +from ..repl.sandbox import SandboxConfig +from .session import _Session, _coerce_context_to_text + +if TYPE_CHECKING: + from ..types import ContentFormat, ContextMetadata + + +__all__ = [ + "close_node_repl", + "configure_node_repl", + "get_or_create_node_repl", + "sync_session_from_node_repl", +] + + +def close_node_repl( + node_repls: dict[str, NodeREPLEnvironment], + context_id: str, +) -> None: + """Remove and close the Node REPL for *context_id*, if any.""" + node_repl = node_repls.pop(context_id, None) + if node_repl is not None: + node_repl.close() + + +def configure_node_repl( + node_repl: NodeREPLEnvironment, + session: _Session, +) -> None: + """Register Python→Node callback bridges on *node_repl*. + + Each callback delegates to the corresponding helper/variable already + injected into the Python *session.repl* by the REPL injection layer. 
+ """ + + def _get_helper(name: str) -> Callable[..., Any]: + fn = session.repl.get_helper(name) + if not callable(fn): + raise RuntimeError(f"{name} is not available in this REPL session") + return cast(Callable[..., Any], fn) + + def _get_callable(name: str) -> Callable[..., Any]: + fn = session.repl.get_variable(name) + if not callable(fn): + raise RuntimeError(f"{name} is not available in this REPL session") + return cast(Callable[..., Any], fn) + + node_repl.register_callback( + "sub_query", + lambda prompt, context_slice=None: _get_callable("sub_query")( + prompt, context_slice + ), + ) + node_repl.register_callback( + "sub_query_map", + lambda prompts, context_slices=None, limit=None, parallel=True: _get_helper( + "sub_query_map" + )( + prompts, + context_slices=context_slices, + limit=limit, + parallel=parallel, + ), + ) + node_repl.register_callback( + "sub_query_batch", + lambda prompt, context_slices, limit=None: _get_helper("sub_query_batch")( + prompt, + context_slices, + limit=limit, + ), + ) + node_repl.register_callback( + "sub_query_strict", + lambda prompt, context_slice=None, validate_regex=None, max_retries=0, retry_prompt=None: ( + _get_helper("sub_query_strict")( + prompt, + context_slice=context_slice, + validate_regex=validate_regex, + max_retries=max_retries, + retry_prompt=retry_prompt, + ) + ), + ) + node_repl.register_callback( + "sub_aleph", + lambda query, context=None: _get_callable("sub_aleph")(query, context), + ) + node_repl.register_callback( + "set_backend", lambda backend: _get_callable("set_backend")(backend) + ) + node_repl.register_callback("get_config", lambda: _get_callable("get_config")()) + + +def get_or_create_node_repl( + node_repls: dict[str, NodeREPLEnvironment], + sessions: dict[str, _Session], + context_id: str, + sandbox_config: SandboxConfig, +) -> NodeREPLEnvironment: + """Return the Node REPL for *context_id*, creating one if needed. + + The returned REPL has its context synced and callbacks configured. 
+ """ + if context_id not in sessions: + raise KeyError(context_id) + + session = sessions[context_id] + node_repl = node_repls.get(context_id) + current_ctx = session.repl.get_variable("ctx") + current_loop = asyncio.get_running_loop() + + if node_repl is None: + node_repl = NodeREPLEnvironment( + context=current_ctx, + context_var_name="ctx", + config=sandbox_config, + loop=current_loop, + ) + node_repls[context_id] = node_repl + else: + node_repl.set_loop(current_loop) + + node_repl.sync_context(current_ctx, session.line_number_base) + configure_node_repl(node_repl, session) + return node_repl + + +def sync_session_from_node_repl( + node_repls: dict[str, NodeREPLEnvironment], + sessions: dict[str, _Session], + context_id: str, + analyze_text_context: Callable[[str, "ContentFormat"], "ContextMetadata"], +) -> list[dict[str, Any]]: + """Sync state from Node REPL back into the Python session. + + Returns any citations drained from the Node REPL. + """ + node_repl = node_repls.get(context_id) + if node_repl is None or context_id not in sessions: + return [] + + session = sessions[context_id] + ctx_value = node_repl.get_variable("ctx") + ctx_text = _coerce_context_to_text(ctx_value) + session.repl.set_variable("ctx", ctx_text) + session.meta = analyze_text_context(ctx_text, session.meta.format) + return node_repl.drain_citations() diff --git a/aleph/mcp/reasoning_tools.py b/aleph/mcp/reasoning_tools.py index 0a39808..cf09a08 100644 --- a/aleph/mcp/reasoning_tools.py +++ b/aleph/mcp/reasoning_tools.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Callable, Literal from .recipe_tools import register_recipe_tools +from .workspace_contexts import binding_status, binding_summary if TYPE_CHECKING: from .local_server import AlephMCPServerLocal @@ -19,6 +20,15 @@ def register_reasoning_tools( ) -> None: _tool = owner._tool_decorator + def _task_store(session: Any) -> list[dict[str, Any]]: + namespace_tasks = session.repl._namespace.get("_tasks") # type: 
ignore[attr-defined] + if isinstance(namespace_tasks, list): + if session.tasks is not namespace_tasks: + session.tasks = namespace_tasks + return namespace_tasks + session.repl._namespace["_tasks"] = session.tasks # type: ignore[attr-defined] + return session.tasks + @_tool() async def think( question: str, @@ -31,6 +41,7 @@ async def think( session = owner._sessions[context_id] session.iterations += 1 + session.think_history.append(question) log_entry = { "iteration": session.iterations, @@ -72,7 +83,7 @@ async def tasks( session = owner._sessions[context_id] session.iterations += 1 - tasks_list: list[dict[str, Any]] = session.repl._namespace.setdefault("_tasks", []) # type: ignore + tasks_list = _task_store(session) if action == "add" and description: new_id = task_id or f"T{len(tasks_list) + 1}" @@ -89,7 +100,7 @@ async def tasks( return f"Error: Task {task_id} not found." if action == "clear": - session.repl._namespace["_tasks"] = [] + tasks_list.clear() return "All tasks cleared." if not tasks_list: @@ -111,7 +122,7 @@ async def get_status( return f"Error: No context loaded with ID '{context_id}'." 
session = owner._sessions[context_id] - tasks_list: list[dict[str, Any]] = session.repl._namespace.get("_tasks", []) # type: ignore + tasks_list = _task_store(session) status = { "context_id": context_id, "iterations": session.iterations, @@ -123,7 +134,11 @@ async def get_status( "workspace_root": str(owner.action_config.workspace_root), "workspace_root_source": owner._workspace_root_source, "context_policy": owner.context_policy, + "action_policy": owner.action_config.action_policy, "auto_memory_pack": owner.context_policy != "isolated", + "workspace_binding": session.workspace_binding, + "workspace_binding_summary": binding_summary(session.workspace_binding), + "workspace_binding_status": binding_status(session.workspace_binding), } if output == "object": @@ -134,11 +149,14 @@ async def get_status( res = [f"## Session Status: {context_id}\n"] res.append(f"- **Iterations**: {session.iterations}") res.append(f"- **Evidence Items**: {len(session.evidence)}") - res.append(f"- **Tracked Tasks**: {len(session.repl._namespace.get('_tasks', []))}") # type: ignore + res.append(f"- **Tracked Tasks**: {len(tasks_list)}") res.append(f"- **User Variables**: {', '.join(status['variables']) or 'None'}") # type: ignore res.append(f"- **Context Size**: {session.meta.size_chars:,} chars ({session.meta.size_lines:,} lines)") res.append(f"- **Workspace Root**: {owner.action_config.workspace_root} ({owner._workspace_root_source})") res.append(f"- **Context Policy**: {owner.context_policy}") + res.append(f"- **Action Policy**: {owner.action_config.action_policy}") + if status["workspace_binding_summary"]: + res.append(f"- **Workspace Binding**: {status['workspace_binding_summary']}") return "\n".join(res) @_tool() diff --git a/aleph/mcp/recipe_runtime.py b/aleph/mcp/recipe_runtime.py new file mode 100644 index 0000000..5c2d453 --- /dev/null +++ b/aleph/mcp/recipe_runtime.py @@ -0,0 +1,472 @@ +"""Recipe execution and compilation runtime for the local MCP server. 
async def execute_recipe(
    owner: "AlephMCPServerLocal",
    *,
    recipe: dict[str, Any],
    context_id_override: str | None = None,
    dry_run: bool = False,
    progress_callback: Callable[[float, float | None, str | None], Any] | None = None,
) -> tuple[bool, dict[str, Any]]:
    """Validate a recipe and run its steps against a loaded context.

    Args:
        owner: Server object providing ``_sessions``, ``_run_sub_query`` and
            ``max_recipe_concurrency``.
        recipe: Raw recipe; it is normalized with ``_validate_recipe`` first.
        context_id_override: Context to run against instead of the recipe's
            own ``context_id``.
        dry_run: When true, return the normalized recipe and its estimate
            without executing any step.
        progress_callback: Optional sync or async callable invoked as
            ``(progress, total, message)``; any exception it raises is ignored
            so progress reporting can never fail a recipe.

    Returns:
        ``(True, payload)`` on success (or dry run) and ``(False, payload)``
        on validation, budget, or step failure. Failure payloads carry an
        ``error``/``errors`` entry and, once execution started, the trace
        collected so far.
    """
    normalized, errors = _validate_recipe(recipe)
    if errors:
        return False, {"errors": errors}
    assert normalized is not None

    resolved_context_id = context_id_override or normalized["context_id"]
    if resolved_context_id not in owner._sessions:
        return False, {"error": f"No context loaded with ID '{resolved_context_id}'."}

    estimate = _estimate_recipe(normalized)
    if dry_run:
        return True, {
            "context_id": resolved_context_id,
            "mode": "dry_run",
            "recipe": normalized,
            "estimate": estimate,
        }

    session = owner._sessions[resolved_context_id]
    budget = normalized["budget"]
    max_steps = int(budget["max_steps"])
    max_sub_queries = int(budget["max_sub_queries"])

    # `current` is the value flowing from one step to the next; `variables`
    # holds named values created by `assign` / `store`.
    current: Any = session.repl.get_variable("ctx")
    variables: dict[str, Any] = {"ctx": current}
    trace: list[dict[str, Any]] = []
    sub_queries_used = 0
    total_steps = float(len(normalized["steps"]))

    async def _report(
        progress: float, total: float | None = None, message: str | None = None
    ) -> None:
        if progress_callback is not None:
            try:
                result = progress_callback(progress, total, message)
                if asyncio.iscoroutine(result):
                    await result
            except Exception:
                # Deliberately best-effort: a broken callback must not abort
                # the recipe.
                pass

    for step_index, step in enumerate(normalized["steps"], 1):
        if step_index > max_steps:
            return False, {
                "error": f"Recipe exceeded budget.max_steps ({step_index} > {max_steps})",
                "failed_step": step_index,
                "trace": trace,
            }

        session.iterations += 1

        input_name = step.get("input")
        if input_name:
            if input_name not in variables:
                return False, {
                    "error": f"Step {step_index}: input variable '{input_name}' not found.",
                    "failed_step": step_index,
                    "trace": trace,
                }
            current = variables[input_name]

        op = step["op"]
        step_trace: dict[str, Any] = {
            "step": step_index,
            "op": op,
        }

        try:
            if op == "search":
                current = repl_helpers.search(
                    current,
                    step["pattern"],
                    context_lines=step.get("context_lines", 2),
                    max_results=step.get("max_results", 20),
                )
                step_trace["result_count"] = (
                    len(current) if isinstance(current, list) else 0
                )

            elif op == "peek":
                current = repl_helpers.peek(
                    current,
                    start=step.get("start", 0),
                    end=step.get("end"),
                )

            elif op == "lines":
                current = repl_helpers.lines(
                    current,
                    start=step.get("start", 0),
                    end=step.get("end"),
                )

            elif op == "take":
                count = int(step["count"])
                if isinstance(current, str):
                    current = current[:count]
                elif isinstance(current, (list, tuple)):
                    current = list(current)[:count]
                else:
                    raise ValueError("take requires a list/tuple/string value")

            elif op == "chunk":
                text = _coerce_context_to_text(current)
                chunk_size = int(step["chunk_size"])
                overlap = int(step.get("overlap", 0))
                current = repl_helpers.chunk(text, chunk_size, overlap)
                step_trace["result_count"] = len(current)

            elif op == "filter":
                if not isinstance(current, list):
                    raise ValueError("filter requires current value to be a list")
                field_name = step.get("field")
                pattern = step.get("pattern")
                contains = step.get("contains")
                rx = re.compile(pattern) if pattern else None
                out: list[Any] = []
                for item in current:
                    candidate: Any = item
                    if field_name:
                        if isinstance(item, dict):
                            candidate = item.get(field_name)
                        else:
                            candidate = None
                    candidate_text = _coerce_context_to_text(candidate)
                    matched = True
                    if rx is not None:
                        matched = bool(rx.search(candidate_text))
                    if contains is not None:
                        matched = matched and contains in candidate_text
                    if matched:
                        out.append(item)
                current = out
                step_trace["result_count"] = len(current)

            elif op == "assign":
                variables[step["name"]] = current

            elif op == "load":
                name = step["name"]
                if name not in variables:
                    raise ValueError(f"variable '{name}' not found")
                current = variables[name]

            elif op == "map_sub_query":
                if not isinstance(current, list):
                    raise ValueError(
                        "map_sub_query requires current value to be a list"
                    )

                limit = step.get("limit")
                items = current[:limit] if isinstance(limit, int) else current
                parallel = step.get("parallel", True)
                continue_on_error = step.get("continue_on_error", False)

                remaining_budget = max_sub_queries - sub_queries_used
                if len(items) > remaining_budget:
                    raise RuntimeError(
                        "Recipe sub-query budget would be exceeded "
                        f"({sub_queries_used} + {len(items)} > {max_sub_queries})"
                    )

                if parallel and len(items) > 1:
                    parallel_limit = max(
                        1, min(owner.max_recipe_concurrency, len(items))
                    )
                    sem = asyncio.Semaphore(parallel_limit)

                    async def _run_item(
                        idx: int, item: object
                    ) -> tuple[int, bool, str]:
                        async with sem:
                            ctx_slice = recipe_context_slice(
                                item, step.get("context_field")
                            )
                            ok, out, _trunc, _bk = await owner._run_sub_query(
                                prompt=step["prompt"],
                                context_slice=ctx_slice,
                                context_id=resolved_context_id,
                                backend=step.get("backend", "auto"),
                            )
                            return idx, ok, out

                    tasks = [_run_item(i, it) for i, it in enumerate(items)]
                    results = await asyncio.gather(*tasks, return_exceptions=True)
                    # Every item was dispatched, so charge the budget before
                    # inspecting results; a failure below must still be
                    # reflected in the reported `sub_queries_used`.
                    sub_queries_used += len(items)
                    outputs: list[str] = [""] * len(items)
                    # gather() returns results in task order, so the position
                    # identifies the item even when the task raised and no
                    # (idx, ok, out) tuple is available.
                    for position, r in enumerate(results):
                        if isinstance(r, BaseException):
                            if not continue_on_error:
                                raise RuntimeError(f"sub_query failed: {r}") from r
                            outputs[position] = f"[ERROR] {r}"
                            continue
                        _idx, ok, item_output = r
                        if not ok and not continue_on_error:
                            raise RuntimeError(f"sub_query failed: {item_output}")
                        outputs[position] = (
                            item_output if ok else f"[ERROR] {item_output}"
                        )
                else:
                    outputs = []
                    for item in items:
                        context_slice = recipe_context_slice(
                            item, step.get("context_field")
                        )
                        (
                            success,
                            output,
                            _truncated,
                            _backend,
                        ) = await owner._run_sub_query(
                            prompt=step["prompt"],
                            context_slice=context_slice,
                            context_id=resolved_context_id,
                            backend=step.get("backend", "auto"),
                        )
                        sub_queries_used += 1
                        if not success and not continue_on_error:
                            raise RuntimeError(f"sub_query failed: {output}")
                        outputs.append(output if success else f"[ERROR] {output}")
                await _report(
                    float(step_index),
                    total_steps,
                    f"map_sub_query: {len(items)} items processed",
                )

                current = outputs
                step_trace["sub_queries"] = len(outputs)
                step_trace["parallel"] = parallel and len(items) > 1

            elif op in {"sub_query", "aggregate"}:
                if sub_queries_used >= max_sub_queries:
                    raise RuntimeError(
                        "Recipe sub-query budget exceeded "
                        f"({sub_queries_used} >= {max_sub_queries})"
                    )

                if op == "aggregate" and isinstance(current, list):
                    # Aggregation feeds all prior outputs to one sub-query.
                    context_slice = "\n\n".join(
                        _coerce_context_to_text(item) for item in current
                    )
                else:
                    context_slice = recipe_context_slice(
                        current, step.get("context_field")
                    )

                success, output, _truncated, _backend = await owner._run_sub_query(
                    prompt=step["prompt"],
                    context_slice=context_slice,
                    context_id=resolved_context_id,
                    backend=step.get("backend", "auto"),
                )
                sub_queries_used += 1
                if not success:
                    raise RuntimeError(f"sub_query failed: {output}")
                current = output
                step_trace["sub_queries"] = 1

            elif op == "finalize":
                step_trace["status"] = "finalized"
                trace.append(step_trace)
                break

            else:
                raise ValueError(f"unsupported op: {op}")
        except Exception as e:
            # Any step failure is recorded both in the trace and as session
            # evidence before being reported to the caller.
            step_trace["status"] = "error"
            step_trace["error"] = str(e)
            trace.append(step_trace)
            session.evidence.append(
                _Evidence(
                    source="exec",
                    line_range=None,
                    pattern=None,
                    note=f"run_recipe failed at step {step_index}",
                    snippet=f"{op}: {str(e)[:180]}",
                )
            )
            return False, {
                "error": f"Step {step_index} ({op}) failed: {e}",
                "failed_step": step_index,
                "trace": trace,
                "sub_queries_used": sub_queries_used,
                "budget": budget,
                "estimate": estimate,
            }

        store_name = step.get("store")
        if store_name:
            variables[store_name] = current

        step_trace["status"] = "ok"
        step_trace["preview"] = recipe_preview(current)
        trace.append(step_trace)
        await _report(
            float(step_index),
            total_steps,
            f"Step {step_index}/{int(total_steps)} ({op}) done",
        )

    session.evidence.append(
        _Evidence(
            source="exec",
            line_range=None,
            pattern=None,
            note=f"run_recipe completed ({len(trace)} steps)",
            snippet=recipe_preview(current),
        )
    )

    payload = {
        "context_id": resolved_context_id,
        "recipe_version": normalized["version"],
        "step_count": len(normalized["steps"]),
        "sub_queries_used": sub_queries_used,
        "budget": budget,
        "estimate": estimate,
        "trace": trace,
        "value": _to_jsonable(current),
        "variables": sorted(variables.keys()),
    }
    return True, payload
def inject_repl_config_helpers(
    owner: "AlephMCPServerLocal",
    session: _Session,
) -> None:
    """Register ``set_backend`` and ``get_config`` as REPL variables.

    Both helpers delegate to the owner's sub-query runtime configuration:
    ``get_config`` returns the current snapshot and ``set_backend`` applies a
    new backend, raising ``ValueError`` with the owner's message when the
    change is rejected.
    """

    def get_config() -> dict[str, object]:
        return owner._get_sub_query_config_snapshot()

    def set_backend(backend: str) -> str:
        applied, detail = owner._apply_sub_query_runtime_config(
            sub_query_backend=backend
        )
        if not applied:
            raise ValueError(detail)
        cfg = owner._get_sub_query_config_snapshot()
        requested = cfg["sub_query_backend"]
        resolved = cfg["sub_query_backend_resolved"]
        return f"sub_query_backend set to {requested!r} (resolved: {resolved!r})"

    # Registration order matches the helper names exposed to REPL users.
    for helper in (set_backend, get_config):
        session.repl.set_variable(helper.__name__, helper)
inject_repl_config_helpers(owner, session) diff --git a/aleph/mcp/server_bootstrap.py b/aleph/mcp/server_bootstrap.py index 850ca08..a76ba42 100644 --- a/aleph/mcp/server_bootstrap.py +++ b/aleph/mcp/server_bootstrap.py @@ -82,6 +82,13 @@ def build_server_argument_parser( default=default_workspace_mode, help="Path scope for action tools: fixed (workspace root only), git (any git repo), any (no path restriction)", ) + parser.add_argument( + "--action-policy", + type=str, + choices=["read-write", "read-only"], + default=None, + help="Filesystem/process policy for action tools: read-write (default) or read-only.", + ) parser.add_argument( "--require-confirmation", action="store_true", @@ -243,6 +250,8 @@ def apply_server_env_overrides(args: argparse.Namespace) -> None: if args.swarm_mode: os.environ["ALEPH_SWARM_MODE"] = "true" + if getattr(args, "action_policy", None) is not None: + os.environ["ALEPH_ACTION_POLICY"] = args.action_policy if args.swarm_name is not None: os.environ["ALEPH_SWARM_NAME"] = args.swarm_name if args.enable_session_sharing: @@ -258,7 +267,9 @@ def build_runtime_configs( *, detect_workspace_root: Callable[[], Path], normalize_context_policy: Callable[[str | None, str], str], + normalize_action_policy: Callable[[str | None, str], str], default_context_policy: str, + default_action_policy: str, sandbox_config_factory: Callable[..., Any], action_config_factory: Callable[..., Any], ) -> tuple[Any, Any, str]: @@ -282,6 +293,10 @@ def build_runtime_configs( env_settings.context_policy, default_context_policy, ), + action_policy=normalize_action_policy( + getattr(args, "action_policy", None) or env_settings.action_policy, + default_action_policy, + ), require_confirmation=bool(args.require_confirmation), max_read_bytes=args.max_file_size, max_write_bytes=args.max_write_bytes, diff --git a/aleph/mcp/session.py b/aleph/mcp/session.py index bd9a225..fa9e84b 100644 --- a/aleph/mcp/session.py +++ b/aleph/mcp/session.py @@ -11,7 +11,7 @@ import json from 
def snapshot_session_state(session: _Session) -> dict[str, Any]:
    """Return a detached copy of a session's mutable reasoning/task state.

    List- and dict-valued fields are shallow-copied so later mutation of the
    session does not alter the snapshot. A ``_reasoning_trace`` entry in the
    REPL namespace that is not a list is recorded as an empty list, and
    non-dict task entries are dropped.
    """
    trace = session.repl._namespace.get("_reasoning_trace")
    chunks = session.chunks
    binding = session.workspace_binding

    state: dict[str, Any] = {
        "created_at": session.created_at,
        "iterations": session.iterations,
    }
    for attr in (
        "think_history",
        "evidence",
        "confidence_history",
        "information_gain",
    ):
        state[attr] = list(getattr(session, attr))
    state["chunks"] = list(chunks) if isinstance(chunks, list) else chunks
    state["workspace_binding"] = dict(binding) if isinstance(binding, dict) else None
    state["tasks"] = [dict(entry) for entry in session.tasks if isinstance(entry, dict)]
    state["task_counter"] = session.task_counter
    state["max_depth_seen"] = session.max_depth_seen
    state["reasoning_trace"] = list(trace) if isinstance(trace, list) else []
    return state
def create_session(
    *,
    sessions: dict[str, _Session],
    context: str,
    context_id: str,
    fmt: ContentFormat,
    line_number_base: LineNumberBase,
    sandbox_config: SandboxConfig,
    analyze_text_context: Callable[[str, ContentFormat], ContextMetadata],
    configure_session: Callable[[_Session, str, asyncio.AbstractEventLoop | None], None],
    close_node_repl: Callable[[str], None] | None = None,
) -> ContextMetadata:
    """Build a fresh session for ``context_id`` and store it in ``sessions``.

    Any existing entry under the same id is overwritten. When
    ``close_node_repl`` is given it is called with the context id first. The
    new REPL exposes the context as ``ctx`` and ``line_number_base`` as a
    variable, and ``configure_session`` is invoked with the running event
    loop. Must be called while an event loop is running.

    Returns:
        The metadata produced by ``analyze_text_context`` for the context.
    """
    if close_node_repl is not None:
        close_node_repl(context_id)

    metadata = analyze_text_context(context, fmt)
    loop = asyncio.get_running_loop()

    sandbox = REPLEnvironment(
        context=context,
        context_var_name="ctx",
        config=sandbox_config,
        loop=loop,
    )
    sandbox.set_variable("line_number_base", line_number_base)

    fresh = _Session(
        repl=sandbox,
        meta=metadata,
        line_number_base=line_number_base,
    )
    sessions[context_id] = fresh
    configure_session(fresh, context_id, loop)
    return metadata
line_number_base=line_number_base, + sandbox_config=sandbox_config, + analyze_text_context=analyze_text_context, + configure_session=configure_session, + close_node_repl=close_node_repl, + ) + if previous_state is not None: + restore_session_state(sessions[context_id], previous_state) + return meta + + +def build_memory_pack_payload( + sessions: dict[str, _Session], + *, + include_ctx: bool = True, +) -> tuple[dict[str, Any], list[str]]: + """Serialize all known sessions into a memory-pack payload.""" + sessions_payload: list[dict[str, Any]] = [] + skipped: list[str] = [] + for sid, sess in sessions.items(): + try: + sessions_payload.append( + _session_to_payload(sid, sess, include_ctx=include_ctx) + ) + except Exception: + skipped.append(sid) + payload = { + "schema": "aleph.memory_pack.v1", + "created_at": datetime.now().isoformat(), + "sessions": sessions_payload, + "skipped": skipped, + } + return payload, skipped + + +def _resolve_session_payload_id(session_payload: Any) -> str | None: + if not isinstance(session_payload, dict): + return None + + raw_id = session_payload.get("id") + if isinstance(raw_id, str) and raw_id.strip(): + return raw_id.strip() + + raw_session_id = session_payload.get("session_id") + if isinstance(raw_session_id, str) and raw_session_id.strip(): + return raw_session_id.strip() + + raw_context_id = session_payload.get("context_id") + if isinstance(raw_context_id, str) and raw_context_id.strip(): + return raw_context_id.strip() + + return None + + +def load_memory_pack_payload( + payload: dict[str, Any], + *, + sessions: dict[str, _Session], + sandbox_config: SandboxConfig, + configure_session: Callable[[_Session, str, asyncio.AbstractEventLoop | None], None], + loop: asyncio.AbstractEventLoop | None, + close_node_repl: Callable[[str], None] | None = None, + skip_existing: bool = False, +) -> tuple[list[str], list[dict[str, str]]]: + """Load sessions from a memory-pack payload into the session registry.""" + if payload.get("schema") != 
"aleph.memory_pack.v1": + raise ValueError("Invalid memory pack schema") + + session_payloads = payload.get("sessions") + if not isinstance(session_payloads, list): + raise ValueError("Invalid memory pack payload: sessions must be a list") + + loaded: list[str] = [] + skipped: list[dict[str, str]] = [] + for session_payload in session_payloads: + resolved_id = _resolve_session_payload_id(session_payload) + if not resolved_id: + skipped.append({"id": "", "error": "missing session identifier"}) + continue + if skip_existing and resolved_id in sessions: + continue + try: + if close_node_repl is not None: + close_node_repl(resolved_id) + session = _session_from_payload( + session_payload, + resolved_id, + sandbox_config, + loop, + ) + configure_session(session, resolved_id, loop) + sessions[resolved_id] = session + loaded.append(resolved_id) + except Exception as exc: + skipped.append({"id": resolved_id, "error": str(exc)}) + + return loaded, skipped + + def _session_from_payload( obj: dict[str, Any], resolved_id: str, @@ -279,32 +519,32 @@ def _session_from_payload( for task in tasks_payload: if not isinstance(task, dict): continue - if "id" not in task or "title" not in task: - continue + tasks.append(dict(task)) + + def _task_counter_seed(items: list[dict[str, Any]]) -> int: + best = 0 + for task in items: raw_id = task.get("id") - if raw_id is None: - continue - try: - task_id = int(raw_id) - except (TypeError, ValueError): + if isinstance(raw_id, int): + best = max(best, raw_id) continue - tasks.append({ - "id": task_id, - "title": str(task.get("title")), - "status": str(task.get("status") or "todo"), - "note": task.get("note"), - "created_at": task.get("created_at"), - "updated_at": task.get("updated_at"), - }) + if isinstance(raw_id, str): + digits = "".join(ch for ch in raw_id if ch.isdigit()) + if digits: + try: + best = max(best, int(digits)) + except ValueError: + continue + return best raw_task_counter = obj.get("task_counter") if 
isinstance(raw_task_counter, (int, str)): try: task_counter = int(raw_task_counter) except (TypeError, ValueError): - task_counter = max((t["id"] for t in tasks), default=0) + task_counter = _task_counter_seed(tasks) else: - task_counter = max((t["id"] for t in tasks), default=0) + task_counter = _task_counter_seed(tasks) session = _Session( repl=repl, @@ -316,9 +556,15 @@ def _session_from_payload( confidence_history=list(obj.get("confidence_history") or []), information_gain=list(obj.get("information_gain") or []), chunks=obj.get("chunks"), + workspace_binding=( + dict(obj["workspace_binding"]) + if isinstance(obj.get("workspace_binding"), dict) + else None + ), tasks=tasks, task_counter=task_counter, ) + repl._namespace["_tasks"] = session.tasks ev_list = obj.get("evidence") if isinstance(ev_list, list): diff --git a/aleph/mcp/sub_query_orchestration.py b/aleph/mcp/sub_query_orchestration.py new file mode 100644 index 0000000..dafa8d0 --- /dev/null +++ b/aleph/mcp/sub_query_orchestration.py @@ -0,0 +1,801 @@ +"""Sub-query and sub-Aleph orchestration helpers for the local MCP server.""" + +from __future__ import annotations + +import asyncio +import json +import os +import re +import time +from pathlib import Path +from typing import TYPE_CHECKING, Callable + +from ..config import AlephConfig +from ..core import Aleph +from ..observability import traced_span +from ..prompts.system import DEFAULT_SYSTEM_PROMPT +from ..providers.registry import get_provider +from ..sub_query import detect_backend +from ..sub_query.api_backend import run_api_sub_query +from ..sub_query.cli_backend import CLI_BACKENDS, run_cli_sub_query +from ..sub_query.codex_mcp_backend import ( + build_codex_mcp_tool_call, + compose_sub_query_prompt, + extract_codex_mcp_result_text, + suppress_mcp_notification_validation_logs, +) +from ..sub_query.config import ( + resolve_codex_mode, + resolve_codex_model, + resolve_codex_profile, + resolve_codex_reasoning_effort, +) +from ..types import 
_FINAL_RE = re.compile(r"FINAL\((.*?)\)", re.DOTALL)
_FINAL_VAR_RE = re.compile(r"FINAL_VAR\((.*?)\)", re.DOTALL)


def _balanced_argument(text: str, open_index: int) -> str | None:
    """Return the text between ``text[open_index]`` ("(") and its matching ")".

    Returns ``None`` when the parentheses never balance.
    """
    depth = 0
    for pos in range(open_index, len(text)):
        ch = text[pos]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return text[open_index + 1 : pos]
    return None


def extract_final_answer(text: str) -> tuple[str, bool]:
    """Extract the answer wrapped in ``FINAL(...)`` or ``FINAL_VAR(...)``.

    ``FINAL(...)`` takes precedence over ``FINAL_VAR(...)``. The argument is
    taken up to the *matching* closing parenthesis so answers that themselves
    contain parentheses are not cut short; if the parentheses are unbalanced
    the shortest regex capture is used instead. A ``FINAL_VAR`` argument has
    one pair of surrounding quotes removed.

    Returns:
        ``(answer, True)`` when a marker was found, otherwise
        ``(text.strip(), False)``.
    """
    match = _FINAL_RE.search(text)
    if match:
        # match.start(1) is the first captured char; the "(" sits just before.
        inner = _balanced_argument(text, match.start(1) - 1)
        if inner is None:
            inner = match.group(1)
        return inner.strip(), True

    match_var = _FINAL_VAR_RE.search(text)
    if match_var:
        raw = _balanced_argument(text, match_var.start(1) - 1)
        if raw is None:
            raw = match_var.group(1)
        raw = raw.strip()
        if len(raw) >= 2 and ((raw[0] == raw[-1] == '"') or (raw[0] == raw[-1] == "'")):
            raw = raw[1:-1].strip()
        return raw, True

    return text.strip(), False
def normalize_streamable_http_path(path: str) -> str:
    """Return ``path`` with a leading slash, defaulting to ``/mcp`` when empty."""
    if not path:
        return "/mcp"
    return path if path.startswith("/") else f"/{path}"


def format_streamable_http_url(host: str, port: int, path: str) -> str:
    """Build the client-facing URL for the streamable HTTP endpoint.

    Wildcard bind addresses (``0.0.0.0`` and ``::``) are not connectable, so
    they are replaced by ``127.0.0.1``. Any other host containing ``:`` is
    treated as an IPv6 literal and wrapped in brackets, as URL syntax
    requires, unless it is already bracketed. ``path`` is appended verbatim.
    """
    connect_host = "127.0.0.1" if host in {"0.0.0.0", "::"} else host
    if ":" in connect_host and not connect_host.startswith("["):
        # Without brackets the address colons are ambiguous with the port.
        connect_host = f"[{connect_host}]"
    return f"http://{connect_host}:{port}{path}"
" + "Install with: pip install uvicorn" + ) from exc + + app = owner.server.streamable_http_app() + config = uvicorn.Config( + app, + host=host, + port=port, + log_level="warning", + access_log=False, + lifespan="on", + ) + server = uvicorn.Server(config) + await server.serve() + + +async def ensure_streamable_http_server( + owner: "AlephMCPServerLocal", + host: str, + port: int, + path: str, +) -> tuple[bool, str]: + normalized_path = normalize_streamable_http_path(path) + async with owner._streamable_http_lock: + if owner._streamable_http_task and not owner._streamable_http_task.done(): + url = owner._streamable_http_url or format_streamable_http_url( + host, + port, + normalized_path, + ) + return True, url + if owner._streamable_http_task and owner._streamable_http_task.done(): + owner._streamable_http_task = None + owner._streamable_http_url = None + + owner.server.settings.host = host + owner.server.settings.port = port + owner.server.settings.streamable_http_path = normalized_path + + owner._streamable_http_task = asyncio.create_task( + owner._run_streamable_http_server(host, port) + ) + owner._streamable_http_host = host + owner._streamable_http_port = port + owner._streamable_http_path = normalized_path + owner._streamable_http_url = format_streamable_http_url( + host, + port, + normalized_path, + ) + + ok, err = await owner._wait_for_streamable_http_ready(host, port) + if not ok: + return False, err + return True, owner._streamable_http_url or format_streamable_http_url( + host, + port, + normalized_path, + ) + + +async def ensure_internal_codex_mcp_server( + owner: "AlephMCPServerLocal", + cwd: Path | None, +) -> str: + server_id = "__aleph_internal_codex__" + handle = owner._remote_servers.get(server_id) + if handle is None: + handle = register_remote_server( + owner._remote_servers, + server_id, + command="codex", + args=["mcp-server", "-c", "mcp_servers={}"], + cwd=cwd, + allow_tools=["codex", "codex-reply"], + ) + elif handle.cwd != cwd: + await 
async def run_internal_codex_mcp_query(
    owner: "AlephMCPServerLocal",
    *,
    prompt: str,
    context_slice: str | None,
    cwd: Path | None,
    mcp_server_url: str | None,
    mcp_server_name: str,
    thread_id: str | None = None,
) -> tuple[bool, str, str | None]:
    """Run one sub-query through the internal Codex MCP server.

    The prompt and context slice are combined, turned into a Codex MCP tool
    call using the owner's sub-query configuration, and sent to the internal
    server (started on demand). Output longer than
    ``cli_max_output_chars`` is cut and suffixed with ``...[truncated]``.

    Returns:
        ``(ok, text, thread_id)``. On failure ``text`` is the error message
        and ``thread_id`` is ``None``; on success ``thread_id`` is whatever
        the result extraction reported.
    """
    cfg = owner.sub_query_config
    model = resolve_codex_model(cfg.codex_model)
    effort = resolve_codex_reasoning_effort(cfg.codex_reasoning_effort)
    profile = resolve_codex_profile(cfg.codex_profile)

    tool_name, arguments = build_codex_mcp_tool_call(
        prompt=compose_sub_query_prompt(prompt, context_slice),
        cwd=cwd,
        mcp_server_url=mcp_server_url,
        mcp_server_name=mcp_server_name,
        trust_mcp_server=True,
        model=model,
        reasoning_effort=effort,
        profile=profile,
        thread_id=thread_id,
    )

    try:
        server_id = await owner._ensure_internal_codex_mcp_server(cwd)
    except Exception as exc:
        return False, f"Failed to start internal Codex MCP server: {exc}", None

    with suppress_mcp_notification_validation_logs():
        ok, result = await owner._remote_call_tool(
            server_id,
            tool_name,
            arguments,
            timeout_seconds=owner.sub_query_config.cli_timeout_seconds,
        )
    if not ok:
        return False, str(result), None

    output, resolved_thread_id = extract_codex_mcp_result_text(result)
    if not output:
        # No extractable text: fall back to the raw result as JSON.
        output = json.dumps(_to_jsonable(result), ensure_ascii=True)

    max_chars = owner.sub_query_config.cli_max_output_chars
    if len(output) > max_chars:
        output = output[:max_chars] + "\n...[truncated]"

    return True, output, resolved_thread_id
None + server_name = "aleph_shared" + share_session = _get_env_bool("ALEPH_SUB_QUERY_SHARE_SESSION", False) + + if share_session and resolved_backend in _SHARED_SESSION_BACKENDS: + host = os.environ.get("ALEPH_SUB_QUERY_HTTP_HOST", "127.0.0.1") + port = _get_env_int("ALEPH_SUB_QUERY_HTTP_PORT", 8765) + path = os.environ.get("ALEPH_SUB_QUERY_HTTP_PATH", "/mcp") + server_name = ( + os.environ.get( + "ALEPH_SUB_QUERY_MCP_SERVER_NAME", + "aleph_shared", + ).strip() + or "aleph_shared" + ) + ok, url_or_err = await owner._ensure_streamable_http_server(host, port, path) + if not ok: + return False, url_or_err, None, server_name + mcp_server_url = url_or_err + prompt = ( + f"{prompt}\n\n" + f"[MCP tools are available via the live Aleph server. " + f"Use context_id={context_id!r} when calling tools. " + f"Tools are prefixed with `mcp__{server_name}__`.]" + ) + + return True, prompt, mcp_server_url, server_name + + +async def run_sub_query( + owner: "AlephMCPServerLocal", + *, + prompt: str, + context_slice: str | None, + context_id: str, + backend: str, + validation_regex: str | None = None, + max_retries: int | None = None, + retry_prompt: str | None = None, +) -> tuple[bool, str, bool, str]: + session = owner._sessions.get(context_id) + if session: + session.iterations += 1 + + truncated = False + if context_slice and len(context_slice) > owner.sub_query_config.max_context_chars: + context_slice = context_slice[: owner.sub_query_config.max_context_chars] + truncated = True + + resolved_backend = backend + if backend == "auto": + resolved_backend = detect_backend(owner.sub_query_config) + + allowed_backends = {"auto", "api", *CLI_BACKENDS} + if resolved_backend not in allowed_backends: + allowed_list = ", ".join(sorted(allowed_backends)) + return ( + False, + f"Unsupported backend '{resolved_backend}'. 
Choose from: {allowed_list}.", + truncated, + resolved_backend, + ) + + resolved_validation_regex = validation_regex + if resolved_validation_regex is None: + resolved_validation_regex = ( + owner.sub_query_config.validation_regex + or os.environ.get("ALEPH_SUB_QUERY_VALIDATION_REGEX") + ) + if resolved_validation_regex is not None: + resolved_validation_regex = resolved_validation_regex.strip() + if not resolved_validation_regex: + resolved_validation_regex = None + + resolved_max_retries = ( + owner.sub_query_config.max_retries if max_retries is None else max_retries + ) + if max_retries is None: + resolved_max_retries = _get_env_int( + "ALEPH_SUB_QUERY_MAX_RETRIES", resolved_max_retries + ) + + resolved_retry_prompt = ( + owner.sub_query_config.retry_prompt if retry_prompt is None else retry_prompt + ) + if retry_prompt is None: + env_retry_prompt = os.environ.get("ALEPH_SUB_QUERY_RETRY_PROMPT") + if env_retry_prompt: + resolved_retry_prompt = env_retry_prompt + + validation_re: re.Pattern[str] | None = None + if resolved_validation_regex: + try: + validation_re = re.compile(resolved_validation_regex, re.MULTILINE) + except re.error as exc: + return ( + False, + f"Invalid validation regex: {exc}", + truncated, + resolved_backend, + ) + + attempt = 0 + base_prompt = prompt + prompt_for_attempt = base_prompt + codex_thread_id: str | None = None + with traced_span( + "aleph.sub_query", + { + "aleph.context_id": context_id, + "aleph.sub_query.backend.requested": backend, + "aleph.sub_query.backend.resolved": resolved_backend, + "aleph.sub_query.context_chars": len(context_slice or ""), + "aleph.sub_query.context_truncated": truncated, + "aleph.sub_query.validation_enabled": bool(resolved_validation_regex), + }, + ) as span: + success = False + output = "" + try: + while True: + run_prompt = prompt_for_attempt + if resolved_backend in CLI_BACKENDS: + ok, prepared_prompt, mcp_server_url, server_name = ( + await _prepare_cli_shared_session( + owner, + 
prompt=run_prompt, + context_id=context_id, + resolved_backend=resolved_backend, + ) + ) + if not ok: + return ( + False, + f"Failed to start streamable HTTP server: {prepared_prompt}", + truncated, + resolved_backend, + ) + run_prompt = prepared_prompt + cwd = ( + owner.action_config.workspace_root + if owner.action_config.enabled + else None + ) + if ( + resolved_backend == "codex" + and resolve_codex_mode(owner.sub_query_config.codex_mode) + == "mcp" + ): + success, output, codex_thread_id = ( + await owner._run_internal_codex_mcp_query( + prompt=run_prompt, + context_slice=context_slice, + cwd=cwd, + mcp_server_url=mcp_server_url, + mcp_server_name=server_name, + thread_id=codex_thread_id, + ) + ) + else: + success, output = await run_cli_sub_query( + prompt=run_prompt, + context_slice=context_slice, + backend=resolved_backend, # type: ignore[arg-type] + timeout=owner.sub_query_config.cli_timeout_seconds, + cwd=cwd, + max_output_chars=owner.sub_query_config.cli_max_output_chars, + max_context_chars=owner.sub_query_config.max_context_chars, + mcp_server_url=mcp_server_url, + mcp_server_name=server_name, + trust_mcp_server=True, + claude_model=owner.sub_query_config.claude_model, + claude_effort=owner.sub_query_config.claude_effort, + codex_mode=owner.sub_query_config.codex_mode, + codex_model=owner.sub_query_config.codex_model, + codex_reasoning_effort=owner.sub_query_config.codex_reasoning_effort, + codex_profile=owner.sub_query_config.codex_profile, + ) + else: + success, output = await run_api_sub_query( + prompt=run_prompt, + context_slice=context_slice, + model=owner.sub_query_config.api_model, + api_key_env=owner.sub_query_config.api_key_env, + api_base_url_env=owner.sub_query_config.api_base_url_env, + api_model_env=owner.sub_query_config.api_model_env, + timeout=owner.sub_query_config.api_timeout_seconds, + system_prompt=owner.sub_query_config.system_prompt + if owner.sub_query_config.include_system_prompt + else None, + 
max_context_chars=owner.sub_query_config.max_context_chars, + ) + + if not success: + break + + if validation_re and not validation_re.search(output): + if attempt >= resolved_max_retries: + success = False + output = ( + f"Output failed validation regex {resolved_validation_regex!r} " + f"after {attempt + 1} attempt(s). Last output: {output}" + ) + break + attempt += 1 + prompt_for_attempt = ( + f"{base_prompt}\n\n" + f"{resolved_retry_prompt}\n" + f"Required format regex: {resolved_validation_regex}" + ) + continue + + break + except Exception as exc: + span.record_exception(exc) + success = False + output = f"{type(exc).__name__}: {exc}" + + span.set_attribute("aleph.sub_query.success", success) + span.set_attribute("aleph.sub_query.attempts", attempt + 1) + span.set_attribute("aleph.sub_query.output_chars", len(output)) + + if session: + note_parts = [f"backend={resolved_backend}"] + if resolved_validation_regex: + note_parts.append(f"validation={resolved_validation_regex!r}") + if attempt: + note_parts.append(f"retries={attempt}") + if truncated: + note_parts.append("truncated_context") + session.evidence.append( + _Evidence( + source="sub_query", + line_range=None, + pattern=None, + snippet=output[:200] if success else f"[ERROR] {output[:150]}", + note=" ".join(note_parts), + ) + ) + session.information_gain.append(1 if success else 0) + + return success, output, truncated, resolved_backend + + +async def run_sub_aleph( + owner: "AlephMCPServerLocal", + *, + query: str, + context_slice: str | None, + context_id: str, + current_depth: int = 1, + root_model: str | None = None, + sub_model: str | None = None, + max_depth: int | None = None, + max_iterations: int | None = None, + max_tokens: int | None = None, + max_sub_queries: int | None = None, + max_wall_time_seconds: float | None = None, + temperature: float | None = None, + analyze_text_context: Callable[[str, ContentFormat], ContextMetadata] | None = None, +) -> tuple[AlephResponse, dict[str, object]]: + 
session = owner._sessions.get(context_id) + if session: + session.iterations += 1 + session.max_depth_seen = max(session.max_depth_seen, current_depth) + + cfg = AlephConfig.from_env() + budget = cfg.to_budget() + if max_tokens is not None: + budget.max_tokens = max_tokens + if max_iterations is not None: + budget.max_iterations = max_iterations + if max_depth is not None: + budget.max_depth = max_depth + if max_wall_time_seconds is not None: + budget.max_wall_time_seconds = max_wall_time_seconds + if max_sub_queries is not None: + budget.max_sub_queries = max_sub_queries + + resolved_root = root_model or cfg.root_model + resolved_sub = sub_model or cfg.sub_model or resolved_root + + temp_val = 0.0 + if temperature is not None: + try: + temp_val = float(temperature) + except (TypeError, ValueError): + temp_val = 0.0 + + resolved_backend = detect_backend(owner.sub_query_config) + truncated_context = False + start_time = time.perf_counter() + response: AlephResponse | None = None + + if resolved_backend in CLI_BACKENDS: + cli_context = context_slice or "" + if cli_context and len(cli_context) > owner.sub_query_config.max_context_chars: + cli_context = cli_context[: owner.sub_query_config.max_context_chars] + truncated_context = True + + context_format = session.meta.format if session else ContentFormat.TEXT + prompt = build_sub_aleph_cli_prompt( + query=query, + context_slice=cli_context, + context_format=context_format, + cfg=cfg, + analyze_text_context=analyze_text_context or _fallback_analyze_text_context, + ) + + mcp_server_url = None + server_name = "aleph_shared" + share_session = _get_env_bool("ALEPH_SUB_QUERY_SHARE_SESSION", False) + if share_session and resolved_backend in _SHARED_SESSION_BACKENDS: + ok, prepared_prompt, mcp_server_url, server_name = ( + await _prepare_cli_shared_session( + owner, + prompt=prompt, + context_id=context_id, + resolved_backend=resolved_backend, + ) + ) + if not ok: + response = AlephResponse( + answer="", + success=False, + 
total_iterations=0, + max_depth_reached=0, + total_tokens=0, + total_cost_usd=0.0, + wall_time_seconds=time.perf_counter() - start_time, + trajectory=[], + error=f"Failed to start streamable HTTP server: {prepared_prompt}", + error_type="cli_error", + ) + else: + prompt = prepared_prompt + + if mcp_server_url is not None or not share_session: + try: + cwd = ( + owner.action_config.workspace_root + if owner.action_config.enabled + else None + ) + if ( + resolved_backend == "codex" + and resolve_codex_mode(owner.sub_query_config.codex_mode) == "mcp" + ): + success, output, _thread_id = await owner._run_internal_codex_mcp_query( + prompt=prompt, + context_slice=cli_context if cli_context else None, + cwd=cwd, + mcp_server_url=mcp_server_url, + mcp_server_name=server_name, + ) + else: + success, output = await run_cli_sub_query( + prompt=prompt, + context_slice=cli_context if cli_context else None, + backend=resolved_backend, # type: ignore[arg-type] + timeout=owner.sub_query_config.cli_timeout_seconds, + cwd=cwd, + max_output_chars=owner.sub_query_config.cli_max_output_chars, + max_context_chars=owner.sub_query_config.max_context_chars, + mcp_server_url=mcp_server_url, + mcp_server_name=server_name, + trust_mcp_server=True, + claude_model=owner.sub_query_config.claude_model, + claude_effort=owner.sub_query_config.claude_effort, + codex_mode=owner.sub_query_config.codex_mode, + codex_model=owner.sub_query_config.codex_model, + codex_reasoning_effort=owner.sub_query_config.codex_reasoning_effort, + codex_profile=owner.sub_query_config.codex_profile, + ) + except Exception as exc: + success, output = False, f"{type(exc).__name__}: {exc}" + + wall_time = time.perf_counter() - start_time + if success: + answer, _ = extract_final_answer(output) + if not answer: + response = AlephResponse( + answer="", + success=False, + total_iterations=current_depth, + max_depth_reached=current_depth, + total_tokens=0, + total_cost_usd=0.0, + wall_time_seconds=wall_time, + trajectory=[], + 
error="Empty response from CLI backend", + error_type="cli_error", + ) + else: + response = AlephResponse( + answer=answer, + success=True, + total_iterations=current_depth, + max_depth_reached=current_depth, + total_tokens=0, + total_cost_usd=0.0, + wall_time_seconds=wall_time, + trajectory=[], + ) + else: + response = AlephResponse( + answer="", + success=False, + total_iterations=current_depth, + max_depth_reached=current_depth, + total_tokens=0, + total_cost_usd=0.0, + wall_time_seconds=wall_time, + trajectory=[], + error=output, + error_type="cli_error", + ) + else: + try: + provider = get_provider(cfg.provider, api_key=cfg.api_key) + runner = Aleph( + provider=provider, + root_model=resolved_root, + sub_model=resolved_sub, + budget=budget, + sandbox_config=owner.sandbox_config, + system_prompt=cfg.system_prompt, + enable_caching=cfg.enable_caching, + log_trajectory=cfg.log_trajectory, + ) + response = await runner.complete( + query=query, + context=context_slice or "", + root_model=resolved_root, + sub_model=resolved_sub, + budget=budget, + temperature=temp_val, + ) + except Exception as exc: + response = AlephResponse( + answer="", + success=False, + total_iterations=0, + max_depth_reached=0, + total_tokens=0, + total_cost_usd=0.0, + wall_time_seconds=0.0, + trajectory=[], + error=str(exc), + error_type="provider_error", + ) + + if response is None: + response = AlephResponse( + answer="", + success=False, + total_iterations=current_depth, + max_depth_reached=current_depth, + total_tokens=0, + total_cost_usd=0.0, + wall_time_seconds=time.perf_counter() - start_time, + trajectory=[], + error="CLI backend could not start.", + error_type="cli_error", + ) + + if session: + note_parts = [ + f"backend={resolved_backend}", + f"models={resolved_root}/{resolved_sub}", + ] + if budget.max_depth is not None: + note_parts.append(f"max_depth={budget.max_depth}") + if truncated_context: + note_parts.append("truncated_context") + session.evidence.append( + _Evidence( + 
source="sub_aleph", + line_range=None, + pattern=None, + snippet=response.answer[:200] + if response.success + else f"[ERROR] {str(response.error)[:150]}", + note=" ".join(note_parts), + ) + ) + session.information_gain.append(1 if response.success else 0) + + meta: dict[str, object] = { + "root_model": resolved_root, + "sub_model": resolved_sub, + "budget": budget, + "temperature": temp_val, + "backend": resolved_backend, + "truncated_context": truncated_context, + } + return response, meta diff --git a/aleph/mcp/workspace.py b/aleph/mcp/workspace.py index 096139d..eec7b35 100644 --- a/aleph/mcp/workspace.py +++ b/aleph/mcp/workspace.py @@ -94,6 +94,19 @@ def _validate_line_number_base(value: int) -> LineNumberBase: return cast(LineNumberBase, value) +def _resolve_line_number_base(session: Any | None, value: int | None) -> LineNumberBase: + if session is not None: + if value is None: + return cast(LineNumberBase, session.line_number_base) + base = _validate_line_number_base(value) + if base != session.line_number_base: + raise ValueError("line_number_base does not match existing session") + return base + if value is None: + return DEFAULT_LINE_NUMBER_BASE + return _validate_line_number_base(value) + + def _uri_to_path(uri: str) -> Path | None: """Convert a file:// URI to a Path, or return None for non-file URIs.""" try: diff --git a/aleph/mcp/workspace_contexts.py b/aleph/mcp/workspace_contexts.py new file mode 100644 index 0000000..3e73fb1 --- /dev/null +++ b/aleph/mcp/workspace_contexts.py @@ -0,0 +1,342 @@ +"""Workspace-backed context helpers for MCP sessions. + +These helpers let Aleph treat some contexts as refreshable workspace assets +instead of anonymous blobs. 
The first supported bindings are: + +- file: a context loaded from a workspace file via ``load_file`` +- manifest: a generated workspace manifest for large codebases/projects +""" + +from __future__ import annotations + +from collections import Counter +from datetime import datetime +import os +from pathlib import Path +from typing import Any, Iterable + +from ..types import ContentFormat +from .io_utils import _load_text_from_path + +WorkspaceBinding = dict[str, Any] + +_SKIP_DIRS = { + ".git", + ".hg", + ".svn", + ".venv", + "venv", + "node_modules", + "dist", + "build", + "__pycache__", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".turbo", + ".next", + ".parcel-cache", + "coverage", +} + +_IMPORTANT_FILES = { + "README.md", + "pyproject.toml", + "package.json", + "package-lock.json", + "pnpm-lock.yaml", + "yarn.lock", + "Cargo.toml", + "go.mod", + "requirements.txt", + "Dockerfile", + "docker-compose.yml", + "docker-compose.yaml", + "Makefile", + ".env.example", +} + +_LANGUAGE_BY_SUFFIX = { + ".py": "python", + ".pyi": "python", + ".js": "javascript", + ".jsx": "javascript", + ".ts": "typescript", + ".tsx": "typescript", + ".json": "json", + ".jsonl": "jsonl", + ".yaml": "yaml", + ".yml": "yaml", + ".md": "markdown", + ".toml": "toml", + ".rs": "rust", + ".go": "go", + ".java": "java", + ".kt": "kotlin", + ".swift": "swift", + ".rb": "ruby", + ".php": "php", + ".c": "c", + ".h": "c", + ".cpp": "cpp", + ".hpp": "cpp", + ".cs": "csharp", + ".sh": "shell", + ".sql": "sql", + ".html": "html", + ".css": "css", +} + + +def _display_path(path: Path, workspace_root: Path) -> str: + try: + return str(path.resolve().relative_to(workspace_root.resolve())) + except Exception: + return str(path.resolve()) + + +def make_file_binding(path: Path, workspace_root: Path) -> WorkspaceBinding: + resolved = path.resolve() + stat = resolved.stat() + return { + "kind": "file", + "path": str(resolved), + "workspace_root": str(workspace_root.resolve()), + "display_path": 
_display_path(resolved, workspace_root), + "size_bytes": stat.st_size, + "mtime_ns": stat.st_mtime_ns, + "refreshed_at": datetime.now().isoformat(), + } + + +def _iter_workspace_files(root: Path, include_hidden: bool) -> Iterable[Path]: + if root.is_file(): + yield root + return + + for dirpath, dirnames, filenames in os.walk(root): + current = Path(dirpath) + dirnames[:] = sorted( + d for d in dirnames + if d not in _SKIP_DIRS and (include_hidden or not d.startswith(".")) + ) + for filename in sorted(filenames): + if not include_hidden and filename.startswith("."): + continue + path = current / filename + if path.is_file(): + yield path + + +def _language_for_path(path: Path) -> str: + if path.name in _IMPORTANT_FILES: + return "project-config" + suffix = path.suffix.lower() + if suffix in _LANGUAGE_BY_SUFFIX: + return _LANGUAGE_BY_SUFFIX[suffix] + if not suffix: + return "plain" + return suffix.lstrip(".") + + +def build_workspace_manifest( + *, + workspace_root: Path, + roots: list[Path], + max_files: int = 2000, + include_hidden: bool = False, +) -> tuple[str, WorkspaceBinding, str | None]: + resolved_workspace_root = workspace_root.resolve() + resolved_roots = [root.resolve() for root in roots] + + files: list[tuple[str, str, int]] = [] + language_counts: Counter[str] = Counter() + top_level_counts: Counter[str] = Counter() + important_files: list[str] = [] + truncated = False + + for root in resolved_roots: + for path in _iter_workspace_files(root, include_hidden=include_hidden): + language = _language_for_path(path) + display = _display_path(path, resolved_workspace_root) + try: + stat = path.stat() + size_bytes = stat.st_size + except OSError: + size_bytes = 0 + + files.append((display, language, size_bytes)) + language_counts[language] += 1 + top_component = display.split("/", 1)[0] if "/" in display else display + top_level_counts[top_component] += 1 + if path.name in _IMPORTANT_FILES and display not in important_files: + important_files.append(display) + 
+ if len(files) >= max_files: + truncated = True + break + if truncated: + break + + lines: list[str] = [ + "# Aleph Workspace Manifest", + "", + f"Workspace root: {resolved_workspace_root}", + f"Generated at: {datetime.now().isoformat()}", + "", + "Indexed roots:", + ] + for root in resolved_roots: + lines.append(f"- {_display_path(root, resolved_workspace_root)}") + + lines.extend( + [ + "", + f"Files indexed: {len(files)}" + (f" (truncated at {max_files})" if truncated else ""), + "", + "Language summary:", + ] + ) + for language, count in language_counts.most_common(12): + lines.append(f"- {language}: {count}") + + if top_level_counts: + lines.extend(["", "Top-level paths:"]) + for name, count in top_level_counts.most_common(12): + lines.append(f"- {name}: {count}") + + if important_files: + lines.extend(["", "Key project files:"]) + for path in sorted(important_files)[:20]: + lines.append(f"- {path}") + + lines.extend(["", "File listing:"]) + for display, language, size_bytes in files: + lines.append(f"- {display} | {language} | {size_bytes} bytes") + + note: str | None = None + if truncated: + note = f"Manifest truncated at {max_files} files. Increase max_files for broader coverage." 
+ + binding: WorkspaceBinding = { + "kind": "manifest", + "workspace_root": str(resolved_workspace_root), + "roots": [str(root) for root in resolved_roots], + "max_files": max_files, + "include_hidden": include_hidden, + "file_count": len(files), + "truncated": truncated, + "refreshed_at": datetime.now().isoformat(), + } + return "\n".join(lines), binding, note + + +def binding_summary(binding: WorkspaceBinding | None) -> str | None: + if not binding: + return None + kind = str(binding.get("kind") or "") + if kind == "file": + return f"file:{binding.get('display_path') or binding.get('path')}" + if kind == "manifest": + file_count = binding.get("file_count") + return f"manifest:{file_count} files" + return kind or None + + +def binding_status(binding: WorkspaceBinding | None) -> dict[str, Any] | None: + if not binding: + return None + + kind = str(binding.get("kind") or "") + if kind == "file": + path_text = binding.get("path") + if not isinstance(path_text, str): + return { + "kind": "file", + "exists": False, + "refreshable": True, + "stale": True, + "reason": "missing file path", + } + path = Path(path_text) + if not path.exists(): + return { + "kind": "file", + "path": path_text, + "display_path": binding.get("display_path"), + "exists": False, + "refreshable": True, + "stale": True, + "reason": "file missing", + } + stat = path.stat() + stale = ( + stat.st_size != int(binding.get("size_bytes") or 0) + or stat.st_mtime_ns != int(binding.get("mtime_ns") or 0) + ) + reason = "file changed on disk" if stale else None + return { + "kind": "file", + "path": path_text, + "display_path": binding.get("display_path"), + "exists": True, + "refreshable": True, + "stale": stale, + "reason": reason, + "last_refreshed_at": binding.get("refreshed_at"), + } + + if kind == "manifest": + return { + "kind": "manifest", + "exists": True, + "roots": list(binding.get("roots") or []), + "file_count": int(binding.get("file_count") or 0), + "truncated": bool(binding.get("truncated") or 
False), + "refreshable": True, + "stale": False, + "reason": None, + "last_refreshed_at": binding.get("refreshed_at"), + } + + return { + "kind": kind or "unknown", + "exists": False, + "refreshable": False, + "stale": False, + "reason": None, + } + + +def refresh_workspace_binding( + binding: WorkspaceBinding, + *, + max_read_bytes: int, + timeout_seconds: float, +) -> tuple[str, ContentFormat, str | None, WorkspaceBinding]: + kind = str(binding.get("kind") or "") + if kind == "file": + path_text = binding.get("path") + workspace_root_text = binding.get("workspace_root") + if not isinstance(path_text, str) or not isinstance(workspace_root_text, str): + raise ValueError("Invalid file binding: missing path metadata") + path = Path(path_text) + workspace_root = Path(workspace_root_text) + text, fmt, warning = _load_text_from_path(path, max_read_bytes, timeout_seconds) + return text, fmt, warning, make_file_binding(path, workspace_root) + + if kind == "manifest": + workspace_root_text = binding.get("workspace_root") + roots_text = binding.get("roots") + if not isinstance(workspace_root_text, str) or not isinstance(roots_text, list): + raise ValueError("Invalid manifest binding: missing roots metadata") + roots = [Path(str(root)) for root in roots_text] + text, new_binding, note = build_workspace_manifest( + workspace_root=Path(workspace_root_text), + roots=roots, + max_files=int(binding.get("max_files") or 2000), + include_hidden=bool(binding.get("include_hidden") or False), + ) + return text, ContentFormat.TEXT, note, new_binding + + raise ValueError("Context is not bound to a refreshable workspace asset") diff --git a/aleph/mcp/workspace_tools.py b/aleph/mcp/workspace_tools.py new file mode 100644 index 0000000..d45759a --- /dev/null +++ b/aleph/mcp/workspace_tools.py @@ -0,0 +1,158 @@ +"""Workspace-oriented MCP tool registrations for the local server.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +from ..types import 
ContentFormat +from .formatting import _format_error, _format_payload +from .workspace import _scoped_path +from .workspace_contexts import ( + build_workspace_manifest, + refresh_workspace_binding, +) + +if TYPE_CHECKING: + from mcp.server.fastmcp import Context + + from .local_server import AlephMCPServerLocal +else: + Context = Any + + +def register_workspace_tools(owner: "AlephMCPServerLocal") -> None: + _tool = owner._tool_decorator + + @_tool() + async def load_workspace_manifest( + paths: list[str] | str | None = None, + context_id: str = "workspace", + max_files: int = 2000, + include_hidden: bool = False, + confirm: bool = False, + output: Literal["markdown", "json", "object"] = "markdown", + ctx: Context = None, # type: ignore[assignment] + ) -> str | dict[str, Any]: + """Load a compact manifest of the current workspace for large codebase/project analysis.""" + err = owner._require_actions(confirm) + if err: + return _format_error(err, output=output) + if max_files <= 0: + return _format_error("max_files must be greater than 0.", output=output) + + if ctx is not None: + await owner._maybe_resolve_workspace_from_roots(ctx) + + if isinstance(paths, str): + paths = [paths] + + resolved_roots = [] + for path in paths or [str(owner.action_config.workspace_root)]: + try: + resolved = _scoped_path( + owner.action_config.workspace_root, + path, + owner.action_config.workspace_mode, + ) + except Exception as exc: + return _format_error(str(exc), output=output) + if not resolved.exists(): + return _format_error(f"Path not found: {path}", output=output) + resolved_roots.append(resolved) + + text, binding, note = build_workspace_manifest( + workspace_root=owner.action_config.workspace_root, + roots=resolved_roots, + max_files=max_files, + include_hidden=include_hidden, + ) + meta = owner._create_session(text, context_id, ContentFormat.TEXT, 1) + session = owner._sessions[context_id] + session.workspace_binding = binding + owner._record_action( + session, + 
note="load_workspace_manifest", + snippet=", ".join(str(root) for root in resolved_roots)[:200], + ) + + payload = { + "status": "success", + "context_id": context_id, + "workspace_root": str(owner.action_config.workspace_root), + "roots": [str(root) for root in resolved_roots], + "file_count": int(binding.get("file_count") or 0), + "truncated": bool(binding.get("truncated") or False), + "binding": binding, + "note": note, + "size_chars": meta.size_chars, + "size_lines": meta.size_lines, + } + if output == "object": + return payload + if output == "json": + return _format_payload(payload, output="json") + return owner._format_context_loaded(context_id, meta, 1, note=note) + + @_tool() + async def refresh_context( + context_id: str = "default", + confirm: bool = False, + output: Literal["markdown", "json", "object"] = "markdown", + ) -> str | dict[str, Any]: + """Refresh a context from its bound workspace file or manifest.""" + err = owner._require_actions(confirm) + if err: + return _format_error(err, output=output) + if context_id not in owner._sessions: + return _format_error(f"No context loaded with ID '{context_id}'.", output=output) + + session = owner._sessions[context_id] + if not session.workspace_binding: + return _format_error( + "This context is not bound to a refreshable workspace file or manifest.", + output=output, + ) + + try: + text, fmt, note, refreshed_binding = refresh_workspace_binding( + session.workspace_binding, + max_read_bytes=owner.action_config.max_read_bytes, + timeout_seconds=owner.action_config.max_cmd_seconds, + ) + except Exception as exc: + return _format_error(f"Refresh failed: {exc}", output=output) + + meta = owner._replace_session_context( + text, + context_id, + fmt, + session.line_number_base, + preserve_state=True, + ) + refreshed_session = owner._sessions[context_id] + refreshed_session.iterations += 1 + refreshed_session.workspace_binding = refreshed_binding + owner._record_action( + refreshed_session, + 
note="refresh_context", + snippet=str(refreshed_binding.get("display_path") or refreshed_binding.get("kind") or context_id), + ) + + payload = { + "status": "success", + "context_id": context_id, + "binding": refreshed_binding, + "size_chars": meta.size_chars, + "size_lines": meta.size_lines, + "note": note, + } + if output == "object": + return payload + if output == "json": + return _format_payload(payload, output="json") + return owner._format_context_loaded( + context_id, + meta, + refreshed_session.line_number_base, + note=note or "Context refreshed from workspace binding.", + ) diff --git a/aleph/repl/node_worker.cjs b/aleph/repl/node_worker.cjs index 8ff057e..d97862f 100644 --- a/aleph/repl/node_worker.cjs +++ b/aleph/repl/node_worker.cjs @@ -1450,10 +1450,226 @@ function detectUpdatedVariables(code, previousCtx, nextCtx) { return Array.from(updated); } +function findMatchingParenStart(source, closeIndex) { + let depth = 0; + let quote = null; + for (let index = closeIndex; index >= 0; index -= 1) { + const ch = source[index]; + const prev = index > 0 ? source[index - 1] : ""; + if (quote) { + if (ch === quote && prev !== "\\") { + quote = null; + } + continue; + } + if (ch === "'" || ch === '"' || ch === "`") { + quote = ch; + continue; + } + if (ch === ")") { + depth += 1; + continue; + } + if (ch === "(") { + depth -= 1; + if (depth === 0) { + return index; + } + } + } + return -1; +} + +function findTypeTerminator(source, startIndex, terminatorChar) { + let quote = null; + const stack = []; + for (let index = startIndex; index < source.length; index += 1) { + const ch = source[index]; + const prev = index > startIndex ? 
source[index - 1] : ""; + if (quote) { + if (ch === quote && prev !== "\\") { + quote = null; + } + continue; + } + if (ch === "'" || ch === '"' || ch === "`") { + quote = ch; + continue; + } + if (ch === "/" && source[index + 1] === "/") { + const newline = source.indexOf("\n", index + 2); + if (newline === -1) { + return -1; + } + index = newline; + continue; + } + if (ch === "/" && source[index + 1] === "*") { + const commentEnd = source.indexOf("*/", index + 2); + if (commentEnd === -1) { + return -1; + } + index = commentEnd + 1; + continue; + } + if (ch === "(" || ch === "[" || ch === "{" || ch === "<") { + stack.push(ch); + continue; + } + if (ch === ")" || ch === "]" || ch === "}" || ch === ">") { + stack.pop(); + continue; + } + if (ch === terminatorChar && stack.length === 0) { + return index; + } + } + return -1; +} + +function stripTypeAnnotationsFromParams(paramsSource) { + let out = ""; + let index = 0; + let quote = null; + while (index < paramsSource.length) { + const ch = paramsSource[index]; + const prev = index > 0 ? paramsSource[index - 1] : ""; + if (quote) { + out += ch; + if (ch === quote && prev !== "\\") { + quote = null; + } + index += 1; + continue; + } + if (ch === "'" || ch === '"' || ch === "`") { + quote = ch; + out += ch; + index += 1; + continue; + } + if (ch === ":") { + let scan = index + 1; + let nestedQuote = null; + let parenDepth = 0; + let bracketDepth = 0; + let braceDepth = 0; + let angleDepth = 0; + while (scan < paramsSource.length) { + const current = paramsSource[scan]; + const currentPrev = scan > index + 1 ? 
paramsSource[scan - 1] : ""; + if (nestedQuote) { + if (current === nestedQuote && currentPrev !== "\\") { + nestedQuote = null; + } + scan += 1; + continue; + } + if (current === "'" || current === '"' || current === "`") { + nestedQuote = current; + scan += 1; + continue; + } + if (current === "(") parenDepth += 1; + else if (current === ")") { + if (parenDepth === 0 && bracketDepth === 0 && braceDepth === 0 && angleDepth === 0) break; + parenDepth -= 1; + } else if (current === "[") bracketDepth += 1; + else if (current === "]") bracketDepth -= 1; + else if (current === "{") braceDepth += 1; + else if (current === "}") braceDepth -= 1; + else if (current === "<") angleDepth += 1; + else if (current === ">") angleDepth -= 1; + else if ( + (current === "," || current === ")") && + parenDepth === 0 && + bracketDepth === 0 && + braceDepth === 0 && + angleDepth === 0 + ) { + break; + } + scan += 1; + } + while (scan < paramsSource.length && /\s/.test(paramsSource[scan])) { + scan += 1; + } + index = scan; + continue; + } + out += ch; + index += 1; + } + return out; +} + +function stripArrowFunctionTypes(source) { + let result = source; + let cursor = 0; + while (cursor < result.length) { + const arrowIndex = result.indexOf("=>", cursor); + if (arrowIndex === -1) { + break; + } + let paramEnd = arrowIndex - 1; + while (paramEnd >= 0 && /\s/.test(result[paramEnd])) { + paramEnd -= 1; + } + if (paramEnd < 0 || result[paramEnd] !== ")") { + cursor = arrowIndex + 2; + continue; + } + const paramStart = findMatchingParenStart(result, paramEnd); + if (paramStart === -1) { + cursor = arrowIndex + 2; + continue; + } + const params = result.slice(paramStart + 1, paramEnd); + const strippedParams = stripTypeAnnotationsFromParams(params); + result = + result.slice(0, paramStart + 1) + + strippedParams + + result.slice(paramEnd); + const nextArrow = result.indexOf("=>", paramStart); + cursor = nextArrow === -1 ? 
result.length : nextArrow + 2; + } + return result; +} + +function stripVariableDeclarationTypes(source) { + const declarationPattern = /\b(const|let|var)\s+([A-Za-z_$][\w$]*)\s*:/g; + let result = ""; + let lastIndex = 0; + let match; + while ((match = declarationPattern.exec(source)) !== null) { + const colonIndex = declarationPattern.lastIndex - 1; + const typeStart = colonIndex + 1; + const equalsIndex = findTypeTerminator(source, typeStart, "="); + if (equalsIndex === -1) { + continue; + } + result += source.slice(lastIndex, colonIndex); + lastIndex = equalsIndex; + declarationPattern.lastIndex = equalsIndex; + } + result += source.slice(lastIndex); + return result; +} + +function stripTypeScriptFallback(code) { + let result = String(code); + result = stripArrowFunctionTypes(result); + result = stripVariableDeclarationTypes(result); + return result; +} + function normalizeCode(code, language) { if (language === "typescript") { - if (typeof stripTypeScriptTypes !== "function") { - throw new Error("TypeScript execution requires Node support for stripTypeScriptTypes"); + if ( + process.env.ALEPH_NODE_FORCE_TS_FALLBACK === "true" || + typeof stripTypeScriptTypes !== "function" + ) { + return stripTypeScriptFallback(code); } return stripTypeScriptTypes(code); } diff --git a/aleph/settings.py b/aleph/settings.py index 7d59521..bd94985 100644 --- a/aleph/settings.py +++ b/aleph/settings.py @@ -245,6 +245,10 @@ def _coerce_share_session(cls, value: object) -> bool | None: class MCPServerEnvSettings(_AlephBaseSettings): tool_docs: Literal["concise", "full"] = Field(default="concise", validation_alias="ALEPH_TOOL_DOCS") context_policy: str | None = Field(default=None, validation_alias="ALEPH_CONTEXT_POLICY") + action_policy: Literal["read-write", "read-only"] = Field( + default="read-write", + validation_alias="ALEPH_ACTION_POLICY", + ) workspace_root: str | None = Field(default=None, validation_alias="ALEPH_WORKSPACE_ROOT") remote_tool_timeout_seconds: float = Field( 
default=120.0, @@ -264,6 +268,16 @@ def _coerce_tool_docs(cls, value: object) -> Literal["concise", "full"]: return text # type: ignore[return-value] return "concise" + @field_validator("action_policy", mode="before") + @classmethod + def _coerce_action_policy(cls, value: object) -> Literal["read-write", "read-only"]: + text = (_strip_optional_text(value) or "read-write").lower() + if text in {"read-write", "workspace-write", "write"}: + return "read-write" + if text in {"read-only", "readonly", "safe"}: + return "read-only" + return "read-write" + @field_validator("context_policy", "workspace_root", "swarm_name", "swarm_context_prefix", mode="before") @classmethod def _strip_text_values(cls, value: object) -> str | None: diff --git a/pyproject.toml b/pyproject.toml index 9ebf05d..2cac9c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "aleph-rlm" -version = "0.9.2" +version = "0.9.3" description = "MCP server for recursive LLM reasoning—load context, iterate with search/code/think tools, converge on answers" readme = "README.md" license = { file = "LICENSE" } diff --git a/tests/test_compatibility_aliases.py b/tests/test_compatibility_aliases.py index 9679aea..af68beb 100644 --- a/tests/test_compatibility_aliases.py +++ b/tests/test_compatibility_aliases.py @@ -5,6 +5,7 @@ import pytest from aleph.core import Aleph +from aleph.mcp.actions import ActionConfig as RuntimeActionConfig, require_actions from aleph.mcp.local_server import ActionConfig, AlephMCPServerLocal from aleph.types import ContentFormat @@ -20,6 +21,33 @@ def test_server_normalizes_output_feedback_env_alias(sandbox_config) -> None: assert server.output_feedback == "metadata" +def test_local_server_reexports_runtime_action_config() -> None: + assert ActionConfig is RuntimeActionConfig + + +def test_runtime_require_actions_preserves_read_only_guards() -> None: + cfg = RuntimeActionConfig(enabled=True, action_policy="read-only") + + 
command_error = require_actions(cfg, confirm=True, requires_command=True) + write_error = require_actions(cfg, confirm=True, requires_write=True) + + assert command_error is not None + assert "read-only" in command_error + assert "Process execution is blocked" in command_error + + assert write_error is not None + assert "read-only" in write_error + assert "Filesystem writes are blocked" in write_error + + +def test_runtime_require_actions_preserves_confirmation_guard() -> None: + cfg = RuntimeActionConfig(enabled=True, require_confirmation=True) + + error = require_actions(cfg, confirm=False) + + assert error == "Confirmation required. Re-run with confirm=true." + + @pytest.mark.asyncio async def test_configure_accepts_minimal_output_feedback_alias(sandbox_config) -> None: server = AlephMCPServerLocal(sandbox_config=sandbox_config) diff --git a/tests/test_mcp_contracts.py b/tests/test_mcp_contracts.py new file mode 100644 index 0000000..3425b92 --- /dev/null +++ b/tests/test_mcp_contracts.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import pytest + +from aleph.mcp.local_server import ActionConfig, AlephMCPServerLocal +from aleph.repl.sandbox import SandboxConfig + + +async def _call_tool(server: AlephMCPServerLocal, tool_name: str, **kwargs: Any) -> Any: + _, payload = await server.server.call_tool(tool_name, kwargs) + return payload["result"] + + +@pytest.mark.asyncio +async def test_workspace_and_status_object_json_contracts(tmp_path: Path) -> None: + (tmp_path / "README.md").write_text("# Demo\n", encoding="utf-8") + src_dir = tmp_path / "src" + src_dir.mkdir() + (src_dir / "app.py").write_text("print('hi')\n", encoding="utf-8") + + server = AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0), + action_config=ActionConfig(enabled=True, workspace_root=tmp_path), + ) + + manifest_obj = await _call_tool( + server, + "load_workspace_manifest", + context_id="repo", + 
output="object", + confirm=True, + ) + manifest_json = await _call_tool( + server, + "load_workspace_manifest", + context_id="repo-json", + output="json", + confirm=True, + ) + assert set(manifest_obj.keys()) == { + "status", + "context_id", + "workspace_root", + "roots", + "file_count", + "truncated", + "binding", + "note", + "size_chars", + "size_lines", + } + assert set(json.loads(manifest_json).keys()) == set(manifest_obj.keys()) + + status_obj = await _call_tool(server, "get_status", context_id="repo", output="object") + status_json = await _call_tool(server, "get_status", context_id="repo", output="json") + expected_status_keys = { + "context_id", + "iterations", + "evidence_count", + "tasks_count", + "variables", + "size_chars", + "size_lines", + "workspace_root", + "workspace_root_source", + "context_policy", + "action_policy", + "auto_memory_pack", + "workspace_binding", + "workspace_binding_summary", + "workspace_binding_status", + } + assert set(status_obj.keys()) == expected_status_keys + assert set(json.loads(status_json).keys()) == expected_status_keys + + list_obj = await _call_tool(server, "list_contexts", output="object") + list_json = await _call_tool(server, "list_contexts", output="json") + assert set(list_obj.keys()) == {"count", "items"} + assert set(json.loads(list_json).keys()) == {"count", "items"} + assert list_obj["count"] >= 2 + assert set(list_obj["items"][0].keys()) == { + "id", + "chars", + "lines", + "iterations", + "evidence", + "workspace_binding", + "workspace_binding_summary", + } + + +@pytest.mark.asyncio +async def test_refresh_context_object_contract(tmp_path: Path) -> None: + file_path = tmp_path / "notes.txt" + file_path.write_text("alpha\n", encoding="utf-8") + server = AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0), + action_config=ActionConfig(enabled=True, workspace_root=tmp_path), + ) + await _call_tool( + server, + "load_file", + path="notes.txt", + context_id="notes", + confirm=True, + ) + + 
file_path.write_text("beta\n", encoding="utf-8") + refresh_obj = await _call_tool( + server, + "refresh_context", + context_id="notes", + output="object", + confirm=True, + ) + refresh_json = await _call_tool( + server, + "refresh_context", + context_id="notes", + output="json", + confirm=True, + ) + expected_keys = {"status", "context_id", "binding", "size_chars", "size_lines", "note"} + assert set(refresh_obj.keys()) == expected_keys + assert set(json.loads(refresh_json).keys()) == expected_keys diff --git a/tests/test_mcp_formatting.py b/tests/test_mcp_formatting.py new file mode 100644 index 0000000..aa66ab7 --- /dev/null +++ b/tests/test_mcp_formatting.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from aleph.mcp.local_server import _format_payload + + +def test_local_server_format_payload_redacts_ctx_and_truncates_large_strings() -> None: + payload = { + "ctx": "alpha\n" * 200, + "note": "z" * 20_000, + } + + rendered = _format_payload(payload, output="object") + + assert rendered["ctx"]["redacted"] is True + assert rendered["ctx"]["reason"] == "context_field_blocked" + assert rendered["ctx"]["original_chars"] == len(payload["ctx"]) + assert "value_preview" in rendered["ctx"] + + assert rendered["note"] != payload["note"] + assert "TRUNCATED" in rendered["note"] diff --git a/tests/test_mcp_recipe_runtime.py b/tests/test_mcp_recipe_runtime.py new file mode 100644 index 0000000..3fd7cf1 --- /dev/null +++ b/tests/test_mcp_recipe_runtime.py @@ -0,0 +1,420 @@ +"""Focused tests for the recipe_runtime extraction. + +Tests that: +1. recipe_preview / recipe_context_slice work as standalone functions +2. execute_recipe / compile_recipe_code work when called via the canonical module +3. 
The server wrapper methods delegate correctly +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from aleph.mcp.local_server import AlephMCPServerLocal, _Session, _analyze_text_context +from aleph.mcp.recipe_runtime import ( + compile_recipe_code, + execute_recipe, + recipe_context_slice, + recipe_preview, +) +from aleph.repl.sandbox import REPLEnvironment, SandboxConfig +from aleph.types import ContentFormat + + +def _make_server() -> AlephMCPServerLocal: + return AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0, max_output_chars=5000) + ) + + +async def _load_context( + server: AlephMCPServerLocal, text: str, context_id: str = "default" +) -> None: + meta = _analyze_text_context(text, ContentFormat.TEXT) + repl = REPLEnvironment( + context=text, + context_var_name="ctx", + config=server.sandbox_config, + loop=asyncio.get_running_loop(), + ) + repl.set_variable("line_number_base", 1) + server._sessions[context_id] = _Session(repl=repl, meta=meta, line_number_base=1) + + +class TestRecipePreview: + def test_short_string(self) -> None: + assert recipe_preview("hello") == "hello" + + def test_long_string_truncated(self) -> None: + long_text = "x" * 300 + result = recipe_preview(long_text) + assert len(result) == 180 + assert result.endswith("...") + + def test_custom_limit(self) -> None: + result = recipe_preview("a" * 100, limit=50) + assert len(result) == 50 + assert result.endswith("...") + + def test_list_value(self) -> None: + result = recipe_preview(["hello", "world"]) + assert isinstance(result, str) + + def test_exact_limit(self) -> None: + text = "x" * 180 + assert recipe_preview(text) == text + + +class TestRecipeContextSlice: + def test_no_field_returns_text(self) -> None: + result = recipe_context_slice("hello world", None) + assert result == "hello world" + + def test_dict_field_extraction(self) -> None: + data = {"name": "Alice", "age": "30"} + result = recipe_context_slice(data, "name") + assert 
"Alice" in result + + def test_list_of_dicts(self) -> None: + data = [ + {"name": "Alice", "age": "30"}, + {"name": "Bob", "age": "25"}, + ] + result = recipe_context_slice(data, "name") + assert "Alice" in result + assert "Bob" in result + + def test_list_of_non_dicts(self) -> None: + data = ["hello", "world"] + result = recipe_context_slice(data, "name") + assert "hello" in result + + def test_missing_field(self) -> None: + data = {"name": "Alice"} + result = recipe_context_slice(data, "missing") + assert "None" in result + + +class TestExecuteRecipeViaRuntime: + @pytest.mark.asyncio + async def test_dry_run(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "search", "pattern": "Line"}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe, dry_run=True) + assert ok is True + assert payload["mode"] == "dry_run" + assert "estimate" in payload + + @pytest.mark.asyncio + async def test_search_step(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "search", "pattern": "Line"}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + assert payload["step_count"] == 1 + assert payload["trace"][0]["op"] == "search" + assert payload["trace"][0]["result_count"] > 0 + + @pytest.mark.asyncio + async def test_peek_step(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "peek", "start": 0, "end": 11}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + assert "Hello World" in str(payload["value"]) + + @pytest.mark.asyncio + async def test_filter_step(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "search", "pattern": "Line"}, + {"op": 
"filter", "pattern": "2"}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + + @pytest.mark.asyncio + async def test_assign_load_cycle(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "search", "pattern": "Line"}, + {"op": "assign", "name": "results"}, + {"op": "load", "name": "results"}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + + @pytest.mark.asyncio + async def test_take_step_string(self) -> None: + server = _make_server() + await _load_context(server, "Hello World") + + recipe = { + "steps": [ + {"op": "peek", "start": 0, "end": 11}, + {"op": "take", "count": 5}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + assert payload["value"] == "Hello" + + @pytest.mark.asyncio + async def test_take_step_list(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "search", "pattern": "Line"}, + {"op": "take", "count": 1}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + + @pytest.mark.asyncio + async def test_lines_step(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "lines", "start": 0, "end": 1}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + + @pytest.mark.asyncio + async def test_finalize_stops_early(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "search", "pattern": "Line"}, + {"op": "finalize"}, + {"op": "search", "pattern": "should_not_run"}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + assert payload["step_count"] == 3 + assert len(payload["trace"]) == 2 + + 
@pytest.mark.asyncio + async def test_invalid_recipe_returns_errors(self) -> None: + server = _make_server() + ok, payload = await execute_recipe(server, recipe={"steps": "not_a_list"}) + assert ok is False + assert "errors" in payload + + @pytest.mark.asyncio + async def test_missing_context_returns_error(self) -> None: + server = _make_server() + + recipe = { + "steps": [{"op": "search", "pattern": "test"}], + "context_id": "nonexistent", + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is False + assert "error" in payload + assert "nonexistent" in payload["error"] + + @pytest.mark.asyncio + async def test_budget_exceeded(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "search", "pattern": "Line"}, + {"op": "search", "pattern": "Line"}, + ], + "budget": {"max_steps": 1, "max_sub_queries": 0}, + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is False + assert "exceeded" in payload.get("error", "").lower() + + @pytest.mark.asyncio + async def test_store_variable(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2\nLine 3") + + recipe = { + "steps": [ + {"op": "search", "pattern": "Line", "store": "search_results"}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + assert "search_results" in payload.get("variables", []) + + @pytest.mark.asyncio + async def test_context_id_override(self) -> None: + server = _make_server() + await _load_context(server, "Hello World", context_id="custom") + + recipe = { + "steps": [{"op": "search", "pattern": "Hello"}], + } + ok, payload = await execute_recipe( + server, recipe=recipe, context_id_override="custom" + ) + assert ok is True + assert payload["context_id"] == "custom" + + +class TestCompileRecipeCodeViaRuntime: + @pytest.mark.asyncio + async def test_compile_dict_return(self) -> None: + server = 
_make_server() + await _load_context(server, "Hello World") + + code = """ +recipe = { + "steps": [ + {"op": "search", "pattern": "Hello"}, + ] +} +""" + ok, payload = await compile_recipe_code( + server, code=code, context_id="default", language="python" + ) + assert ok is True + assert "recipe" in payload + assert payload["recipe"]["steps"][0]["op"] == "search" + assert "estimate" in payload + + @pytest.mark.asyncio + async def test_compile_no_recipe_value(self) -> None: + server = _make_server() + await _load_context(server, "Hello World") + + code = "x = 42" + ok, payload = await compile_recipe_code( + server, code=code, context_id="default", language="python" + ) + assert ok is False + assert "error" in payload + + @pytest.mark.asyncio + async def test_compile_invalid_recipe(self) -> None: + server = _make_server() + await _load_context(server, "Hello World") + + code = 'recipe = {"steps": "not_a_list"}' + ok, payload = await compile_recipe_code( + server, code=code, context_id="default", language="python" + ) + assert ok is False + assert "error" in payload + + @pytest.mark.asyncio + async def test_compile_missing_context(self) -> None: + server = _make_server() + ok, payload = await compile_recipe_code( + server, code="recipe = {}", context_id="nonexistent" + ) + assert ok is False + assert "error" in payload + + +class TestServerWrapperDelegation: + """Tests that the server methods on AlephMCPServerLocal delegate correctly.""" + + @pytest.mark.asyncio + async def test_server_execute_recipe_delegates(self) -> None: + server = _make_server() + await _load_context(server, "Hello World\nLine 2") + + recipe = { + "steps": [{"op": "search", "pattern": "Hello"}], + } + ok, payload = await server._execute_recipe(recipe=recipe) + assert ok is True + assert payload["step_count"] == 1 + + @pytest.mark.asyncio + async def test_server_compile_recipe_code_delegates(self) -> None: + server = _make_server() + await _load_context(server, "Hello World") + + code = 'recipe = 
{"steps": [{"op": "search", "pattern": "Hello"}]}' + ok, payload = await server._compile_recipe_code( + code=code, context_id="default", language="python" + ) + assert ok is True + + def test_server_recipe_preview_delegates(self) -> None: + server = _make_server() + result = server._recipe_preview("hello") + assert result == "hello" + + def test_server_recipe_context_slice_delegates(self) -> None: + server = _make_server() + result = server._recipe_context_slice("hello world", None) + assert result == "hello world" + + +class TestChunkStep: + @pytest.mark.asyncio + async def test_chunk_op(self) -> None: + server = _make_server() + text = "Hello World\n" * 20 + await _load_context(server, text) + + recipe = { + "steps": [ + {"op": "chunk", "chunk_size": 50, "overlap": 0}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + assert payload["trace"][0]["result_count"] > 1 + + @pytest.mark.asyncio + async def test_chunk_with_overlap(self) -> None: + server = _make_server() + text = "Hello World\n" * 20 + await _load_context(server, text) + + recipe = { + "steps": [ + {"op": "chunk", "chunk_size": 50, "overlap": 10}, + ] + } + ok, payload = await execute_recipe(server, recipe=recipe) + assert ok is True + assert payload["trace"][0]["result_count"] > 1 diff --git a/tests/test_mcp_server_bootstrap.py b/tests/test_mcp_server_bootstrap.py index 03d8df2..9af332a 100644 --- a/tests/test_mcp_server_bootstrap.py +++ b/tests/test_mcp_server_bootstrap.py @@ -6,11 +6,13 @@ from unittest.mock import patch from aleph.mcp.local_server import ( + DEFAULT_ACTION_POLICY, DEFAULT_CONTEXT_POLICY, DEFAULT_TOOL_DOCS_MODE, DEFAULT_WORKSPACE_MODE, ActionConfig, SandboxConfig, + _normalize_action_policy, _normalize_context_policy, ) from aleph.mcp.server_bootstrap import ( @@ -44,6 +46,7 @@ def test_apply_server_env_overrides_sets_sub_query_and_swarm_env(): sub_query_codex_reasoning_effort="high", sub_query_codex_profile="subquery", 
context_policy="isolated", + action_policy="read-only", swarm_mode=True, swarm_name="release-cutover", enable_session_sharing=True, @@ -65,6 +68,7 @@ def test_apply_server_env_overrides_sets_sub_query_and_swarm_env(): assert os.environ["ALEPH_SUB_QUERY_CODEX_REASONING_EFFORT"] == "high" assert os.environ["ALEPH_SUB_QUERY_CODEX_PROFILE"] == "subquery" assert os.environ["ALEPH_CONTEXT_POLICY"] == "isolated" + assert os.environ["ALEPH_ACTION_POLICY"] == "read-only" assert os.environ["ALEPH_SWARM_MODE"] == "true" assert os.environ["ALEPH_SWARM_NAME"] == "release-cutover" assert os.environ["ALEPH_SWARM_SESSION_SHARING"] == "true" @@ -89,6 +93,7 @@ def test_build_runtime_configs_uses_explicit_workspace_root(tmp_path: Path): max_file_size=456, max_write_bytes=789, tool_docs="full", + action_policy=None, ) with patch.dict(os.environ, {}, clear=True): @@ -96,7 +101,9 @@ def test_build_runtime_configs_uses_explicit_workspace_root(tmp_path: Path): args, detect_workspace_root=lambda: auto_root, normalize_context_policy=_normalize_context_policy, + normalize_action_policy=_normalize_action_policy, default_context_policy=DEFAULT_CONTEXT_POLICY, + default_action_policy=DEFAULT_ACTION_POLICY, sandbox_config_factory=SandboxConfig, action_config_factory=ActionConfig, ) @@ -108,6 +115,7 @@ def test_build_runtime_configs_uses_explicit_workspace_root(tmp_path: Path): assert action_config.workspace_root == workspace_root.resolve() assert action_config.workspace_mode == "git" assert action_config.context_policy == DEFAULT_CONTEXT_POLICY + assert action_config.action_policy == DEFAULT_ACTION_POLICY assert action_config.require_confirmation is True assert action_config.max_read_bytes == 456 assert action_config.max_write_bytes == 789 diff --git a/tests/test_mcp_sub_query_orchestration.py b/tests/test_mcp_sub_query_orchestration.py new file mode 100644 index 0000000..93fb7f0 --- /dev/null +++ b/tests/test_mcp_sub_query_orchestration.py @@ -0,0 +1,61 @@ +from __future__ import annotations + 
+from aleph.config import AlephConfig +from aleph.mcp.local_server import ( + _analyze_text_context, + _build_sub_aleph_cli_prompt, + _extract_final_answer, +) +from aleph.mcp.sub_query_orchestration import ( + build_sub_aleph_cli_prompt, + extract_final_answer, + format_streamable_http_url, + normalize_streamable_http_path, +) +from aleph.types import ContentFormat + + +def test_extract_final_answer_variants_match_local_server_wrapper() -> None: + assert extract_final_answer("prefix FINAL(done)") == ("done", True) + assert extract_final_answer("FINAL_VAR('named_result')") == ("named_result", True) + assert _extract_final_answer("plain answer") == ("plain answer", False) + + +def test_build_sub_aleph_cli_prompt_redacts_context_preview() -> None: + cfg = AlephConfig( + system_prompt=( + "Query={query}\n" + "Preview={context_preview}\n" + "Format={context_format}\n" + "Chars={context_size_chars}" + ) + ) + context_slice = "top secret context" + + module_prompt = build_sub_aleph_cli_prompt( + query="Summarize", + context_slice=context_slice, + context_format=ContentFormat.TEXT, + cfg=cfg, + analyze_text_context=_analyze_text_context, + ) + wrapper_prompt = _build_sub_aleph_cli_prompt( + query="Summarize", + context_slice=context_slice, + context_format=ContentFormat.TEXT, + cfg=cfg, + ) + + assert module_prompt == wrapper_prompt + assert "[OMITTED FOR CONTEXT ISOLATION]" in module_prompt + assert "top secret context" not in module_prompt + assert f"Chars={len(context_slice)}" in module_prompt + + +def test_streamable_http_helpers_normalize_urls() -> None: + assert normalize_streamable_http_path("") == "/mcp" + assert normalize_streamable_http_path("rpc") == "/rpc" + assert normalize_streamable_http_path("/custom") == "/custom" + assert format_streamable_http_url("0.0.0.0", 8765, "/mcp") == "http://127.0.0.1:8765/mcp" + assert format_streamable_http_url("::", 8765, "/mcp") == "http://127.0.0.1:8765/mcp" + assert format_streamable_http_url("localhost", 8765, "/mcp") == 
"http://localhost:8765/mcp" diff --git a/tests/test_node_bridge.py b/tests/test_node_bridge.py new file mode 100644 index 0000000..d06c37d --- /dev/null +++ b/tests/test_node_bridge.py @@ -0,0 +1,175 @@ +"""Focused tests for the node_bridge extraction. + +Tests that: +1. close_node_repl / configure_node_repl / get_or_create_node_repl / + sync_session_from_node_repl work as standalone functions +2. The server wrapper methods delegate correctly +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from aleph.mcp.local_server import AlephMCPServerLocal, _Session, _analyze_text_context +from aleph.mcp.node_bridge import ( + close_node_repl, + configure_node_repl, + get_or_create_node_repl, + sync_session_from_node_repl, +) +from aleph.repl.sandbox import REPLEnvironment, SandboxConfig +from aleph.types import ContentFormat + + +def _make_server() -> AlephMCPServerLocal: + return AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0, max_output_chars=5000) + ) + + +def _make_session(text: str = "test context") -> _Session: + meta = _analyze_text_context(text, ContentFormat.TEXT) + repl = REPLEnvironment( + context=text, + context_var_name="ctx", + config=SandboxConfig(timeout_seconds=5.0, max_output_chars=5000), + ) + repl.set_variable("line_number_base", 1) + return _Session(repl=repl, meta=meta, line_number_base=1) + + +class TestCloseNodeRepl: + def test_close_existing(self) -> None: + mock_repl = MagicMock() + node_repls: dict = {"ctx1": mock_repl} + close_node_repl(node_repls, "ctx1") + mock_repl.close.assert_called_once() + assert "ctx1" not in node_repls + + def test_close_nonexistent(self) -> None: + node_repls: dict = {} + close_node_repl(node_repls, "missing") # should not raise + + def test_close_leaves_others(self) -> None: + mock1 = MagicMock() + mock2 = MagicMock() + node_repls: dict = {"a": mock1, "b": mock2} + close_node_repl(node_repls, "a") + assert "b" in node_repls + 
mock2.close.assert_not_called() + + +class TestConfigureNodeRepl: + def test_registers_expected_callbacks(self) -> None: + session = _make_session() + # Inject minimal callables so configure_node_repl can find them + session.repl.set_variable("sub_query", lambda p, c=None: "ok") + session.repl.set_variable("sub_aleph", lambda q, c=None: "ok") + session.repl.set_variable("set_backend", lambda b: "ok") + session.repl.set_variable("get_config", lambda: {}) + + mock_repl = MagicMock() + configure_node_repl(mock_repl, session) + + registered = {call.args[0] for call in mock_repl.register_callback.call_args_list} + expected = { + "sub_query", "sub_query_map", "sub_query_batch", + "sub_query_strict", "sub_aleph", "set_backend", "get_config", + } + assert expected == registered + + +class TestGetOrCreateNodeRepl: + def test_missing_session_raises(self) -> None: + with pytest.raises(KeyError): + get_or_create_node_repl({}, {}, "missing", SandboxConfig()) + + @pytest.mark.asyncio + async def test_creates_and_caches(self) -> None: + sessions: dict = {"default": _make_session("hello world")} + node_repls: dict = {} + cfg = SandboxConfig(timeout_seconds=5.0) + + # Inject minimal callables for configure_node_repl + sessions["default"].repl.set_variable("sub_query", lambda p, c=None: "ok") + sessions["default"].repl.set_variable("sub_aleph", lambda q, c=None: "ok") + sessions["default"].repl.set_variable("set_backend", lambda b: "ok") + sessions["default"].repl.set_variable("get_config", lambda: {}) + + repl = get_or_create_node_repl(node_repls, sessions, "default", cfg) + assert "default" in node_repls + assert node_repls["default"] is repl + + # Second call returns same instance (not recreated) + repl2 = get_or_create_node_repl(node_repls, sessions, "default", cfg) + assert repl2 is repl + + +class TestSyncSessionFromNodeRepl: + def test_no_node_repl_returns_empty(self) -> None: + result = sync_session_from_node_repl({}, {}, "missing", _analyze_text_context) + assert result == 
[] + + def test_no_session_returns_empty(self) -> None: + mock_repl = MagicMock() + result = sync_session_from_node_repl( + {"ctx1": mock_repl}, {}, "ctx1", _analyze_text_context + ) + assert result == [] + + def test_syncs_context_back(self) -> None: + session = _make_session("original text") + sessions: dict = {"ctx1": session} + + mock_node_repl = MagicMock() + mock_node_repl.get_variable.return_value = "updated text from node" + mock_node_repl.drain_citations.return_value = [{"cite": "test"}] + node_repls: dict = {"ctx1": mock_node_repl} + + citations = sync_session_from_node_repl( + node_repls, sessions, "ctx1", _analyze_text_context + ) + + assert citations == [{"cite": "test"}] + assert session.repl.get_variable("ctx") == "updated text from node" + assert session.meta.size_chars == len("updated text from node") + + +class TestServerDelegation: + """Verify server methods delegate to node_bridge functions.""" + + def test_close_node_repl_delegates(self) -> None: + server = _make_server() + mock_repl = MagicMock() + server._node_repls["test"] = mock_repl + server._close_node_repl("test") + mock_repl.close.assert_called_once() + assert "test" not in server._node_repls + + def test_configure_node_repl_delegates(self) -> None: + server = _make_server() + session = _make_session() + session.repl.set_variable("sub_query", lambda p, c=None: "ok") + session.repl.set_variable("sub_aleph", lambda q, c=None: "ok") + session.repl.set_variable("set_backend", lambda b: "ok") + session.repl.set_variable("get_config", lambda: {}) + + mock_repl = MagicMock() + server._configure_node_repl(mock_repl, session) + assert mock_repl.register_callback.call_count == 7 + + def test_sync_session_from_node_repl_delegates(self) -> None: + server = _make_server() + session = _make_session("original") + server._sessions["ctx1"] = session + + mock_node = MagicMock() + mock_node.get_variable.return_value = "new text" + mock_node.drain_citations.return_value = [] + server._node_repls["ctx1"] = 
mock_node + + result = server._sync_session_from_node_repl("ctx1") + assert result == [] + assert session.repl.get_variable("ctx") == "new text" diff --git a/tests/test_node_runtime.py b/tests/test_node_runtime.py index 189e696..2e797e8 100644 --- a/tests/test_node_runtime.py +++ b/tests/test_node_runtime.py @@ -31,6 +31,27 @@ def test_exec_typescript_expression(self, sandbox_config) -> None: finally: repl.close() + def test_exec_typescript_expression_with_fallback_strip(self, sandbox_config, monkeypatch) -> None: + monkeypatch.setenv("ALEPH_NODE_FORCE_TS_FALLBACK", "true") + repl = NodeREPLEnvironment(context="hello", config=sandbox_config) + try: + result = repl.execute( + """ +const routes: string[] = ["read", "write"]; +const mapped = routes.map((route: string) => route.toUpperCase()); +const report: { routeCount: number; first: string } = { + routeCount: mapped.length, + first: mapped[0], +}; +report + """, + language="typescript", + ) + assert result.error is None + assert result.return_value == {"routeCount": 2, "first": "READ"} + finally: + repl.close() + def test_context_helpers_and_variable_lookup(self, sandbox_config) -> None: repl = NodeREPLEnvironment(context="Line 1: Hello World\nLine 2: Goodbye", config=sandbox_config) try: @@ -377,6 +398,20 @@ def test_recipe_typescript_compilation(self, sandbox_config) -> None: finally: repl.close() + def test_recipe_typescript_compilation_with_fallback_strip(self, sandbox_config, monkeypatch) -> None: + monkeypatch.setenv("ALEPH_NODE_FORCE_TS_FALLBACK", "true") + repl = NodeREPLEnvironment(context="test", config=sandbox_config) + try: + result = repl.execute( + 'const r: object = Recipe("ts").search("err").take(2).compile(); r', + language="typescript", + ) + assert result.error is None + assert result.return_value["steps"][0]["op"] == "search" + assert result.return_value["steps"][1]["count"] == 2 + finally: + repl.close() + def test_recipe_json_serialization(self, sandbox_config) -> None: repl = 
NodeREPLEnvironment(context="test", config=sandbox_config) try: diff --git a/tests/test_repl_injection.py b/tests/test_repl_injection.py new file mode 100644 index 0000000..70b7adc --- /dev/null +++ b/tests/test_repl_injection.py @@ -0,0 +1,173 @@ +"""Focused tests for the repl_injection extraction.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +from aleph.mcp.local_server import AlephMCPServerLocal, _Session, _analyze_text_context +from aleph.mcp.repl_injection import ( + configure_session, + inject_repl_config_helpers, + inject_repl_sub_aleph, + inject_repl_sub_query, +) +from aleph.repl.sandbox import REPLEnvironment, SandboxConfig +from aleph.types import AlephResponse, ContentFormat + + +def _make_server() -> AlephMCPServerLocal: + return AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0, max_output_chars=5000) + ) + + +def _make_session(text: str = "test context") -> _Session: + meta = _analyze_text_context(text, ContentFormat.TEXT) + repl = REPLEnvironment( + context=text, + context_var_name="ctx", + config=SandboxConfig(timeout_seconds=5.0, max_output_chars=5000), + ) + repl.set_variable("line_number_base", 1) + return _Session(repl=repl, meta=meta, line_number_base=1) + + +class TestInjectReplConfigHelpers: + def test_registers_set_backend_and_get_config(self) -> None: + server = _make_server() + session = _make_session() + + inject_repl_config_helpers(server, session) + + set_backend = session.repl.get_variable("set_backend") + get_config = session.repl.get_variable("get_config") + + assert callable(set_backend) + assert callable(get_config) + with patch.dict("os.environ", {}, clear=False): + result = set_backend("codex") + snapshot = get_config() + + assert "sub_query_backend set to 'codex'" in result + assert snapshot["sub_query_backend"] == "codex" + + +class TestInjectReplSubQuery: + @pytest.mark.asyncio + async def test_success_passthrough(self) -> None: 
+ server = _make_server() + session = _make_session() + server._run_sub_query = AsyncMock(return_value=(True, "OK", False, "codex")) # type: ignore[method-assign] + + inject_repl_sub_query(server, session, "ctx1") + + result = await session.repl._sub_query_fn("summarize", "slice") # type: ignore[misc] + + assert result == "OK" + server._run_sub_query.assert_awaited_once() + + @pytest.mark.asyncio + async def test_failure_wraps_error(self) -> None: + server = _make_server() + session = _make_session() + server._run_sub_query = AsyncMock(return_value=(False, "boom", False, "codex")) # type: ignore[method-assign] + + inject_repl_sub_query(server, session, "ctx1") + + result = await session.repl._sub_query_fn("summarize", None) # type: ignore[misc] + + assert result == "[ERROR: sub_query failed: boom]" + + +class TestInjectReplSubAleph: + @pytest.mark.asyncio + async def test_coerces_structured_context(self) -> None: + server = _make_server() + session = _make_session() + response = AlephResponse( + answer="done", + success=True, + total_iterations=1, + max_depth_reached=1, + total_tokens=0, + total_cost_usd=0.0, + wall_time_seconds=0.0, + trajectory=[], + ) + server._run_sub_aleph = AsyncMock(return_value=(response, {})) # type: ignore[method-assign] + + inject_repl_sub_aleph(server, session, "ctx1") + + result = await session.repl._sub_aleph_fn("q", {"a": 1}) # type: ignore[misc] + + assert result is response + kwargs = server._run_sub_aleph.await_args.kwargs + assert kwargs["context_slice"] == '{\n "a": 1\n}' + assert kwargs["context_id"] == "ctx1" + + +class TestConfigureSession: + @pytest.mark.asyncio + async def test_sets_loop_and_injects_helpers(self) -> None: + server = _make_server() + session = _make_session() + server._run_sub_query = AsyncMock(return_value=(True, "OK", False, "codex")) # type: ignore[method-assign] + response = AlephResponse( + answer="done", + success=True, + total_iterations=1, + max_depth_reached=1, + total_tokens=0, + 
total_cost_usd=0.0, + wall_time_seconds=0.0, + trajectory=[], + ) + server._run_sub_aleph = AsyncMock(return_value=(response, {})) # type: ignore[method-assign] + + loop = asyncio.get_running_loop() + configure_session(server, session, "ctx1", loop=loop) + + assert session.repl._loop is loop # type: ignore[attr-defined] + assert callable(session.repl.get_variable("set_backend")) + assert callable(session.repl.get_variable("get_config")) + assert session.repl._sub_query_fn is not None # type: ignore[attr-defined] + assert session.repl._sub_aleph_fn is not None # type: ignore[attr-defined] + + +class TestServerWrapperDelegation: + @pytest.mark.asyncio + async def test_server_configure_session_delegates(self) -> None: + server = _make_server() + session = _make_session() + server._run_sub_query = AsyncMock(return_value=(True, "OK", False, "codex")) # type: ignore[method-assign] + response = AlephResponse( + answer="done", + success=True, + total_iterations=1, + max_depth_reached=1, + total_tokens=0, + total_cost_usd=0.0, + wall_time_seconds=0.0, + trajectory=[], + ) + server._run_sub_aleph = AsyncMock(return_value=(response, {})) # type: ignore[method-assign] + + server._configure_session(session, "ctx1", loop=asyncio.get_running_loop()) + + assert callable(session.repl.get_variable("set_backend")) + assert session.repl._sub_query_fn is not None # type: ignore[attr-defined] + assert session.repl._sub_aleph_fn is not None # type: ignore[attr-defined] + + @pytest.mark.asyncio + async def test_server_inject_repl_sub_query_delegates(self) -> None: + server = _make_server() + session = _make_session() + server._run_sub_query = AsyncMock(return_value=(True, "OK", False, "codex")) # type: ignore[method-assign] + + server._inject_repl_sub_query(session, "ctx1") + + result = await session.repl._sub_query_fn("summarize", None) # type: ignore[misc] + assert result == "OK" diff --git a/tests/test_workspace_tools.py b/tests/test_workspace_tools.py new file mode 100644 index 
0000000..91c9554 --- /dev/null +++ b/tests/test_workspace_tools.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +from aleph.mcp.local_server import ActionConfig, AlephMCPServerLocal +from aleph.repl.sandbox import SandboxConfig + + +async def _call_tool(server: AlephMCPServerLocal, tool_name: str, **kwargs: Any) -> Any: + _, payload = await server.server.call_tool(tool_name, kwargs) + return payload["result"] + + +@pytest.mark.asyncio +async def test_workspace_manifest_contract_and_refresh(tmp_path: Path) -> None: + (tmp_path / "README.md").write_text("# Repo\n", encoding="utf-8") + src_dir = tmp_path / "src" + src_dir.mkdir() + (src_dir / "main.py").write_text("print('alpha')\n", encoding="utf-8") + + server = AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0), + action_config=ActionConfig(enabled=True, workspace_root=tmp_path), + ) + + manifest = await _call_tool( + server, + "load_workspace_manifest", + context_id="workspace", + output="object", + confirm=True, + ) + assert manifest["status"] == "success" + assert manifest["binding"]["kind"] == "manifest" + initial_file_count = manifest["binding"]["file_count"] + assert initial_file_count >= 2 + + status = await _call_tool(server, "get_status", context_id="workspace", output="object") + assert status["action_policy"] == "read-write" + assert status["workspace_binding"]["kind"] == "manifest" + assert status["workspace_binding_status"]["kind"] == "manifest" + + listed = await _call_tool(server, "list_contexts", output="object") + item = next(entry for entry in listed["items"] if entry["id"] == "workspace") + assert item["workspace_binding_summary"].startswith("manifest:") + + (src_dir / "extra.py").write_text("print('beta')\n", encoding="utf-8") + refreshed = await _call_tool( + server, + "refresh_context", + context_id="workspace", + output="object", + confirm=True, + ) + assert refreshed["status"] == "success" + 
assert refreshed["binding"]["kind"] == "manifest" + assert refreshed["binding"]["file_count"] == initial_file_count + 1 + + +@pytest.mark.asyncio +async def test_load_file_creates_refreshable_workspace_binding(tmp_path: Path) -> None: + file_path = tmp_path / "notes.txt" + file_path.write_text("alpha\n", encoding="utf-8") + + server = AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0), + action_config=ActionConfig(enabled=True, workspace_root=tmp_path), + ) + + result = await _call_tool( + server, + "load_file", + path="notes.txt", + context_id="notes", + confirm=True, + ) + assert "Context loaded 'notes'" in result + + status = await _call_tool(server, "get_status", context_id="notes", output="object") + assert status["workspace_binding"]["kind"] == "file" + assert status["workspace_binding_summary"] == "file:notes.txt" + assert status["workspace_binding_status"]["exists"] is True + assert status["workspace_binding_status"]["stale"] is False + + file_path.write_text("beta\n", encoding="utf-8") + refreshed = await _call_tool( + server, + "refresh_context", + context_id="notes", + output="object", + confirm=True, + ) + assert refreshed["status"] == "success" + assert server._sessions["notes"].repl.get_variable("ctx") == "beta\n" + + +@pytest.mark.asyncio +async def test_refresh_context_preserves_reasoning_and_task_state(tmp_path: Path) -> None: + file_path = tmp_path / "story.txt" + file_path.write_text("alpha\nbeta\ngamma\n", encoding="utf-8") + + server = AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0), + action_config=ActionConfig(enabled=True, workspace_root=tmp_path), + ) + await _call_tool( + server, + "load_file", + path="story.txt", + context_id="story", + confirm=True, + ) + await _call_tool(server, "search_context", context_id="story", pattern="beta") + await _call_tool( + server, + "tasks", + context_id="story", + action="add", + description="check beta branch", + ) + await _call_tool(server, "think", 
context_id="story", question="What changed in beta?") + + file_path.write_text("alpha\nbeta\nbeta-2\ngamma\n", encoding="utf-8") + refreshed = await _call_tool( + server, + "refresh_context", + context_id="story", + output="object", + confirm=True, + ) + assert refreshed["status"] == "success" + + status = await _call_tool(server, "get_status", context_id="story", output="object") + assert status["tasks_count"] == 1 + assert status["evidence_count"] >= 2 + assert status["iterations"] >= 4 + + task_list = await _call_tool(server, "tasks", context_id="story", action="list") + assert "check beta branch" in task_list + + assert "What changed in beta?" in server._sessions["story"].think_history + + +@pytest.mark.asyncio +async def test_read_only_action_policy_blocks_writes_and_commands(tmp_path: Path) -> None: + server = AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0), + action_config=ActionConfig( + enabled=True, + workspace_root=tmp_path, + action_policy="read-only", + ), + ) + await _call_tool(server, "load_context", context="persist me", context_id="doc") + + save_result = await _call_tool( + server, + "save_session", + path="pack.json", + confirm=True, + output="object", + ) + assert "read-only" in save_result["error"] + + write_result = await _call_tool( + server, + "write_file", + path="blocked.txt", + content="nope", + confirm=True, + output="object", + ) + assert "read-only" in write_result["error"] + + command_result = await _call_tool( + server, + "run_command", + cmd="echo hi", + confirm=True, + output="object", + ) + assert "read-only" in command_result["error"] + + tests_result = await _call_tool( + server, + "run_tests", + confirm=True, + output="object", + ) + assert "read-only" in tests_result["error"] + + +@pytest.mark.asyncio +async def test_workspace_binding_persists_through_memory_pack(tmp_path: Path) -> None: + file_path = tmp_path / "persisted.txt" + file_path.write_text("v1\n", encoding="utf-8") + + server = 
AlephMCPServerLocal( + sandbox_config=SandboxConfig(timeout_seconds=5.0), + action_config=ActionConfig(enabled=True, workspace_root=tmp_path), + ) + await _call_tool( + server, + "load_file", + path="persisted.txt", + context_id="persisted", + confirm=True, + ) + await _call_tool( + server, + "tasks", + context_id="persisted", + action="add", + description="re-open after restore", + ) + + save_result = await _call_tool( + server, + "save_session", + path="workspace-pack.json", + confirm=True, + output="object", + ) + assert save_result["status"] == "success" + + server._sessions.clear() + load_result = await _call_tool( + server, + "load_session", + path="workspace-pack.json", + confirm=True, + output="object", + ) + assert "persisted" in load_result["loaded"] + + status = await _call_tool(server, "get_status", context_id="persisted", output="object") + assert status["workspace_binding"]["kind"] == "file" + assert status["workspace_binding_summary"] == "file:persisted.txt" + assert status["tasks_count"] == 1 + + task_list = await _call_tool(server, "tasks", context_id="persisted", action="list") + assert "re-open after restore" in task_list + + file_path.write_text("v2\n", encoding="utf-8") + refreshed = await _call_tool( + server, + "refresh_context", + context_id="persisted", + confirm=True, + output="object", + ) + assert refreshed["status"] == "success" + assert server._sessions["persisted"].repl.get_variable("ctx") == "v2\n" diff --git a/web/index.html b/web/index.html index 26c3af6..9f4ebb1 100644 --- a/web/index.html +++ b/web/index.html @@ -615,7 +615,7 @@
-
v0.9.2
+
v0.9.3

Load It
Once.

Aleph keeps big working context in RAM so agents can search, run Python, recurse, and return small answers instead of re-sending the whole file.