diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java index 9a1f6e84..2df3514b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java @@ -35,9 +35,17 @@ public Mono handleRequest(McpTransportContext transpo } return requestHandler.handle(transportContext, request.params()) .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(t -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, t.getMessage(), - null)))); + .onErrorResume(t -> { + McpSchema.JSONRPCResponse.JSONRPCError error; + if (t instanceof McpError mcpError && mcpError.getJsonRpcError() != null) { + error = mcpError.getJsonRpcError(); + } + else { + error = new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + t.getMessage(), null); + } + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); + }); } @Override diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java index 00942226..4c3f22d7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -4,33 +4,13 @@ package io.modelcontextprotocol.server; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.awaitility.Awaitility.await; - -import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; - -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.springframework.web.client.RestClient; - import com.fasterxml.jackson.databind.ObjectMapper; - import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; @@ -41,7 +21,33 @@ import io.modelcontextprotocol.spec.McpSchema.PromptReference; import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.ProtocolVersions; import net.javacrumbs.jsonunit.core.Option; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.web.client.RestClient; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + +import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON; +import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; class HttpServletStatelessIntegrationTests { @@ -460,6 +466,49 @@ void testStructuredOutputRuntimeToolAddition(String clientType) { mcpServer.close(); } + @Test + void testThrownMcpError() throws Exception { + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool testTool = Tool.builder().name("test").description("test").build(); + + McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( + testTool, (transportContext, request) -> { + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(12345, "testing", Map.of("a", "b"))); + }); + + mcpServer.addTool(toolSpec); + + McpSchema.CallToolRequest callToolRequest = new McpSchema.CallToolRequest("test", Map.of()); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_TOOLS_CALL, "test", callToolRequest); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", CUSTOM_MESSAGE_ENDPOINT); + MockHttpServletResponse response = new MockHttpServletResponse(); + + byte[] content = new ObjectMapper().writeValueAsBytes(jsonrpcRequest); + request.setContent(content); + request.addHeader("Content-Type", "application/json"); + request.addHeader("Content-Length", Integer.toString(content.length)); + request.addHeader("Content-Length", Integer.toString(content.length)); + request.addHeader("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM); + request.addHeader("Content-Type", APPLICATION_JSON); + request.addHeader("Cache-Control", "no-cache"); + request.addHeader(HttpHeaders.PROTOCOL_VERSION, ProtocolVersions.MCP_2025_03_26); + mcpStatelessServerTransport.service(request, response); + + McpSchema.JSONRPCResponse jsonrpcResponse = new ObjectMapper().readValue(response.getContentAsByteArray(), + McpSchema.JSONRPCResponse.class); + + assertThat(jsonrpcResponse.error()) + .isEqualTo(new McpSchema.JSONRPCResponse.JSONRPCError(12345, "testing", Map.of("a", "b"))); + + mcpServer.close(); + } + private double evaluateExpression(String expression) { // Simple expression evaluator for testing return switch (expression) {