Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/mcpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
f"{endpoint_name}_form_model",
inputSchema.get("properties", {}),
inputSchema.get("required", []),
inputSchema.get("$defs", {}),
schema_defs=inputSchema.get("$defs", {}),
root_schema=inputSchema,
)

response_model_fields = None
Expand All @@ -53,7 +54,8 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None):
f"{endpoint_name}_response_model",
outputSchema.get("properties", {}),
outputSchema.get("required", []),
outputSchema.get("$defs", {}),
schema_defs=outputSchema.get("$defs", {}),
root_schema=outputSchema,
)

tool_handler = get_tool_handler(
Expand Down
97 changes: 95 additions & 2 deletions src/mcpo/tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
from pydantic import BaseModel, Field
from typing import Any, List, Dict, Union
from typing import Any, List, Dict, Type, Union

from mcpo.utils.main import _process_schema_property


_model_cache = {}
_model_cache: Dict[str, Type] = {}


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -310,3 +310,96 @@ def test_multi_type_property_with_any_of():

# assert result_field parameter config
assert result_field.description == "A property with multiple types"


def test_process_property_reference():
schema = {
"type": "object",
"properties": {
"start_time": {
"type": "string",
"format": "date-time",
"description": "Start time in ISO 8601 format",
},
"end_time": {
"$ref": "#/properties/start_time",
"description": "End time in ISO 8601 format",
},
},
"required": ["start_time"],
}

# First process the start_time property to ensure reference target exists
result_type, result_field = _process_schema_property(
_model_cache,
schema,
"test",
"prop",
True,
schema_defs=None,
root_schema=schema,
)

assert issubclass(result_type, BaseModel)
model_fields = result_type.model_fields

# Check that both fields have the same type (string)
assert model_fields["start_time"].annotation is str
assert model_fields["end_time"].annotation is str

# Check descriptions are preserved
assert model_fields["start_time"].description == "Start time in ISO 8601 format"
assert model_fields["end_time"].description == "End time in ISO 8601 format"


def test_process_invalid_property_reference():
schema = {
"type": "object",
"properties": {"invalid_ref": {"$ref": "#/properties/nonexistent"}},
}

with pytest.raises(
ValueError, match="Reference not found: #/properties/nonexistent"
):
_process_schema_property(
_model_cache,
schema,
"test",
"prop",
True,
schema_defs=None,
root_schema=schema,
)


def test_process_nested_property_reference():
schema = {
"type": "object",
"properties": {
"user": {
"type": "object",
"properties": {
"created_at": {"type": "string", "format": "date-time"},
"updated_at": {"$ref": "#/properties/user/properties/created_at"},
},
}
},
}

result_type, _ = _process_schema_property(
_model_cache,
schema,
"test",
"prop",
True,
schema_defs=None,
root_schema=schema,
)

assert issubclass(result_type, BaseModel)
user_field = result_type.model_fields["user"]
user_model = user_field.annotation

# Both timestamps should be strings
assert user_model.model_fields["created_at"].annotation is str
assert user_model.model_fields["updated_at"].annotation is str
61 changes: 51 additions & 10 deletions src/mcpo/utils/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ def _process_schema_property(
model_name_prefix: str,
prop_name: str,
is_required: bool,
schema_defs: Optional[Dict] = None,
) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]:
schema_defs: Optional[Dict[str, Any]] = None,
root_schema: Optional[Dict[str, Any]] = None,
) -> tuple[Union[Type, List[Any], ForwardRef, Any], FieldInfo]:
"""
Recursively processes a schema property to determine its Python type hint
and Pydantic Field definition.
Expand All @@ -64,11 +65,34 @@ def _process_schema_property(
A tuple containing (python_type_hint, pydantic_field).
The pydantic_field contains default value and description.
"""
original_schema = prop_schema.copy()
if "$ref" in prop_schema:
ref = prop_schema["$ref"]
ref = ref.split("/")[-1]
assert ref in schema_defs, "Custom field not found"
prop_schema = schema_defs[ref]
ref_parts = ref.split("/")[1:] # Skip the '#' at the start

# Start from the root schema
current: Optional[Dict[str, Any]] = None
if ref_parts[0] in ["definitions", "$defs"] and schema_defs is not None:
current = schema_defs
elif ref_parts[0] == "properties" and root_schema is not None:
current = root_schema.get("properties", {})

if current is None:
raise ValueError(f"Cannot resolve reference: {ref}")

# Navigate through the reference path
for part in ref_parts[1:]: # Skip the first part since we already used it
if not isinstance(current, dict):
raise ValueError(f"Invalid reference path: {ref}")
current = current.get(part)
if current is None:
raise ValueError(f"Reference not found: {ref}")

# Merge referenced schema while preserving local overrides
prop_schema = {
**current,
**{k: v for k, v in original_schema.items() if k != "$ref"},
}

prop_type = prop_schema.get("type")
prop_desc = prop_schema.get("description", "")
Expand All @@ -87,6 +111,8 @@ def _process_schema_property(
f"{model_name_prefix}_{prop_name}",
f"choice_{i}",
False,
schema_defs=schema_defs,
root_schema=root_schema,
)
type_hints.append(type_hint)
return Union[tuple(type_hints)], pydantic_field
Expand All @@ -100,7 +126,13 @@ def _process_schema_property(
temp_schema = dict(prop_schema)
temp_schema["type"] = type_option
type_hint, _ = _process_schema_property(
_model_cache, temp_schema, model_name_prefix, prop_name, False
_model_cache,
temp_schema,
model_name_prefix,
prop_name,
False,
schema_defs=schema_defs,
root_schema=root_schema,
)
type_hints.append(type_hint)

Expand All @@ -127,7 +159,8 @@ def _process_schema_property(
nested_model_name,
name,
is_nested_required,
schema_defs,
schema_defs=schema_defs,
root_schema=root_schema,
)

nested_fields[name] = (nested_type_hint, nested_pydantic_field)
Expand All @@ -153,7 +186,8 @@ def _process_schema_property(
f"{model_name_prefix}_{prop_name}",
"item",
False, # Items aren't required at this level,
schema_defs,
schema_defs=schema_defs,
root_schema=root_schema,
)
list_type_hint = List[item_type_hint]
return list_type_hint, pydantic_field
Expand All @@ -172,7 +206,13 @@ def _process_schema_property(
return Any, pydantic_field


def get_model_fields(form_model_name, properties, required_fields, schema_defs=None):
def get_model_fields(
form_model_name: str,
properties: Dict[str, Any],
required_fields: List[str],
schema_defs: Optional[Dict[str, Any]] = None,
root_schema: Optional[Dict[str, Any]] = None,
) -> Dict[str, tuple[Union[Type, List[Any], ForwardRef, Any], FieldInfo]]:
model_fields = {}

_model_cache: Dict[str, Type] = {}
Expand All @@ -185,7 +225,8 @@ def get_model_fields(form_model_name, properties, required_fields, schema_defs=N
form_model_name,
param_name,
is_required,
schema_defs,
schema_defs=schema_defs,
root_schema=root_schema,
)
# Use the generated type hint and Field info
model_fields[param_name] = (python_type_hint, pydantic_field_info)
Expand Down