Skip to content

Commit b13c5c5

Browse files
Ketansuhaasketan-clairyondbschmigelski
authored
feat(mcp): Add list_prompts, get_prompt methods (#160)
Co-authored-by: ketan-clairyon <[email protected]> Co-authored-by: Dean Schmigelski <[email protected]>
1 parent 4e0e0a6 commit b13c5c5

File tree

3 files changed

+184
-10
lines changed

3 files changed

+184
-10
lines changed

src/strands/tools/mcp/mcp_client.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from mcp import ClientSession, ListToolsResult
2222
from mcp.types import CallToolResult as MCPCallToolResult
23+
from mcp.types import GetPromptResult, ListPromptsResult
2324
from mcp.types import ImageContent as MCPImageContent
2425
from mcp.types import TextContent as MCPTextContent
2526

@@ -165,6 +166,54 @@ async def _list_tools_async() -> ListToolsResult:
165166
self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools))
166167
return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor)
167168

169+
def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult:
170+
"""Synchronously retrieves the list of available prompts from the MCP server.
171+
172+
This method calls the asynchronous list_prompts method on the MCP session
173+
and returns the raw ListPromptsResult with pagination support.
174+
175+
Args:
176+
pagination_token: Optional token for pagination
177+
178+
Returns:
179+
ListPromptsResult: The raw MCP response containing prompts and pagination info
180+
"""
181+
self._log_debug_with_thread("listing MCP prompts synchronously")
182+
if not self._is_session_active():
183+
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
184+
185+
async def _list_prompts_async() -> ListPromptsResult:
186+
return await self._background_thread_session.list_prompts(cursor=pagination_token)
187+
188+
list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result()
189+
self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts))
190+
for prompt in list_prompts_result.prompts:
191+
self._log_debug_with_thread(prompt.name)
192+
193+
return list_prompts_result
194+
195+
def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult:
196+
"""Synchronously retrieves a prompt from the MCP server.
197+
198+
Args:
199+
prompt_id: The ID of the prompt to retrieve
200+
args: Optional arguments to pass to the prompt
201+
202+
Returns:
203+
GetPromptResult: The prompt response from the MCP server
204+
"""
205+
self._log_debug_with_thread("getting MCP prompt synchronously")
206+
if not self._is_session_active():
207+
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
208+
209+
async def _get_prompt_async() -> GetPromptResult:
210+
return await self._background_thread_session.get_prompt(prompt_id, arguments=args)
211+
212+
get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result()
213+
self._log_debug_with_thread("received prompt from MCP server")
214+
215+
return get_prompt_result
216+
168217
def call_tool_sync(
169218
self,
170219
tool_use_id: str,

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from mcp import ListToolsResult
66
from mcp.types import CallToolResult as MCPCallToolResult
7+
from mcp.types import GetPromptResult, ListPromptsResult, Prompt, PromptMessage
78
from mcp.types import TextContent as MCPTextContent
89
from mcp.types import Tool as MCPTool
910

@@ -404,3 +405,64 @@ def test_exception_when_future_not_running():
404405

405406
# Verify that set_exception was not called since the future was not running
406407
mock_future.set_exception.assert_not_called()
408+
409+
410+
# Prompt Tests - Sync Methods
411+
412+
413+
def test_list_prompts_sync(mock_transport, mock_session):
414+
"""Test that list_prompts_sync correctly retrieves prompts."""
415+
mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1")
416+
mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt])
417+
418+
with MCPClient(mock_transport["transport_callable"]) as client:
419+
result = client.list_prompts_sync()
420+
421+
mock_session.list_prompts.assert_called_once_with(cursor=None)
422+
assert len(result.prompts) == 1
423+
assert result.prompts[0].name == "test_prompt"
424+
assert result.nextCursor is None
425+
426+
427+
def test_list_prompts_sync_with_pagination_token(mock_transport, mock_session):
428+
"""Test that list_prompts_sync correctly passes pagination token and returns next cursor."""
429+
mock_prompt = Prompt(name="test_prompt", description="A test prompt", id="prompt_1")
430+
mock_session.list_prompts.return_value = ListPromptsResult(prompts=[mock_prompt], nextCursor="next_page_token")
431+
432+
with MCPClient(mock_transport["transport_callable"]) as client:
433+
result = client.list_prompts_sync(pagination_token="current_page_token")
434+
435+
mock_session.list_prompts.assert_called_once_with(cursor="current_page_token")
436+
assert len(result.prompts) == 1
437+
assert result.prompts[0].name == "test_prompt"
438+
assert result.nextCursor == "next_page_token"
439+
440+
441+
def test_list_prompts_sync_session_not_active():
442+
"""Test that list_prompts_sync raises an error when session is not active."""
443+
client = MCPClient(MagicMock())
444+
445+
with pytest.raises(MCPClientInitializationError, match="client session is not running"):
446+
client.list_prompts_sync()
447+
448+
449+
def test_get_prompt_sync(mock_transport, mock_session):
450+
"""Test that get_prompt_sync correctly retrieves a prompt."""
451+
mock_message = PromptMessage(role="user", content=MCPTextContent(type="text", text="This is a test prompt"))
452+
mock_session.get_prompt.return_value = GetPromptResult(messages=[mock_message])
453+
454+
with MCPClient(mock_transport["transport_callable"]) as client:
455+
result = client.get_prompt_sync("test_prompt_id", {"key": "value"})
456+
457+
mock_session.get_prompt.assert_called_once_with("test_prompt_id", arguments={"key": "value"})
458+
assert len(result.messages) == 1
459+
assert result.messages[0].role == "user"
460+
assert result.messages[0].content.text == "This is a test prompt"
461+
462+
463+
def test_get_prompt_sync_session_not_active():
464+
"""Test that get_prompt_sync raises an error when session is not active."""
465+
client = MCPClient(MagicMock())
466+
467+
with pytest.raises(MCPClientInitializationError, match="client session is not running"):
468+
client.get_prompt_sync("test_prompt_id", {})

