Skip to content

Commit 8794a04

Browse files
committed
feat: Enhance dynamic endpoint creation with union schema handling
1 parent d38edbe commit 8794a04

File tree

1 file changed

+199
-59
lines changed

1 file changed

+199
-59
lines changed

src/mcpo/main.py

Lines changed: 199 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from fastapi import FastAPI, Body
22
from fastapi.middleware.cors import CORSMiddleware
33
from starlette.routing import Mount
4-
from pydantic import create_model
4+
from pydantic import create_model, Field
55
from contextlib import AsyncExitStack, asynccontextmanager
66

77
from mcp import ClientSession, StdioServerParameters, types
@@ -11,6 +11,7 @@
1111
import uvicorn
1212
import json
1313
import os
14+
import asyncio
1415

1516

1617
def get_python_type(param_type: str):
@@ -31,68 +32,199 @@ def get_python_type(param_type: str):
3132
# Expand as needed. PRs welcome!
3233

3334

35+
def handle_union_schema(schema: Dict[str, Any], tool_name: str = "") -> Dict[str, Any]:
36+
"""Handle anyOf/oneOf schemas by flattening them for FastAPI/Pydantic"""
37+
if "anyOf" in schema or "oneOf" in schema:
38+
union_key = "anyOf" if "anyOf" in schema else "oneOf"
39+
union_types = schema[union_key]
40+
41+
flattened = {
42+
"type": "object",
43+
"properties": {},
44+
"required": []
45+
}
46+
47+
discriminator = schema.get("discriminator", {}).get("propertyName")
48+
if not discriminator:
49+
potential_discriminators = {}
50+
51+
for variant in union_types:
52+
if variant.get("type") == "object" and "properties" in variant:
53+
for prop_name, prop_schema in variant.get("properties", {}).items():
54+
if "const" in prop_schema:
55+
potential_discriminators.setdefault(prop_name, 0)
56+
potential_discriminators[prop_name] += 1
57+
58+
if potential_discriminators:
59+
max_count = max(potential_discriminators.values())
60+
for prop_name, count in potential_discriminators.items():
61+
if count == max_count:
62+
discriminator = prop_name
63+
break
64+
65+
variant_map = {}
66+
discriminator_values = []
67+
68+
for variant in union_types:
69+
if variant.get("type") == "object" and "properties" in variant:
70+
for prop_name, prop_schema in variant.get("properties", {}).items():
71+
if prop_name not in flattened["properties"] or prop_name == discriminator:
72+
flattened["properties"][prop_name] = prop_schema.copy()
73+
74+
if discriminator and discriminator in variant.get("properties", {}):
75+
disc_prop = variant["properties"][discriminator]
76+
if "const" in disc_prop:
77+
disc_value = disc_prop["const"]
78+
discriminator_values.append(disc_value)
79+
variant_map[disc_value] = variant
80+
81+
if discriminator and discriminator_values and discriminator in flattened["properties"]:
82+
flattened["properties"][discriminator] = {
83+
"type": "string",
84+
"enum": discriminator_values,
85+
"description": flattened["properties"][discriminator].get("description", f"Operation type: {', '.join(discriminator_values)}")
86+
}
87+
88+
if discriminator not in flattened["required"]:
89+
flattened["required"].append(discriminator)
90+
91+
flattened["x-enumValueMappings"] = {}
92+
93+
for value, variant in variant_map.items():
94+
variant_required = variant.get("required", [])
95+
if variant_required:
96+
flattened["x-enumValueMappings"][value] = {
97+
"required": variant_required
98+
}
99+
100+
if "description" in schema:
101+
flattened["description"] = schema["description"]
102+
103+
return flattened
104+
105+
# If not a union schema, return the original schema
106+
return schema
107+
108+
34109
async def create_dynamic_endpoints(app: FastAPI):
35110
session = app.state.session
36111
if not session:
37112
raise ValueError("Session is not initialized in the app state.")
38113

