11from fastapi import FastAPI , Body
22from fastapi .middleware .cors import CORSMiddleware
33from starlette .routing import Mount
4- from pydantic import create_model
4+ from pydantic import create_model , Field
55from contextlib import AsyncExitStack , asynccontextmanager
66
77from mcp import ClientSession , StdioServerParameters , types
1111import uvicorn
1212import json
1313import os
14+ import asyncio
1415
1516
1617def get_python_type (param_type : str ):
@@ -31,68 +32,199 @@ def get_python_type(param_type: str):
3132 # Expand as needed. PRs welcome!
3233
3334
35+ def handle_union_schema (schema : Dict [str , Any ], tool_name : str = "" ) -> Dict [str , Any ]:
36+ """Handle anyOf/oneOf schemas by flattening them for FastAPI/Pydantic"""
37+ if "anyOf" in schema or "oneOf" in schema :
38+ union_key = "anyOf" if "anyOf" in schema else "oneOf"
39+ union_types = schema [union_key ]
40+
41+ flattened = {
42+ "type" : "object" ,
43+ "properties" : {},
44+ "required" : []
45+ }
46+
47+ discriminator = schema .get ("discriminator" , {}).get ("propertyName" )
48+ if not discriminator :
49+ potential_discriminators = {}
50+
51+ for variant in union_types :
52+ if variant .get ("type" ) == "object" and "properties" in variant :
53+ for prop_name , prop_schema in variant .get ("properties" , {}).items ():
54+ if "const" in prop_schema :
55+ potential_discriminators .setdefault (prop_name , 0 )
56+ potential_discriminators [prop_name ] += 1
57+
58+ if potential_discriminators :
59+ max_count = max (potential_discriminators .values ())
60+ for prop_name , count in potential_discriminators .items ():
61+ if count == max_count :
62+ discriminator = prop_name
63+ break
64+
65+ variant_map = {}
66+ discriminator_values = []
67+
68+ for variant in union_types :
69+ if variant .get ("type" ) == "object" and "properties" in variant :
70+ for prop_name , prop_schema in variant .get ("properties" , {}).items ():
71+ if prop_name not in flattened ["properties" ] or prop_name == discriminator :
72+ flattened ["properties" ][prop_name ] = prop_schema .copy ()
73+
74+ if discriminator and discriminator in variant .get ("properties" , {}):
75+ disc_prop = variant ["properties" ][discriminator ]
76+ if "const" in disc_prop :
77+ disc_value = disc_prop ["const" ]
78+ discriminator_values .append (disc_value )
79+ variant_map [disc_value ] = variant
80+
81+ if discriminator and discriminator_values and discriminator in flattened ["properties" ]:
82+ flattened ["properties" ][discriminator ] = {
83+ "type" : "string" ,
84+ "enum" : discriminator_values ,
85+ "description" : flattened ["properties" ][discriminator ].get ("description" , f"Operation type: { ', ' .join (discriminator_values )} " )
86+ }
87+
88+ if discriminator not in flattened ["required" ]:
89+ flattened ["required" ].append (discriminator )
90+
91+ flattened ["x-enumValueMappings" ] = {}
92+
93+ for value , variant in variant_map .items ():
94+ variant_required = variant .get ("required" , [])
95+ if variant_required :
96+ flattened ["x-enumValueMappings" ][value ] = {
97+ "required" : variant_required
98+ }
99+
100+ if "description" in schema :
101+ flattened ["description" ] = schema ["description" ]
102+
103+ return flattened
104+
105+ # If not a union schema, return the original schema
106+ return schema
107+
108+
34109async def create_dynamic_endpoints (app : FastAPI ):
35110 session = app .state .session
36111 if not session :
37112 raise ValueError ("Session is not initialized in the app state." )
38113
39- result = await session .initialize ()
40- server_info = getattr (result , "serverInfo" , None )
41- if server_info :
42- app .title = server_info .name or app .title
43- app .description = (
44- f"{ server_info .name } MCP Server" if server_info .name else app .description
45- )
46- app .version = server_info .version or app .version
47-
48- tools_result = await session .list_tools ()
49- tools = tools_result .tools
50-
51- for tool in tools :
52- endpoint_name = tool .name
53- endpoint_description = tool .description
54- schema = tool .inputSchema
55-
56- # Build Pydantic model
57- model_fields = {}
58- required_fields = schema .get ("required" , [])
59- for param_name , param_schema in schema ["properties" ].items ():
60- param_type = param_schema .get ("type" , "string" )
61- param_desc = param_schema .get ("description" , "" )
62- python_type = get_python_type (param_type )
63- default_value = ... if param_name in required_fields else None
64- model_fields [param_name ] = (
65- python_type ,
66- Body (default_value , description = param_desc ),
114+ try :
115+ result = await session .initialize ()
116+
117+ server_info = getattr (result , "serverInfo" , None )
118+ if server_info :
119+ app .title = server_info .name or app .title
120+ app .description = (
121+ f"{ server_info .name } MCP Server" if server_info .name else app .description
67122 )
123+ app .version = server_info .version or app .version
124+ except Exception as e :
125+ raise ValueError (f"Error initializing MCP session: { str (e )} " )
68126
69- FormModel = create_model (f"{ endpoint_name } _form_model" , ** model_fields )
127+ try :
128+ tools_result = await session .list_tools ()
129+ tools = tools_result .tools
130+ except Exception as e :
131+ raise ValueError (f"Error listing tools: { str (e )} " )
70132
71- def make_endpoint_func (endpoint_name : str , FormModel , session : ClientSession ):
72- async def tool_endpoint (form_data : FormModel ):
73- args = form_data .model_dump ()
74- print (f"Calling { endpoint_name } with arguments:" , args )
75- result = await session .call_tool (endpoint_name , arguments = args )
76- response = []
77- for content in result .content :
78- text = content .text
79- if isinstance (text , str ):
80- try :
81- text = json .loads (text )
82- except json .JSONDecodeError :
83- pass
84- response .append (text )
85- return response
86-
87- return tool_endpoint
88-
89- tool = make_endpoint_func (endpoint_name , FormModel , session )
133+ for tool in tools :
134+ try :
135+ endpoint_name = tool .name
136+ endpoint_description = tool .description
137+ original_schema = tool .inputSchema
138+
139+ schema = handle_union_schema (original_schema , endpoint_name )
140+
141+ # Build Pydantic model
142+ model_fields = {}
143+ required_fields = schema .get ("required" , [])
144+ discriminator = None
145+ enum_mappings = schema .get ("x-enumValueMappings" , {})
146+
147+ for param_name , param_schema in schema .get ("properties" , {}).items ():
148+ if "enum" in param_schema and param_name in required_fields :
149+ discriminator = param_name
150+ break
151+
152+ for param_name , param_schema in schema .get ("properties" , {}).items ():
153+ param_type = param_schema .get ("type" , "string" )
154+ param_desc = param_schema .get ("description" , "" )
155+ python_type = get_python_type (param_type )
156+
157+ if "enum" in param_schema :
158+ enum_values = param_schema ["enum" ]
159+ param_desc += f" Allowed values: { ', ' .join (map (str , enum_values ))} "
160+
161+ is_required = param_name in required_fields
162+
163+ if "const" in param_schema :
164+ const_value = param_schema ["const" ]
165+ model_fields [param_name ] = (
166+ python_type ,
167+ Field (default = const_value , description = param_desc ),
168+ )
169+ else :
170+ default_value = ... if is_required else None
171+ model_fields [param_name ] = (
172+ python_type ,
173+ Field (default = default_value , description = param_desc ),
174+ )
90175
91- app .post (
92- f"/{ endpoint_name } " ,
93- summary = endpoint_name .replace ("_" , " " ).title (),
94- description = endpoint_description ,
95- )(tool )
176+ if not model_fields :
177+ model_fields = {
178+ "params" : (Dict [str , Any ], Field (default = ..., description = "Tool parameters" ))
179+ }
180+
181+ FormModel = create_model (f"{ endpoint_name } _form_model" , ** model_fields )
182+
183+ def make_endpoint_func (endpoint_name : str , FormModel , session : ClientSession , enum_mappings = None , discriminator = None ):
184+ async def tool_endpoint (form_data : FormModel ):
185+ # Convert form_data to dict for sending to MCP
186+ if hasattr (form_data , "params" ) and len (model_fields ) == 1 and "params" in model_fields :
187+ args = form_data .params
188+ else :
189+ args = form_data .model_dump (exclude_unset = True )
190+
191+ if discriminator and discriminator in args and enum_mappings :
192+ disc_value = args [discriminator ]
193+ if disc_value in enum_mappings :
194+ mapping = enum_mappings [disc_value ]
195+ required_fields = mapping .get ("required" , [])
196+
197+ missing = [field for field in required_fields if field not in args ]
198+ if missing :
199+ error_msg = f"When operation is '{ disc_value } ', the following fields are required: { ', ' .join (missing )} "
200+ return [{"success" : False , "error" : error_msg , "guidance" : f"Please provide values for: { ', ' .join (missing )} " }]
201+
202+ try :
203+ result = await session .call_tool (endpoint_name , arguments = args )
204+ response = []
205+ for content in result .content :
206+ text = content .text
207+ if isinstance (text , str ):
208+ try :
209+ text = json .loads (text )
210+ except json .JSONDecodeError :
211+ pass
212+ response .append (text )
213+ return response
214+ except Exception as e :
215+ return [{"success" : False , "error" : str (e ), "guidance" : "Please check your parameters and try again." }]
216+
217+ return tool_endpoint
218+
219+ tool = make_endpoint_func (endpoint_name , FormModel , session , enum_mappings , discriminator )
220+
221+ app .post (
222+ f"/{ endpoint_name } " ,
223+ summary = endpoint_name .replace ("_" , " " ).title (),
224+ description = endpoint_description ,
225+ )(tool )
226+ except Exception :
227+ continue
96228
97229
98230@asynccontextmanager
@@ -117,11 +249,19 @@ async def lifespan(app: FastAPI):
117249 env = {** env },
118250 )
119251
120- async with stdio_client (server_params ) as (reader , writer ):
121- async with ClientSession (reader , writer ) as session :
122- app .state .session = session
123- await create_dynamic_endpoints (app )
124- yield
252+ try :
253+ async with stdio_client (server_params ) as (reader , writer ):
254+ async with ClientSession (reader , writer ) as session :
255+ app .state .session = session
256+
257+ try :
258+ await asyncio .wait_for (create_dynamic_endpoints (app ), timeout = 30 )
259+ except asyncio .TimeoutError :
260+ pass
261+
262+ yield
263+ except Exception :
264+ yield
125265
126266
127267async def run (host : str = "127.0.0.1" , port : int = 8000 , ** kwargs ):
0 commit comments