tests_integ/test_mcp_client.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,17 @@
1818
from strands.types.tools import ToolUse
1919

2020

21-
def start_calculator_server(transport: Literal["sse", "streamable-http"], port=int):
21+
def start_comprehensive_mcp_server(transport: Literal["sse", "streamable-http"], port=int):
2222
"""
23-
Initialize and start an MCP calculator server for integration testing.
23+
Initialize and start a comprehensive MCP server for integration testing.
2424
25-
This function creates a FastMCP server instance that provides a simple
26-
calculator tool for performing addition operations. The server uses
27-
Server-Sent Events (SSE) transport for communication, making it accessible
28-
over HTTP.
25+
This function creates a FastMCP server instance that provides tools, prompts,
26+
and resources all in one server for comprehensive testing. The server uses
27+
Server-Sent Events (SSE) or streamable HTTP transport for communication.
2928
"""
3029
from mcp.server import FastMCP
3130

32-
mcp = FastMCP("Calculator Server", port=port)
31+
mcp = FastMCP("Comprehensive MCP Server", port=port)
3332

3433
@mcp.tool(description="Calculator tool which performs calculations")
3534
def calculator(x: int, y: int) -> int:
@@ -44,6 +43,15 @@ def generate_custom_image() -> MCPImageContent:
4443
except Exception as e:
4544
print("Error while generating custom image: {}".format(e))
4645

46+
# Prompts
47+
@mcp.prompt(description="A greeting prompt template")
48+
def greeting_prompt(name: str = "World") -> str:
49+
return f"Hello, {name}! How are you today?"
50+
51+
@mcp.prompt(description="A math problem prompt template")
52+
def math_prompt(operation: str = "addition", difficulty: str = "easy") -> str:
53+
return f"Create a {difficulty} {operation} math problem and solve it step by step."
54+
4755
mcp.run(transport=transport)
4856

4957

