From a0b5dcc30f5179bf7bf05c815c38158843a09019 Mon Sep 17 00:00:00 2001 From: bmen25124 Date: Mon, 5 May 2025 21:52:59 +0300 Subject: [PATCH] Fixed "Custom field not found" error --- src/mcpo/main.py | 6 ++- src/mcpo/tests/test_main.py | 97 ++++++++++++++++++++++++++++++++++++- src/mcpo/utils/main.py | 61 +++++++++++++++++++---- 3 files changed, 150 insertions(+), 14 deletions(-) diff --git a/src/mcpo/main.py b/src/mcpo/main.py index 7339c710..a193da79 100644 --- a/src/mcpo/main.py +++ b/src/mcpo/main.py @@ -44,7 +44,8 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): f"{endpoint_name}_form_model", inputSchema.get("properties", {}), inputSchema.get("required", []), - inputSchema.get("$defs", {}), + schema_defs=inputSchema.get("$defs", {}), + root_schema=inputSchema, ) response_model_fields = None @@ -53,7 +54,8 @@ async def create_dynamic_endpoints(app: FastAPI, api_dependency=None): f"{endpoint_name}_response_model", outputSchema.get("properties", {}), outputSchema.get("required", []), - outputSchema.get("$defs", {}), + schema_defs=outputSchema.get("$defs", {}), + root_schema=outputSchema, ) tool_handler = get_tool_handler( diff --git a/src/mcpo/tests/test_main.py b/src/mcpo/tests/test_main.py index a7cfe5f0..a9fa155f 100644 --- a/src/mcpo/tests/test_main.py +++ b/src/mcpo/tests/test_main.py @@ -1,11 +1,11 @@ import pytest from pydantic import BaseModel, Field -from typing import Any, List, Dict, Union +from typing import Any, List, Dict, Type, Union from mcpo.utils.main import _process_schema_property -_model_cache = {} +_model_cache: Dict[str, Type] = {} @pytest.fixture(autouse=True) @@ -310,3 +310,96 @@ def test_multi_type_property_with_any_of(): # assert result_field parameter config assert result_field.description == "A property with multiple types" + + +def test_process_property_reference(): + schema = { + "type": "object", + "properties": { + "start_time": { + "type": "string", + "format": "date-time", + "description": "Start time in ISO 8601 format", + }, + "end_time": { + "$ref": "#/properties/start_time", + "description": "End time in ISO 8601 format", + }, + }, + "required": ["start_time"], + } + + # First process the start_time property to ensure reference target exists + result_type, result_field = _process_schema_property( + _model_cache, + schema, + "test", + "prop", + True, + schema_defs=None, + root_schema=schema, + ) + + assert issubclass(result_type, BaseModel) + model_fields = result_type.model_fields + + # Check that both fields have the same type (string) + assert model_fields["start_time"].annotation is str + assert model_fields["end_time"].annotation is str + + # Check descriptions are preserved + assert model_fields["start_time"].description == "Start time in ISO 8601 format" + assert model_fields["end_time"].description == "End time in ISO 8601 format" + + +def test_process_invalid_property_reference(): + schema = { + "type": "object", + "properties": {"invalid_ref": {"$ref": "#/properties/nonexistent"}}, + } + + with pytest.raises( + ValueError, match="Reference not found: #/properties/nonexistent" + ): + _process_schema_property( + _model_cache, + schema, + "test", + "prop", + True, + schema_defs=None, + root_schema=schema, + ) + + +def test_process_nested_property_reference(): + schema = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "created_at": {"type": "string", "format": "date-time"}, + "updated_at": {"$ref": "#/properties/user/properties/created_at"}, + }, + } + }, + } + + result_type, _ = _process_schema_property( + _model_cache, + schema, + "test", + "prop", + True, + schema_defs=None, + root_schema=schema, + ) + + assert issubclass(result_type, BaseModel) + user_field = result_type.model_fields["user"] + user_model = user_field.annotation + + # Both timestamps should be strings + assert user_model.model_fields["created_at"].annotation is str + assert user_model.model_fields["updated_at"].annotation is str diff --git a/src/mcpo/utils/main.py b/src/mcpo/utils/main.py index 4d72815a..13ad924d 100644 --- a/src/mcpo/utils/main.py +++ b/src/mcpo/utils/main.py @@ -54,8 +54,9 @@ def _process_schema_property( model_name_prefix: str, prop_name: str, is_required: bool, - schema_defs: Optional[Dict] = None, -) -> tuple[Union[Type, List, ForwardRef, Any], FieldInfo]: + schema_defs: Optional[Dict[str, Any]] = None, + root_schema: Optional[Dict[str, Any]] = None, +) -> tuple[Union[Type, List[Any], ForwardRef, Any], FieldInfo]: """ Recursively processes a schema property to determine its Python type hint and Pydantic Field definition. @@ -64,11 +65,34 @@ def _process_schema_property( A tuple containing (python_type_hint, pydantic_field). The pydantic_field contains default value and description. """ + original_schema = prop_schema.copy() if "$ref" in prop_schema: ref = prop_schema["$ref"] - ref = ref.split("/")[-1] - assert ref in schema_defs, "Custom field not found" - prop_schema = schema_defs[ref] + ref_parts = ref.split("/")[1:] # Skip the '#' at the start + + # Start from the root schema + current: Optional[Dict[str, Any]] = None + if ref_parts[0] in ["definitions", "$defs"] and schema_defs is not None: + current = schema_defs + elif ref_parts[0] == "properties" and root_schema is not None: + current = root_schema.get("properties", {}) + + if current is None: + raise ValueError(f"Cannot resolve reference: {ref}") + + # Navigate through the reference path + for part in ref_parts[1:]: # Skip the first part since we already used it + if not isinstance(current, dict): + raise ValueError(f"Invalid reference path: {ref}") + current = current.get(part) + if current is None: + raise ValueError(f"Reference not found: {ref}") + + # Merge referenced schema while preserving local overrides + prop_schema = { + **current, + **{k: v for k, v in original_schema.items() if k != "$ref"}, + } prop_type = prop_schema.get("type") prop_desc = prop_schema.get("description", "") @@ -87,6 +111,8 @@ def _process_schema_property( f"{model_name_prefix}_{prop_name}", f"choice_{i}", False, + schema_defs=schema_defs, + root_schema=root_schema, ) type_hints.append(type_hint) return Union[tuple(type_hints)], pydantic_field @@ -100,7 +126,13 @@ def _process_schema_property( temp_schema = dict(prop_schema) temp_schema["type"] = type_option type_hint, _ = _process_schema_property( - _model_cache, temp_schema, model_name_prefix, prop_name, False + _model_cache, + temp_schema, + model_name_prefix, + prop_name, + False, + schema_defs=schema_defs, + root_schema=root_schema, ) type_hints.append(type_hint) @@ -127,7 +159,8 @@ def _process_schema_property( nested_model_name, name, is_nested_required, - schema_defs, + schema_defs=schema_defs, + root_schema=root_schema, ) nested_fields[name] = (nested_type_hint, nested_pydantic_field) @@ -153,7 +186,8 @@ def _process_schema_property( f"{model_name_prefix}_{prop_name}", "item", False, # Items aren't required at this level, - schema_defs, + schema_defs=schema_defs, + root_schema=root_schema, ) list_type_hint = List[item_type_hint] return list_type_hint, pydantic_field @@ -172,7 +206,13 @@ def _process_schema_property( return Any, pydantic_field -def get_model_fields(form_model_name, properties, required_fields, schema_defs=None): +def get_model_fields( + form_model_name: str, + properties: Dict[str, Any], + required_fields: List[str], + schema_defs: Optional[Dict[str, Any]] = None, + root_schema: Optional[Dict[str, Any]] = None, +) -> Dict[str, tuple[Union[Type, List[Any], ForwardRef, Any], FieldInfo]]: model_fields = {} _model_cache: Dict[str, Type] = {} @@ -185,7 +225,8 @@ def get_model_fields(form_model_name, properties, required_fields, schema_defs=N form_model_name, param_name, is_required, - schema_defs, + schema_defs=schema_defs, + root_schema=root_schema, ) # Use the generated type hint and Field info model_fields[param_name] = (python_type_hint, pydantic_field_info)