diff --git a/Cargo.lock b/Cargo.lock index 37c3795..ffb7f9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -511,9 +511,22 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "conformance-server" +version = "0.1.0" +dependencies = [ + "async-trait", + "base64 0.22.1", + "pulseengine-mcp-protocol", + "pulseengine-mcp-server", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "conformance-tests" -version = "0.15.0" +version = "0.16.0" dependencies = [ "anyhow", "chrono", @@ -1038,7 +1051,7 @@ dependencies = [ [[package]] name = "hello-world-with-auth" -version = "0.15.0" +version = "0.16.0" dependencies = [ "anyhow", "async-trait", @@ -2094,7 +2107,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-auth" -version = "0.15.0" +version = "0.16.0" dependencies = [ "aes-gcm", "anyhow", @@ -2134,7 +2147,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-external-validation" -version = "0.15.0" +version = "0.16.0" dependencies = [ "anyhow", "arbitrary", @@ -2172,7 +2185,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-integration-tests" -version = "0.15.0" +version = "0.16.0" dependencies = [ "anyhow", "assert_matches", @@ -2198,7 +2211,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-logging" -version = "0.15.0" +version = "0.16.0" dependencies = [ "chrono", "hex", @@ -2216,7 +2229,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-macros" -version = "0.15.0" +version = "0.16.0" dependencies = [ "anyhow", "async-trait", @@ -2242,7 +2255,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-protocol" -version = "0.15.0" +version = "0.16.0" dependencies = [ "async-trait", "chrono", @@ -2259,7 +2272,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-security" -version = "0.15.0" +version = "0.16.0" dependencies = [ "anyhow", "async-trait", @@ -2281,7 +2294,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-security-middleware" -version = "0.15.0" +version = "0.16.0" dependencies = [ "anyhow", "assert_matches", @@ -2314,7 +2327,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-server" -version = "0.15.0" +version = "0.16.0" dependencies = [ "anyhow", "async-trait", @@ -2342,7 +2355,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-transport" -version = "0.15.0" +version = "0.16.0" dependencies = [ "anyhow", "async-stream", diff --git a/Cargo.toml b/Cargo.toml index 6108e99..42df842 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,18 +11,19 @@ members = [ "mcp-external-validation", "integration-tests", "conformance-tests", - # Examples (5 total - consolidated from 19) + # Examples (6 total - consolidated from 19) "examples/hello-world", # Minimal starter example "examples/hello-world-with-auth", # Security/auth integration "examples/ultra-simple", # Macro showcase (8 lines) "examples/ui-enabled-server", # MCP Apps Extension (SEP-1865) "examples/resources-demo", # Resource handling patterns + "examples/conformance-server", # MCP conformance test server ] resolver = "2" [workspace.package] -version = "0.15.0" +version = "0.16.0" rust-version = "1.88" edition = "2024" license = "MIT OR Apache-2.0" @@ -100,15 +101,15 @@ assert_matches = "1.5" serde_yaml = "0.9" # Framework internal dependencies (published versions) -pulseengine-mcp-protocol = { version = "0.15.0", path = "mcp-protocol" } -pulseengine-mcp-logging = { version = "0.15.0", path = "mcp-logging" } -pulseengine-mcp-auth = { version = "0.15.0", path = "mcp-auth" } -pulseengine-mcp-security = { version = "0.15.0", path = "mcp-security" } -pulseengine-mcp-security-middleware = { version = "0.15.0", path = "mcp-security-middleware" } -pulseengine-mcp-transport = { version = "0.15.0", path = "mcp-transport" } -pulseengine-mcp-server = { version = "0.15.0", path = "mcp-server" } -pulseengine-mcp-macros = { version = "0.15.0", path = "mcp-macros" } -pulseengine-mcp-external-validation = { version = "0.15.0", path = "mcp-external-validation" } +pulseengine-mcp-protocol = { version = "0.16.0", path = "mcp-protocol" } +pulseengine-mcp-logging = { version = "0.16.0", path = "mcp-logging" } +pulseengine-mcp-auth = { version = "0.16.0", path = "mcp-auth" } +pulseengine-mcp-security = { version = "0.16.0", path = "mcp-security" } +pulseengine-mcp-security-middleware = { version = "0.16.0", path = "mcp-security-middleware" } +pulseengine-mcp-transport = { version = "0.16.0", path = "mcp-transport" } +pulseengine-mcp-server = { version = "0.16.0", path = "mcp-server" } +pulseengine-mcp-macros = { version = "0.16.0", path = "mcp-macros" } +pulseengine-mcp-external-validation = { version = "0.16.0", path = "mcp-external-validation" } [profile.release] opt-level = "s" diff --git a/examples/conformance-server/Cargo.toml b/examples/conformance-server/Cargo.toml new file mode 100644 index 0000000..f3a04e0 --- /dev/null +++ b/examples/conformance-server/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "conformance-server" +version = "0.1.0" +edition = "2021" +description = "MCP conformance test server implementing all required fixtures" + +[[bin]] +name = "conformance-server" +path = "src/main.rs" + +[dependencies] +pulseengine-mcp-server = { path = "../../mcp-server" } +pulseengine-mcp-protocol = { path = "../../mcp-protocol" } +tokio = { version = "1.0", features = ["full"] } +async-trait = "0.1" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +base64 = "0.22" diff --git a/examples/conformance-server/src/main.rs b/examples/conformance-server/src/main.rs new file mode 100644 index 0000000..4e9d7f5 --- /dev/null +++ b/examples/conformance-server/src/main.rs @@ -0,0 +1,739 @@ +//! MCP Conformance Test Server +//! +//! This server implements all fixtures required by the official +//! `@modelcontextprotocol/conformance` test suite. +//! +//! Run with: cargo run --bin conformance-server +//! Test with: npx @modelcontextprotocol/conformance server --url http://localhost:3000/mcp + +use async_trait::async_trait; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine}; +use pulseengine_mcp_protocol::*; +use pulseengine_mcp_server::common_backend::CommonMcpError; +use pulseengine_mcp_server::{ + try_current_context, CreateMessageRequest, ElicitationRequest, McpBackend, McpServer, + SamplingContent, SamplingMessage, SamplingRole, ServerConfig, TransportConfig, +}; + +/// Minimal 1x1 red PNG image (base64 encoded) +const MINIMAL_PNG: &str = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="; + +/// Minimal WAV audio (44 bytes: RIFF header + minimal data) +fn minimal_wav_base64() -> String { + // Minimal valid WAV: 44-byte header with 0 data samples + let wav_bytes: Vec = vec![ + 0x52, 0x49, 0x46, 0x46, // "RIFF" + 0x24, 0x00, 0x00, 0x00, // File size - 8 + 0x57, 0x41, 0x56, 0x45, // "WAVE" + 0x66, 0x6D, 0x74, 0x20, // "fmt " + 0x10, 0x00, 0x00, 0x00, // Subchunk1Size (16) + 0x01, 0x00, // AudioFormat (1 = PCM) + 0x01, 0x00, // NumChannels (1) + 0x44, 0xAC, 0x00, 0x00, // SampleRate (44100) + 0x88, 0x58, 0x01, 0x00, // ByteRate + 0x02, 0x00, // BlockAlign + 0x10, 0x00, // BitsPerSample (16) + 0x64, 0x61, 0x74, 0x61, // "data" + 0x00, 0x00, 0x00, 0x00, // Subchunk2Size (0) + ]; + BASE64.encode(&wav_bytes) +} + +#[derive(Clone)] +struct ConformanceBackend; + +#[async_trait] +impl McpBackend for ConformanceBackend { + type Error = CommonMcpError; + type Config = (); + + async fn initialize(_config: Self::Config) -> std::result::Result { + Ok(Self) + } + + fn get_server_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::default(), + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .enable_prompts() + .enable_logging() + .build(), + server_info: Implementation::new("MCP Conformance Test Server", "1.0.0"), + instructions: Some("Conformance test server for MCP protocol validation".to_string()), + } + } + + async fn health_check(&self) -> std::result::Result<(), Self::Error> { + Ok(()) + } + + // ==================== TOOLS ==================== + + async fn list_tools( + &self, + _params: PaginatedRequestParam, + ) -> std::result::Result { + let empty_schema = serde_json::json!({ + "type": "object", + "properties": {} + }); + + Ok(ListToolsResult { + tools: vec![ + // tools-call-simple-text + Tool { + name: "test_simple_text".to_string(), + title: Some("Simple Text Tool".to_string()), + description: "Returns simple text content".to_string(), + input_schema: empty_schema.clone(), + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + // tools-call-image + Tool { + name: "test_image_content".to_string(), + title: Some("Image Content Tool".to_string()), + description: "Returns image content".to_string(), + input_schema: empty_schema.clone(), + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + // tools-call-audio + Tool { + name: "test_audio_content".to_string(), + title: Some("Audio Content Tool".to_string()), + description: "Returns audio content".to_string(), + input_schema: empty_schema.clone(), + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + // tools-call-embedded-resource + Tool { + name: "test_embedded_resource".to_string(), + title: Some("Embedded Resource Tool".to_string()), + description: "Returns embedded resource content".to_string(), + input_schema: empty_schema.clone(), + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + // tools-call-mixed-content (test_multiple_content_types) + Tool { + name: "test_multiple_content_types".to_string(), + title: Some("Multiple Content Types Tool".to_string()), + description: "Returns multiple content types".to_string(), + input_schema: empty_schema.clone(), + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + // tools-call-with-logging + Tool { + name: "test_tool_with_logging".to_string(), + title: Some("Tool With Logging".to_string()), + description: "Tool that emits log notifications".to_string(), + input_schema: empty_schema.clone(), + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + // tools-call-error + Tool { + name: "test_error_handling".to_string(), + title: Some("Error Handling Tool".to_string()), + description: "Tool that returns an error".to_string(), + input_schema: empty_schema.clone(), + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + // tools-call-with-progress + Tool { + name: "test_tool_with_progress".to_string(), + title: Some("Tool With Progress".to_string()), + description: "Tool that emits progress notifications".to_string(), + input_schema: empty_schema.clone(), + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + // tools-call-sampling + Tool { + name: "test_sampling".to_string(), + title: Some("Sampling Tool".to_string()), + description: "Tool that requests LLM sampling".to_string(), + input_schema: empty_schema.clone(), + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + // tools-call-elicitation + Tool { + name: "test_elicitation".to_string(), + title: Some("Elicitation Tool".to_string()), + description: "Tool that requests user input".to_string(), + input_schema: empty_schema, + output_schema: None, + annotations: None, + icons: None, + execution: None, + _meta: None, + }, + ], + next_cursor: None, + }) + } + + async fn call_tool( + &self, + request: CallToolRequestParam, + ) -> std::result::Result { + match request.name.as_str() { + "test_simple_text" => Ok(CallToolResult { + content: vec![Content::text( + "This is a simple text response from the test tool.", + )], + is_error: Some(false), + structured_content: None, + _meta: None, + }), + + "test_image_content" => Ok(CallToolResult { + content: vec![Content::image(MINIMAL_PNG, "image/png")], + is_error: Some(false), + structured_content: None, + _meta: None, + }), + + "test_audio_content" => Ok(CallToolResult { + content: vec![Content::audio(minimal_wav_base64(), "audio/wav")], + is_error: Some(false), + structured_content: None, + _meta: None, + }), + + "test_embedded_resource" => Ok(CallToolResult { + content: vec![Content::resource( + "test://static-text", + Some("text/plain".to_string()), + Some("Embedded resource content".to_string()), + )], + is_error: Some(false), + structured_content: None, + _meta: None, + }), + + "test_multiple_content_types" => Ok(CallToolResult { + content: vec![ + Content::text("Text content"), + Content::image(MINIMAL_PNG, "image/png"), + Content::resource( + "test://static-text", + Some("text/plain".to_string()), + Some("Resource content".to_string()), + ), + ], + is_error: Some(false), + structured_content: None, + _meta: None, + }), + + "test_tool_with_logging" => { + // Send log notifications if context is available + if let Some(ctx) = try_current_context() { + // Send a few log messages at different levels + let _ = ctx + .send_log( + LogLevel::Info, + Some("conformance_test"), + serde_json::json!({"message": "Starting tool execution"}), + ) + .await; + let _ = ctx + .send_log( + LogLevel::Debug, + Some("conformance_test"), + serde_json::json!({"step": 1, "action": "processing"}), + ) + .await; + let _ = ctx + .send_log( + LogLevel::Info, + Some("conformance_test"), + serde_json::json!({"message": "Tool execution completed"}), + ) + .await; + } + Ok(CallToolResult { + content: vec![Content::text("Tool executed with logging")], + is_error: Some(false), + structured_content: None, + _meta: None, + }) + } + + "test_error_handling" => Ok(CallToolResult { + content: vec![Content::text("This is an error message from the tool")], + is_error: Some(true), + structured_content: None, + _meta: None, + }), + + "test_tool_with_progress" => { + // Send progress notifications if context is available + if let Some(ctx) = try_current_context() { + // Simulate progress over a few steps + let total = 10u64; + for i in 0..=total { + let _ = ctx.send_progress(i, Some(total)).await; + // Small delay to make progress visible in tests + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + } + } + Ok(CallToolResult { + content: vec![Content::text("Tool executed with progress")], + is_error: Some(false), + structured_content: None, + _meta: None, + }) + } + + "test_sampling" => { + // Request LLM sampling if context is available + if let Some(ctx) = try_current_context() { + let sampling_request = CreateMessageRequest { + messages: vec![SamplingMessage { + role: SamplingRole::User, + content: SamplingContent::Text { + text: "What is 2 + 2? Answer with just the number.".to_string(), + }, + }], + system_prompt: Some("You are a helpful assistant.".to_string()), + max_tokens: 100, + temperature: Some(0.0), + ..Default::default() + }; + + match ctx + .request_sampling(sampling_request, std::time::Duration::from_secs(30)) + .await + { + Ok(response) => { + let response_text = match &response.content { + SamplingContent::Text { text } => text.clone(), + SamplingContent::Image { .. } => "Image response".to_string(), + }; + return Ok(CallToolResult { + content: vec![Content::text(format!( + "LLM response: {} (model: {})", + response_text, response.model + ))], + is_error: Some(false), + structured_content: None, + _meta: None, + }); + } + Err(e) => { + return Ok(CallToolResult { + content: vec![Content::text(format!("Sampling error: {e}"))], + is_error: Some(true), + structured_content: None, + _meta: None, + }); + } + } + } + // No context available - return a message indicating this + Ok(CallToolResult { + content: vec![Content::text("Sampling not available (no context)")], + is_error: Some(false), + structured_content: None, + _meta: None, + }) + } + + "test_elicitation" => { + // Request user input via elicitation if context is available + if let Some(ctx) = try_current_context() { + let elicitation_request = ElicitationRequest { + message: "Please provide your name:".to_string(), + requested_schema: serde_json::json!({ + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Your name" + } + }, + "required": ["name"] + }), + meta: None, + }; + + match ctx + .request_elicitation( + elicitation_request, + std::time::Duration::from_secs(60), + ) + .await + { + Ok(response) => { + let user_input = response + .content + .as_ref() + .and_then(|c| c.get("name")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let action = match &response.action { + pulseengine_mcp_server::ElicitationAction::Accept => "accepted", + pulseengine_mcp_server::ElicitationAction::Decline => "declined", + pulseengine_mcp_server::ElicitationAction::Cancel => "cancelled", + }; + return Ok(CallToolResult { + content: vec![Content::text(format!( + "User {action} with name: {user_input}" + ))], + is_error: Some(false), + structured_content: None, + _meta: None, + }); + } + Err(e) => { + return Ok(CallToolResult { + content: vec![Content::text(format!("Elicitation error: {e}"))], + is_error: Some(true), + structured_content: None, + _meta: None, + }); + } + } + } + // No context available - return a message indicating this + Ok(CallToolResult { + content: vec![Content::text("Elicitation not available (no context)")], + is_error: Some(false), + structured_content: None, + _meta: None, + }) + } + + _ => Err(CommonMcpError::InvalidParams(format!( + "Unknown tool: {}", + request.name + ))), + } + } + + // ==================== RESOURCES ==================== + + async fn list_resources( + &self, + _params: PaginatedRequestParam, + ) -> std::result::Result { + Ok(ListResourcesResult { + resources: vec![ + // resources-read-text + Resource { + uri: "test://static-text".to_string(), + name: "Static Text Resource".to_string(), + title: None, + description: Some("A static text resource for conformance testing".to_string()), + mime_type: Some("text/plain".to_string()), + annotations: None, + icons: None, + raw: None, + _meta: None, + }, + // resources-read-binary + Resource { + uri: "test://static-binary".to_string(), + name: "Static Binary Resource".to_string(), + title: None, + description: Some("A static binary resource (PNG image)".to_string()), + mime_type: Some("image/png".to_string()), + annotations: None, + icons: None, + raw: None, + _meta: None, + }, + // resources-subscribe / resources-unsubscribe + Resource { + uri: "test://watched-resource".to_string(), + name: "Watched Resource".to_string(), + title: None, + description: Some("A resource that can be subscribed to".to_string()), + mime_type: Some("text/plain".to_string()), + annotations: None, + icons: None, + raw: None, + _meta: None, + }, + ], + next_cursor: None, + }) + } + + async fn list_resource_templates( + &self, + _params: PaginatedRequestParam, + ) -> std::result::Result { + Ok(ListResourceTemplatesResult { + resource_templates: vec![ + // resources-templates-read + ResourceTemplate { + uri_template: "test://template/{id}/data".to_string(), + name: "Template Resource".to_string(), + description: Some("A parameterized template resource".to_string()), + mime_type: Some("text/plain".to_string()), + }, + ], + next_cursor: None, + }) + } + + async fn read_resource( + &self, + params: ReadResourceRequestParam, + ) -> std::result::Result { + let uri = ¶ms.uri; + + // Handle static resources + if uri == "test://static-text" { + return Ok(ReadResourceResult { + contents: vec![ResourceContents { + uri: uri.clone(), + mime_type: Some("text/plain".to_string()), + text: Some("This is the content of the static text resource.".to_string()), + blob: None, + _meta: None, + }], + }); + } + + if uri == "test://static-binary" { + return Ok(ReadResourceResult { + contents: vec![ResourceContents { + uri: uri.clone(), + mime_type: Some("image/png".to_string()), + text: None, + blob: Some(MINIMAL_PNG.to_string()), + _meta: None, + }], + }); + } + + if uri == "test://watched-resource" { + return Ok(ReadResourceResult { + contents: vec![ResourceContents { + uri: uri.clone(), + mime_type: Some("text/plain".to_string()), + text: Some("Watched resource content".to_string()), + blob: None, + _meta: None, + }], + }); + } + + // Handle template resources: test://template/{id}/data + if uri.starts_with("test://template/") && uri.ends_with("/data") { + let id = uri + .strip_prefix("test://template/") + .and_then(|s| s.strip_suffix("/data")) + .unwrap_or("unknown"); + + return Ok(ReadResourceResult { + contents: vec![ResourceContents { + uri: uri.clone(), + mime_type: Some("text/plain".to_string()), + text: Some(format!("Template resource data for id: {id}")), + blob: None, + _meta: None, + }], + }); + } + + Err(CommonMcpError::InvalidParams(format!( + "Resource not found: {uri}" + ))) + } + + // ==================== PROMPTS ==================== + + async fn list_prompts( + &self, + _params: PaginatedRequestParam, + ) -> std::result::Result { + Ok(ListPromptsResult { + prompts: vec![ + // prompts-get-simple + Prompt { + name: "test_simple_prompt".to_string(), + title: Some("Simple Test Prompt".to_string()), + description: Some("A simple prompt without arguments".to_string()), + arguments: None, + icons: None, + }, + // prompts-get-with-args + Prompt { + name: "test_prompt_with_arguments".to_string(), + title: Some("Prompt With Arguments".to_string()), + description: Some("A prompt that requires arguments".to_string()), + arguments: Some(vec![ + PromptArgument { + name: "arg1".to_string(), + description: Some("First argument".to_string()), + required: Some(true), + }, + PromptArgument { + name: "arg2".to_string(), + description: Some("Second argument".to_string()), + required: Some(true), + }, + ]), + icons: None, + }, + // prompts-get-embedded-resource + Prompt { + name: "test_prompt_with_embedded_resource".to_string(), + title: Some("Prompt With Embedded Resource".to_string()), + description: Some("A prompt that includes an embedded resource".to_string()), + arguments: Some(vec![PromptArgument { + name: "resourceUri".to_string(), + description: Some("URI of the resource to embed".to_string()), + required: Some(true), + }]), + icons: None, + }, + // prompts-get-with-image + Prompt { + name: "test_prompt_with_image".to_string(), + title: Some("Prompt With Image".to_string()), + description: Some("A prompt that includes an image".to_string()), + arguments: None, + icons: None, + }, + ], + next_cursor: None, + }) + } + + async fn get_prompt( + &self, + params: GetPromptRequestParam, + ) -> std::result::Result { + match params.name.as_str() { + "test_simple_prompt" => Ok(GetPromptResult { + description: Some("A simple test prompt".to_string()), + messages: vec![PromptMessage::new_text( + PromptMessageRole::User, + "This is a simple test prompt message.", + )], + }), + + "test_prompt_with_arguments" => { + let arg1 = params + .arguments + .as_ref() + .and_then(|a| a.get("arg1")) + .map(|s| s.as_str()) + .unwrap_or("default1"); + let arg2 = params + .arguments + .as_ref() + .and_then(|a| a.get("arg2")) + .map(|s| s.as_str()) + .unwrap_or("default2"); + + Ok(GetPromptResult { + description: Some("A prompt with arguments".to_string()), + messages: vec![PromptMessage::new_text( + PromptMessageRole::User, + format!("Prompt with arg1={arg1} and arg2={arg2}"), + )], + }) + } + + "test_prompt_with_embedded_resource" => { + // Return actual embedded resource content + let resource_uri = params + .arguments + .as_ref() + .and_then(|a| a.get("resourceUri")) + .map(|s| s.as_str()) + .unwrap_or("test://static-text"); + + Ok(GetPromptResult { + description: Some("A prompt with an embedded resource".to_string()), + messages: vec![PromptMessage::new_resource( + PromptMessageRole::User, + resource_uri, + Some("text/plain".to_string()), + Some("This is the embedded resource content.".to_string()), + )], + }) + } + + "test_prompt_with_image" => Ok(GetPromptResult { + description: Some("A prompt with an image".to_string()), + messages: vec![PromptMessage::new_image( + PromptMessageRole::User, + MINIMAL_PNG, + "image/png", + )], + }), + + _ => Err(CommonMcpError::InvalidParams(format!( + "Unknown prompt: {}", + params.name + ))), + } + } +} + +#[tokio::main] +async fn main() -> std::result::Result<(), Box> { + let backend = ConformanceBackend::initialize(()).await?; + + let port = std::env::var("PORT") + .ok() + .and_then(|p| p.parse().ok()) + .unwrap_or(3000); + + let mut config = ServerConfig::default(); + config.auth_config.enabled = false; + config.transport_config = TransportConfig::StreamableHttp { port, host: None }; + + let mut server = McpServer::new(backend, config).await?; + + eprintln!("MCP Conformance Test Server running on http://localhost:{port}"); + eprintln!(); + eprintln!("Test with:"); + eprintln!(" npx @modelcontextprotocol/conformance server --url http://localhost:{port}/mcp"); + eprintln!(); + + server.run().await?; + Ok(()) +} diff --git a/mcp-macros/src/mcp_tool.rs b/mcp-macros/src/mcp_tool.rs index a27f275..1006017 100644 --- a/mcp-macros/src/mcp_tool.rs +++ b/mcp-macros/src/mcp_tool.rs @@ -131,10 +131,14 @@ pub fn mcp_tool_impl(attr: TokenStream, item: TokenStream) -> syn::Result syn::Result syn::Result, + /// Whether a ToolContext parameter was detected (first non-self param) + has_tool_context: bool, +} + /// Extract parameter information from function signature -fn extract_parameters( - sig: &syn::Signature, - tool_name: &str, -) -> syn::Result<(syn::Type, Vec)> { +/// +/// This function: +/// 1. Detects if the first non-self parameter is a ToolContext type +/// 2. Skips ToolContext from schema generation (it's injected at runtime) +/// 3. Extracts remaining parameters for JSON schema and extraction code +fn extract_parameters(sig: &syn::Signature, tool_name: &str) -> syn::Result { let mut param_fields = Vec::new(); let mut param_types = Vec::new(); let mut param_names = Vec::new(); + let mut has_tool_context = false; + let mut first_non_self_param = true; for input in &sig.inputs { match input { @@ -680,6 +722,16 @@ fn extract_parameters( continue; } syn::FnArg::Typed(pat_type) => { + // Check if this is the first non-self param and if it's ToolContext + if first_non_self_param { + first_non_self_param = false; + if is_tool_context_type(&pat_type.ty) { + has_tool_context = true; + // Skip ToolContext from schema generation and JSON extraction + continue; + } + } + if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { let param_name = &pat_ident.ident; let param_type = &*pat_type.ty; @@ -765,7 +817,11 @@ fn extract_parameters( })? }; - Ok((param_struct, param_fields)) + Ok(ExtractedParameters { + param_struct, + param_fields, + has_tool_context, + }) } /// Parameters for tool implementation generation @@ -864,9 +920,12 @@ fn generate_tool_implementation( } /// Generate JSON schema for method parameters from function signature +/// +/// This function filters out ToolContext parameters since they are runtime-injected +/// and should not appear in the tool's input schema. fn generate_input_schema_for_method(sig: &syn::Signature) -> syn::Result { - // Collect non-self parameters - let params: Vec<_> = sig + // Collect non-self parameters, skipping ToolContext if it's the first one + let all_params: Vec<_> = sig .inputs .iter() .filter_map(|input| { @@ -878,6 +937,17 @@ fn generate_input_schema_for_method(sig: &syn::Signature) -> syn::Result = if let Some(first) = all_params.first() { + if is_tool_context_type(&first.ty) { + all_params.into_iter().skip(1).collect() + } else { + all_params + } + } else { + all_params + }; + match params.len() { 0 => { // No parameters - return empty schema @@ -982,6 +1052,83 @@ fn extract_option_inner_type(ty: &syn::Type) -> (bool, &syn::Type) { (false, ty) } +/// Check if a type is a ToolContext parameter +/// +/// Supports the following patterns: +/// - `Arc` +/// - `&dyn ToolContext` +/// - Custom type names containing "ToolContext" (for flexibility) +fn is_tool_context_type(ty: &syn::Type) -> bool { + match ty { + // Handle reference types: &dyn ToolContext + syn::Type::Reference(type_ref) => { + if let syn::Type::TraitObject(trait_obj) = &*type_ref.elem { + return trait_obj.bounds.iter().any(|bound| { + if let syn::TypeParamBound::Trait(trait_bound) = bound { + trait_bound + .path + .segments + .last() + .map(|seg| seg.ident == "ToolContext") + .unwrap_or(false) + } else { + false + } + }); + } + false + } + // Handle path types: Arc, Box, etc. + syn::Type::Path(type_path) => { + if let Some(segment) = type_path.path.segments.last() { + // Check if it's a wrapper like Arc, Box, Rc containing dyn ToolContext + if matches!(segment.ident.to_string().as_str(), "Arc" | "Box" | "Rc") { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + for arg in &args.args { + // Check for dyn ToolContext (collapsed pattern) + if let syn::GenericArgument::Type(syn::Type::TraitObject(trait_obj)) = + arg + { + return trait_obj.bounds.iter().any(|bound| { + if let syn::TypeParamBound::Trait(trait_bound) = bound { + trait_bound + .path + .segments + .last() + .map(|seg| seg.ident == "ToolContext") + .unwrap_or(false) + } else { + false + } + }); + } + } + } + } + // Also check if the type itself is named ToolContext (for direct use) + if segment.ident == "ToolContext" { + return true; + } + } + false + } + // Handle trait objects directly: dyn ToolContext + syn::Type::TraitObject(trait_obj) => trait_obj.bounds.iter().any(|bound| { + if let syn::TypeParamBound::Trait(trait_bound) = bound { + trait_bound + .path + .segments + .last() + .map(|seg| seg.ident == "ToolContext") + .unwrap_or(false) + } else { + false + } + }), + _ => false, + } +} + /// Check if a type is a primitive or standard library type (not a custom struct) fn is_primitive_or_std_type(ty: &syn::Type) -> bool { match ty { @@ -1080,6 +1227,8 @@ fn generate_parameter_schema(sig: &syn::Signature) -> syn::Result { } /// Generate parameter extraction and method call for tools +/// +/// This function handles ToolContext detection and injection for the #[mcp_tools] macro. fn generate_method_call_with_params( sig: &syn::Signature, method_name: &syn::Ident, @@ -1088,12 +1237,24 @@ fn generate_method_call_with_params( let mut param_declarations = Vec::new(); let mut param_names = Vec::new(); let mut param_types = Vec::new(); + let mut has_tool_context = false; + let mut first_non_self_param = true; - // Collect all parameters (skip self) + // Collect all parameters (skip self and ToolContext) for input in &sig.inputs { match input { syn::FnArg::Receiver(_) => continue, // Skip self syn::FnArg::Typed(pat_type) => { + // Check if this is the first non-self param and if it's ToolContext + if first_non_self_param { + first_non_self_param = false; + if is_tool_context_type(&pat_type.ty) { + has_tool_context = true; + // Skip ToolContext from JSON extraction - it's injected at runtime + continue; + } + } + if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { let param_name = &pat_ident.ident; let param_type = &*pat_type.ty; @@ -1158,8 +1319,16 @@ fn generate_method_call_with_params( } } - if param_declarations.is_empty() { - // No parameters - call method directly + // Generate context acquisition if needed + let context_decl = if has_tool_context { + quote! { let __tool_ctx = pulseengine_mcp_server::current_context(); } + } else { + quote! {} + }; + + // Generate method call with optional context as first parameter + if param_declarations.is_empty() && !has_tool_context { + // No parameters and no context - call method directly if is_async { Ok(quote! { self.#method_name().await @@ -1169,8 +1338,44 @@ fn generate_method_call_with_params( self.#method_name() }) } + } else if param_declarations.is_empty() && has_tool_context { + // Only ToolContext parameter + if is_async { + Ok(quote! { + { + #context_decl + self.#method_name(__tool_ctx).await + } + }) + } else { + Ok(quote! { + { + #context_decl + self.#method_name(__tool_ctx) + } + }) + } + } else if has_tool_context { + // Has parameters AND ToolContext - context is first param + if is_async { + Ok(quote! { + { + #context_decl + #(#param_declarations)* + self.#method_name(__tool_ctx, #(#param_names),*).await + } + }) + } else { + Ok(quote! { + { + #context_decl + #(#param_declarations)* + self.#method_name(__tool_ctx, #(#param_names),*) + } + }) + } } else { - // Has parameters - extract them and call method + // Has parameters but no context if is_async { Ok(quote! { { diff --git a/mcp-protocol/src/model.rs b/mcp-protocol/src/model.rs index c6e61d6..3e7393c 100644 --- a/mcp-protocol/src/model.rs +++ b/mcp-protocol/src/model.rs @@ -613,15 +613,26 @@ pub enum Content { #[serde(rename = "image")] Image { data: String, + #[serde(rename = "mimeType")] + mime_type: String, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + _meta: Option, + }, + /// Audio content with base64-encoded data + #[serde(rename = "audio")] + Audio { + /// Base64-encoded audio data + data: String, + /// MIME type (e.g., "audio/wav", "audio/mp3") + #[serde(rename = "mimeType")] mime_type: String, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] _meta: Option, }, #[serde(rename = "resource")] Resource { - #[serde(with = "serde_json_string_or_object")] - resource: String, - text: Option, + /// The embedded resource contents + resource: EmbeddedResourceContents, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] _meta: Option, }, @@ -660,7 +671,11 @@ pub enum ToolResultContent { #[serde(rename = "text")] Text { text: String }, #[serde(rename = "image")] - Image { data: String, mime_type: String }, + Image { + data: String, + #[serde(rename = "mimeType")] + mime_type: String, + }, } impl Content { @@ -679,10 +694,37 @@ impl Content { } } - pub fn resource(resource: impl Into, text: Option) -> Self { + /// Create audio content with base64-encoded data + pub fn audio(data: impl Into, mime_type: impl Into) -> Self { + Self::Audio { + data: data.into(), + mime_type: mime_type.into(), + _meta: None, + } + } + + /// Create embedded resource content + pub fn resource( + uri: impl Into, + mime_type: Option, + text: Option, + ) -> Self { Self::Resource { - resource: resource.into(), - text, + resource: EmbeddedResourceContents { + uri: uri.into(), + mime_type, + text, + blob: None, + _meta: None, + }, + _meta: None, + } + } + + /// Create embedded resource content from ResourceContents + pub fn from_resource_contents(contents: ResourceContents) -> Self { + Self::Resource { + resource: contents, _meta: None, } } @@ -769,17 +811,9 @@ impl Content { /// _meta: None, /// } /// ``` + /// Create a UI HTML resource content (for MCP Apps Extension / MCP-UI) pub fn ui_html(uri: impl Into, html: impl Into) -> Self { - let resource_json = serde_json::json!({ - "uri": uri.into(), - "mimeType": "text/html", - "text": html.into() - }); - Self::Resource { - resource: resource_json.to_string(), - text: None, - _meta: None, - } + Self::resource(uri, Some("text/html".to_string()), Some(html.into())) } /// Create a UI resource content with custom MIME type (for MCP Apps Extension / MCP-UI) @@ -803,16 +837,7 @@ impl Content { mime_type: impl Into, content: impl Into, ) -> Self { - let resource_json = serde_json::json!({ - "uri": uri.into(), - "mimeType": mime_type.into(), - "text": content.into() - }); - Self::Resource { - resource: resource_json.to_string(), - text: None, - _meta: None, - } + Self::resource(uri, Some(mime_type.into()), Some(content.into())) } /// Get text content if this is a text content type @@ -1136,8 +1161,11 @@ pub struct ReadResourceRequestParam { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ResourceContents { pub uri: String, + #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")] pub mime_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub text: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub blob: Option, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub _meta: Option, @@ -1149,6 +1177,9 @@ pub struct ReadResourceResult { pub contents: Vec, } +/// Embedded resource contents for tool responses (alias for ResourceContents) +pub type EmbeddedResourceContents = ResourceContents; + /// Raw resource (for internal use) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RawResource { @@ -1183,16 +1214,50 @@ impl PromptMessage { }, } } + + /// Create a new resource message with embedded resource content + pub fn new_resource( + role: PromptMessageRole, + uri: impl Into, + mime_type: Option, + text: Option, + ) -> Self { + Self { + role, + content: PromptMessageContent::Resource { + resource: EmbeddedResourceContents { + uri: uri.into(), + mime_type, + text, + blob: None, + _meta: None, + }, + }, + } + } } impl CompleteResult { - /// Create a simple completion result - pub fn simple(completion: impl Into) -> Self { + /// Create a simple completion result with a single value + pub fn simple(value: impl Into) -> Self { Self { - completion: vec![CompletionInfo { - completion: completion.into(), + completion: CompletionValues { + values: vec![value.into()], + total: None, has_more: Some(false), - }], + }, + } + } + + /// Create a completion result with multiple values + pub fn with_values(values: Vec) -> Self { + let total = values.len() as u64; + Self { + completion: CompletionValues { + values, + total: Some(total), + has_more: Some(false), + }, } } } @@ -1203,7 +1268,9 @@ pub struct Prompt { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] pub title: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub icons: Option>, @@ -1248,7 +1315,17 @@ pub enum PromptMessageContent { #[serde(rename = "text")] Text { text: String }, #[serde(rename = "image")] - Image { data: String, mime_type: String }, + Image { + data: String, + #[serde(rename = "mimeType")] + mime_type: String, + }, + /// Embedded resource content in prompts + #[serde(rename = "resource")] + Resource { + /// The embedded resource contents + resource: EmbeddedResourceContents, + }, } /// Prompt message @@ -1312,29 +1389,60 @@ impl CompletionContext { } } +/// Reference type for completion requests +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum CompletionRef { + /// Reference to a prompt for argument completion + #[serde(rename = "ref/prompt")] + Prompt { name: String }, + /// Reference to a resource for URI completion + #[serde(rename = "ref/resource")] + Resource { uri: String }, +} + +/// Completion argument being completed +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompletionArgument { + pub name: String, + pub value: String, +} + /// Completion request parameters #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompleteRequestParam { - pub ref_: String, - pub argument: serde_json::Value, + #[serde(rename = "ref")] + pub ref_: CompletionRef, + pub argument: CompletionArgument, /// Optional context for context-aware completion (MCP 2025-06-18) #[serde(skip_serializing_if = "Option::is_none")] pub context: Option, } -/// Completion information +/// Completion values object #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CompletionInfo { - pub completion: String, +pub struct CompletionValues { + /// Array of completion suggestions + pub values: Vec, + /// Optional total number of available matches + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, + /// Boolean indicating if additional results exist + #[serde(rename = "hasMore", skip_serializing_if = "Option::is_none")] pub has_more: Option, } /// Complete result #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompleteResult { - pub completion: Vec, + /// The completion object containing values and metadata + pub completion: CompletionValues, } +// Keep the old type for backwards compatibility +#[deprecated(note = "Use CompletionValues instead")] +pub type CompletionInfo = CompletionValues; + /// Set logging level parameters #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SetLevelRequestParam { @@ -2125,29 +2233,3 @@ impl Default for TaskSupport { Self::Optional } } - -/// Serde module for serializing/deserializing JSON strings as objects -mod serde_json_string_or_object { - use serde::{Deserialize, Deserializer, Serialize, Serializer}; - use serde_json::Value; - - pub fn serialize(value: &str, serializer: S) -> Result - where - S: Serializer, - { - // Parse the string as JSON and serialize it as an object - match serde_json::from_str::(value) { - Ok(json_value) => json_value.serialize(serializer), - Err(_) => serializer.serialize_str(value), // Fall back to string if not valid JSON - } - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - // Deserialize as JSON Value and convert to string - let value = Value::deserialize(deserializer)?; - Ok(value.to_string()) - } -} diff --git a/mcp-protocol/src/model_tests.rs b/mcp-protocol/src/model_tests.rs index 293da58..38ac7c2 100644 --- a/mcp-protocol/src/model_tests.rs +++ b/mcp-protocol/src/model_tests.rs @@ -169,12 +169,15 @@ mod tests { } // Resource content - let resource_content = - Content::resource("file://path/to/resource", Some("text".to_string())); + let resource_content = Content::resource( + "file://path/to/resource", + Some("text/plain".to_string()), + Some("text".to_string()), + ); match &resource_content { - Content::Resource { resource, text, .. } => { - assert_eq!(resource, "file://path/to/resource"); - assert_eq!(text.as_ref().unwrap(), "text"); + Content::Resource { resource, .. } => { + assert_eq!(resource.uri, "file://path/to/resource"); + assert_eq!(resource.text.as_ref().unwrap(), "text"); } _ => panic!("Expected resource content"), } @@ -397,9 +400,9 @@ mod tests { #[test] fn test_complete_result_simple() { let result = CompleteResult::simple("Completion text"); - assert_eq!(result.completion.len(), 1); - assert_eq!(result.completion[0].completion, "Completion text"); - assert_eq!(result.completion[0].has_more, Some(false)); + assert_eq!(result.completion.values.len(), 1); + assert_eq!(result.completion.values[0], "Completion text"); + assert_eq!(result.completion.has_more, Some(false)); } #[test] @@ -1451,4 +1454,208 @@ mod tests { assert!(task.is_terminal()); assert!(!task.is_running()); } + + // ============================================================================ + // New Content Type Tests (MCP 2025-11-25) + // ============================================================================ + + #[test] + fn test_content_audio() { + let content = Content::audio("base64audiodata", "audio/wav"); + + if let Content::Audio { + data, + mime_type, + _meta, + } = content + { + assert_eq!(data, "base64audiodata"); + assert_eq!(mime_type, "audio/wav"); + assert!(_meta.is_none()); + } else { + panic!("Expected Audio variant"); + } + } + + #[test] + fn test_content_audio_serialization() { + let content = Content::audio("ZGF0YQ==", "audio/mp3"); + + let json = serde_json::to_string(&content).unwrap(); + assert!(json.contains("\"audio\"")); + assert!(json.contains("\"ZGF0YQ==\"")); + assert!(json.contains("\"mimeType\":\"audio/mp3\"")); + + // Round-trip + let deserialized: Content = serde_json::from_str(&json).unwrap(); + if let Content::Audio { + data, mime_type, .. + } = deserialized + { + assert_eq!(data, "ZGF0YQ=="); + assert_eq!(mime_type, "audio/mp3"); + } else { + panic!("Expected Audio variant"); + } + } + + #[test] + fn test_content_resource_new_signature() { + let content = Content::resource( + "file:///path/to/file.txt", + Some("text/plain".to_string()), + Some("File contents here".to_string()), + ); + + if let Content::Resource { resource, _meta } = content { + assert_eq!(resource.uri, "file:///path/to/file.txt"); + assert_eq!(resource.mime_type, Some("text/plain".to_string())); + assert_eq!(resource.text, Some("File contents here".to_string())); + assert!(_meta.is_none()); + } else { + panic!("Expected Resource variant"); + } + } + + #[test] + fn test_content_resource_minimal() { + let content = Content::resource("file:///test.bin", None, None); + + if let Content::Resource { resource, _meta } = content { + assert_eq!(resource.uri, "file:///test.bin"); + assert!(resource.mime_type.is_none()); + assert!(resource.text.is_none()); + } else { + panic!("Expected Resource variant"); + } + } + + #[test] + fn test_content_from_resource_contents() { + let resource_contents = ResourceContents { + uri: "http://example.com/data.json".to_string(), + mime_type: Some("application/json".to_string()), + text: Some("{\"key\": \"value\"}".to_string()), + blob: None, + _meta: None, + }; + + let content = Content::from_resource_contents(resource_contents); + + if let Content::Resource { resource, _meta } = content { + assert_eq!(resource.uri, "http://example.com/data.json"); + assert_eq!(resource.mime_type, Some("application/json".to_string())); + assert_eq!(resource.text, Some("{\"key\": \"value\"}".to_string())); + } else { + panic!("Expected Resource variant"); + } + } + + #[test] + fn test_prompt_message_new_resource() { + let message = PromptMessage::new_resource( + PromptMessageRole::User, + "file:///document.pdf", + Some("application/pdf".to_string()), + Some("PDF text contents".to_string()), + ); + + // Use matches! macro since PromptMessageRole doesn't implement PartialEq + assert!(matches!(message.role, PromptMessageRole::User)); + if let PromptMessageContent::Resource { resource } = message.content { + assert_eq!(resource.uri, "file:///document.pdf"); + assert_eq!(resource.mime_type, Some("application/pdf".to_string())); + assert_eq!(resource.text, Some("PDF text contents".to_string())); + } else { + panic!("Expected Resource content"); + } + } + + #[test] + fn test_prompt_message_new_resource_minimal() { + let message = PromptMessage::new_resource( + PromptMessageRole::Assistant, + "file:///image.png", + None, + None, + ); + + assert!(matches!(message.role, PromptMessageRole::Assistant)); + if let PromptMessageContent::Resource { resource } = message.content { + assert_eq!(resource.uri, "file:///image.png"); + assert!(resource.mime_type.is_none()); + assert!(resource.text.is_none()); + } else { + panic!("Expected Resource content"); + } + } + + #[test] + fn test_complete_result_simple_new_api() { + let result = CompleteResult::simple("completion_value"); + + assert_eq!(result.completion.values.len(), 1); + assert_eq!(result.completion.values[0], "completion_value"); + assert!(result.completion.total.is_none()); + assert_eq!(result.completion.has_more, Some(false)); + } + + #[test] + fn test_complete_result_with_values_new_api() { + let values = vec![ + "option1".to_string(), + "option2".to_string(), + "option3".to_string(), + ]; + let result = CompleteResult::with_values(values); + + assert_eq!(result.completion.values.len(), 3); + assert_eq!(result.completion.values[0], "option1"); + assert_eq!(result.completion.values[1], "option2"); + assert_eq!(result.completion.values[2], "option3"); + assert_eq!(result.completion.total, Some(3)); + assert_eq!(result.completion.has_more, Some(false)); + } + + #[test] + fn test_complete_result_with_values_empty_new_api() { + let result = CompleteResult::with_values(vec![]); + + assert!(result.completion.values.is_empty()); + assert_eq!(result.completion.total, Some(0)); + assert_eq!(result.completion.has_more, Some(false)); + } + + #[test] + fn test_complete_result_serialization_new_api() { + let result = CompleteResult::simple("test_completion"); + + let json = serde_json::to_string(&result).unwrap(); + assert!(json.contains("\"values\"")); + assert!(json.contains("\"test_completion\"")); + + // Round-trip + let deserialized: CompleteResult = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.completion.values.len(), 1); + assert_eq!(deserialized.completion.values[0], "test_completion"); + } + + #[test] + fn test_resource_contents_serialization() { + let contents = ResourceContents { + uri: "file:///test.txt".to_string(), + mime_type: Some("text/plain".to_string()), + text: Some("Hello".to_string()), + blob: None, + _meta: None, + }; + + let json = serde_json::to_string(&contents).unwrap(); + assert!(json.contains("\"uri\":\"file:///test.txt\"")); + assert!(json.contains("\"mimeType\":\"text/plain\"")); + assert!(json.contains("\"text\":\"Hello\"")); + // blob and _meta should not be serialized when None + assert!(!json.contains("\"blob\"")); + assert!(!json.contains("\"_meta\"")); + } } diff --git a/mcp-server/src/backend.rs b/mcp-server/src/backend.rs index af49795..10b76ab 100644 --- a/mcp-server/src/backend.rs +++ b/mcp-server/src/backend.rs @@ -184,7 +184,13 @@ pub trait McpBackend: Send + Sync + Clone { request: CompleteRequestParam, ) -> std::result::Result { let _ = request; - Ok(CompleteResult { completion: vec![] }) + Ok(CompleteResult { + completion: CompletionValues { + values: vec![], + total: None, + has_more: None, + }, + }) } // Elicitation (optional) diff --git a/mcp-server/src/backend_tests.rs b/mcp-server/src/backend_tests.rs index eb28c0a..8373945 100644 --- a/mcp-server/src/backend_tests.rs +++ b/mcp-server/src/backend_tests.rs @@ -355,16 +355,18 @@ async fn test_mock_backend_optional_methods() { // Test complete (default implementation) let complete_result = backend .complete(CompleteRequestParam { - ref_: "test://resource".to_string(), - argument: serde_json::json!({ - "name": "test", - "value": "test" - }), + ref_: pulseengine_mcp_protocol::CompletionRef::Resource { + uri: "test://resource".to_string(), + }, + argument: pulseengine_mcp_protocol::CompletionArgument { + name: "test".to_string(), + value: "test".to_string(), + }, context: None, }) .await .unwrap(); - assert!(complete_result.completion.is_empty()); + assert!(complete_result.completion.values.is_empty()); // Test set level (default accepts any level) let set_level_result = backend diff --git a/mcp-server/src/handler.rs b/mcp-server/src/handler.rs index 65eb61f..696ba4c 100644 --- a/mcp-server/src/handler.rs +++ b/mcp-server/src/handler.rs @@ -1,9 +1,11 @@ //! Generic request handler for MCP protocol +use crate::tool_context::{NoOpToolContext, ToolContext, create_tool_context, with_context}; use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack}; use pulseengine_mcp_auth::AuthenticationManager; use pulseengine_mcp_logging::{get_metrics, spans}; use pulseengine_mcp_protocol::*; +use pulseengine_mcp_transport::{Transport, try_current_session_id}; use std::collections::HashSet; use std::sync::Arc; @@ -73,6 +75,10 @@ pub struct GenericServerHandler { /// Note: This is a simplified global implementation. For per-client /// subscriptions, use a `HashMap>` instead. subscriptions: Arc>>, + /// Optional transport reference for bidirectional communication (shared across clones) + /// When set, enables tools to send notifications and make requests to the client. + /// Uses Arc> so that all clones of the handler share the same transport. + transport: Arc>>>, } /// Helper to create a JSON-RPC response with a result @@ -121,6 +127,66 @@ impl GenericServerHandler { auth_manager, middleware, subscriptions: Arc::new(RwLock::new(HashSet::new())), + transport: Arc::new(RwLock::new(None)), + } + } + + /// Set the transport for bidirectional communication + /// + /// When set, tools can send notifications and make requests to the client + /// using the task-local ToolContext. This updates the shared transport + /// reference, so all clones of this handler will see the change. + pub fn set_transport(&self, transport: Arc) { + // Use try_write to avoid blocking; if we can't acquire the lock, + // another thread is updating it which is fine. + if let Ok(mut guard) = self.transport.try_write() { + *guard = Some(transport); + } + } + + /// Create a ToolContext for the current request + /// + /// If a transport is set and supports bidirectional communication, returns + /// a context that can send notifications and requests. Otherwise returns + /// a no-op context. + async fn make_tool_context( + &self, + request_id: String, + tool_name: String, + progress_token: Option, + session_id: Option, + ) -> Arc { + // Read the transport from the shared RwLock + let transport_guard = self.transport.read().await; + if let Some(ref transport) = *transport_guard { + let supports_bidir = transport.supports_bidirectional(); + eprintln!( + "[DEBUG] make_tool_context: tool={tool_name}, session_id={session_id:?}, supports_bidirectional={supports_bidir}" + ); + debug!( + tool = %tool_name, + supports_bidirectional = %supports_bidir, + "Creating tool context" + ); + if supports_bidir { + eprintln!("[DEBUG] Creating DefaultToolContext for {tool_name}"); + // Use the factory function from tool_context module + create_tool_context( + transport.clone(), + request_id, + tool_name, + progress_token, + session_id, + ) + } else { + eprintln!("[DEBUG] Transport doesn't support bidirectional for {tool_name}"); + debug!(tool = %tool_name, "Transport doesn't support bidirectional, using NoOp context"); + Arc::new(NoOpToolContext::new(request_id, tool_name)) + } + } else { + eprintln!("[DEBUG] No transport available for {tool_name}"); + debug!(tool = %tool_name, "No transport available, using NoOp context"); + Arc::new(NoOpToolContext::new(request_id, tool_name)) } } @@ -244,11 +310,28 @@ impl GenericServerHandler { #[instrument(skip(self, request), fields(mcp.method = "initialize"))] async fn handle_initialize(&self, request: Request) -> std::result::Result { - let _params: InitializeRequestParam = serde_json::from_value(request.params)?; + let params: InitializeRequestParam = serde_json::from_value(request.params)?; + + // Negotiate protocol version: use the client's version if we support it, + // otherwise fall back to the server's latest supported version + let negotiated_version = + if pulseengine_mcp_protocol::is_protocol_version_supported(¶ms.protocol_version) { + params.protocol_version.clone() + } else { + // If client version is unsupported, use our latest version + // The client may reject this if it doesn't support our version + pulseengine_mcp_protocol::MCP_VERSION.to_string() + }; + + info!( + client_version = %params.protocol_version, + negotiated_version = %negotiated_version, + "Protocol version negotiated" + ); let server_info = self.backend.get_server_info(); let result = InitializeResult { - protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(), + protocol_version: negotiated_version, capabilities: server_info.capabilities, server_info: server_info.server_info.clone(), instructions: server_info.instructions, @@ -275,18 +358,63 @@ impl GenericServerHandler { #[instrument(skip(self, request), fields(mcp.method = "tools/call"))] async fn handle_call_tool(&self, request: Request) -> std::result::Result { - let params: CallToolRequestParam = serde_json::from_value(request.params)?; + let params: CallToolRequestParam = serde_json::from_value(request.params.clone())?; let tool_name = params.name.clone(); let start_time = Instant::now(); + // Extract request ID for context + let request_id = request + .id + .as_ref() + .map(|id| match id { + NumberOrString::Number(n) => n.to_string(), + NumberOrString::String(s) => s.to_string(), + }) + .unwrap_or_else(|| "unknown".to_string()); + + // Extract progress token if provided in the request + // Per MCP spec, progressToken can be either a string or integer + let progress_token = request + .params + .get("_meta") + .and_then(|m| m.get("progressToken")) + .and_then(|t| { + // Handle both string and integer tokens + if let Some(s) = t.as_str() { + Some(s.to_string()) + } else if let Some(n) = t.as_i64() { + Some(n.to_string()) + } else { + t.as_u64().map(|n| n.to_string()) + } + }); + // Get metrics collector for tool-specific tracking let metrics = get_metrics(); metrics.record_request_start(&tool_name).await; + // Create the tool context for bidirectional communication + // The session ID is retrieved from task-local storage (set by the transport layer) + let session_id = try_current_session_id(); + eprintln!( + "[DEBUG] handle_call_tool: session_id from task-local = {:?}", + session_id + ); + let context = self + .make_tool_context(request_id, tool_name.clone(), progress_token, session_id) + .await; + let result = { let span = spans::backend_operation_span("call_tool", Some(&tool_name)); let _guard = span.enter(); - match self.backend.call_tool(params).await { + + // Execute the backend call within the context scope + // This makes the context available via try_current_context() in tools + let backend = self.backend.clone(); + let tool_result = + with_context(context, async move { backend.call_tool(params).await }).await; + + match tool_result { Ok(result) => { let duration = start_time.elapsed(); metrics.record_request_end(&tool_name, duration, true).await; @@ -466,8 +594,8 @@ mod tests { use pulseengine_mcp_auth::config::AuthConfig; use pulseengine_mcp_logging::ErrorClassification; use pulseengine_mcp_protocol::{ - CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult, CompletionInfo, - Content, Error, GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult, + CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult, Content, Error, + GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult, ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult, ListToolsResult, LoggingCapability, PaginatedRequestParam, Prompt, PromptMessage, PromptMessageContent, PromptMessageRole, PromptsCapability, ProtocolVersion, ReadResourceRequestParam, @@ -730,16 +858,11 @@ mod tests { } Ok(CompleteResult { - completion: vec![ - CompletionInfo { - completion: "completion1".to_string(), - has_more: Some(false), - }, - CompletionInfo { - completion: "completion2".to_string(), - has_more: Some(false), - }, - ], + completion: CompletionValues { + values: vec!["completion1".to_string(), "completion2".to_string()], + total: Some(2), + has_more: Some(false), + }, }) } @@ -859,10 +982,8 @@ mod tests { assert!(response.error.is_none()); let result: InitializeResult = serde_json::from_value(response.result.unwrap()).unwrap(); - assert_eq!( - result.protocol_version, - pulseengine_mcp_protocol::MCP_VERSION - ); + // Server now negotiates version with client - returns client's version if supported + assert_eq!(result.protocol_version, "2024-11-05"); assert_eq!(result.server_info.name, "test-server"); } @@ -1101,7 +1222,10 @@ mod tests { jsonrpc: "2.0".to_string(), method: "completion/complete".to_string(), params: json!({ - "ref_": "test_prompt", + "ref": { + "type": "ref/prompt", + "name": "test_prompt" + }, "argument": { "name": "query", "value": "test" @@ -1121,7 +1245,7 @@ mod tests { assert!(response.error.is_none()); let result: CompleteResult = serde_json::from_value(response.result.unwrap()).unwrap(); - assert_eq!(result.completion.len(), 2); + assert_eq!(result.completion.values.len(), 2); } #[tokio::test] @@ -1296,4 +1420,149 @@ mod tests { let error = HandlerError::Backend("Test backend error".to_string()); assert_eq!(error.to_string(), "Backend error: Test backend error"); } + + // ============================================================================ + // Transport and ToolContext Tests + // ============================================================================ + + /// Mock transport for testing set_transport and make_tool_context + struct MockTestTransport { + supports_bidir: bool, + } + + #[async_trait::async_trait] + impl pulseengine_mcp_transport::Transport for MockTestTransport { + async fn start( + &mut self, + _handler: pulseengine_mcp_transport::RequestHandler, + ) -> std::result::Result<(), pulseengine_mcp_transport::TransportError> { + Ok(()) + } + + async fn stop( + &mut self, + ) -> std::result::Result<(), pulseengine_mcp_transport::TransportError> { + Ok(()) + } + + async fn health_check( + &self, + ) -> std::result::Result<(), pulseengine_mcp_transport::TransportError> { + Ok(()) + } + + fn supports_bidirectional(&self) -> bool { + self.supports_bidir + } + + async fn send_notification( + &self, + _session_id: Option<&str>, + _method: &str, + _params: serde_json::Value, + ) -> std::result::Result<(), pulseengine_mcp_transport::TransportError> { + Ok(()) + } + + async fn send_request( + &self, + _session_id: Option<&str>, + _method: &str, + _params: serde_json::Value, + _timeout: std::time::Duration, + ) -> std::result::Result + { + Ok(serde_json::json!({})) + } + } + + #[tokio::test] + async fn test_set_transport() { + let handler = create_test_handler().await; + + // Initially no transport + { + let guard = handler.transport.read().await; + assert!(guard.is_none()); + } + + // Set transport + let transport = Arc::new(MockTestTransport { + supports_bidir: true, + }) as Arc; + handler.set_transport(transport); + + // Now transport should be set + { + let guard = handler.transport.read().await; + assert!(guard.is_some()); + } + } + + #[tokio::test] + async fn test_make_tool_context_no_transport() { + let handler = create_test_handler().await; + + // Without transport, should return NoOp context + let ctx = handler + .make_tool_context( + "req-1".to_string(), + "test-tool".to_string(), + None, + Some("session-1".to_string()), + ) + .await; + + assert_eq!(ctx.request_id(), "req-1"); + assert_eq!(ctx.tool_name(), "test-tool"); + } + + #[tokio::test] + async fn test_make_tool_context_with_bidirectional_transport() { + let handler = create_test_handler().await; + + // Set bidirectional transport + let transport = Arc::new(MockTestTransport { + supports_bidir: true, + }) as Arc; + handler.set_transport(transport); + + let ctx = handler + .make_tool_context( + "req-2".to_string(), + "bidir-tool".to_string(), + Some("progress-token".to_string()), + Some("session-2".to_string()), + ) + .await; + + assert_eq!(ctx.request_id(), "req-2"); + assert_eq!(ctx.tool_name(), "bidir-tool"); + assert_eq!(ctx.progress_token(), Some("progress-token")); + assert_eq!(ctx.session_id(), Some("session-2")); + } + + #[tokio::test] + async fn test_make_tool_context_with_non_bidirectional_transport() { + let handler = create_test_handler().await; + + // Set non-bidirectional transport + let transport = Arc::new(MockTestTransport { + supports_bidir: false, + }) as Arc; + handler.set_transport(transport); + + // Should return NoOp context since transport doesn't support bidirectional + let ctx = handler + .make_tool_context( + "req-3".to_string(), + "non-bidir-tool".to_string(), + None, + None, + ) + .await; + + assert_eq!(ctx.request_id(), "req-3"); + assert_eq!(ctx.tool_name(), "non-bidir-tool"); + } } diff --git a/mcp-server/src/handler_tests.rs b/mcp-server/src/handler_tests.rs index ce47fb0..6f0c83b 100644 --- a/mcp-server/src/handler_tests.rs +++ b/mcp-server/src/handler_tests.rs @@ -684,7 +684,10 @@ async fn test_handler_optional_methods() { ))), method: "completion/complete".to_string(), params: serde_json::json!({ - "ref_": "test://resource", + "ref": { + "type": "ref/resource", + "uri": "test://resource" + }, "argument": {"name": "test", "value": "test"} }), }; diff --git a/mcp-server/src/lib.rs b/mcp-server/src/lib.rs index 2aa6d45..9a22dd0 100644 --- a/mcp-server/src/lib.rs +++ b/mcp-server/src/lib.rs @@ -125,6 +125,7 @@ pub mod builder_trait; pub mod cli_helpers; pub mod common_backend; pub mod observability; +pub mod tool_context; pub mod backend; pub mod context; @@ -151,6 +152,8 @@ mod lib_tests; mod middleware_tests; #[cfg(test)] mod server_tests; +#[cfg(test)] +mod tool_context_tests; // Re-export core types pub use backend::{BackendError, McpBackend}; @@ -163,6 +166,13 @@ pub use context::RequestContext; pub use handler::{GenericServerHandler, HandlerError}; pub use middleware::{Middleware, MiddlewareStack}; pub use server::{McpServer, ServerConfig, ServerError}; +pub use tool_context::{ + CreateMessageRequest, CreateMessageResult, DefaultToolContext, ElicitationAction, + ElicitationRequest, ElicitationResult, IncludeContext, LogNotificationParams, ModelHint, + ModelPreferences, NoOpToolContext, NotificationSender, ProgressNotificationParams, + RequestSender, SamplingContent, SamplingMessage, SamplingRole, ToolContext, ToolContextError, + TransportBridge, create_tool_context, current_context, try_current_context, with_context, +}; // Re-export CLI helpers pub use cli_helpers::{CliError, DefaultLoggingConfig, LogFormat, LogOutput, create_server_info}; diff --git a/mcp-server/src/server.rs b/mcp-server/src/server.rs index fb2430e..397029b 100644 --- a/mcp-server/src/server.rs +++ b/mcp-server/src/server.rs @@ -2,6 +2,7 @@ use crate::observability::{MetricsCollector, MonitoringConfig}; use crate::{backend::McpBackend, handler::GenericServerHandler, middleware::MiddlewareStack}; +use async_trait::async_trait; use pulseengine_mcp_auth::{AuthConfig, AuthenticationManager}; use pulseengine_mcp_logging::{ AlertConfig, AlertManager, DashboardConfig, DashboardManager, PerformanceProfiler, @@ -9,13 +10,90 @@ use pulseengine_mcp_logging::{ }; use pulseengine_mcp_protocol::*; use pulseengine_mcp_security::{SecurityConfig, SecurityMiddleware}; -use pulseengine_mcp_transport::{Transport, TransportConfig}; +use pulseengine_mcp_transport::{RequestHandler, Transport, TransportConfig, TransportError}; use std::sync::Arc; +use std::time::Duration; use thiserror::Error; use tokio::signal; +use tokio::sync::RwLock; use tracing::{error, info, warn}; +/// A wrapper around a shared transport reference that implements Transport. +/// This allows the handler to access transport methods while the server +/// owns the transport behind a RwLock. +struct TransportHandle { + transport: Arc>>, +} + +#[async_trait] +impl Transport for TransportHandle { + async fn start(&mut self, _handler: RequestHandler) -> std::result::Result<(), TransportError> { + // The actual transport is started by the server, not through this handle + Err(TransportError::NotSupported( + "Cannot start transport through handle".to_string(), + )) + } + + async fn stop(&mut self) -> std::result::Result<(), TransportError> { + // The actual transport is stopped by the server, not through this handle + Err(TransportError::NotSupported( + "Cannot stop transport through handle".to_string(), + )) + } + + async fn health_check(&self) -> std::result::Result<(), TransportError> { + let transport = self.transport.read().await; + transport.health_check().await + } + + fn supports_bidirectional(&self) -> bool { + // We need to use try_read here since we can't await in a non-async fn + // If we can't get the lock, assume bidirectional is not supported + self.transport + .try_read() + .map(|t| t.supports_bidirectional()) + .unwrap_or(false) + } + + async fn send_notification( + &self, + session_id: Option<&str>, + method: &str, + params: serde_json::Value, + ) -> std::result::Result<(), TransportError> { + let transport = self.transport.read().await; + transport + .send_notification(session_id, method, params) + .await + } + + async fn send_request( + &self, + session_id: Option<&str>, + method: &str, + params: serde_json::Value, + timeout: Duration, + ) -> std::result::Result { + let transport = self.transport.read().await; + transport + .send_request(session_id, method, params, timeout) + .await + } + + fn register_pending_request( + &self, + request_id: &str, + ) -> Option> { + // Delegate to the underlying transport + // We need to use try_read here since we can't await in a non-async fn + self.transport + .try_read() + .ok() + .and_then(|t| t.register_pending_request(request_id)) + } +} + /// Error type for server operations #[derive(Debug, Error)] pub enum ServerError { @@ -110,7 +188,9 @@ pub struct McpServer { backend: Arc, handler: GenericServerHandler, auth_manager: Arc, - transport: Box, + /// Transport layer - wrapped in RwLock to allow both mutable access for + /// start/stop and shared access for bidirectional communication + transport: Arc>>, #[allow(dead_code)] middleware_stack: MiddlewareStack, monitoring_metrics: Arc, @@ -145,10 +225,11 @@ impl McpServer { Arc::new(AuthenticationManager::new_disabled()) }; - // Initialize transport - let transport = + // Initialize transport (wrap in Arc> for shared access) + let transport = Arc::new(tokio::sync::RwLock::new( pulseengine_mcp_transport::create_transport(config.transport_config.clone()) - .map_err(|e| ServerError::Transport(e.to_string()))?; + .map_err(|e| ServerError::Transport(e.to_string()))?, + )); // Initialize security middleware let security_middleware = SecurityMiddleware::new(config.security_config.clone()); @@ -191,7 +272,7 @@ impl McpServer { None }; - // Create handler + // Create handler (transport will be set after transport.start()) let handler = GenericServerHandler::new( backend.clone(), auth_manager.clone(), @@ -262,25 +343,38 @@ impl McpServer { // Metrics persistence is now handled internally by the logging metrics collector // No need for manual snapshot saving - // Start transport + // Create a transport handle for the handler to use for bidirectional communication. + // This wraps the shared transport reference and implements Transport. + let transport_handle: Arc = Arc::new(TransportHandle { + transport: self.transport.clone(), + }); + + // Wire up transport to handler BEFORE starting, since the handler's transport + // reference is shared via Arc> and will be accessible after start + self.handler.set_transport(transport_handle); + + // Start transport (acquire write lock for mutable access) let handler = self.handler.clone(); - self.transport - .start(Box::new(move |request| { - let handler = handler.clone(); - Box::pin(async move { - match handler.handle_request(request).await { - Ok(response) => response, - Err(error) => Response { - jsonrpc: "2.0".to_string(), - id: None, - result: None, - error: Some(error.into()), - }, - } - }) - })) - .await - .map_err(|e| ServerError::Transport(e.to_string()))?; + { + let mut transport_guard = self.transport.write().await; + transport_guard + .start(Box::new(move |request| { + let handler = handler.clone(); + Box::pin(async move { + match handler.handle_request(request).await { + Ok(response) => response, + Err(error) => Response { + jsonrpc: "2.0".to_string(), + id: None, + result: None, + error: Some(error.into()), + }, + } + }) + })) + .await + .map_err(|e| ServerError::Transport(e.to_string()))?; + } info!("MCP server started successfully"); @@ -310,11 +404,14 @@ impl McpServer { info!("Stopping MCP server"); - // Stop transport - self.transport - .stop() - .await - .map_err(|e| ServerError::Transport(e.to_string()))?; + // Stop transport (acquire write lock for mutable access) + { + let mut transport_guard = self.transport.write().await; + transport_guard + .stop() + .await + .map_err(|e| ServerError::Transport(e.to_string()))?; + } // Stop background services self.monitoring_metrics.stop_collection().await; @@ -364,8 +461,11 @@ impl McpServer { // Check backend health let backend_healthy = self.backend.health_check().await.is_ok(); - // Check transport health - let transport_healthy = self.transport.health_check().await.is_ok(); + // Check transport health (acquire read lock to access transport) + let transport_healthy = { + let transport_guard = self.transport.read().await; + transport_guard.health_check().await.is_ok() + }; // Check auth health let auth_healthy = self.auth_manager.health_check().await.is_ok(); diff --git a/mcp-server/src/server_tests.rs b/mcp-server/src/server_tests.rs index d63ff15..193014d 100644 --- a/mcp-server/src/server_tests.rs +++ b/mcp-server/src/server_tests.rs @@ -662,3 +662,199 @@ fn test_server_error_debug() { assert!(debug_str.contains("Backend")); assert!(debug_str.contains("test")); } + +// ============================================================================ +// Additional Server Error Tests +// ============================================================================ + +#[test] +fn test_server_error_all_variants() { + // Test each variant for coverage + let errors = vec![ + ServerError::Configuration("config error".to_string()), + ServerError::Transport("transport error".to_string()), + ServerError::Authentication("auth error".to_string()), + ServerError::Backend("backend error".to_string()), + ServerError::AlreadyRunning, + ServerError::NotRunning, + ServerError::ShutdownTimeout, + ]; + + for error in errors { + // All should implement Display + let display = error.to_string(); + assert!(!display.is_empty()); + + // All should implement Debug + let debug = format!("{error:?}"); + assert!(!debug.is_empty()); + } +} + +#[test] +fn test_server_error_std_error_trait() { + // Verify ServerError implements std::error::Error + let error: Box = + Box::new(ServerError::Configuration("test".to_string())); + assert!(error.to_string().contains("Server configuration error")); +} + +// ============================================================================ +// Transport Configuration Tests +// ============================================================================ + +#[tokio::test] +async fn test_server_with_websocket_transport() { + let backend = + MockServerBackend::initialize((false, false, false, "WebSocket Server".to_string())) + .await + .unwrap(); + + // Test with WebSocket transport + let config = ServerConfig { + transport_config: TransportConfig::WebSocket { + host: Some("127.0.0.1".to_string()), + port: 0, // Random port + }, + auth_config: AuthConfig { + storage: StorageConfig::Memory, + enabled: false, + cache_size: 100, + session_timeout_secs: 3600, + max_failed_attempts: 5, + rate_limit_window_secs: 900, + }, + ..Default::default() + }; + + let server = McpServer::new(backend, config).await; + assert!(server.is_ok()); +} + +#[tokio::test] +async fn test_server_with_streamable_http_transport() { + let backend = + MockServerBackend::initialize((false, false, false, "StreamableHTTP Server".to_string())) + .await + .unwrap(); + + // Test with StreamableHttp transport + let config = ServerConfig { + transport_config: TransportConfig::StreamableHttp { + host: Some("127.0.0.1".to_string()), + port: 0, // Random port + }, + auth_config: AuthConfig { + storage: StorageConfig::Memory, + enabled: false, + cache_size: 100, + session_timeout_secs: 3600, + max_failed_attempts: 5, + rate_limit_window_secs: 900, + }, + ..Default::default() + }; + + let server = McpServer::new(backend, config).await; + assert!(server.is_ok()); +} + +// ============================================================================ +// Config Edge Cases +// ============================================================================ + +#[tokio::test] +async fn test_server_with_profiling_enabled() { + use pulseengine_mcp_logging::ProfilingConfig; + + let backend = + MockServerBackend::initialize((false, false, false, "Profiling Server".to_string())) + .await + .unwrap(); + + // Use a profiling config with enabled = true + let profiling_config = ProfilingConfig { + enabled: true, + ..Default::default() + }; + + let config = ServerConfig { + transport_config: TransportConfig::Stdio, + auth_config: AuthConfig { + storage: StorageConfig::Memory, + enabled: false, + cache_size: 100, + session_timeout_secs: 3600, + max_failed_attempts: 5, + rate_limit_window_secs: 900, + }, + profiling_config, + ..Default::default() + }; + + let server = McpServer::new(backend, config).await; + assert!(server.is_ok()); +} + +#[tokio::test] +async fn test_server_with_persistence_config() { + use pulseengine_mcp_logging::PersistenceConfig; + + let backend = + MockServerBackend::initialize((false, false, false, "Persistence Server".to_string())) + .await + .unwrap(); + + let temp_dir = std::env::temp_dir().join("mcp_test_metrics"); + + // Use default PersistenceConfig with modified data_dir + let persistence = PersistenceConfig { + data_dir: temp_dir.clone(), + ..Default::default() + }; + + let config = ServerConfig { + transport_config: TransportConfig::Stdio, + auth_config: AuthConfig { + storage: StorageConfig::Memory, + enabled: false, + cache_size: 100, + session_timeout_secs: 3600, + max_failed_attempts: 5, + rate_limit_window_secs: 900, + }, + persistence_config: Some(persistence), + ..Default::default() + }; + + let server = McpServer::new(backend, config).await; + // May fail if temp dir can't be created, but should not crash + let _ = server; + + // Cleanup + let _ = std::fs::remove_dir_all(temp_dir); +} + +#[tokio::test] +async fn test_server_with_auth_enabled() { + let backend = + MockServerBackend::initialize((false, false, false, "Auth Enabled Server".to_string())) + .await + .unwrap(); + + let config = ServerConfig { + transport_config: TransportConfig::Stdio, + auth_config: AuthConfig { + storage: StorageConfig::Memory, + enabled: true, // Enable auth + cache_size: 100, + session_timeout_secs: 3600, + max_failed_attempts: 5, + rate_limit_window_secs: 900, + }, + ..Default::default() + }; + + let server = McpServer::new(backend, config).await; + assert!(server.is_ok()); +} diff --git a/mcp-server/src/tool_context.rs b/mcp-server/src/tool_context.rs new file mode 100644 index 0000000..f7d40c1 --- /dev/null +++ b/mcp-server/src/tool_context.rs @@ -0,0 +1,1137 @@ +//! Tool execution context for bidirectional server-to-client communication +//! +//! This module provides the [`ToolContext`] trait and related types that enable +//! MCP tools to send notifications and make requests back to the client during +//! execution. This is essential for implementing: +//! +//! - **Logging notifications** (`notifications/message`) - Send log messages to client +//! - **Progress notifications** (`notifications/progress`) - Report progress during long operations +//! - **Sampling requests** (`sampling/createMessage`) - Request LLM completions from client +//! - **Elicitation requests** (`elicitation/create`) - Request user input from client +//! +//! # Architecture +//! +//! The context uses task-local storage to make it available during tool execution +//! without requiring changes to the `McpBackend` trait signature. This provides +//! a non-breaking way to add bidirectional communication support. +//! +//! # Usage with `#[mcp_tool]` macro +//! +//! Tools can opt into receiving context by adding it as the first parameter: +//! +//! ```rust,ignore +//! use pulseengine_mcp_server::tool_context::ToolContext; +//! +//! #[mcp_tool(name = "long_operation")] +//! async fn long_operation(ctx: &dyn ToolContext, input: String) -> Result { +//! // Send progress notifications +//! for i in 0..=100 { +//! ctx.send_progress(i, Some(100)).await?; +//! tokio::time::sleep(Duration::from_millis(50)).await; +//! } +//! +//! // Log completion +//! ctx.send_log(LogLevel::Info, Some("long_op"), json!({"completed": true})).await?; +//! +//! Ok("Done!".to_string()) +//! } +//! ``` +//! +//! # Manual Usage +//! +//! For backends not using macros, context can be accessed via task-local storage: +//! +//! ```rust,ignore +//! use pulseengine_mcp_server::tool_context::try_current_context; +//! +//! async fn my_tool_handler() { +//! if let Some(ctx) = try_current_context() { +//! ctx.send_progress(50, Some(100)).await.ok(); +//! } +//! } +//! ``` + +use async_trait::async_trait; +use pulseengine_mcp_protocol::{Error, LogLevel}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::fmt; +use std::sync::Arc; +use std::time::Duration; + +// ============================================================================ +// Error Types +// ============================================================================ + +/// Errors that can occur during tool context operations +#[derive(Debug)] +pub enum ToolContextError { + /// Notification sending failed + NotificationFailed(String), + /// Request to client failed + RequestFailed(String), + /// Request timed out waiting for response + Timeout, + /// Client declined the request (e.g., user cancelled elicitation) + Declined(String), + /// Context is not available (tool not running in context scope) + NotAvailable, + /// Serialization error + Serialization(String), + /// Transport error + Transport(String), +} + +impl fmt::Display for ToolContextError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::NotificationFailed(msg) => write!(f, "Notification failed: {msg}"), + Self::RequestFailed(msg) => write!(f, "Request failed: {msg}"), + Self::Timeout => write!(f, "Request timed out"), + Self::Declined(msg) => write!(f, "Client declined: {msg}"), + Self::NotAvailable => write!(f, "Tool context not available"), + Self::Serialization(msg) => write!(f, "Serialization error: {msg}"), + Self::Transport(msg) => write!(f, "Transport error: {msg}"), + } + } +} + +impl std::error::Error for ToolContextError {} + +impl From for Error { + fn from(err: ToolContextError) -> Self { + Error::internal_error(err.to_string()) + } +} + +// ============================================================================ +// Sampling Types (for LLM requests) +// ============================================================================ + +/// Request to create a message via client's LLM (sampling/createMessage) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateMessageRequest { + /// Messages to send to the LLM + pub messages: Vec, + /// Maximum tokens to generate + pub max_tokens: u32, + /// Model preferences (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub model_preferences: Option, + /// System prompt (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub system_prompt: Option, + /// Stop sequences (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_sequences: Option>, + /// Temperature (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + /// Include context (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub include_context: Option, + /// Metadata (optional) + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +impl Default for CreateMessageRequest { + fn default() -> Self { + Self { + messages: vec![], + max_tokens: 1000, + model_preferences: None, + system_prompt: None, + stop_sequences: None, + temperature: None, + include_context: None, + meta: None, + } + } +} + +/// A message in a sampling request +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SamplingMessage { + /// Role of the message sender + pub role: SamplingRole, + /// Content of the message + pub content: SamplingContent, +} + +impl SamplingMessage { + /// Create a user message with text content + pub fn user(text: impl Into) -> Self { + Self { + role: SamplingRole::User, + content: SamplingContent::Text { text: text.into() }, + } + } + + /// Create an assistant message with text content + pub fn assistant(text: impl Into) -> Self { + Self { + role: SamplingRole::Assistant, + content: SamplingContent::Text { text: text.into() }, + } + } +} + +/// Role in a sampling conversation +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SamplingRole { + /// User message + User, + /// Assistant message + Assistant, +} + +/// Content of a sampling message +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum SamplingContent { + /// Text content + Text { + /// The text content + text: String, + }, + /// Image content + Image { + /// Base64-encoded image data + data: String, + /// MIME type of the image + #[serde(rename = "mimeType")] + mime_type: String, + }, +} + +impl SamplingContent { + /// Get text content if this is a text variant + pub fn as_text(&self) -> Option<&str> { + match self { + Self::Text { text } => Some(text), + _ => None, + } + } +} + +/// Model preferences for sampling +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelPreferences { + /// Cost priority (0.0 to 1.0) + #[serde(skip_serializing_if = "Option::is_none")] + pub cost_priority: Option, + /// Speed priority (0.0 to 1.0) + #[serde(skip_serializing_if = "Option::is_none")] + pub speed_priority: Option, + /// Intelligence priority (0.0 to 1.0) + #[serde(skip_serializing_if = "Option::is_none")] + pub intelligence_priority: Option, + /// Hints for model selection + #[serde(skip_serializing_if = "Option::is_none")] + pub hints: Option>, +} + +/// Hint for model selection +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelHint { + /// Model name hint + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +/// What context to include in sampling request +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum IncludeContext { + /// Include no additional context + None, + /// Include this server's context only + ThisServer, + /// Include all available context + AllServers, +} + +/// Result of a sampling request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateMessageResult { + /// Role of the response + pub role: SamplingRole, + /// Content of the response + pub content: SamplingContent, + /// Model that generated the response + pub model: String, + /// Reason the generation stopped + pub stop_reason: Option, +} + +// ============================================================================ +// Elicitation Types (for user input requests) +// ============================================================================ + +/// Request for user input (elicitation/create) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ElicitationRequest { + /// Message to show the user + pub message: String, + /// JSON Schema for the requested data + pub requested_schema: Value, + /// Metadata (optional) + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +impl ElicitationRequest { + /// Create a simple text input request + pub fn text(message: impl Into) -> Self { + Self { + message: message.into(), + requested_schema: serde_json::json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + }, + "required": ["value"] + }), + meta: None, + } + } + + /// Create a request with custom schema + pub fn with_schema(message: impl Into, schema: Value) -> Self { + Self { + message: message.into(), + requested_schema: schema, + meta: None, + } + } +} + +/// Result of an elicitation request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ElicitationResult { + /// Action taken by user + pub action: ElicitationAction, + /// Data provided by user (if accepted) + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, +} + +/// Action taken on elicitation request +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ElicitationAction { + /// User accepted and provided data + Accept, + /// User declined the request + Decline, + /// Request was cancelled + Cancel, +} + +// ============================================================================ +// Notification Types +// ============================================================================ + +/// Parameters for a log notification +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LogNotificationParams { + /// Log level + pub level: LogLevel, + /// Logger name (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub logger: Option, + /// Log data + pub data: Value, +} + +/// Parameters for a progress notification +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ProgressNotificationParams { + /// Progress token from the request + pub progress_token: String, + /// Current progress value + pub progress: u64, + /// Total expected value (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, + /// Message describing current progress (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +// ============================================================================ +// Sender Traits (implemented by transport layer) +// ============================================================================ + +/// Trait for sending notifications to the client +#[async_trait] +pub trait NotificationSender: Send + Sync { + /// Send a notification to the client + async fn send_notification(&self, method: &str, params: Value) -> Result<(), ToolContextError>; +} + +/// Trait for making requests to the client +#[async_trait] +pub trait RequestSender: Send + Sync { + /// Send a request to the client and wait for response + async fn send_request( + &self, + method: &str, + params: Value, + timeout: Duration, + ) -> Result; +} + +// ============================================================================ +// ToolContext Trait +// ============================================================================ + +/// Context provided to tool handlers for bidirectional communication +/// +/// This trait defines the interface for server-to-client communication during +/// tool execution. It is automatically provided via task-local storage when +/// a tool is invoked through the handler. +#[async_trait] +pub trait ToolContext: Send + Sync { + /// Send a log notification to the client + /// + /// # Arguments + /// * `level` - Log severity level + /// * `logger` - Optional logger name + /// * `data` - JSON data to log + async fn send_log( + &self, + level: LogLevel, + logger: Option<&str>, + data: Value, + ) -> Result<(), ToolContextError>; + + /// Send a progress notification to the client + /// + /// # Arguments + /// * `progress` - Current progress value + /// * `total` - Optional total value for percentage calculation + async fn send_progress( + &self, + progress: u64, + total: Option, + ) -> Result<(), ToolContextError>; + + /// Send a progress notification with a message + /// + /// # Arguments + /// * `progress` - Current progress value + /// * `total` - Optional total value for percentage calculation + /// * `message` - Description of current progress + async fn send_progress_with_message( + &self, + progress: u64, + total: Option, + message: String, + ) -> Result<(), ToolContextError>; + + /// Request LLM sampling from the client + /// + /// This blocks until the client responds with a completion. + /// + /// # Arguments + /// * `request` - Sampling request parameters + /// * `timeout` - Maximum time to wait for response + async fn request_sampling( + &self, + request: CreateMessageRequest, + timeout: Duration, + ) -> Result; + + /// Request user input from the client + /// + /// This blocks until the user responds or cancels. + /// + /// # Arguments + /// * `request` - Elicitation request parameters + /// * `timeout` - Maximum time to wait for response + async fn request_elicitation( + &self, + request: ElicitationRequest, + timeout: Duration, + ) -> Result; + + /// Get the current request ID + fn request_id(&self) -> &str; + + /// Get the name of the tool being executed + fn tool_name(&self) -> &str; + + /// Get the progress token for this request (if provided by client) + fn progress_token(&self) -> Option<&str>; + + /// Get the session ID for this request (if applicable) + fn session_id(&self) -> Option<&str>; +} + +// ============================================================================ +// Default Implementation +// ============================================================================ + +/// Default implementation of ToolContext +pub struct DefaultToolContext { + request_id: String, + tool_name: String, + progress_token: Option, + session_id: Option, + notification_sender: Arc, + request_sender: Arc, +} + +impl DefaultToolContext { + /// Create a new DefaultToolContext + pub fn new( + request_id: impl Into, + tool_name: impl Into, + progress_token: Option, + session_id: Option, + notification_sender: Arc, + request_sender: Arc, + ) -> Self { + Self { + request_id: request_id.into(), + tool_name: tool_name.into(), + progress_token, + session_id, + notification_sender, + request_sender, + } + } +} + +#[async_trait] +impl ToolContext for DefaultToolContext { + async fn send_log( + &self, + level: LogLevel, + logger: Option<&str>, + data: Value, + ) -> Result<(), ToolContextError> { + let params = LogNotificationParams { + level, + logger: logger.map(String::from), + data, + }; + let value = serde_json::to_value(¶ms) + .map_err(|e| ToolContextError::Serialization(e.to_string()))?; + self.notification_sender + .send_notification("notifications/message", value) + .await + } + + async fn send_progress( + &self, + progress: u64, + total: Option, + ) -> Result<(), ToolContextError> { + let Some(token) = &self.progress_token else { + // No progress token means client didn't request progress tracking + return Ok(()); + }; + + let params = ProgressNotificationParams { + progress_token: token.clone(), + progress, + total, + message: None, + }; + let value = serde_json::to_value(¶ms) + .map_err(|e| ToolContextError::Serialization(e.to_string()))?; + self.notification_sender + .send_notification("notifications/progress", value) + .await + } + + async fn send_progress_with_message( + &self, + progress: u64, + total: Option, + message: String, + ) -> Result<(), ToolContextError> { + let Some(token) = &self.progress_token else { + return Ok(()); + }; + + let params = ProgressNotificationParams { + progress_token: token.clone(), + progress, + total, + message: Some(message), + }; + let value = serde_json::to_value(¶ms) + .map_err(|e| ToolContextError::Serialization(e.to_string()))?; + self.notification_sender + .send_notification("notifications/progress", value) + .await + } + + async fn request_sampling( + &self, + request: CreateMessageRequest, + timeout: Duration, + ) -> Result { + let params = serde_json::to_value(&request) + .map_err(|e| ToolContextError::Serialization(e.to_string()))?; + + let response = self + .request_sender + .send_request("sampling/createMessage", params, timeout) + .await?; + + serde_json::from_value(response).map_err(|e| ToolContextError::Serialization(e.to_string())) + } + + async fn request_elicitation( + &self, + request: ElicitationRequest, + timeout: Duration, + ) -> Result { + let params = serde_json::to_value(&request) + .map_err(|e| ToolContextError::Serialization(e.to_string()))?; + + let response = self + .request_sender + .send_request("elicitation/create", params, timeout) + .await?; + + serde_json::from_value(response).map_err(|e| ToolContextError::Serialization(e.to_string())) + } + + fn request_id(&self) -> &str { + &self.request_id + } + + fn tool_name(&self) -> &str { + &self.tool_name + } + + fn progress_token(&self) -> Option<&str> { + self.progress_token.as_deref() + } + + fn session_id(&self) -> Option<&str> { + self.session_id.as_deref() + } +} + +// ============================================================================ +// Task-Local Storage +// ============================================================================ + +tokio::task_local! { + /// Task-local storage for the current tool context + pub static TOOL_CONTEXT: Arc; +} + +/// Get the current tool context +/// +/// # Panics +/// Panics if called outside of a tool execution scope +pub fn current_context() -> Arc { + TOOL_CONTEXT.with(|ctx| ctx.clone()) +} + +/// Try to get the current tool context +/// +/// Returns `None` if called outside of a tool execution scope +pub fn try_current_context() -> Option> { + TOOL_CONTEXT.try_with(|ctx| ctx.clone()).ok() +} + +/// Execute an async block with a tool context +pub async fn with_context(context: Arc, f: F) -> T +where + F: std::future::Future, +{ + TOOL_CONTEXT.scope(context, f).await +} + +// ============================================================================ +// Transport Bridge (connects ToolContext to Transport) +// ============================================================================ + +use pulseengine_mcp_transport::{ + NotificationSender as StreamingNotificationSender, StreamingNotification, Transport, + TransportError, +}; + +/// Bridge that connects ToolContext to a Transport implementation +/// +/// This allows tools to send notifications and requests through the transport layer. +pub struct TransportBridge { + transport: Arc, + session_id: Option, + /// Captured streaming notification sender for this request + /// This is captured at construction time to avoid task-local scope issues + streaming_sender: Option, +} + +impl TransportBridge { + /// Create a new transport bridge + pub fn new(transport: Arc, session_id: Option) -> Self { + // Capture the streaming sender from the current task-local context + // This ensures we have a direct reference that works even after + // the task-local scope changes (e.g., when with_context is called) + let streaming_sender = pulseengine_mcp_transport::try_notification_sender(); + Self { + transport, + session_id, + streaming_sender, + } + } +} + +#[async_trait] +impl NotificationSender for TransportBridge { + async fn send_notification(&self, method: &str, params: Value) -> Result<(), ToolContextError> { + eprintln!( + "[DEBUG] TransportBridge::send_notification: method={}, session_id={:?}, has_streaming_sender={}", + method, + self.session_id, + self.streaming_sender.is_some() + ); + tracing::debug!( + method = %method, + session_id = ?self.session_id, + has_streaming_sender = self.streaming_sender.is_some(), + "TransportBridge: sending notification" + ); + + // FIRST: Try to send via the captured streaming sender + // This is the preferred path for MCP 2025-03-26 Streamable HTTP + // We use the captured sender rather than task-local lookup because + // the tool executes in a different task-local scope (with_context) + if let Some(ref sender) = self.streaming_sender { + let notification = StreamingNotification { + id: None, // Notifications don't have IDs + method: method.to_string(), + params: params.clone(), + }; + if sender.send(notification).is_ok() { + eprintln!( + "[DEBUG] Notification sent via captured streaming channel: method={method}" + ); + tracing::debug!(method = %method, "Notification sent via captured streaming channel"); + return Ok(()); + } + // If send failed (channel closed), fall through to transport fallback + eprintln!("[DEBUG] Captured streaming channel closed, falling back: method={method}"); + } + + // FALLBACK: Send via transport's broadcast channel (for SSE endpoint) + // This path is used when there's no streaming context + eprintln!( + "[DEBUG] Falling back to transport notification: method={}", + method + ); + let result = self + .transport + .send_notification(self.session_id.as_deref(), method, params) + .await; + match &result { + Ok(()) => { + eprintln!("[DEBUG] Notification sent successfully via transport: method={method}"); + tracing::debug!(method = %method, "Notification sent successfully via transport"); + } + Err(e) => { + eprintln!("[DEBUG] Notification failed: method={method}, error={e}"); + tracing::warn!(method = %method, error = %e, "Notification failed"); + } + } + result.map_err(|e| match e { + TransportError::SessionNotFound(id) => { + ToolContextError::NotificationFailed(format!("Session not found: {id}")) + } + TransportError::ChannelClosed => { + ToolContextError::NotificationFailed("Channel closed".to_string()) + } + TransportError::NotSupported(msg) => { + ToolContextError::NotificationFailed(format!("Not supported: {msg}")) + } + other => ToolContextError::Transport(other.to_string()), + }) + } +} + +#[async_trait] +impl RequestSender for TransportBridge { + async fn send_request( + &self, + method: &str, + params: Value, + timeout: Duration, + ) -> Result { + // Generate a unique request ID + let request_id = uuid::Uuid::new_v4().to_string(); + + eprintln!( + "[DEBUG] TransportBridge::send_request: method={}, request_id={}, has_streaming_sender={}", + method, + request_id, + self.streaming_sender.is_some() + ); + + // FIRST: Try to send via the streaming channel (for POST response streams) + // This is required for MCP conformance when tools make server-to-client requests + if let Some(ref sender) = self.streaming_sender { + // Register the pending request with the transport to get a response receiver + let response_rx = self.transport.register_pending_request(&request_id); + + if let Some(rx) = response_rx { + // Send the request via streaming channel + let request = StreamingNotification { + id: Some(request_id.clone()), + method: method.to_string(), + params: params.clone(), + }; + + if sender.send(request).is_ok() { + eprintln!( + "[DEBUG] Request sent via streaming channel: method={method}, id={request_id}" + ); + + // Wait for response with timeout + match tokio::time::timeout(timeout, rx).await { + Ok(Ok(response)) => { + eprintln!("[DEBUG] Received response for request {request_id}"); + // Check if response is an error + if let Some(error) = response.get("error") { + return Err(ToolContextError::RequestFailed(error.to_string())); + } + return Ok(response); + } + Ok(Err(_)) => { + eprintln!("[DEBUG] Response channel closed for request {request_id}"); + return Err(ToolContextError::RequestFailed( + "Response channel closed".to_string(), + )); + } + Err(_) => { + eprintln!( + "[DEBUG] Timeout waiting for response to request {request_id}" + ); + return Err(ToolContextError::Timeout); + } + } + } + // If send failed, fall through to transport fallback + eprintln!( + "[DEBUG] Streaming channel send failed for request {request_id}, falling back" + ); + } else { + eprintln!("[DEBUG] Could not register pending request {request_id}, falling back"); + } + } + + // FALLBACK: Send via transport's direct method (for SSE endpoint) + eprintln!( + "[DEBUG] Falling back to transport.send_request: method={}", + method + ); + self.transport + .send_request(self.session_id.as_deref(), method, params, timeout) + .await + .map_err(|e| match e { + TransportError::SessionNotFound(id) => { + ToolContextError::RequestFailed(format!("Session not found: {id}")) + } + TransportError::Timeout => ToolContextError::Timeout, + TransportError::ChannelClosed => { + ToolContextError::RequestFailed("Channel closed".to_string()) + } + TransportError::NotSupported(msg) => { + ToolContextError::RequestFailed(format!("Not supported: {msg}")) + } + other => ToolContextError::Transport(other.to_string()), + }) + } +} + +/// Create a ToolContext from a Transport +/// +/// This is the main entry point for wiring up bidirectional communication. +pub fn create_tool_context( + transport: Arc, + request_id: impl Into, + tool_name: impl Into, + progress_token: Option, + session_id: Option, +) -> Arc { + let bridge = Arc::new(TransportBridge::new( + Arc::clone(&transport), + session_id.clone(), + )); + + Arc::new(DefaultToolContext::new( + request_id, + tool_name, + progress_token, + session_id, + bridge.clone(), + bridge, + )) +} + +// ============================================================================ +// No-Op Implementation (for testing/when transport doesn't support bidirectional) +// ============================================================================ + +/// A no-op tool context for when bidirectional communication is not available +pub struct NoOpToolContext { + request_id: String, + tool_name: String, +} + +impl NoOpToolContext { + /// Create a new NoOpToolContext + pub fn new(request_id: impl Into, tool_name: impl Into) -> Self { + Self { + request_id: request_id.into(), + tool_name: tool_name.into(), + } + } +} + +#[async_trait] +impl ToolContext for NoOpToolContext { + async fn send_log( + &self, + _level: LogLevel, + _logger: Option<&str>, + _data: Value, + ) -> Result<(), ToolContextError> { + // No-op: silently succeed + Ok(()) + } + + async fn send_progress( + &self, + _progress: u64, + _total: Option, + ) -> Result<(), ToolContextError> { + Ok(()) + } + + async fn send_progress_with_message( + &self, + _progress: u64, + _total: Option, + _message: String, + ) -> Result<(), ToolContextError> { + Ok(()) + } + + async fn request_sampling( + &self, + _request: CreateMessageRequest, + _timeout: Duration, + ) -> Result { + Err(ToolContextError::NotAvailable) + } + + async fn request_elicitation( + &self, + _request: ElicitationRequest, + _timeout: Duration, + ) -> Result { + Err(ToolContextError::NotAvailable) + } + + fn request_id(&self) -> &str { + &self.request_id + } + + fn tool_name(&self) -> &str { + &self.tool_name + } + + fn progress_token(&self) -> Option<&str> { + None + } + + fn session_id(&self) -> Option<&str> { + None + } +} + +// ============================================================================ +// Mock Implementation (for testing) +// ============================================================================ + +#[cfg(test)] +pub mod mock { + use super::*; + use std::sync::Mutex; + + /// A mock tool context for testing that records all operations + pub struct MockToolContext { + request_id: String, + tool_name: String, + progress_token: Option, + /// Recorded log notifications + pub logs: Mutex>, + /// Recorded progress notifications + pub progress: Mutex>, + /// Response to return for sampling requests + pub sampling_response: Mutex>, + /// Response to return for elicitation requests + pub elicitation_response: Mutex>, + } + + impl MockToolContext { + /// Create a new mock context + pub fn new(tool_name: impl Into) -> Self { + Self { + request_id: uuid::Uuid::new_v4().to_string(), + tool_name: tool_name.into(), + progress_token: Some("test-progress-token".to_string()), + logs: Mutex::new(vec![]), + progress: Mutex::new(vec![]), + sampling_response: Mutex::new(None), + elicitation_response: Mutex::new(None), + } + } + + /// Create a mock context with a specific progress token + pub fn with_progress_token(tool_name: impl Into, token: impl Into) -> Self { + Self { + request_id: uuid::Uuid::new_v4().to_string(), + tool_name: tool_name.into(), + progress_token: Some(token.into()), + logs: Mutex::new(vec![]), + progress: Mutex::new(vec![]), + sampling_response: Mutex::new(None), + elicitation_response: Mutex::new(None), + } + } + + /// Set the response for sampling requests + pub fn set_sampling_response(&self, response: CreateMessageResult) { + *self.sampling_response.lock().unwrap() = Some(response); + } + + /// Set the response for elicitation requests + pub fn set_elicitation_response(&self, response: ElicitationResult) { + *self.elicitation_response.lock().unwrap() = Some(response); + } + + /// Get all recorded logs + pub fn get_logs(&self) -> Vec { + self.logs.lock().unwrap().clone() + } + + /// Get all recorded progress notifications + pub fn get_progress(&self) -> Vec { + self.progress.lock().unwrap().clone() + } + } + + #[async_trait] + impl ToolContext for MockToolContext { + async fn send_log( + &self, + level: LogLevel, + logger: Option<&str>, + data: Value, + ) -> Result<(), ToolContextError> { + self.logs.lock().unwrap().push(LogNotificationParams { + level, + logger: logger.map(String::from), + data, + }); + Ok(()) + } + + async fn send_progress( + &self, + progress: u64, + total: Option, + ) -> Result<(), ToolContextError> { + if let Some(token) = &self.progress_token { + self.progress + .lock() + .unwrap() + .push(ProgressNotificationParams { + progress_token: token.clone(), + progress, + total, + message: None, + }); + } + Ok(()) + } + + async fn send_progress_with_message( + &self, + progress: u64, + total: Option, + message: String, + ) -> Result<(), ToolContextError> { + if let Some(token) = &self.progress_token { + self.progress + .lock() + .unwrap() + .push(ProgressNotificationParams { + progress_token: token.clone(), + progress, + total, + message: Some(message), + }); + } + Ok(()) + } + + async fn request_sampling( + &self, + _request: CreateMessageRequest, + _timeout: Duration, + ) -> Result { + self.sampling_response + .lock() + .unwrap() + .clone() + .ok_or(ToolContextError::NotAvailable) + } + + async fn request_elicitation( + &self, + _request: ElicitationRequest, + _timeout: Duration, + ) -> Result { + self.elicitation_response + .lock() + .unwrap() + .clone() + .ok_or(ToolContextError::NotAvailable) + } + + fn request_id(&self) -> &str { + &self.request_id + } + + fn tool_name(&self) -> &str { + &self.tool_name + } + + fn progress_token(&self) -> Option<&str> { + self.progress_token.as_deref() + } + + fn session_id(&self) -> Option<&str> { + None + } + } +} diff --git a/mcp-server/src/tool_context_tests.rs b/mcp-server/src/tool_context_tests.rs new file mode 100644 index 0000000..74682b8 --- /dev/null +++ b/mcp-server/src/tool_context_tests.rs @@ -0,0 +1,1406 @@ +//! Tests for tool execution context and bidirectional communication + +use crate::tool_context::{ + CreateMessageRequest, CreateMessageResult, ElicitationAction, ElicitationRequest, + ElicitationResult, IncludeContext, LogNotificationParams, ModelHint, ModelPreferences, + NoOpToolContext, NotificationSender, ProgressNotificationParams, RequestSender, + SamplingContent, SamplingMessage, SamplingRole, ToolContext, ToolContextError, + mock::MockToolContext, +}; +use async_trait::async_trait; +use pulseengine_mcp_protocol::{Error, LogLevel}; +use serde_json::{Value, json}; +use std::sync::Arc; +use std::time::Duration; + +// ============================================================================ +// Error Type Tests +// ============================================================================ + +#[test] +fn test_tool_context_error_display() { + let err = ToolContextError::NotificationFailed("test message".to_string()); + assert_eq!(err.to_string(), "Notification failed: test message"); + + let err = ToolContextError::RequestFailed("request error".to_string()); + assert_eq!(err.to_string(), "Request failed: request error"); + + let err = ToolContextError::Timeout; + assert_eq!(err.to_string(), "Request timed out"); + + let err = ToolContextError::Declined("user cancelled".to_string()); + assert_eq!(err.to_string(), "Client declined: user cancelled"); + + let err = ToolContextError::NotAvailable; + assert_eq!(err.to_string(), "Tool context not available"); + + let err = ToolContextError::Serialization("json error".to_string()); + assert_eq!(err.to_string(), "Serialization error: json error"); + + let err = ToolContextError::Transport("connection lost".to_string()); + assert_eq!(err.to_string(), "Transport error: connection lost"); +} + +#[test] +fn test_tool_context_error_to_protocol_error() { + let ctx_err = ToolContextError::NotificationFailed("test".to_string()); + let proto_err: Error = ctx_err.into(); + assert!(proto_err.message.contains("Notification failed")); + + let ctx_err = ToolContextError::Timeout; + let proto_err: Error = ctx_err.into(); + assert!(proto_err.message.contains("timed out")); +} + +#[test] +fn test_tool_context_error_is_std_error() { + let err: Box = + Box::new(ToolContextError::NotificationFailed("test".to_string())); + assert!(err.to_string().contains("Notification failed")); +} + +// ============================================================================ +// Sampling Types Tests +// ============================================================================ + +#[test] +fn test_create_message_request_default() { + let request = CreateMessageRequest::default(); + assert!(request.messages.is_empty()); + assert_eq!(request.max_tokens, 1000); + assert!(request.model_preferences.is_none()); + assert!(request.system_prompt.is_none()); + assert!(request.stop_sequences.is_none()); + assert!(request.temperature.is_none()); + assert!(request.include_context.is_none()); + assert!(request.meta.is_none()); +} + +#[test] +fn test_create_message_request_serialization() { + let request = CreateMessageRequest { + messages: vec![SamplingMessage::user("Hello")], + max_tokens: 500, + model_preferences: Some(ModelPreferences { + cost_priority: Some(0.5), + speed_priority: Some(0.3), + intelligence_priority: Some(0.8), + hints: Some(vec![ModelHint { + name: Some("claude-3".to_string()), + }]), + }), + system_prompt: Some("You are helpful".to_string()), + stop_sequences: Some(vec!["END".to_string()]), + temperature: Some(0.5), // Use 0.5 for exact f32 representation + include_context: Some(IncludeContext::ThisServer), + meta: Some(json!({"key": "value"})), + }; + + let json = serde_json::to_value(&request).unwrap(); + assert_eq!(json["maxTokens"], 500); + assert_eq!(json["systemPrompt"], "You are helpful"); + assert_eq!(json["temperature"], 0.5); + // Note: IncludeContext uses lowercase serialization + assert_eq!(json["includeContext"], "thisserver"); + assert_eq!(json["_meta"]["key"], "value"); + + // Deserialize back + let deserialized: CreateMessageRequest = serde_json::from_value(json).unwrap(); + assert_eq!(deserialized.max_tokens, 500); + assert_eq!( + deserialized.system_prompt, + Some("You are helpful".to_string()) + ); +} + +#[test] +fn test_sampling_message_user() { + let msg = SamplingMessage::user("Hello, world!"); + assert!(matches!(msg.role, SamplingRole::User)); + assert!(matches!(&msg.content, SamplingContent::Text { text } if text == "Hello, world!")); +} + +#[test] +fn test_sampling_message_assistant() { + let msg = SamplingMessage::assistant("Hi there!"); + assert!(matches!(msg.role, SamplingRole::Assistant)); + assert!(matches!(&msg.content, SamplingContent::Text { text } if text == "Hi there!")); +} + +#[test] +fn test_sampling_content_as_text() { + let text_content = SamplingContent::Text { + text: "hello".to_string(), + }; + assert_eq!(text_content.as_text(), Some("hello")); + + let image_content = SamplingContent::Image { + data: "base64data".to_string(), + mime_type: "image/png".to_string(), + }; + assert_eq!(image_content.as_text(), None); +} + +#[test] +fn test_sampling_role_serialization() { + let user = SamplingRole::User; + let json = serde_json::to_value(user).unwrap(); + assert_eq!(json, "user"); + + let assistant = SamplingRole::Assistant; + let json = serde_json::to_value(assistant).unwrap(); + assert_eq!(json, "assistant"); + + // Deserialize + let role: SamplingRole = serde_json::from_str("\"user\"").unwrap(); + assert!(matches!(role, SamplingRole::User)); +} + +#[test] +fn test_sampling_content_serialization() { + let text = SamplingContent::Text { + text: "hello".to_string(), + }; + let json = serde_json::to_value(&text).unwrap(); + assert_eq!(json["type"], "text"); + assert_eq!(json["text"], "hello"); + + let image = SamplingContent::Image { + data: "abc123".to_string(), + mime_type: "image/png".to_string(), + }; + let json = serde_json::to_value(&image).unwrap(); + assert_eq!(json["type"], "image"); + assert_eq!(json["data"], "abc123"); + assert_eq!(json["mimeType"], "image/png"); +} + +#[test] +fn test_include_context_serialization() { + // Note: Uses lowercase serialization (rename_all = "lowercase") + let none_ctx = IncludeContext::None; + let json = serde_json::to_value(none_ctx).unwrap(); + assert_eq!(json, "none"); + + let this_server = IncludeContext::ThisServer; + let json = serde_json::to_value(this_server).unwrap(); + assert_eq!(json, "thisserver"); + + let all_servers = IncludeContext::AllServers; + let json = serde_json::to_value(all_servers).unwrap(); + assert_eq!(json, "allservers"); +} + +#[test] +fn test_model_preferences_default() { + let prefs = ModelPreferences::default(); + assert!(prefs.cost_priority.is_none()); + assert!(prefs.speed_priority.is_none()); + assert!(prefs.intelligence_priority.is_none()); + assert!(prefs.hints.is_none()); +} + +#[test] +fn test_create_message_result_serialization() { + let result = CreateMessageResult { + role: SamplingRole::Assistant, + content: SamplingContent::Text { + text: "Hello!".to_string(), + }, + model: "claude-3-sonnet".to_string(), + stop_reason: Some("end_turn".to_string()), + }; + + let json = serde_json::to_value(&result).unwrap(); + assert_eq!(json["role"], "assistant"); + assert_eq!(json["model"], "claude-3-sonnet"); + assert_eq!(json["stopReason"], "end_turn"); + + let deserialized: CreateMessageResult = serde_json::from_value(json).unwrap(); + assert_eq!(deserialized.model, "claude-3-sonnet"); +} + +// ============================================================================ +// Elicitation Types Tests +// ============================================================================ + +#[test] +fn test_elicitation_request_text() { + let req = ElicitationRequest::text("Please enter your name"); + assert_eq!(req.message, "Please enter your name"); + assert!(req.requested_schema["type"] == "object"); + assert!(req.requested_schema["properties"]["value"]["type"] == "string"); + assert!(req.meta.is_none()); +} + +#[test] +fn test_elicitation_request_with_schema() { + let schema = json!({ + "type": "object", + "properties": { + "age": { "type": "integer", "minimum": 0 } + } + }); + let req = ElicitationRequest::with_schema("Enter your age", schema.clone()); + assert_eq!(req.message, "Enter your age"); + assert_eq!(req.requested_schema, schema); +} + +#[test] +fn test_elicitation_request_serialization() { + let req = ElicitationRequest { + message: "Test message".to_string(), + requested_schema: json!({"type": "string"}), + meta: Some(json!({"key": "value"})), + }; + + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["message"], "Test message"); + assert_eq!(json["requestedSchema"]["type"], "string"); + assert_eq!(json["_meta"]["key"], "value"); +} + +#[test] +fn test_elicitation_result_serialization() { + let result = ElicitationResult { + action: ElicitationAction::Accept, + content: Some(json!({"value": "test input"})), + }; + + let json = serde_json::to_value(&result).unwrap(); + assert_eq!(json["action"], "accept"); + assert_eq!(json["content"]["value"], "test input"); + + let deserialized: ElicitationResult = serde_json::from_value(json).unwrap(); + assert!(matches!(deserialized.action, ElicitationAction::Accept)); +} + +#[test] +fn test_elicitation_action_serialization() { + let accept = ElicitationAction::Accept; + assert_eq!(serde_json::to_value(accept).unwrap(), "accept"); + + let decline = ElicitationAction::Decline; + assert_eq!(serde_json::to_value(decline).unwrap(), "decline"); + + let cancel = ElicitationAction::Cancel; + assert_eq!(serde_json::to_value(cancel).unwrap(), "cancel"); +} + +// ============================================================================ +// Notification Types Tests +// ============================================================================ + +#[test] +fn test_log_notification_params_serialization() { + let params = LogNotificationParams { + level: LogLevel::Info, + logger: Some("my-tool".to_string()), + data: json!({"message": "test"}), + }; + + let json = serde_json::to_value(¶ms).unwrap(); + assert_eq!(json["logger"], "my-tool"); + assert_eq!(json["data"]["message"], "test"); +} + +#[test] +fn test_progress_notification_params_serialization() { + let params = ProgressNotificationParams { + progress_token: "token123".to_string(), + progress: 50, + total: Some(100), + message: Some("Processing...".to_string()), + }; + + let json = serde_json::to_value(¶ms).unwrap(); + assert_eq!(json["progressToken"], "token123"); + assert_eq!(json["progress"], 50); + assert_eq!(json["total"], 100); + assert_eq!(json["message"], "Processing..."); + + // Without optional fields + let params_minimal = ProgressNotificationParams { + progress_token: "token".to_string(), + progress: 10, + total: None, + message: None, + }; + let json = serde_json::to_value(¶ms_minimal).unwrap(); + assert!(json.get("total").is_none()); + assert!(json.get("message").is_none()); +} + +// ============================================================================ +// NoOpToolContext Tests +// ============================================================================ + +#[tokio::test] +async fn test_noop_context_send_log() { + let ctx = NoOpToolContext::new("req-123", "test-tool"); + + // Should succeed silently + let result = ctx + .send_log(LogLevel::Info, Some("logger"), json!({})) + .await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_noop_context_send_progress() { + let ctx = NoOpToolContext::new("req-123", "test-tool"); + + let result = ctx.send_progress(50, Some(100)).await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_noop_context_send_progress_with_message() { + let ctx = NoOpToolContext::new("req-123", "test-tool"); + + let result = ctx + .send_progress_with_message(50, Some(100), "Processing".to_string()) + .await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_noop_context_request_sampling_fails() { + let ctx = NoOpToolContext::new("req-123", "test-tool"); + + let result = ctx + .request_sampling(CreateMessageRequest::default(), Duration::from_secs(5)) + .await; + assert!(matches!(result, Err(ToolContextError::NotAvailable))); +} + +#[tokio::test] +async fn test_noop_context_request_elicitation_fails() { + let ctx = NoOpToolContext::new("req-123", "test-tool"); + + let result = ctx + .request_elicitation(ElicitationRequest::text("test"), Duration::from_secs(5)) + .await; + assert!(matches!(result, Err(ToolContextError::NotAvailable))); +} + +#[test] +fn test_noop_context_accessors() { + let ctx = NoOpToolContext::new("req-123", "my-tool"); + + assert_eq!(ctx.request_id(), "req-123"); + assert_eq!(ctx.tool_name(), "my-tool"); + assert!(ctx.progress_token().is_none()); + assert!(ctx.session_id().is_none()); +} + +// ============================================================================ +// MockToolContext Tests +// ============================================================================ + +#[tokio::test] +async fn test_mock_context_records_logs() { + let ctx = MockToolContext::new("test-tool"); + + ctx.send_log(LogLevel::Info, Some("logger"), json!({"msg": "test"})) + .await + .unwrap(); + ctx.send_log(LogLevel::Error, None, json!({"error": true})) + .await + .unwrap(); + + let logs = ctx.get_logs(); + assert_eq!(logs.len(), 2); + assert!(matches!(logs[0].level, LogLevel::Info)); + assert_eq!(logs[0].logger, Some("logger".to_string())); + assert!(matches!(logs[1].level, LogLevel::Error)); + assert!(logs[1].logger.is_none()); +} + +#[tokio::test] +async fn test_mock_context_records_progress() { + let ctx = MockToolContext::new("test-tool"); + + ctx.send_progress(10, Some(100)).await.unwrap(); + ctx.send_progress(50, Some(100)).await.unwrap(); + ctx.send_progress_with_message(100, Some(100), "Done".to_string()) + .await + .unwrap(); + + let progress = ctx.get_progress(); + assert_eq!(progress.len(), 3); + assert_eq!(progress[0].progress, 10); + assert_eq!(progress[1].progress, 50); + assert_eq!(progress[2].progress, 100); + assert_eq!(progress[2].message, Some("Done".to_string())); +} + +#[tokio::test] +async fn test_mock_context_sampling_without_response() { + let ctx = MockToolContext::new("test-tool"); + + let result = ctx + .request_sampling(CreateMessageRequest::default(), Duration::from_secs(1)) + .await; + assert!(matches!(result, Err(ToolContextError::NotAvailable))); +} + +#[tokio::test] +async fn test_mock_context_sampling_with_response() { + let ctx = MockToolContext::new("test-tool"); + + let response = CreateMessageResult { + role: SamplingRole::Assistant, + content: SamplingContent::Text { + text: "Hello!".to_string(), + }, + model: "test-model".to_string(), + stop_reason: Some("end_turn".to_string()), + }; + ctx.set_sampling_response(response); + + let result = ctx + .request_sampling(CreateMessageRequest::default(), Duration::from_secs(1)) + .await + .unwrap(); + assert_eq!(result.model, "test-model"); +} + +#[tokio::test] +async fn test_mock_context_elicitation_without_response() { + let ctx = MockToolContext::new("test-tool"); + + let result = ctx + .request_elicitation(ElicitationRequest::text("test"), Duration::from_secs(1)) + .await; + assert!(matches!(result, Err(ToolContextError::NotAvailable))); +} + +#[tokio::test] +async fn test_mock_context_elicitation_with_response() { + let ctx = MockToolContext::new("test-tool"); + + let response = ElicitationResult { + action: ElicitationAction::Accept, + content: Some(json!({"value": "user input"})), + }; + ctx.set_elicitation_response(response); + + let result = ctx + .request_elicitation(ElicitationRequest::text("test"), Duration::from_secs(1)) + .await + .unwrap(); + assert!(matches!(result.action, ElicitationAction::Accept)); +} + +#[test] +fn test_mock_context_accessors() { + let ctx = MockToolContext::new("my-tool"); + + assert!(!ctx.request_id().is_empty()); // UUID generated + assert_eq!(ctx.tool_name(), "my-tool"); + assert_eq!(ctx.progress_token(), Some("test-progress-token")); + assert!(ctx.session_id().is_none()); +} + +#[test] +fn test_mock_context_with_progress_token() { + let ctx = MockToolContext::with_progress_token("my-tool", "custom-token"); + + assert_eq!(ctx.progress_token(), Some("custom-token")); +} + +// ============================================================================ +// Task-Local Storage Tests +// ============================================================================ + +#[tokio::test] +async fn test_try_current_context_without_scope() { + use crate::tool_context::try_current_context; + + let ctx = try_current_context(); + assert!(ctx.is_none()); +} + +#[tokio::test] +async fn test_with_context_scope() { + use crate::tool_context::{try_current_context, with_context}; + + let mock = Arc::new(MockToolContext::new("scoped-tool")) as Arc; + + let result = with_context(mock.clone(), async { + let ctx = try_current_context(); + assert!(ctx.is_some()); + ctx.unwrap().tool_name().to_string() + }) + .await; + + assert_eq!(result, "scoped-tool"); + + // Outside scope, should be None again + assert!(try_current_context().is_none()); +} + +#[tokio::test] +#[should_panic(expected = "cannot access a task-local")] +async fn test_current_context_panics_without_scope() { + use crate::tool_context::current_context; + + // This should panic + let _ = current_context(); +} + +#[tokio::test] +async fn test_current_context_in_scope() { + use crate::tool_context::{current_context, with_context}; + + let mock = Arc::new(MockToolContext::new("test")) as Arc; + + with_context(mock, async { + let ctx = current_context(); + assert_eq!(ctx.tool_name(), "test"); + }) + .await; +} + +// ============================================================================ +// DefaultToolContext Tests (with mock senders) +// ============================================================================ + +struct MockNotificationSender { + sent: std::sync::Mutex>, +} + +impl MockNotificationSender { + fn new() -> Self { + Self { + sent: std::sync::Mutex::new(vec![]), + } + } + + fn get_sent(&self) -> Vec<(String, Value)> { + self.sent.lock().unwrap().clone() + } +} + +#[async_trait] +impl NotificationSender for MockNotificationSender { + async fn send_notification(&self, method: &str, params: Value) -> Result<(), ToolContextError> { + self.sent.lock().unwrap().push((method.to_string(), params)); + Ok(()) + } +} + +struct MockRequestSender { + response: std::sync::Mutex>, + error: std::sync::Mutex>, +} + +impl MockRequestSender { + fn new() -> Self { + Self { + response: std::sync::Mutex::new(None), + error: std::sync::Mutex::new(None), + } + } + + fn set_response(&self, response: Value) { + *self.response.lock().unwrap() = Some(response); + } + + fn set_error(&self, err: ToolContextError) { + *self.error.lock().unwrap() = Some(err); + } +} + +#[async_trait] +impl RequestSender for MockRequestSender { + async fn send_request( + &self, + _method: &str, + _params: Value, + _timeout: Duration, + ) -> Result { + if let Some(err) = self.error.lock().unwrap().take() { + return Err(err); + } + self.response + .lock() + .unwrap() + .take() + .ok_or(ToolContextError::NotAvailable) + } +} + +#[tokio::test] +async fn test_default_context_send_log() { + use crate::tool_context::DefaultToolContext; + + let notif_sender = Arc::new(MockNotificationSender::new()); + let req_sender = Arc::new(MockRequestSender::new()); + + let ctx = DefaultToolContext::new( + "req-1", + "tool-1", + None, + Some("session-1".to_string()), + notif_sender.clone(), + req_sender, + ); + + ctx.send_log(LogLevel::Warning, Some("my-logger"), json!({"test": true})) + .await + .unwrap(); + + let sent = notif_sender.get_sent(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0].0, "notifications/message"); + assert_eq!(sent[0].1["logger"], "my-logger"); +} + +#[tokio::test] +async fn test_default_context_send_progress_with_token() { + use crate::tool_context::DefaultToolContext; + + let notif_sender = Arc::new(MockNotificationSender::new()); + let req_sender = Arc::new(MockRequestSender::new()); + + let ctx = DefaultToolContext::new( + "req-1", + "tool-1", + Some("progress-token-123".to_string()), + None, + notif_sender.clone(), + req_sender, + ); + + ctx.send_progress(25, Some(100)).await.unwrap(); + + let sent = notif_sender.get_sent(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0].0, "notifications/progress"); + assert_eq!(sent[0].1["progressToken"], "progress-token-123"); + assert_eq!(sent[0].1["progress"], 25); + assert_eq!(sent[0].1["total"], 100); +} + +#[tokio::test] +async fn test_default_context_send_progress_without_token() { + use crate::tool_context::DefaultToolContext; + + let notif_sender = Arc::new(MockNotificationSender::new()); + let req_sender = Arc::new(MockRequestSender::new()); + + // No progress token + let ctx = DefaultToolContext::new( + "req-1", + "tool-1", + None, + None, + notif_sender.clone(), + req_sender, + ); + + // Should succeed but not send anything + ctx.send_progress(25, Some(100)).await.unwrap(); + + let sent = notif_sender.get_sent(); + assert!(sent.is_empty()); +} + +#[tokio::test] +async fn test_default_context_send_progress_with_message() { + use crate::tool_context::DefaultToolContext; + + let notif_sender = Arc::new(MockNotificationSender::new()); + let req_sender = Arc::new(MockRequestSender::new()); + + let ctx = DefaultToolContext::new( + "req-1", + "tool-1", + Some("token".to_string()), + None, + notif_sender.clone(), + req_sender, + ); + + ctx.send_progress_with_message(75, Some(100), "Almost done".to_string()) + .await + .unwrap(); + + let sent = notif_sender.get_sent(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0].1["message"], "Almost done"); +} + +#[tokio::test] +async fn test_default_context_request_sampling_success() { + use crate::tool_context::DefaultToolContext; + + let notif_sender = Arc::new(MockNotificationSender::new()); + let req_sender = Arc::new(MockRequestSender::new()); + + let response = json!({ + "role": "assistant", + "content": {"type": "text", "text": "Hello!"}, + "model": "test-model", + "stopReason": "end_turn" + }); + req_sender.set_response(response); + + let ctx = DefaultToolContext::new("req-1", "tool-1", None, None, notif_sender, req_sender); + + let result = ctx + .request_sampling(CreateMessageRequest::default(), Duration::from_secs(5)) + .await + .unwrap(); + + assert_eq!(result.model, "test-model"); +} + +#[tokio::test] +async fn test_default_context_request_elicitation_success() { + use crate::tool_context::DefaultToolContext; + + let notif_sender = Arc::new(MockNotificationSender::new()); + let req_sender = Arc::new(MockRequestSender::new()); + + let response = json!({ + "action": "accept", + "content": {"value": "user input"} + }); + req_sender.set_response(response); + + let ctx = DefaultToolContext::new("req-1", "tool-1", None, None, notif_sender, req_sender); + + let result = ctx + .request_elicitation( + ElicitationRequest::text("Enter name"), + Duration::from_secs(5), + ) + .await + .unwrap(); + + assert!(matches!(result.action, ElicitationAction::Accept)); +} + +#[tokio::test] +async fn test_default_context_request_error() { + use crate::tool_context::DefaultToolContext; + + let notif_sender = Arc::new(MockNotificationSender::new()); + let req_sender = Arc::new(MockRequestSender::new()); + req_sender.set_error(ToolContextError::Timeout); + + let ctx = DefaultToolContext::new("req-1", "tool-1", None, None, notif_sender, req_sender); + + let result = ctx + .request_sampling(CreateMessageRequest::default(), Duration::from_secs(5)) + .await; + + assert!(matches!(result, Err(ToolContextError::Timeout))); +} + +#[test] +fn test_default_context_accessors() { + use crate::tool_context::DefaultToolContext; + + let notif_sender = Arc::new(MockNotificationSender::new()); + let req_sender = Arc::new(MockRequestSender::new()); + + let ctx = DefaultToolContext::new( + "req-abc", + "my-tool", + Some("prog-token".to_string()), + Some("sess-123".to_string()), + notif_sender, + req_sender, + ); + + assert_eq!(ctx.request_id(), "req-abc"); + assert_eq!(ctx.tool_name(), "my-tool"); + assert_eq!(ctx.progress_token(), Some("prog-token")); + assert_eq!(ctx.session_id(), Some("sess-123")); +} + +// ============================================================================ +// TransportBridge Tests +// ============================================================================ + +use pulseengine_mcp_transport::{Transport, TransportError}; + +/// Mock error type for tests (String-based for clonability) +#[derive(Clone)] +enum MockTransportResult { + Ok, + ConnectionError(String), + Timeout, + ChannelClosed, + NotSupported(String), + SessionNotFound(String), +} + +impl MockTransportResult { + fn to_notification_result(&self) -> Result<(), TransportError> { + match self { + MockTransportResult::Ok => Ok(()), + MockTransportResult::ConnectionError(msg) => { + Err(TransportError::Connection(msg.clone())) + } + MockTransportResult::Timeout => Err(TransportError::Timeout), + MockTransportResult::ChannelClosed => Err(TransportError::ChannelClosed), + MockTransportResult::NotSupported(msg) => { + Err(TransportError::NotSupported(msg.clone())) + } + MockTransportResult::SessionNotFound(id) => { + Err(TransportError::SessionNotFound(id.clone())) + } + } + } + + fn to_request_result(&self, response: &Value) -> Result { + match self { + MockTransportResult::Ok => Ok(response.clone()), + MockTransportResult::ConnectionError(msg) => { + Err(TransportError::Connection(msg.clone())) + } + MockTransportResult::Timeout => Err(TransportError::Timeout), + MockTransportResult::ChannelClosed => Err(TransportError::ChannelClosed), + MockTransportResult::NotSupported(msg) => { + Err(TransportError::NotSupported(msg.clone())) + } + MockTransportResult::SessionNotFound(id) => { + Err(TransportError::SessionNotFound(id.clone())) + } + } + } +} + +struct MockTransport { + supports_bidirectional: bool, + notification_result: std::sync::Mutex, + request_result: std::sync::Mutex, + request_response: std::sync::Mutex, +} + +impl MockTransport { + fn new(supports_bidirectional: bool) -> Self { + Self { + supports_bidirectional, + notification_result: std::sync::Mutex::new(MockTransportResult::Ok), + request_result: std::sync::Mutex::new(MockTransportResult::Ok), + request_response: std::sync::Mutex::new(json!({})), + } + } + + fn set_notification_error(&self, result: MockTransportResult) { + *self.notification_result.lock().unwrap() = result; + } + + fn set_request_error(&self, result: MockTransportResult) { + *self.request_result.lock().unwrap() = result; + } + + fn set_request_response(&self, response: Value) { + *self.request_response.lock().unwrap() = response; + } +} + +#[async_trait] +impl Transport for MockTransport { + async fn start( + &mut self, + _handler: pulseengine_mcp_transport::RequestHandler, + ) -> Result<(), TransportError> { + Ok(()) + } + + async fn stop(&mut self) -> Result<(), TransportError> { + Ok(()) + } + + async fn health_check(&self) -> Result<(), TransportError> { + Ok(()) + } + + fn supports_bidirectional(&self) -> bool { + self.supports_bidirectional + } + + async fn send_notification( + &self, + _session_id: Option<&str>, + _method: &str, + _params: Value, + ) -> Result<(), TransportError> { + self.notification_result + .lock() + .unwrap() + .to_notification_result() + } + + async fn send_request( + &self, + _session_id: Option<&str>, + _method: &str, + _params: Value, + _timeout: Duration, + ) -> Result { + let result = self.request_result.lock().unwrap().clone(); + let response = self.request_response.lock().unwrap().clone(); + result.to_request_result(&response) + } +} + +#[tokio::test] +async fn test_transport_bridge_send_notification_success() { + use crate::tool_context::TransportBridge; + + let transport = Arc::new(MockTransport::new(true)) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge + .send_notification("test/method", json!({"key": "value"})) + .await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_transport_bridge_send_notification_error() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_notification_error(MockTransportResult::ConnectionError( + "connection lost".to_string(), + )); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge.send_notification("test/method", json!({})).await; + // Connection errors map to Transport (only SessionNotFound/ChannelClosed/NotSupported map to NotificationFailed) + assert!(matches!(result, Err(ToolContextError::Transport(_)))); +} + +#[tokio::test] +async fn test_transport_bridge_send_request_success() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_request_response(json!({"result": "success"})); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge + .send_request("test/method", json!({}), Duration::from_secs(5)) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap()["result"], "success"); +} + +#[tokio::test] +async fn test_transport_bridge_send_request_timeout() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_request_error(MockTransportResult::Timeout); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge + .send_request("test/method", json!({}), Duration::from_secs(5)) + .await; + assert!(matches!(result, Err(ToolContextError::Timeout))); +} + +#[tokio::test] +async fn test_transport_bridge_send_request_channel_closed() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_request_error(MockTransportResult::ChannelClosed); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge + .send_request("test/method", json!({}), Duration::from_secs(5)) + .await; + assert!(matches!(result, Err(ToolContextError::RequestFailed(_)))); +} + +#[tokio::test] +async fn test_transport_bridge_send_request_not_supported() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_request_error(MockTransportResult::NotSupported("sampling".to_string())); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge + .send_request("test/method", json!({}), Duration::from_secs(5)) + .await; + assert!(matches!(result, Err(ToolContextError::RequestFailed(_)))); +} + +#[tokio::test] +async fn test_transport_bridge_send_request_session_not_found() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_request_error(MockTransportResult::SessionNotFound("sess-999".to_string())); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge + .send_request("test/method", json!({}), Duration::from_secs(5)) + .await; + // SessionNotFound maps to RequestFailed for requests + assert!(matches!(result, Err(ToolContextError::RequestFailed(_)))); +} + +// ============================================================================ +// create_tool_context Tests +// ============================================================================ + +#[test] +fn test_create_tool_context() { + use crate::tool_context::create_tool_context; + + let transport = Arc::new(MockTransport::new(true)) as Arc; + + let ctx = create_tool_context( + transport, + "req-123", + "my-tool", + Some("progress-token".to_string()), + Some("session-id".to_string()), + ); + + assert_eq!(ctx.request_id(), "req-123"); + assert_eq!(ctx.tool_name(), "my-tool"); + assert_eq!(ctx.progress_token(), Some("progress-token")); + assert_eq!(ctx.session_id(), Some("session-id")); +} + +#[tokio::test] +async fn test_create_tool_context_send_log() { + use crate::tool_context::create_tool_context; + + let transport = Arc::new(MockTransport::new(true)) as Arc; + + let ctx = create_tool_context( + transport, + "req-123", + "my-tool", + None, + Some("session-id".to_string()), + ); + + // This should succeed via the TransportBridge + let result = ctx + .send_log(LogLevel::Info, Some("logger"), json!({})) + .await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_create_tool_context_send_progress() { + use crate::tool_context::create_tool_context; + + let transport = Arc::new(MockTransport::new(true)) as Arc; + + let ctx = create_tool_context( + transport, + "req-123", + "my-tool", + Some("progress-token".to_string()), + Some("session-id".to_string()), + ); + + let result = ctx.send_progress(50, Some(100)).await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_create_tool_context_request_sampling() { + use crate::tool_context::create_tool_context; + + let mock = MockTransport::new(true); + mock.set_request_response(json!({ + "role": "assistant", + "content": {"type": "text", "text": "Hello!"}, + "model": "test-model", + "stopReason": "end_turn" + })); + let transport = Arc::new(mock) as Arc; + + let ctx = create_tool_context( + transport, + "req-123", + "my-tool", + None, + Some("session-id".to_string()), + ); + + let result = ctx + .request_sampling(CreateMessageRequest::default(), Duration::from_secs(5)) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().model, "test-model"); +} + +#[tokio::test] +async fn test_create_tool_context_request_elicitation() { + use crate::tool_context::create_tool_context; + + let mock = MockTransport::new(true); + mock.set_request_response(json!({ + "action": "accept", + "content": {"value": "user input"} + })); + let transport = Arc::new(mock) as Arc; + + let ctx = create_tool_context( + transport, + "req-123", + "my-tool", + None, + Some("session-id".to_string()), + ); + + let result = ctx + .request_elicitation(ElicitationRequest::text("test"), Duration::from_secs(5)) + .await; + assert!(result.is_ok()); + assert!(matches!(result.unwrap().action, ElicitationAction::Accept)); +} + +// ============================================================================ +// Additional TransportBridge Notification Error Tests +// ============================================================================ + +#[tokio::test] +async fn test_transport_bridge_send_notification_session_not_found() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_notification_error(MockTransportResult::SessionNotFound("sess-123".to_string())); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge.send_notification("test/method", json!({})).await; + // SessionNotFound maps to NotificationFailed for notifications + assert!(matches!( + result, + Err(ToolContextError::NotificationFailed(_)) + )); +} + +#[tokio::test] +async fn test_transport_bridge_send_notification_channel_closed() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_notification_error(MockTransportResult::ChannelClosed); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge.send_notification("test/method", json!({})).await; + // ChannelClosed maps to NotificationFailed for notifications + assert!(matches!( + result, + Err(ToolContextError::NotificationFailed(_)) + )); +} + +#[tokio::test] +async fn test_transport_bridge_send_notification_not_supported() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_notification_error(MockTransportResult::NotSupported("logging".to_string())); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge.send_notification("test/method", json!({})).await; + // NotSupported maps to NotificationFailed for notifications + assert!(matches!( + result, + Err(ToolContextError::NotificationFailed(_)) + )); +} + +#[tokio::test] +async fn test_transport_bridge_send_notification_timeout() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_notification_error(MockTransportResult::Timeout); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge.send_notification("test/method", json!({})).await; + // Timeout (like other unlisted errors) maps to Transport + assert!(matches!(result, Err(ToolContextError::Transport(_)))); +} + +#[tokio::test] +async fn test_transport_bridge_send_request_connection_error() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_request_error(MockTransportResult::ConnectionError( + "network down".to_string(), + )); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + let result = bridge + .send_request("test/method", json!({}), Duration::from_secs(5)) + .await; + // Connection errors map to Transport for requests (catch-all) + assert!(matches!(result, Err(ToolContextError::Transport(_)))); +} + +// ============================================================================ +// Context Accessor Tests +// ============================================================================ + +#[tokio::test] +async fn test_try_current_context_none() { + use crate::tool_context::try_current_context; + + // Without with_context, there should be no current context + let ctx = try_current_context(); + assert!(ctx.is_none()); +} + +#[tokio::test] +async fn test_with_context_and_current_context() { + use crate::tool_context::{NoOpToolContext, with_context}; + + let noop_ctx = Arc::new(NoOpToolContext::new("req-1", "tool-1")) as Arc; + + // Inside with_context, try_current_context should return the context + let result = with_context(noop_ctx.clone(), async { + // Verify context accessors work inside the scope + noop_ctx.request_id().to_string() + }) + .await; + + assert_eq!(result, "req-1"); +} + +// ============================================================================ +// TransportBridge Without Session ID Tests +// ============================================================================ + +#[tokio::test] +async fn test_transport_bridge_no_session_id() { + use crate::tool_context::TransportBridge; + + let transport = Arc::new(MockTransport::new(true)) as Arc; + // Create bridge without session_id (None) + let bridge = TransportBridge::new(transport, None); + + let result = bridge + .send_notification("test/method", json!({"key": "value"})) + .await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_transport_bridge_request_no_session_id() { + use crate::tool_context::TransportBridge; + + let mock = MockTransport::new(true); + mock.set_request_response(json!({"status": "ok"})); + let transport = Arc::new(mock) as Arc; + let bridge = TransportBridge::new(transport, None); + + let result = bridge + .send_request("test/method", json!({}), Duration::from_secs(5)) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap()["status"], "ok"); +} + +// ============================================================================ +// TransportBridge Not Bidirectional Tests +// ============================================================================ + +#[tokio::test] +async fn test_transport_bridge_not_bidirectional() { + use crate::tool_context::TransportBridge; + + // Transport that does not support bidirectional + let transport = Arc::new(MockTransport::new(false)) as Arc; + let bridge = TransportBridge::new(transport, Some("session-1".to_string())); + + // Should still work (falls through to transport method) + let result = bridge + .send_notification("test/method", json!({"data": true})) + .await; + assert!(result.is_ok()); +} + +// ============================================================================ +// create_tool_context Without Optional Fields Tests +// ============================================================================ + +#[tokio::test] +async fn test_create_tool_context_minimal() { + use crate::tool_context::create_tool_context; + + let transport = Arc::new(MockTransport::new(true)) as Arc; + + let ctx = create_tool_context(transport, "req-456", "minimal-tool", None, None); + + assert_eq!(ctx.request_id(), "req-456"); + assert_eq!(ctx.tool_name(), "minimal-tool"); + assert_eq!(ctx.progress_token(), None); + assert_eq!(ctx.session_id(), None); +} + +#[tokio::test] +async fn test_create_tool_context_notification_error() { + use crate::tool_context::create_tool_context; + + let mock = MockTransport::new(true); + mock.set_notification_error(MockTransportResult::ConnectionError("failed".to_string())); + let transport = Arc::new(mock) as Arc; + + let ctx = create_tool_context( + transport, + "req-err", + "error-tool", + None, + Some("session".to_string()), + ); + + // send_log should fail because notification fails + let result = ctx.send_log(LogLevel::Error, None, json!({})).await; + assert!(matches!(result, Err(ToolContextError::Transport(_)))); +} + +#[tokio::test] +async fn test_create_tool_context_request_error() { + use crate::tool_context::create_tool_context; + + let mock = MockTransport::new(true); + mock.set_request_error(MockTransportResult::Timeout); + let transport = Arc::new(mock) as Arc; + + let ctx = create_tool_context( + transport, + "req-timeout", + "timeout-tool", + None, + Some("session".to_string()), + ); + + let result = ctx + .request_sampling(CreateMessageRequest::default(), Duration::from_secs(1)) + .await; + assert!(matches!(result, Err(ToolContextError::Timeout))); +} diff --git a/mcp-transport/src/lib.rs b/mcp-transport/src/lib.rs index 7889f65..50b2800 100644 --- a/mcp-transport/src/lib.rs +++ b/mcp-transport/src/lib.rs @@ -58,6 +58,9 @@ mod websocket_tests; use async_trait::async_trait; use pulseengine_mcp_protocol::{Request, Response}; +use serde_json::Value; +use std::sync::Arc; +use std::time::Duration; // std::error::Error not needed with thiserror use thiserror::Error as ThisError; @@ -73,6 +76,18 @@ pub enum TransportError { #[error("Protocol error: {0}")] Protocol(String), + + #[error("Timeout waiting for response")] + Timeout, + + #[error("Session not found: {0}")] + SessionNotFound(String), + + #[error("Channel closed")] + ChannelClosed, + + #[error("Not supported by this transport: {0}")] + NotSupported(String), } /// Request handler function type @@ -82,12 +97,224 @@ pub type RequestHandler = Box< + Sync, >; +/// Response handler for server-initiated requests +/// +/// When the server sends a request to the client, this handler is used to +/// route responses back to the waiting request. +pub type ResponseHandler = Arc< + dyn Fn(Response) -> std::pin::Pin + Send>> + + Send + + Sync, +>; + /// Transport layer trait #[async_trait] pub trait Transport: Send + Sync { + /// Start the transport with the given request handler async fn start(&mut self, handler: RequestHandler) -> std::result::Result<(), TransportError>; + + /// Stop the transport async fn stop(&mut self) -> std::result::Result<(), TransportError>; + + /// Check if the transport is healthy async fn health_check(&self) -> std::result::Result<(), TransportError>; + + /// Send a notification to a client session + /// + /// Notifications are fire-and-forget messages that don't expect a response. + /// Used for: notifications/message (logging), notifications/progress + /// + /// # Arguments + /// * `session_id` - The session to send to (None for broadcast or default session) + /// * `method` - The notification method name + /// * `params` - The notification parameters + /// + /// # Default Implementation + /// Returns NotSupported error - transports should override if they support notifications + async fn send_notification( + &self, + _session_id: Option<&str>, + _method: &str, + _params: Value, + ) -> std::result::Result<(), TransportError> { + Err(TransportError::NotSupported( + "Notifications not supported by this transport".to_string(), + )) + } + + /// Send a request to a client and wait for a response + /// + /// Used for: sampling/createMessage, elicitation/create + /// + /// # Arguments + /// * `session_id` - The session to send to + /// * `method` - The request method name + /// * `params` - The request parameters + /// * `timeout` - Maximum time to wait for response + /// + /// # Default Implementation + /// Returns NotSupported error - transports should override if they support requests + async fn send_request( + &self, + _session_id: Option<&str>, + _method: &str, + _params: Value, + _timeout: Duration, + ) -> std::result::Result { + Err(TransportError::NotSupported( + "Server-initiated requests not supported by this transport".to_string(), + )) + } + + /// Set the handler for routing responses to server-initiated requests + /// + /// When the transport receives a response to a server-initiated request, + /// it uses this handler to route it back. + /// + /// # Default Implementation + /// Does nothing - transports should override if they support requests + fn set_response_handler(&mut self, _handler: ResponseHandler) { + // Default: no-op for transports that don't support server requests + } + + /// Check if this transport supports bidirectional communication + /// + /// Returns true if the transport can send notifications and requests to clients + fn supports_bidirectional(&self) -> bool { + false + } + + /// Register a pending request and get a receiver for the response + /// + /// This is used for server-initiated requests that need response correlation. + /// Returns a oneshot receiver that will receive the response value. + /// + /// # Arguments + /// * `request_id` - The unique ID for this request + /// + /// # Default Implementation + /// Returns None - transports should override if they support requests + fn register_pending_request( + &self, + _request_id: &str, + ) -> Option> { + None + } +} + +// ============================================================================ +// Session Context (Task-Local Storage) +// ============================================================================ + +/// A message to be sent during request processing via the streaming response +/// +/// Can be either a notification (no id) or a request (with id for response correlation) +#[derive(Debug, Clone)] +pub struct StreamingNotification { + /// Request ID - if Some, this is a request expecting a response; if None, it's a notification + pub id: Option, + pub method: String, + pub params: Value, +} + +/// Sender for streaming notifications during a request +pub type NotificationSender = tokio::sync::mpsc::UnboundedSender; + +tokio::task_local! { + /// Task-local storage for the current session ID + /// + /// This allows handlers to access the session ID of the client that made + /// the current request, enabling bidirectional communication to be routed + /// to the correct client. + pub static SESSION_ID: String; + + /// Task-local storage for the notification sender + /// + /// When a request supports streaming (like tools/call), this sender is set + /// and tools can use it to send notifications that will be included in the + /// SSE response stream. + pub static NOTIFICATION_SENDER: NotificationSender; +} + +/// Get the current session ID +/// +/// # Panics +/// Panics if called outside of a request handling scope +pub fn current_session_id() -> String { + SESSION_ID.with(|id| id.clone()) +} + +/// Try to get the current session ID +/// +/// Returns `None` if called outside of a request handling scope +pub fn try_current_session_id() -> Option { + SESSION_ID.try_with(|id| id.clone()).ok() +} + +/// Execute an async block with a session ID context +pub async fn with_session(session_id: String, f: F) -> T +where + F: std::future::Future, +{ + SESSION_ID.scope(session_id, f).await +} + +/// Try to get the current notification sender +/// +/// Returns `Some` if we're in a streaming request context, `None` otherwise +pub fn try_notification_sender() -> Option { + NOTIFICATION_SENDER.try_with(|s| s.clone()).ok() +} + +/// Send a notification via the streaming response (if available) +/// +/// Returns `true` if notification was queued, `false` if no streaming context +pub fn send_streaming_notification(method: &str, params: Value) -> bool { + if let Some(sender) = try_notification_sender() { + sender + .send(StreamingNotification { + id: None, + method: method.to_string(), + params, + }) + .is_ok() + } else { + false + } +} + +/// Send a request via the streaming response (if available) +/// +/// Returns `true` if request was queued, `false` if no streaming context +pub fn send_streaming_request(id: &str, method: &str, params: Value) -> bool { + if let Some(sender) = try_notification_sender() { + sender + .send(StreamingNotification { + id: Some(id.to_string()), + method: method.to_string(), + params, + }) + .is_ok() + } else { + false + } +} + +/// Execute an async block with both session ID and notification sender contexts +pub async fn with_streaming_context( + session_id: String, + notification_sender: NotificationSender, + f: F, +) -> T +where + F: std::future::Future, +{ + SESSION_ID + .scope( + session_id, + NOTIFICATION_SENDER.scope(notification_sender, f), + ) + .await } /// Create a transport from configuration diff --git a/mcp-transport/src/lib_tests.rs b/mcp-transport/src/lib_tests.rs index 36956c3..78ee188 100644 --- a/mcp-transport/src/lib_tests.rs +++ b/mcp-transport/src/lib_tests.rs @@ -75,6 +75,20 @@ mod tests { assert!(display.contains("Protocol error")); assert!(display.contains(msg)); } + TransportError::Timeout => { + assert!(display.contains("Timeout")); + } + TransportError::SessionNotFound(msg) => { + assert!(display.contains("Session not found")); + assert!(display.contains(msg)); + } + TransportError::ChannelClosed => { + assert!(display.contains("Channel closed")); + } + TransportError::NotSupported(msg) => { + assert!(display.contains("Not supported")); + assert!(display.contains(msg)); + } } } } @@ -254,4 +268,316 @@ mod tests { let _stdio_mod = std::any::type_name::(); let _websocket_mod = std::any::type_name::(); } + + // ============================================================================ + // Streaming Context Tests + // ============================================================================ + + #[test] + fn test_try_notification_sender_without_context() { + // Outside any streaming context, should return None + let sender = crate::try_notification_sender(); + assert!(sender.is_none()); + } + + #[test] + fn test_send_streaming_notification_without_context() { + // Without context, should return false + let result = crate::send_streaming_notification("test/method", serde_json::json!({})); + assert!(!result); + } + + #[test] + fn test_send_streaming_request_without_context() { + // Without context, should return false + let result = crate::send_streaming_request("req-123", "test/method", serde_json::json!({})); + assert!(!result); + } + + #[test] + fn test_try_current_session_id_without_context() { + // Outside any session context, should return None + let session_id = crate::try_current_session_id(); + assert!(session_id.is_none()); + } + + #[tokio::test] + async fn test_with_session_context() { + let session_id = "test-session-123"; + + let result = crate::with_session(session_id.to_string(), async { + crate::try_current_session_id() + }) + .await; + + assert_eq!(result, Some(session_id.to_string())); + } + + #[tokio::test] + async fn test_with_streaming_context() { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + + let result = crate::with_streaming_context("session-456".to_string(), tx, async { + // Inside context, should be able to send + let sent = crate::send_streaming_notification( + "test/method", + serde_json::json!({"key": "value"}), + ); + assert!(sent); + + // Try sending a request too + let sent_req = crate::send_streaming_request( + "req-789", + "sampling/createMessage", + serde_json::json!({}), + ); + assert!(sent_req); + + // Session ID should be available + crate::try_current_session_id() + }) + .await; + + // Session ID was captured + assert_eq!(result, Some("session-456".to_string())); + + // Verify messages were received + let notification = rx.recv().await.unwrap(); + assert!(notification.id.is_none()); // Notification has no ID + assert_eq!(notification.method, "test/method"); + + let request = rx.recv().await.unwrap(); + assert_eq!(request.id, Some("req-789".to_string())); + assert_eq!(request.method, "sampling/createMessage"); + } + + #[tokio::test] + async fn test_streaming_notification_struct() { + let notification = crate::StreamingNotification { + id: None, + method: "notifications/progress".to_string(), + params: serde_json::json!({"progress": 50}), + }; + + assert!(notification.id.is_none()); + assert_eq!(notification.method, "notifications/progress"); + assert_eq!(notification.params["progress"], 50); + + // Test clone + let cloned = notification.clone(); + assert_eq!(cloned.method, notification.method); + + // Test debug + let debug_str = format!("{notification:?}"); + assert!(debug_str.contains("StreamingNotification")); + } + + #[tokio::test] + async fn test_streaming_request_struct() { + let request = crate::StreamingNotification { + id: Some("req-abc".to_string()), + method: "sampling/createMessage".to_string(), + params: serde_json::json!({"maxTokens": 100}), + }; + + assert_eq!(request.id, Some("req-abc".to_string())); + assert_eq!(request.method, "sampling/createMessage"); + } + + // ============================================================================ + // create_transport Tests + // ============================================================================ + + #[test] + fn test_create_transport_stdio() { + let config = TransportConfig::Stdio; + let transport = crate::create_transport(config); + + assert!(transport.is_ok()); + } + + #[test] + fn test_create_transport_http() { + let config = TransportConfig::Http { + host: Some("127.0.0.1".to_string()), + port: 3000, + }; + let transport = crate::create_transport(config); + + assert!(transport.is_ok()); + } + + #[test] + fn test_create_transport_streamable_http() { + let config = TransportConfig::StreamableHttp { + host: Some("127.0.0.1".to_string()), + port: 3001, + }; + let transport = crate::create_transport(config); + + assert!(transport.is_ok()); + } + + #[test] + fn test_create_transport_websocket() { + let config = TransportConfig::WebSocket { + host: Some("127.0.0.1".to_string()), + port: 3002, + }; + let transport = crate::create_transport(config); + + assert!(transport.is_ok()); + } + + // ============================================================================ + // Additional TransportError Tests + // ============================================================================ + + #[test] + fn test_transport_error_timeout() { + let error = TransportError::Timeout; + assert!(error.to_string().contains("Timeout")); + } + + #[test] + fn test_transport_error_session_not_found() { + let error = TransportError::SessionNotFound("sess-123".to_string()); + let msg = error.to_string(); + assert!(msg.contains("Session not found")); + assert!(msg.contains("sess-123")); + } + + #[test] + fn test_transport_error_channel_closed() { + let error = TransportError::ChannelClosed; + assert!(error.to_string().contains("Channel closed")); + } + + #[test] + fn test_transport_error_not_supported() { + let error = TransportError::NotSupported("bidirectional".to_string()); + let msg = error.to_string(); + assert!(msg.contains("Not supported")); + assert!(msg.contains("bidirectional")); + } + + // ============================================================================ + // Transport Trait Default Implementation Tests + // ============================================================================ + + #[tokio::test] + async fn test_transport_default_send_notification() { + // Create a minimal transport that uses default implementations + struct MinimalTransport; + + #[async_trait::async_trait] + impl Transport for MinimalTransport { + async fn start( + &mut self, + _handler: RequestHandler, + ) -> std::result::Result<(), TransportError> { + Ok(()) + } + async fn stop(&mut self) -> std::result::Result<(), TransportError> { + Ok(()) + } + async fn health_check(&self) -> std::result::Result<(), TransportError> { + Ok(()) + } + } + + let transport = MinimalTransport; + + // Default implementation should return NotSupported + let result = transport + .send_notification(None, "test", serde_json::json!({})) + .await; + assert!(matches!(result, Err(TransportError::NotSupported(_)))); + } + + #[tokio::test] + async fn test_transport_default_send_request() { + struct MinimalTransport; + + #[async_trait::async_trait] + impl Transport for MinimalTransport { + async fn start( + &mut self, + _handler: RequestHandler, + ) -> std::result::Result<(), TransportError> { + Ok(()) + } + async fn stop(&mut self) -> std::result::Result<(), TransportError> { + Ok(()) + } + async fn health_check(&self) -> std::result::Result<(), TransportError> { + Ok(()) + } + } + + let transport = MinimalTransport; + + // Default implementation should return NotSupported + let result = transport + .send_request( + None, + "test", + serde_json::json!({}), + std::time::Duration::from_secs(1), + ) + .await; + assert!(matches!(result, Err(TransportError::NotSupported(_)))); + } + + #[test] + fn test_transport_default_supports_bidirectional() { + struct MinimalTransport; + + #[async_trait::async_trait] + impl Transport for MinimalTransport { + async fn start( + &mut self, + _handler: RequestHandler, + ) -> std::result::Result<(), TransportError> { + Ok(()) + } + async fn stop(&mut self) -> std::result::Result<(), TransportError> { + Ok(()) + } + async fn health_check(&self) -> std::result::Result<(), TransportError> { + Ok(()) + } + } + + let transport = MinimalTransport; + + // Default implementation should return false + assert!(!transport.supports_bidirectional()); + } + + #[test] + fn test_transport_default_register_pending_request() { + struct MinimalTransport; + + #[async_trait::async_trait] + impl Transport for MinimalTransport { + async fn start( + &mut self, + _handler: RequestHandler, + ) -> std::result::Result<(), TransportError> { + Ok(()) + } + async fn stop(&mut self) -> std::result::Result<(), TransportError> { + Ok(()) + } + async fn health_check(&self) -> std::result::Result<(), TransportError> { + Ok(()) + } + } + + let transport = MinimalTransport; + + // Default implementation should return None + assert!(transport.register_pending_request("req-123").is_none()); + } } diff --git a/mcp-transport/src/streamable_http.rs b/mcp-transport/src/streamable_http.rs index 0e2ad8c..9f8beb4 100644 --- a/mcp-transport/src/streamable_http.rs +++ b/mcp-transport/src/streamable_http.rs @@ -8,20 +8,28 @@ //! - Event IDs for stream resumption via `Last-Event-ID` header //! - Server-initiated disconnect with `retry` field for polling //! - Origin header validation (HTTP 403 for invalid origins) +//! - **Bidirectional communication** - server can send notifications and requests to clients -use crate::{RequestHandler, Transport, TransportError}; +use crate::{ + RequestHandler, StreamingNotification, Transport, TransportError, with_streaming_context, +}; use async_trait::async_trait; use axum::{ Json, Router, extract::{Query, State}, http::{HeaderMap, StatusCode}, - response::IntoResponse, + response::{IntoResponse, Response as AxumResponse, Sse, sse::Event as SseEvent}, routing::{get, post}, }; +use futures::stream::Stream; use serde::Deserialize; use serde_json::Value; -use std::{collections::HashMap, net::SocketAddr, sync::Arc}; -use tokio::sync::RwLock; +use std::{collections::HashMap, convert::Infallible, net::SocketAddr, sync::Arc, time::Duration}; +use tokio::sync::{RwLock, broadcast, oneshot}; + +// Type aliases for clarity - sessions use async RwLock, pending_requests use sync RwLock +type SessionsMap = RwLock>; +type PendingRequestsMap = std::sync::RwLock>; use tower::ServiceBuilder; use tower_http::cors::CorsLayer; use tracing::{debug, info, warn}; @@ -46,6 +54,10 @@ pub struct StreamableHttpConfig { /// Whether to enable SSE stream resumption (MCP 2025-11-25) /// When true, server will attach event IDs and support Last-Event-ID header pub sse_resumable: bool, + /// Channel capacity for SSE message broadcasting + pub channel_capacity: usize, + /// Default timeout for server-initiated requests (sampling, elicitation) + pub request_timeout: Duration, } impl Default for StreamableHttpConfig { @@ -58,6 +70,8 @@ impl Default for StreamableHttpConfig { enforce_origin_validation: false, sse_retry_ms: 3000, // 3 seconds default retry interval sse_resumable: true, + channel_capacity: 100, + request_timeout: Duration::from_secs(60), } } } @@ -84,8 +98,21 @@ impl StreamableHttpConfig { } } -/// Session information +/// Message that can be sent via SSE to clients #[derive(Debug, Clone)] +pub enum SseMessage { + /// A JSON-RPC notification (no response expected) + Notification { method: String, params: Value }, + /// A JSON-RPC request (response expected) + Request { + id: String, + method: String, + params: Value, + }, +} + +/// Session information +#[derive(Debug)] struct SessionInfo { #[allow(dead_code)] id: String, @@ -93,6 +120,13 @@ struct SessionInfo { created_at: std::time::Instant, /// Counter for generating unique event IDs within this session event_counter: u64, + /// Broadcast channel sender for this session's SSE messages + message_sender: broadcast::Sender, +} + +/// Pending request awaiting response from client +struct PendingRequest { + response_sender: oneshot::Sender, } /// SSE Event ID (MCP 2025-11-25) @@ -140,10 +174,179 @@ impl SseEventId { #[derive(Clone)] struct AppState { handler: Arc, - sessions: Arc>>, + sessions: Arc, + pending_requests: Arc, + config: StreamableHttpConfig, +} + +/// Handle for accessing transport state from outside the HTTP server +#[derive(Clone)] +pub struct TransportHandle { + sessions: Arc, + pending_requests: Arc, + #[allow(dead_code)] config: StreamableHttpConfig, } +impl TransportHandle { + /// Send a notification to a specific session or all sessions + pub async fn send_notification( + &self, + session_id: Option<&str>, + method: &str, + params: Value, + ) -> Result<(), TransportError> { + let message = SseMessage::Notification { + method: method.to_string(), + params, + }; + + let sessions = self.sessions.read().await; + + if let Some(id) = session_id { + // Send to specific session + if let Some(session) = sessions.get(id) { + // Note: broadcast::Sender.send() returns Err if there are no receivers, + // but this is not an error condition - it just means no SSE clients are + // currently connected. The message is discarded, which is fine. + let receiver_count = session.message_sender.receiver_count(); + if receiver_count > 0 { + if let Err(e) = session.message_sender.send(message.clone()) { + warn!( + "Failed to send notification {} to session {}: {}", + method, id, e + ); + } else { + debug!( + "Sent notification {} to session {} ({} receivers)", + method, id, receiver_count + ); + } + } else { + debug!( + "No SSE receivers for session {}, notification {} discarded", + id, method + ); + } + } else { + return Err(TransportError::SessionNotFound(id.to_string())); + } + } else { + // Broadcast to all sessions + for (id, session) in sessions.iter() { + if session.message_sender.send(message.clone()).is_err() { + warn!( + "Failed to send notification to session {} (channel closed)", + id + ); + } else { + debug!("Sent notification {} to session {}", method, id); + } + } + } + + Ok(()) + } + + /// Send a request to a specific session and wait for response + pub async fn send_request( + &self, + session_id: Option<&str>, + method: &str, + params: Value, + timeout: Duration, + ) -> Result { + let session_id = session_id.ok_or_else(|| { + TransportError::Config("Session ID required for requests".to_string()) + })?; + + let request_id = Uuid::new_v4().to_string(); + let (tx, rx) = oneshot::channel(); + + // Register pending request (using sync RwLock) + { + let mut pending = self.pending_requests.write().unwrap(); + pending.insert( + request_id.clone(), + PendingRequest { + response_sender: tx, + }, + ); + } + + // Send the request message + let message = SseMessage::Request { + id: request_id.clone(), + method: method.to_string(), + params, + }; + + { + let sessions = self.sessions.read().await; + if let Some(session) = sessions.get(session_id) { + session + .message_sender + .send(message) + .map_err(|_| TransportError::ChannelClosed)?; + debug!( + "Sent request {} ({}) to session {}", + method, request_id, session_id + ); + } else { + // Clean up pending request + let mut pending = self.pending_requests.write().unwrap(); + pending.remove(&request_id); + return Err(TransportError::SessionNotFound(session_id.to_string())); + } + } + + // Wait for response with timeout + match tokio::time::timeout(timeout, rx).await { + Ok(Ok(response)) => Ok(response), + Ok(Err(_)) => { + // Channel was closed (shouldn't happen normally) + let mut pending = self.pending_requests.write().unwrap(); + pending.remove(&request_id); + Err(TransportError::ChannelClosed) + } + Err(_) => { + // Timeout + let mut pending = self.pending_requests.write().unwrap(); + pending.remove(&request_id); + Err(TransportError::Timeout) + } + } + } + + /// Handle a response from the client to a server-initiated request + pub fn handle_response(&self, id: &str, result: Value) -> bool { + let mut pending = self.pending_requests.write().unwrap(); + if let Some(pending_request) = pending.remove(id) { + let _ = pending_request.response_sender.send(result); + true + } else { + false + } + } + + /// Register a pending request and get the response receiver + /// + /// This separates registration from sending, allowing the request to be + /// sent via the streaming response channel while still correlating responses. + pub fn register_pending_request_sync(&self, request_id: &str) -> oneshot::Receiver { + let (tx, rx) = oneshot::channel(); + // Using sync RwLock - safe for use in any context + let mut pending = self.pending_requests.write().unwrap(); + pending.insert( + request_id.to_string(), + PendingRequest { + response_sender: tx, + }, + ); + rx + } +} + /// Validate Origin header against allowed origins (MCP 2025-11-25) /// /// Returns None if validation passes, Some(response) with 403 Forbidden if invalid @@ -210,6 +413,8 @@ struct StreamQuery { pub struct StreamableHttpTransport { config: StreamableHttpConfig, server_handle: Option>, + /// Handle for sending messages to sessions + transport_handle: Option, } impl StreamableHttpTransport { @@ -220,6 +425,7 @@ impl StreamableHttpTransport { ..Default::default() }, server_handle: None, + transport_handle: None, } } @@ -228,6 +434,7 @@ impl StreamableHttpTransport { Self { config, server_handle: None, + transport_handle: None, } } @@ -259,6 +466,11 @@ impl StreamableHttpTransport { &mut self.config } + /// Get the transport handle for sending messages + pub fn handle(&self) -> Option { + self.transport_handle.clone() + } + /// Create or get session async fn ensure_session(state: &AppState, session_id: Option) -> String { if let Some(id) = session_id { @@ -269,10 +481,12 @@ impl StreamableHttpTransport { } // If session doesn't exist, create it with the provided ID drop(sessions); + let (sender, _) = broadcast::channel(state.config.channel_capacity); let session = SessionInfo { id: id.clone(), created_at: std::time::Instant::now(), event_counter: 0, + message_sender: sender, }; let mut sessions = state.sessions.write().await; sessions.insert(id.clone(), session); @@ -282,10 +496,12 @@ impl StreamableHttpTransport { // Create new session with generated ID let id = Uuid::new_v4().to_string(); + let (sender, _) = broadcast::channel(state.config.channel_capacity); let session = SessionInfo { id: id.clone(), created_at: std::time::Instant::now(), event_counter: 0, + message_sender: sender, }; let mut sessions = state.sessions.write().await; @@ -320,7 +536,7 @@ async fn handle_messages( State(state): State>, headers: HeaderMap, body: String, -) -> axum::response::Response { +) -> AxumResponse { debug!("Received POST /messages: {}", body); // MCP 2025-11-25: Validate Origin header - return 403 Forbidden for invalid origins @@ -336,11 +552,11 @@ async fn handle_messages( let session_id = StreamableHttpTransport::ensure_session(&state, session_id).await; - // Parse the request - let request: Value = match serde_json::from_str(&body) { + // Parse the request/response + let message: Value = match serde_json::from_str(&body) { Ok(v) => v, Err(e) => { - warn!("Failed to parse request: {}", e); + warn!("Failed to parse message: {}", e); return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({ @@ -356,9 +572,44 @@ async fn handle_messages( } }; + // Check if this is a response to a server-initiated request + if message.get("result").is_some() || message.get("error").is_some() { + // This is a response, not a request + if let Some(id) = message.get("id").and_then(|v| v.as_str()) { + let handle = TransportHandle { + sessions: Arc::clone(&state.sessions), + pending_requests: Arc::clone(&state.pending_requests), + config: state.config.clone(), + }; + + let result = if let Some(result) = message.get("result") { + result.clone() + } else if let Some(error) = message.get("error") { + // Convert error to a value that the caller can handle + serde_json::json!({ "error": error }) + } else { + Value::Null + }; + + if handle.handle_response(id, result) { + debug!("Routed response for request {}", id); + let mut response_headers = HeaderMap::new(); + response_headers.insert("Mcp-Session-Id", session_id.parse().unwrap()); + return ( + StatusCode::OK, + response_headers, + Json(serde_json::json!({})), + ) + .into_response(); + } else { + warn!("Received response for unknown request {}", id); + } + } + } + // Convert to MCP Request let mcp_request: pulseengine_mcp_protocol::Request = - match serde_json::from_value(request.clone()) { + match serde_json::from_value(message.clone()) { Ok(r) => r, Err(e) => { warn!("Invalid request format: {}", e); @@ -370,22 +621,353 @@ async fn handle_messages( "code": -32600, "message": "Invalid request" }, - "id": request.get("id").cloned().unwrap_or(Value::Null) + "id": message.get("id").cloned().unwrap_or(Value::Null) })), ) .into_response(); } }; - // Process through handler - let response = (state.handler)(mcp_request).await; + // Check if client accepts SSE responses + let accepts_sse = headers + .get("Accept") + .and_then(|v| v.to_str().ok()) + .map(|s| s.contains("text/event-stream")) + .unwrap_or(false); + + // Build response headers + let mut response_headers = HeaderMap::new(); + response_headers.insert("Mcp-Session-Id", session_id.parse().unwrap()); + + // For bidirectional communication (sampling, elicitation), we need true streaming + // where events are sent as they're produced, not collected after handler completes. + // This is required because tools may block waiting for client responses. + if accepts_sse { + // Create streaming response that sends events in real-time + let stream = create_realtime_sse_stream(state, session_id, mcp_request); + + response_headers.insert("Content-Type", "text/event-stream".parse().unwrap()); + response_headers.insert("Cache-Control", "no-cache".parse().unwrap()); + response_headers.insert("Connection", "keep-alive".parse().unwrap()); + + return (StatusCode::OK, response_headers, Sse::new(stream)).into_response(); + } + + // Non-SSE path: collect notifications after handler completes + // This works for simple tools but NOT for sampling/elicitation + let (notification_tx, mut notification_rx) = + tokio::sync::mpsc::unbounded_channel::(); + + let handler = state.handler.clone(); + let session_id_for_context = session_id.clone(); + let response = with_streaming_context(session_id_for_context, notification_tx, async move { + (handler)(mcp_request).await + }) + .await; + + // Collect all notifications that were sent during processing + let mut notifications: Vec = Vec::new(); + while let Ok(notification) = notification_rx.try_recv() { + notifications.push(notification); + } + + debug!( + "Sending response with session ID: {}, notifications: {}", + session_id, + notifications.len() + ); + + // If there are notifications, return an SSE stream + if !notifications.is_empty() { + eprintln!( + "[DEBUG] Returning SSE stream with {} notifications", + notifications.len() + ); + + let stream = create_post_response_stream(notifications, response); + + response_headers.insert("Content-Type", "text/event-stream".parse().unwrap()); + response_headers.insert("Cache-Control", "no-cache".parse().unwrap()); + response_headers.insert("Connection", "keep-alive".parse().unwrap()); + + return (StatusCode::OK, response_headers, Sse::new(stream)).into_response(); + } + + // No notifications - return simple JSON response + (StatusCode::OK, response_headers, Json(response)).into_response() +} + +/// Create a real-time SSE stream that sends events as they're produced +/// +/// This is essential for bidirectional communication (sampling, elicitation) where +/// the server needs to send requests to the client and wait for responses during +/// tool execution. The stream sends notifications/requests immediately as they're +/// generated, then the final response when the handler completes. +fn create_realtime_sse_stream( + state: Arc, + session_id: String, + mcp_request: pulseengine_mcp_protocol::Request, +) -> impl Stream> + Send { + async_stream::stream! { + eprintln!("[DEBUG SSE RT] Starting real-time stream for session {}", session_id); + + // Create channel for streaming notifications/requests + let (notification_tx, mut notification_rx) = + tokio::sync::mpsc::unbounded_channel::(); + + // Spawn handler in a separate task so we can stream events concurrently + let handler = state.handler.clone(); + let session_id_for_context = session_id.clone(); + let handler_task = tokio::spawn(async move { + with_streaming_context(session_id_for_context, notification_tx, async move { + (handler)(mcp_request).await + }) + .await + }); + + // Wrap the handler task in a fuse to allow awaiting multiple times safely + let mut handler_task = std::pin::pin!(handler_task); + let mut handler_result: Option = None; + + eprintln!("[DEBUG SSE RT] Entering main loop"); + + // Stream events until handler completes and all notifications are drained + loop { + // If handler already completed, just drain notifications + if handler_result.is_some() { + eprintln!("[DEBUG SSE RT] Handler complete, draining notifications"); + match notification_rx.try_recv() { + Ok(notification) => { + let is_request = notification.id.is_some(); + let json_message = if let Some(ref id) = notification.id { + serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "method": notification.method, + "params": notification.params + }) + } else { + serde_json::json!({ + "jsonrpc": "2.0", + "method": notification.method, + "params": notification.params + }) + }; + eprintln!("[DEBUG SSE RT] Draining {}: {}", if is_request { "request" } else { "notification" }, notification.method); + yield Ok(SseEvent::default().data(json_message.to_string())); + } + Err(_) => { + // No more notifications, send final response and exit + if let Some(response) = handler_result.take() { + let json_response = serde_json::to_value(&response).unwrap_or(Value::Null); + eprintln!("[DEBUG SSE RT] Sending final response"); + yield Ok(SseEvent::default().data(json_response.to_string())); + } + break; + } + } + continue; + } + + // Handler not yet complete - use select to wait for either + tokio::select! { + // Check if handler completed + result = &mut handler_task => { + eprintln!("[DEBUG SSE RT] Handler task completed"); + match result { + Ok(response) => { + handler_result = Some(response); + // Continue loop to drain notifications + } + Err(e) => { + eprintln!("[DEBUG SSE RT] Handler task failed: {e}"); + let error_response = serde_json::json!({ + "jsonrpc": "2.0", + "error": { + "code": -32603, + "message": format!("Internal error: {}", e) + }, + "id": null + }); + yield Ok(SseEvent::default().data(error_response.to_string())); + break; + } + } + } + + // Receive and send notifications/requests as they arrive + notification = notification_rx.recv() => { + match notification { + Some(notification) => { + let is_request = notification.id.is_some(); + let json_message = if let Some(ref id) = notification.id { + serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "method": notification.method, + "params": notification.params + }) + } else { + serde_json::json!({ + "jsonrpc": "2.0", + "method": notification.method, + "params": notification.params + }) + }; + eprintln!("[DEBUG SSE RT] Sending {}: {}", + if is_request { "request" } else { "notification" }, + notification.method); + yield Ok(SseEvent::default().data(json_message.to_string())); + } + None => { + // Channel closed - handler should be done + eprintln!("[DEBUG SSE RT] Channel closed"); + // Wait for handler task to complete + if let Ok(response) = (&mut handler_task).await { + handler_result = Some(response); + } + } + } + } + } + } + eprintln!("[DEBUG SSE RT] Stream complete"); + } +} + +/// Create an SSE stream for POST responses that includes notifications/requests and the final response +fn create_post_response_stream( + notifications: Vec, + response: pulseengine_mcp_protocol::Response, +) -> impl Stream> + Send { + async_stream::stream! { + // First, send all notifications/requests as SSE events + for notification in notifications { + let is_request = notification.id.is_some(); + let json_message = if let Some(ref id) = notification.id { + // This is a request (has ID for response correlation) + serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "method": notification.method, + "params": notification.params + }) + } else { + // This is a notification (no ID) + serde_json::json!({ + "jsonrpc": "2.0", + "method": notification.method, + "params": notification.params + }) + }; + + eprintln!("[DEBUG SSE] Sending {}: {}", + if is_request { "request" } else { "notification" }, + notification.method); + yield Ok(SseEvent::default().data(json_message.to_string())); + } + + // Then send the final response + let json_response = serde_json::to_value(&response).unwrap_or(Value::Null); + eprintln!("[DEBUG SSE] Sending final response"); + yield Ok(SseEvent::default().data(json_response.to_string())); + } +} + +/// Create an SSE stream for a session +fn create_sse_stream( + state: Arc, + session_id: String, + stream_id: String, +) -> impl Stream> + Send { + async_stream::stream! { + eprintln!("[DEBUG SSE] Stream started for session {}, stream {}", session_id, stream_id); + // Get a receiver for this session's messages + let mut receiver = { + let sessions = state.sessions.read().await; + if let Some(session) = sessions.get(&session_id) { + let rx = session.message_sender.subscribe(); + let receiver_count = session.message_sender.receiver_count(); + eprintln!("[DEBUG SSE] Subscribed to session {session_id}, receiver count now: {receiver_count}"); + rx + } else { + warn!("Session {} not found when creating SSE stream", session_id); + eprintln!("[DEBUG SSE] Session {session_id} NOT FOUND!"); + return; + } + }; + + // Send initial priming event + let event_id = StreamableHttpTransport::next_event_id(&state, &session_id, &stream_id).await; + let mut event = SseEvent::default() + .retry(std::time::Duration::from_millis(state.config.sse_retry_ms)) + .data(""); + if let Some(id) = event_id { + event = event.id(id.encode()); + } + yield Ok(event); + + // Send connection established event + let connection_event = serde_json::json!({ + "type": "connection", + "status": "connected", + "sessionId": session_id, + "streamId": stream_id, + "transport": "streamable-http", + "resumable": state.config.sse_resumable, + "bidirectional": true + }); + + let event_id = StreamableHttpTransport::next_event_id(&state, &session_id, &stream_id).await; + let mut event = SseEvent::default().data(connection_event.to_string()); + if let Some(id) = event_id { + event = event.id(id.encode()); + } + yield Ok(event); - // Return JSON response with session header - let mut headers = HeaderMap::new(); - headers.insert("Mcp-Session-Id", session_id.parse().unwrap()); - debug!("Sending response with session ID: {}", session_id); + // Listen for messages and forward them + eprintln!("[DEBUG SSE] Entering message loop for session {}", session_id); + loop { + match receiver.recv().await { + Ok(message) => { + eprintln!("[DEBUG SSE] Received message for session {session_id}: {message:?}"); + let json_message = match message { + SseMessage::Notification { method, params } => { + serde_json::json!({ + "jsonrpc": "2.0", + "method": method, + "params": params + }) + } + SseMessage::Request { id, method, params } => { + serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params + }) + } + }; - (StatusCode::OK, headers, Json(response)).into_response() + let event_id = StreamableHttpTransport::next_event_id(&state, &session_id, &stream_id).await; + let mut event = SseEvent::default().data(json_message.to_string()); + if let Some(id) = event_id { + event = event.id(id.encode()); + } + eprintln!("[DEBUG SSE] Yielding SSE event for session {session_id}"); + yield Ok(event); + } + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!("SSE stream lagged by {} messages", n); + // Continue receiving + } + Err(broadcast::error::RecvError::Closed) => { + debug!("SSE channel closed for session {}", session_id); + break; + } + } + } + } } /// Handle SSE requests for server-to-client streaming (MCP 2025-11-25) @@ -395,12 +977,17 @@ async fn handle_messages( /// - Stream resumption via `Last-Event-ID` header /// - Event IDs for stream identity /// - Server-initiated disconnect with `retry` field +/// - **Server-to-client notifications and requests** async fn handle_sse( State(state): State>, headers: HeaderMap, Query(query): Query, -) -> axum::response::Response { +) -> AxumResponse { info!("SSE connection request: {:?}", query); + eprintln!( + "[DEBUG SSE HANDLER] SSE request with query session_id: {:?}", + query.session_id + ); // MCP 2025-11-25: Validate Origin header - return 403 Forbidden for invalid origins if let Some(forbidden_response) = validate_origin(&headers, &state.config) { @@ -426,63 +1013,20 @@ async fn handle_sse( // Generate a stream ID for this connection let stream_id = Uuid::new_v4().to_string(); - // Get initial event ID for this stream (MCP 2025-11-25) - // This primes the client for reconnection - let initial_event_id = if state.config.sse_resumable { - StreamableHttpTransport::next_event_id(&state, &session_id, &stream_id).await - } else { - None - }; - - // Build SSE response - // MCP 2025-11-25: Server SHOULD immediately send an SSE event with an event ID - // and empty data field to prime the client for reconnection - let mut sse_body = String::new(); - - // Add retry field (MCP 2025-11-25) - sse_body.push_str(&format!("retry: {}\n", state.config.sse_retry_ms)); - - // Send initial priming event with event ID - if let Some(event_id) = initial_event_id { - sse_body.push_str(&format!("id: {}\n", event_id.encode())); - } - sse_body.push_str("data: \n\n"); // Empty data field for priming - - // Send connection established event - let connection_event = serde_json::json!({ - "type": "connection", - "status": "connected", - "sessionId": session_id, - "streamId": stream_id, - "transport": "streamable-http", - "resumable": state.config.sse_resumable - }); - - // Get next event ID for the connection event - let connection_event_id = if state.config.sse_resumable { - StreamableHttpTransport::next_event_id(&state, &session_id, &stream_id).await - } else { - None - }; - - if let Some(event_id) = connection_event_id { - sse_body.push_str(&format!("id: {}\n", event_id.encode())); - } - sse_body.push_str(&format!("data: {connection_event}\n\n")); - - // Build response headers - let mut response_headers = HeaderMap::new(); - response_headers.insert("Content-Type", "text/event-stream".parse().unwrap()); - response_headers.insert("Cache-Control", "no-cache".parse().unwrap()); - response_headers.insert("Connection", "keep-alive".parse().unwrap()); - response_headers.insert("Mcp-Session-Id", session_id.parse().unwrap()); - debug!( "SSE response with session ID: {}, stream ID: {}", session_id, stream_id ); - (StatusCode::OK, response_headers, sse_body).into_response() + // Create the SSE stream + let stream = create_sse_stream(Arc::clone(&state), session_id.clone(), stream_id); + + // Build response with headers + let mut response_headers = HeaderMap::new(); + response_headers.insert("Mcp-Session-Id", session_id.parse().unwrap()); + response_headers.insert("Cache-Control", "no-cache".parse().unwrap()); + + (response_headers, Sse::new(stream)).into_response() } #[async_trait] @@ -493,9 +1037,21 @@ impl Transport for StreamableHttpTransport { self.config.host, self.config.port ); + let sessions: Arc = Arc::new(RwLock::new(HashMap::new())); + let pending_requests: Arc = + Arc::new(std::sync::RwLock::new(HashMap::new())); + + // Create transport handle for external access + self.transport_handle = Some(TransportHandle { + sessions: Arc::clone(&sessions), + pending_requests: Arc::clone(&pending_requests), + config: self.config.clone(), + }); + let state = Arc::new(AppState { handler: Arc::new(handler), - sessions: Arc::new(RwLock::new(HashMap::new())), + sessions, + pending_requests, config: self.config.clone(), }); @@ -504,7 +1060,10 @@ impl Transport for StreamableHttpTransport { .route("/mcp", post(handle_messages).get(handle_sse)) .route("/messages", post(handle_messages)) // Legacy endpoint .route("/sse", get(handle_sse)) // Legacy endpoint - .route("/", get(|| async { "MCP Streamable HTTP Server" })) + .route( + "/", + get(|| async { "MCP Streamable HTTP Server (Bidirectional)" }), + ) .layer(ServiceBuilder::new().layer(if self.config.enable_cors { CorsLayer::permissive() } else { @@ -528,14 +1087,11 @@ impl Transport for StreamableHttpTransport { addr ); info!( - " GET http://{}/mcp - Session establishment (MCP-UI compatible)", + " GET http://{}/mcp - SSE stream (bidirectional, MCP-UI compatible)", addr ); info!(" POST http://{}/messages - MCP messages (legacy)", addr); - info!( - " GET http://{}/sse - Session establishment (legacy)", - addr - ); + info!(" GET http://{}/sse - SSE stream (legacy)", addr); let server_handle = tokio::spawn(async move { if let Err(e) = axum::serve(listener, app).await { @@ -551,6 +1107,7 @@ impl Transport for StreamableHttpTransport { if let Some(handle) = self.server_handle.take() { handle.abort(); } + self.transport_handle = None; Ok(()) } @@ -561,4 +1118,50 @@ impl Transport for StreamableHttpTransport { Err(TransportError::Connection("Not running".to_string())) } } + + async fn send_notification( + &self, + session_id: Option<&str>, + method: &str, + params: Value, + ) -> Result<(), TransportError> { + if let Some(handle) = &self.transport_handle { + handle.send_notification(session_id, method, params).await + } else { + Err(TransportError::Connection( + "Transport not started".to_string(), + )) + } + } + + async fn send_request( + &self, + session_id: Option<&str>, + method: &str, + params: Value, + timeout: Duration, + ) -> Result { + if let Some(handle) = &self.transport_handle { + handle + .send_request(session_id, method, params, timeout) + .await + } else { + Err(TransportError::Connection( + "Transport not started".to_string(), + )) + } + } + + fn supports_bidirectional(&self) -> bool { + true + } + + fn register_pending_request( + &self, + request_id: &str, + ) -> Option> { + self.transport_handle + .as_ref() + .map(|handle| handle.register_pending_request_sync(request_id)) + } } diff --git a/mcp-transport/src/streamable_http_tests.rs b/mcp-transport/src/streamable_http_tests.rs index b4ec4fd..6c9eedd 100644 --- a/mcp-transport/src/streamable_http_tests.rs +++ b/mcp-transport/src/streamable_http_tests.rs @@ -62,6 +62,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 5000, sse_resumable: false, + ..Default::default() }; assert_eq!(config.port, 8080); @@ -126,6 +127,7 @@ mod tests { enforce_origin_validation: true, sse_retry_ms: 2000, sse_resumable: true, + ..Default::default() }; let transport = StreamableHttpTransport::with_config(config); @@ -174,6 +176,7 @@ mod tests { enforce_origin_validation: true, sse_retry_ms: 4000, sse_resumable: true, + ..Default::default() }; let cloned = config.clone(); @@ -233,6 +236,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 3000, sse_resumable: true, + ..Default::default() }, StreamableHttpConfig { port: 8080, @@ -242,6 +246,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 3000, sse_resumable: true, + ..Default::default() }, StreamableHttpConfig { port: 65535, @@ -251,6 +256,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 3000, sse_resumable: true, + ..Default::default() }, ]; @@ -272,6 +278,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 3000, sse_resumable: true, + ..Default::default() }; // Test that host string is properly stored @@ -446,6 +453,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 3000, sse_resumable: true, + ..Default::default() }; assert_eq!(config.port, port); @@ -469,6 +477,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 3000, sse_resumable: true, + ..Default::default() }; assert_eq!(config.port, 0); @@ -484,6 +493,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 1000, sse_resumable: false, + ..Default::default() }; assert_eq!(config.port, 65535); @@ -517,6 +527,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 3000, sse_resumable: true, + ..Default::default() }; assert_eq!(config.host, host); @@ -537,6 +548,7 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 3000, sse_resumable: true, + ..Default::default() }; assert_eq!(config.enable_cors, enable_cors); @@ -726,9 +738,415 @@ mod tests { enforce_origin_validation: false, sse_retry_ms: 10000, // 10 seconds sse_resumable: false, + ..Default::default() }; assert_eq!(config.sse_retry_ms, 10000); assert!(!config.sse_resumable); } + + // ============================================================================ + // SseMessage Tests + // ============================================================================ + + #[test] + fn test_sse_message_notification() { + let message = SseMessage::Notification { + method: "notifications/progress".to_string(), + params: json!({"progress": 50, "total": 100}), + }; + + if let SseMessage::Notification { method, params } = message { + assert_eq!(method, "notifications/progress"); + assert_eq!(params["progress"], 50); + } else { + panic!("Expected Notification variant"); + } + } + + #[test] + fn test_sse_message_request() { + let message = SseMessage::Request { + id: "req-123".to_string(), + method: "sampling/createMessage".to_string(), + params: json!({"maxTokens": 100}), + }; + + if let SseMessage::Request { id, method, params } = message { + assert_eq!(id, "req-123"); + assert_eq!(method, "sampling/createMessage"); + assert_eq!(params["maxTokens"], 100); + } else { + panic!("Expected Request variant"); + } + } + + #[test] + fn test_sse_message_debug() { + let notification = SseMessage::Notification { + method: "test".to_string(), + params: json!({}), + }; + let debug_str = format!("{notification:?}"); + assert!(debug_str.contains("Notification")); + assert!(debug_str.contains("test")); + + let request = SseMessage::Request { + id: "id123".to_string(), + method: "test".to_string(), + params: json!({}), + }; + let debug_str = format!("{request:?}"); + assert!(debug_str.contains("Request")); + assert!(debug_str.contains("id123")); + } + + #[test] + fn test_sse_message_clone() { + let original = SseMessage::Notification { + method: "test".to_string(), + params: json!({"key": "value"}), + }; + + let cloned = original.clone(); + if let ( + SseMessage::Notification { + method: m1, + params: p1, + }, + SseMessage::Notification { + method: m2, + params: p2, + }, + ) = (&original, &cloned) + { + assert_eq!(m1, m2); + assert_eq!(p1, p2); + } + } + + // ============================================================================ + // TransportHandle Tests (via Transport trait) + // ============================================================================ + + #[tokio::test] + async fn test_transport_send_notification_not_started() { + let transport = StreamableHttpTransport::new(18200); + + let result = transport + .send_notification(Some("session-123"), "test/method", json!({})) + .await; + + assert!(result.is_err()); + if let Err(TransportError::Connection(msg)) = result { + assert!(msg.contains("not started")); + } + } + + #[tokio::test] + async fn test_transport_send_request_not_started() { + use std::time::Duration; + + let transport = StreamableHttpTransport::new(18201); + + let result = transport + .send_request( + Some("session-123"), + "sampling/createMessage", + json!({}), + Duration::from_secs(5), + ) + .await; + + assert!(result.is_err()); + if let Err(TransportError::Connection(msg)) = result { + assert!(msg.contains("not started")); + } + } + + #[tokio::test] + async fn test_transport_supports_bidirectional() { + let transport = StreamableHttpTransport::new(18202); + assert!(transport.supports_bidirectional()); + } + + #[tokio::test] + async fn test_transport_register_pending_request_not_started() { + let transport = StreamableHttpTransport::new(18203); + + let result = transport.register_pending_request("req-123"); + // When transport is not started, handle is None, so register returns None + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_transport_handle_not_available_before_start() { + let transport = StreamableHttpTransport::new(18204); + assert!(transport.handle().is_none()); + } + + #[tokio::test] + async fn test_transport_handle_available_after_start() { + let mut transport = StreamableHttpTransport::new(18205); + let handler = Box::new(mock_handler); + + if transport.start(handler).await.is_ok() { + assert!(transport.handle().is_some()); + transport.stop().await.ok(); + } + } + + // ============================================================================ + // TransportHandle Direct Tests + // ============================================================================ + + #[tokio::test] + async fn test_transport_handle_send_notification_session_not_found() { + let mut transport = StreamableHttpTransport::new(18206); + let handler = Box::new(mock_handler); + + if transport.start(handler).await.is_ok() { + let handle = transport.handle().unwrap(); + + // Try to send to a non-existent session + let result = handle + .send_notification(Some("nonexistent-session"), "test/method", json!({})) + .await; + + assert!(result.is_err()); + if let Err(TransportError::SessionNotFound(id)) = result { + assert_eq!(id, "nonexistent-session"); + } + + transport.stop().await.ok(); + } + } + + #[tokio::test] + async fn test_transport_handle_send_notification_broadcast_empty() { + let mut transport = StreamableHttpTransport::new(18207); + let handler = Box::new(mock_handler); + + if transport.start(handler).await.is_ok() { + let handle = transport.handle().unwrap(); + + // Broadcast to all sessions (empty, should succeed) + let result = handle + .send_notification(None, "test/method", json!({})) + .await; + + assert!(result.is_ok()); + + transport.stop().await.ok(); + } + } + + #[tokio::test] + async fn test_transport_handle_send_request_missing_session_id() { + let mut transport = StreamableHttpTransport::new(18208); + let handler = Box::new(mock_handler); + + if transport.start(handler).await.is_ok() { + let handle = transport.handle().unwrap(); + + // Try to send request without session ID + let result = handle + .send_request( + None, + "sampling/createMessage", + json!({}), + std::time::Duration::from_secs(1), + ) + .await; + + assert!(result.is_err()); + if let Err(TransportError::Config(msg)) = result { + assert!(msg.contains("Session ID required")); + } + + transport.stop().await.ok(); + } + } + + #[tokio::test] + async fn test_transport_handle_send_request_session_not_found() { + let mut transport = StreamableHttpTransport::new(18209); + let handler = Box::new(mock_handler); + + if transport.start(handler).await.is_ok() { + let handle = transport.handle().unwrap(); + + let result = handle + .send_request( + Some("nonexistent-session"), + "sampling/createMessage", + json!({}), + std::time::Duration::from_secs(1), + ) + .await; + + assert!(result.is_err()); + if let Err(TransportError::SessionNotFound(id)) = result { + assert_eq!(id, "nonexistent-session"); + } + + transport.stop().await.ok(); + } + } + + #[tokio::test] + async fn test_transport_handle_handle_response_unknown_request() { + let mut transport = StreamableHttpTransport::new(18210); + let handler = Box::new(mock_handler); + + if transport.start(handler).await.is_ok() { + let handle = transport.handle().unwrap(); + + // Try to handle response for unknown request + let result = handle.handle_response("unknown-req-id", json!({"result": "test"})); + assert!(!result); // Should return false for unknown request + + transport.stop().await.ok(); + } + } + + #[tokio::test] + async fn test_transport_handle_register_and_handle_response() { + let mut transport = StreamableHttpTransport::new(18211); + let handler = Box::new(mock_handler); + + if transport.start(handler).await.is_ok() { + let handle = transport.handle().unwrap(); + + // Register a pending request + let rx = handle.register_pending_request_sync("req-456"); + + // Handle the response + let handled = handle.handle_response("req-456", json!({"result": "success"})); + assert!(handled); + + // Receive the response + let response = rx.await.unwrap(); + assert_eq!(response["result"], "success"); + + // Trying to handle same ID again should fail + let handled_again = handle.handle_response("req-456", json!({"result": "again"})); + assert!(!handled_again); + + transport.stop().await.ok(); + } + } + + #[tokio::test] + async fn test_transport_handle_clone() { + let mut transport = StreamableHttpTransport::new(18212); + let handler = Box::new(mock_handler); + + if transport.start(handler).await.is_ok() { + let handle1 = transport.handle().unwrap(); + let handle2 = handle1.clone(); + + // Both handles should work + let result1 = handle1.send_notification(None, "test1", json!({})).await; + let result2 = handle2.send_notification(None, "test2", json!({})).await; + + assert!(result1.is_ok()); + assert!(result2.is_ok()); + + transport.stop().await.ok(); + } + } + + // ============================================================================ + // SseEventId Edge Cases + // ============================================================================ + + #[test] + fn test_sse_event_id_parse_with_colons_in_ids() { + // IDs that contain colons (splitn(3) splits into at most 3 parts) + // "session:with:colons:stream:id:7" -> ["session", "with", "colons:stream:id:7"] + // The third part "colons:stream:id:7" fails to parse as u64, so returns None + let parsed = SseEventId::parse("session:with:colons:stream:id:7"); + assert!(parsed.is_none()); // Fails because sequence contains non-numeric chars + + // Test case where colons appear but sequence is still valid + // "session:stream-with-dashes:42" -> ["session", "stream-with-dashes", "42"] + let parsed = SseEventId::parse("session:stream-with-dashes:42"); + assert!(parsed.is_some()); + let event_id = parsed.unwrap(); + assert_eq!(event_id.session_id, "session"); + assert_eq!(event_id.stream_id, "stream-with-dashes"); + assert_eq!(event_id.sequence, 42); + } + + #[test] + fn test_sse_event_id_parse_large_sequence() { + let parsed = SseEventId::parse("session:stream:18446744073709551615"); + + assert!(parsed.is_some()); + let event_id = parsed.unwrap(); + assert_eq!(event_id.sequence, u64::MAX); + } + + #[test] + fn test_sse_event_id_parse_zero_sequence() { + let parsed = SseEventId::parse("s:t:0"); + + assert!(parsed.is_some()); + let event_id = parsed.unwrap(); + assert_eq!(event_id.sequence, 0); + } + + #[test] + fn test_sse_event_id_parse_negative_sequence() { + // Negative numbers should fail to parse as u64 + let parsed = SseEventId::parse("session:stream:-1"); + assert!(parsed.is_none()); + } + + #[test] + fn test_sse_event_id_empty_parts() { + // Empty session or stream should still parse (just empty strings) + let parsed = SseEventId::parse("::42"); + + assert!(parsed.is_some()); + let event_id = parsed.unwrap(); + assert_eq!(event_id.session_id, ""); + assert_eq!(event_id.stream_id, ""); + assert_eq!(event_id.sequence, 42); + } + + // ============================================================================ + // Config Request Timeout Tests + // ============================================================================ + + #[test] + fn test_config_request_timeout_default() { + let config = StreamableHttpConfig::default(); + assert_eq!(config.request_timeout, std::time::Duration::from_secs(60)); + } + + #[test] + fn test_config_request_timeout_custom() { + let config = StreamableHttpConfig { + request_timeout: std::time::Duration::from_secs(30), + ..Default::default() + }; + assert_eq!(config.request_timeout, std::time::Duration::from_secs(30)); + } + + #[test] + fn test_config_channel_capacity_default() { + let config = StreamableHttpConfig::default(); + assert_eq!(config.channel_capacity, 100); + } + + #[test] + fn test_config_channel_capacity_custom() { + let config = StreamableHttpConfig { + channel_capacity: 500, + ..Default::default() + }; + assert_eq!(config.channel_capacity, 500); + } }