@@ -58,8 +66,9 @@ def test_mcp_client():
5866
{'role': 'assistant', 'content': [{'text': '\n\nThe result of adding 1 and 2 is 3.'}]}
5967
""" # noqa: E501
6068

69+
# Start comprehensive server with tools, prompts, and resources
6170
server_thread = threading.Thread(
62-
target=start_calculator_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
71+
target=start_comprehensive_mcp_server, kwargs={"transport": "sse", "port": 8000}, daemon=True
6372
)
6473
server_thread.start()
6574
time.sleep(2) # wait for server to startup completely
@@ -68,8 +77,14 @@ def test_mcp_client():
6877
stdio_mcp_client = MCPClient(
6978
lambda: stdio_client(StdioServerParameters(command="python", args=["tests_integ/echo_server.py"]))
7079
)
80+
7181
with sse_mcp_client, stdio_mcp_client:
72-
agent = Agent(tools=sse_mcp_client.list_tools_sync() + stdio_mcp_client.list_tools_sync())
82+
# Test Tools functionality
83+
sse_tools = sse_mcp_client.list_tools_sync()
84+
stdio_tools = stdio_mcp_client.list_tools_sync()
85+
all_tools = sse_tools + stdio_tools
86+
87+
agent = Agent(tools=all_tools)
7388
agent("add 1 and 2, then echo the result back to me")
7489

7590
tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
@@ -88,6 +103,43 @@ def test_mcp_client():
88103
]
89104
)
90105

106+
# Test Prompts functionality
107+
prompts_result = sse_mcp_client.list_prompts_sync()
108+
assert len(prompts_result.prompts) >= 2 # We expect at least greeting and math prompts
109+
110+
prompt_names = [prompt.name for prompt in prompts_result.prompts]
111+
assert "greeting_prompt" in prompt_names
112+
assert "math_prompt" in prompt_names
113+
114+
# Test get_prompt_sync with greeting prompt
115+
greeting_result = sse_mcp_client.get_prompt_sync("greeting_prompt", {"name": "Alice"})
116+
assert len(greeting_result.messages) > 0
117+
prompt_text = greeting_result.messages[0].content.text
118+
assert "Hello, Alice!" in prompt_text
119+
assert "How are you today?" in prompt_text
120+
121+
# Test get_prompt_sync with math prompt
122+
math_result = sse_mcp_client.get_prompt_sync(
123+
"math_prompt", {"operation": "multiplication", "difficulty": "medium"}
124+
)
125+
assert len(math_result.messages) > 0
126+
math_text = math_result.messages[0].content.text
127+
assert "multiplication" in math_text
128+
assert "medium" in math_text
129+
assert "step by step" in math_text
130+
131+
# Test pagination support for prompts
132+
prompts_with_token = sse_mcp_client.list_prompts_sync(pagination_token=None)
133+
assert len(prompts_with_token.prompts) >= 0
134+
135+
# Test pagination support for tools (existing functionality)
136+
tools_with_token = sse_mcp_client.list_tools_sync(pagination_token=None)
137+
assert len(tools_with_token) >= 0
138+
139+
# TODO: Add resources testing when resources are implemented
140+
# resources_result = sse_mcp_client.list_resources_sync()
141+
# assert len(resources_result.resources) >= 0
142+
91143
tool_use_id = "test-structured-content-123"
92144
result = stdio_mcp_client.call_tool_sync(
93145
tool_use_id=tool_use_id,
@@ -185,8 +237,9 @@ def test_mcp_client_without_structured_content():
185237
reason="streamable transport is failing in GitHub actions, debugging if linux compatibility issue",
186238
)
187239
def test_streamable_http_mcp_client():
240+
"""Test comprehensive MCP client with streamable HTTP transport."""
188241
server_thread = threading.Thread(
189-
target=start_calculator_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
242+
target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
190243
)
191244
server_thread.start()
192245
time.sleep(2) # wait for server to startup completely
@@ -196,12 +249,22 @@ def transport_callback() -> MCPTransport:
196249

197250
streamable_http_client = MCPClient(transport_callback)
198251
with streamable_http_client:
252+
# Test tools
199253
agent = Agent(tools=streamable_http_client.list_tools_sync())
200254
agent("add 1 and 2 using a calculator")
201255

202256
tool_use_content_blocks = _messages_to_content_blocks(agent.messages)
203257
assert any([block["name"] == "calculator" for block in tool_use_content_blocks])
204258

259+
# Test prompts
260+
prompts_result = streamable_http_client.list_prompts_sync()
261+
assert len(prompts_result.prompts) >= 2
262+
263+
greeting_result = streamable_http_client.get_prompt_sync("greeting_prompt", {"name": "Charlie"})
264+
assert len(greeting_result.messages) > 0
265+
prompt_text = greeting_result.messages[0].content.text
266+
assert "Hello, Charlie!" in prompt_text
267+
205268

206269
def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]:
207270
return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block]

0 commit comments

Comments
 (0)