Skip to content

Commit 29c69e6

Browse files
authored
Fix uncaught exception in MCP server (#822)
1 parent 1a9ead0 commit 29c69e6

File tree

2 files changed

+212
-19
lines changed

2 files changed

+212
-19
lines changed

src/mcp/shared/session.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
1616
from mcp.types import (
1717
CONNECTION_CLOSED,
18+
INVALID_PARAMS,
1819
CancelledNotification,
1920
ClientNotification,
2021
ClientRequest,
@@ -354,27 +355,47 @@ async def _receive_loop(self) -> None:
354355
if isinstance(message, Exception):
355356
await self._handle_incoming(message)
356357
elif isinstance(message.message.root, JSONRPCRequest):
357-
validated_request = self._receive_request_type.model_validate(
358-
message.message.root.model_dump(
359-
by_alias=True, mode="json", exclude_none=True
358+
try:
359+
validated_request = self._receive_request_type.model_validate(
360+
message.message.root.model_dump(
361+
by_alias=True, mode="json", exclude_none=True
362+
)
360363
)
361-
)
362-
responder = RequestResponder(
363-
request_id=message.message.root.id,
364-
request_meta=validated_request.root.params.meta
365-
if validated_request.root.params
366-
else None,
367-
request=validated_request,
368-
session=self,
369-
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
370-
message_metadata=message.metadata,
371-
)
372-
373-
self._in_flight[responder.request_id] = responder
374-
await self._received_request(responder)
364+
responder = RequestResponder(
365+
request_id=message.message.root.id,
366+
request_meta=validated_request.root.params.meta
367+
if validated_request.root.params
368+
else None,
369+
request=validated_request,
370+
session=self,
371+
on_complete=lambda r: self._in_flight.pop(
372+
r.request_id, None),
373+
message_metadata=message.metadata,
374+
)
375+
self._in_flight[responder.request_id] = responder
376+
await self._received_request(responder)
375377

376-
if not responder._completed: # type: ignore[reportPrivateUsage]
377-
await self._handle_incoming(responder)
378+
if not responder._completed: # type: ignore[reportPrivateUsage]
379+
await self._handle_incoming(responder)
380+
except Exception as e:
381+
# For request validation errors, send a proper JSON-RPC error
382+
# response instead of crashing the server
383+
logging.warning(f"Failed to validate request: {e}")
384+
logging.debug(
385+
f"Message that failed validation: {message.message.root}"
386+
)
387+
error_response = JSONRPCError(
388+
jsonrpc="2.0",
389+
id=message.message.root.id,
390+
error=ErrorData(
391+
code=INVALID_PARAMS,
392+
message="Invalid request parameters",
393+
data="",
394+
),
395+
)
396+
session_message = SessionMessage(
397+
message=JSONRPCMessage(error_response))
398+
await self._write_stream.send(session_message)
378399

379400
elif isinstance(message.message.root, JSONRPCNotification):
380401
try:

tests/issues/test_malformed_input.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Claude Debug
2+
"""Test for HackerOne vulnerability report #3156202 - malformed input DOS."""
3+
4+
import anyio
5+
import pytest
6+
7+
from mcp.server.models import InitializationOptions
8+
from mcp.server.session import ServerSession
9+
from mcp.shared.message import SessionMessage
10+
from mcp.types import (
11+
INVALID_PARAMS,
12+
JSONRPCError,
13+
JSONRPCMessage,
14+
JSONRPCRequest,
15+
ServerCapabilities,
16+
)
17+
18+
19+
@pytest.mark.anyio
20+
async def test_malformed_initialize_request_does_not_crash_server():
21+
"""
22+
Test that malformed initialize requests return proper error responses
23+
instead of crashing the server (HackerOne #3156202).
24+
"""
25+
# Create in-memory streams for testing
26+
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[
27+
SessionMessage | Exception
28+
](10)
29+
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[
30+
SessionMessage
31+
](10)
32+
33+
try:
34+
# Create a malformed initialize request (missing required params field)
35+
malformed_request = JSONRPCRequest(
36+
jsonrpc="2.0",
37+
id="f20fe86132ed4cd197f89a7134de5685",
38+
method="initialize",
39+
# params=None # Missing required params field
40+
)
41+
42+
# Wrap in session message
43+
request_message = SessionMessage(message=JSONRPCMessage(malformed_request))
44+
45+
# Start a server session
46+
async with ServerSession(
47+
read_stream=read_receive_stream,
48+
write_stream=write_send_stream,
49+
init_options=InitializationOptions(
50+
server_name="test_server",
51+
server_version="1.0.0",
52+
capabilities=ServerCapabilities(),
53+
),
54+
):
55+
# Send the malformed request
56+
await read_send_stream.send(request_message)
57+
58+
# Give the session time to process the request
59+
await anyio.sleep(0.1)
60+
61+
# Check that we received an error response instead of a crash
62+
try:
63+
response_message = write_receive_stream.receive_nowait()
64+
response = response_message.message.root
65+
66+
# Verify it's a proper JSON-RPC error response
67+
assert isinstance(response, JSONRPCError)
68+
assert response.jsonrpc == "2.0"
69+
assert response.id == "f20fe86132ed4cd197f89a7134de5685"
70+
assert response.error.code == INVALID_PARAMS
71+
assert "Invalid request parameters" in response.error.message
72+
73+
# Verify the session is still alive and can handle more requests
74+
# Send another malformed request to confirm server stability
75+
another_malformed_request = JSONRPCRequest(
76+
jsonrpc="2.0",
77+
id="test_id_2",
78+
method="tools/call",
79+
# params=None # Missing required params
80+
)
81+
another_request_message = SessionMessage(
82+
message=JSONRPCMessage(another_malformed_request)
83+
)
84+
85+
await read_send_stream.send(another_request_message)
86+
await anyio.sleep(0.1)
87+
88+
# Should get another error response, not a crash
89+
second_response_message = write_receive_stream.receive_nowait()
90+
second_response = second_response_message.message.root
91+
92+
assert isinstance(second_response, JSONRPCError)
93+
assert second_response.id == "test_id_2"
94+
assert second_response.error.code == INVALID_PARAMS
95+
96+
except anyio.WouldBlock:
97+
pytest.fail("No response received - server likely crashed")
98+
finally:
99+
# Close all streams to ensure proper cleanup
100+
await read_send_stream.aclose()
101+
await write_send_stream.aclose()
102+
await read_receive_stream.aclose()
103+
await write_receive_stream.aclose()
104+
105+
106+
@pytest.mark.anyio
107+
async def test_multiple_concurrent_malformed_requests():
108+
"""
109+
Test that multiple concurrent malformed requests don't crash the server.
110+
"""
111+
# Create in-memory streams for testing
112+
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[
113+
SessionMessage | Exception
114+
](100)
115+
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[
116+
SessionMessage
117+
](100)
118+
119+
try:
120+
# Start a server session
121+
async with ServerSession(
122+
read_stream=read_receive_stream,
123+
write_stream=write_send_stream,
124+
init_options=InitializationOptions(
125+
server_name="test_server",
126+
server_version="1.0.0",
127+
capabilities=ServerCapabilities(),
128+
),
129+
):
130+
# Send multiple malformed requests concurrently
131+
malformed_requests = []
132+
for i in range(10):
133+
malformed_request = JSONRPCRequest(
134+
jsonrpc="2.0",
135+
id=f"malformed_{i}",
136+
method="initialize",
137+
# params=None # Missing required params
138+
)
139+
request_message = SessionMessage(
140+
message=JSONRPCMessage(malformed_request)
141+
)
142+
malformed_requests.append(request_message)
143+
144+
# Send all requests
145+
for request in malformed_requests:
146+
await read_send_stream.send(request)
147+
148+
# Give time to process
149+
await anyio.sleep(0.2)
150+
151+
# Verify we get error responses for all requests
152+
error_responses = []
153+
try:
154+
while True:
155+
response_message = write_receive_stream.receive_nowait()
156+
error_responses.append(response_message.message.root)
157+
except anyio.WouldBlock:
158+
pass # No more messages
159+
160+
# Should have received 10 error responses
161+
assert len(error_responses) == 10
162+
163+
for i, response in enumerate(error_responses):
164+
assert isinstance(response, JSONRPCError)
165+
assert response.id == f"malformed_{i}"
166+
assert response.error.code == INVALID_PARAMS
167+
finally:
168+
# Close all streams to ensure proper cleanup
169+
await read_send_stream.aclose()
170+
await write_send_stream.aclose()
171+
await read_receive_stream.aclose()
172+
await write_receive_stream.aclose()

0 commit comments

Comments
 (0)