39-
result = await session.initialize()
40-
server_info = getattr(result, "serverInfo", None)
41-
if server_info:
42-
app.title = server_info.name or app.title
43-
app.description = (
44-
f"{server_info.name} MCP Server" if server_info.name else app.description
45-
)
46-
app.version = server_info.version or app.version
47-
48-
tools_result = await session.list_tools()
49-
tools = tools_result.tools
50-
51-
for tool in tools:
52-
endpoint_name = tool.name
53-
endpoint_description = tool.description
54-
schema = tool.inputSchema
55-
56-
# Build Pydantic model
57-
model_fields = {}
58-
required_fields = schema.get("required", [])
59-
for param_name, param_schema in schema["properties"].items():
60-
param_type = param_schema.get("type", "string")
61-
param_desc = param_schema.get("description", "")
62-
python_type = get_python_type(param_type)
63-
default_value = ... if param_name in required_fields else None
64-
model_fields[param_name] = (
65-
python_type,
66-
Body(default_value, description=param_desc),
114+
try:
115+
result = await session.initialize()
116+
117+
server_info = getattr(result, "serverInfo", None)
118+
if server_info:
119+
app.title = server_info.name or app.title
120+
app.description = (
121+
f"{server_info.name} MCP Server" if server_info.name else app.description
67122
)
123+
app.version = server_info.version or app.version
124+
except Exception as e:
125+
raise ValueError(f"Error initializing MCP session: {str(e)}")
68126

69-
FormModel = create_model(f"{endpoint_name}_form_model", **model_fields)
127+
try:
128+
tools_result = await session.list_tools()
129+
tools = tools_result.tools
130+
except Exception as e:
131+
raise ValueError(f"Error listing tools: {str(e)}")
70132

71-
def make_endpoint_func(endpoint_name: str, FormModel, session: ClientSession):
72-
async def tool_endpoint(form_data: FormModel):
73-
args = form_data.model_dump()
74-
print(f"Calling {endpoint_name} with arguments:", args)
75-
result = await session.call_tool(endpoint_name, arguments=args)
76-
response = []
77-
for content in result.content:
78-
text = content.text
79-
if isinstance(text, str):
80-
try:
81-
text = json.loads(text)
82-
except json.JSONDecodeError:
83-
pass
84-
response.append(text)
85-
return response
86-
87-
return tool_endpoint
88-
89-
tool = make_endpoint_func(endpoint_name, FormModel, session)
133+
for tool in tools:
134+
try:
135+
endpoint_name = tool.name
136+
endpoint_description = tool.description
137+
original_schema = tool.inputSchema
138+
139+
schema = handle_union_schema(original_schema, endpoint_name)
140+
141+
# Build Pydantic model
142+
model_fields = {}
143+
required_fields = schema.get("required", [])
144+
discriminator = None
145+
enum_mappings = schema.get("x-enumValueMappings", {})
146+
147+
for param_name, param_schema in schema.get("properties", {}).items():
148+
if "enum" in param_schema and param_name in required_fields:
149+
discriminator = param_name
150+
break
151+
152+
for param_name, param_schema in schema.get("properties", {}).items():
153+
param_type = param_schema.get("type", "string")
154+
param_desc = param_schema.get("description", "")
155+
python_type = get_python_type(param_type)
156+
157+
if "enum" in param_schema:
158+
enum_values = param_schema["enum"]
159+
param_desc += f" Allowed values: {', '.join(map(str, enum_values))}"
160+
161+
is_required = param_name in required_fields
162+
163+
if "const" in param_schema:
164+
const_value = param_schema["const"]
165+
model_fields[param_name] = (
166+
python_type,
167+
Field(default=const_value, description=param_desc),
168+
)
169+
else:
170+
default_value = ... if is_required else None
171+
model_fields[param_name] = (
172+
python_type,
173+
Field(default=default_value, description=param_desc),
174+
)
90175

91-
app.post(
92-
f"/{endpoint_name}",
93-
summary=endpoint_name.replace("_", " ").title(),
94-
description=endpoint_description,
95-
)(tool)
176+
if not model_fields:
177+
model_fields = {
178+
"params": (Dict[str, Any], Field(default=..., description="Tool parameters"))
179+
}
180+
181+
FormModel = create_model(f"{endpoint_name}_form_model", **model_fields)
182+
183+
def make_endpoint_func(endpoint_name: str, FormModel, session: ClientSession, enum_mappings=None, discriminator=None):
184+
async def tool_endpoint(form_data: FormModel):
185+
# Convert form_data to dict for sending to MCP
186+
if hasattr(form_data, "params") and len(model_fields) == 1 and "params" in model_fields:
187+
args = form_data.params
188+
else:
189+
args = form_data.model_dump(exclude_unset=True)
190+
191+
if discriminator and discriminator in args and enum_mappings:
192+
disc_value = args[discriminator]
193+
if disc_value in enum_mappings:
194+
mapping = enum_mappings[disc_value]
195+
required_fields = mapping.get("required", [])
196+
197+
missing = [field for field in required_fields if field not in args]
198+
if missing:
199+
error_msg = f"When operation is '{disc_value}', the following fields are required: {', '.join(missing)}"
200+
return [{"success": False, "error": error_msg, "guidance": f"Please provide values for: {', '.join(missing)}"}]
201+
202+
try:
203+
result = await session.call_tool(endpoint_name, arguments=args)
204+
response = []
205+
for content in result.content:
206+
text = content.text
207+
if isinstance(text, str):
208+
try:
209+
text = json.loads(text)
210+
except json.JSONDecodeError:
211+
pass
212+
response.append(text)
213+
return response
214+
except Exception as e:
215+
return [{"success": False, "error": str(e), "guidance": "Please check your parameters and try again."}]
216+
217+
return tool_endpoint
218+
219+
tool = make_endpoint_func(endpoint_name, FormModel, session, enum_mappings, discriminator)
220+
221+
app.post(
222+
f"/{endpoint_name}",
223+
summary=endpoint_name.replace("_", " ").title(),
224+
description=endpoint_description,
225+
)(tool)
226+
except Exception:
227+
continue
96228

97229

98230
@asynccontextmanager
@@ -117,11 +249,19 @@ async def lifespan(app: FastAPI):
117249
env={**env},
118250
)
119251

120-
async with stdio_client(server_params) as (reader, writer):
121-
async with ClientSession(reader, writer) as session:
122-
app.state.session = session
123-
await create_dynamic_endpoints(app)
124-
yield
252+
try:
253+
async with stdio_client(server_params) as (reader, writer):
254+
async with ClientSession(reader, writer) as session:
255+
app.state.session = session
256+
257+
try:
258+
await asyncio.wait_for(create_dynamic_endpoints(app), timeout=30)
259+
except asyncio.TimeoutError:
260+
pass
261+
262+
yield
263+
except Exception:
264+
yield
125265

126266

127267
async def run(host: str = "127.0.0.1", port: int = 8000, **kwargs):

0 commit comments

Comments
 (0)