diff --git a/pydantic_xml/model.py b/pydantic_xml/model.py index 1ee18ef..aa6c895 100644 --- a/pydantic_xml/model.py +++ b/pydantic_xml/model.py @@ -304,12 +304,13 @@ def to_xml_tree( assert self.__xml_serializer__ is not None, f"model {type(self).__name__} is partially initialized" root = XmlElement(tag=self.__xml_serializer__.element_name, nsmap=self.__xml_serializer__.nsmap) + encoded = pdc.to_jsonable_python( + self, + by_alias=False, + fallback=lambda obj: obj if not isinstance(obj, ElementT) else None, # for raw fields support + ) self.__xml_serializer__.serialize( - root, self, pdc.to_jsonable_python( - self, - by_alias=False, - fallback=lambda obj: obj if not isinstance(obj, ElementT) else None, # for raw fields support - ), + root, self, encoded, skip_empty=skip_empty, exclude_none=exclude_none, exclude_unset=exclude_unset, diff --git a/pydantic_xml/serializers/factories/model.py b/pydantic_xml/serializers/factories/model.py index 39fcf4a..5de15cf 100644 --- a/pydantic_xml/serializers/factories/model.py +++ b/pydantic_xml/serializers/factories/model.py @@ -9,8 +9,11 @@ import pydantic_xml as pxml from pydantic_xml import errors, utils from pydantic_xml.element import XmlElementReader, XmlElementWriter, is_element_nill, make_element_nill +from pydantic_xml.element.native import ElementT from pydantic_xml.fields import ComputedXmlEntityInfo, XmlEntityInfoP, extract_field_xml_entity_info +from pydantic_xml.serializers.factories.primitive import AttributeSerializer from pydantic_xml.serializers.serializer import SearchMode, Serializer +from pydantic_xml.serializers.factories.raw import ElementSerializer as RawElementSerializer from pydantic_xml.typedefs import EntityLocation, Location, NsMap from pydantic_xml.utils import QName, merge_nsmaps, select_ns @@ -51,6 +54,29 @@ def _check_extra(cls, error_title: str, element: XmlElementReader) -> None: if line_errors: raise pd.ValidationError.from_exception_data(title=error_title, line_errors=line_errors) + def _keep_extra(self, element: XmlElementReader) -> Dict: + """Get a struct of extra (=unmapped) data from the XML. + + Attributes are put in key-value pairs directly. + Child elements are put in as native Elements. + """ + result = {} + + # Extract un-mapped attributes: + if attrs := element._state.attrib: + for name, value in attrs.items(): + if name not in self._field_serializers: + result[name] = value + + # `get_unbound` returns paths of leaf-level elements of type `XmlElement`, while + # we want to produce the same result of a raw-element, which is `ElemenT` + # So manually find un-mapped elements and get the native element back: + for sub_element in element._state.elements: + if (tag := sub_element.tag) not in self._field_serializers: + result[tag] = sub_element.to_native() + + return result + class ModelSerializer(BaseModelSerializer): @classmethod @@ -163,7 +189,23 @@ def serialize( if self._model.__xml_skip_empty__ is not None: skip_empty = self._model.__xml_skip_empty__ - for field_name, field_serializer in self._field_serializers.items(): + all_fields = list(self._field_serializers.keys()) + all_fields += [k for k in encoded.keys() if k not in all_fields] + # ^ avoid sets to preserve order of fields + + for field_name in all_fields: + field_serializer = self._field_serializers.get(field_name, None) + if field_serializer is None: # Probably from an `extra` field + if encoded[field_name] is None: + field_serializer = RawElementSerializer(field_name, ns=None, + nsmap=None, + search_mode=SearchMode.ORDERED, + computed=False) + else: + field_serializer = AttributeSerializer(field_name, ns=None, + nsmap=None, computed=False) + + if field_name in self._fields_serialization_exclude: continue if exclude_unset and field_name not in value.__pydantic_fields_set__: @@ -212,8 +254,13 @@ def deserialize( if field_errors: raise utils.into_validation_error(title=self._model.__name__, errors_map=field_errors) - if self._model.model_config.get('extra', 'ignore') == 'forbid': + config_extra = self._model.model_config.get('extra', 'ignore') + if config_extra == 'forbid': self._check_extra(self._model.__name__, element) + elif config_extra == 'allow': + result.update( + self._keep_extra(element) + ) try: return self._model.model_validate(result, strict=False, context=context) diff --git a/tests/test_extra_allow.py b/tests/test_extra_allow.py new file mode 100644 index 0000000..98cdf06 --- /dev/null +++ b/tests/test_extra_allow.py @@ -0,0 +1,148 @@ +from pydantic import ValidationError +from pydantic_xml import BaseXmlModel, attr, element +from pydantic_xml.element.native import ElementT + +import pytest + +from tests.helpers import assert_xml_equal + + +def test_extra_attributes_ignored(): + class TestModel(BaseXmlModel, tag='model', extra='ignore'): + prop1: str = attr() + data: str + + xml = ''' + text + ''' + + actual_obj = TestModel.from_xml(xml) + assert actual_obj.model_extra is None + + +def test_extra_attributes_forbidden(): + class TestModel(BaseXmlModel, tag='model', extra='forbid'): + prop1: str = attr() + data: str + + xml = ''' + text + ''' + + with pytest.raises(ValidationError): + _ = TestModel.from_xml(xml) + + +def test_extra_attributes(): + class TestModel(BaseXmlModel, tag='model', extra='allow'): + prop1: str = attr() + data: str + + xml = ''' + text + ''' + + actual_obj = TestModel.from_xml(xml) + assert 'p2' == actual_obj.model_extra['prop2'] + + +def test_extra_elements(): + + class TestModelChild(BaseXmlModel, tag='child'): + data: str + + class TestModel(BaseXmlModel, tag='model', extra='allow'): + child: TestModelChild + + xml = ''' + + hello world! + hi again... + + 1 + 2 + 3 + + 3.14 + + + + ''' + + actual_obj = TestModel.from_xml(xml) + assert 'hello world!' == actual_obj.child.data + assert 'extra_child' in actual_obj.model_extra + assert 'extra_nested' in actual_obj.model_extra + + +def test_raw_save(): + # Just for debugging!!! + + class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True): + extra_nested: ElementT = element() + + xml = ''' + + + 1 + 2 + 3 + + 3.14 + + + + ''' + + actual_obj = TestModel.from_xml(xml) + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + +def test_extra_save(): + + class TestModelChild(BaseXmlModel, tag='child'): + data: str + + class TestModel(BaseXmlModel, tag='model', extra='allow'): + prop1: str = attr() + child: TestModelChild + + xml = ''' + + hello world! + hi again... + + 1 + 2 + 3 + + 3.14 + + + + ''' + + actual_obj = TestModel.from_xml(xml) + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml) + + +def test_extra_save_order(): + + class TestModelChild(BaseXmlModel, tag='child'): + data: str + + class TestModel(BaseXmlModel, tag='model', extra='allow', search_mode='ordered'): + child: TestModelChild + + xml = ''' + + Hi there + Hello world! + + ''' + + actual_obj = TestModel.from_xml(xml) + actual_xml = actual_obj.to_xml() + assert_xml_equal(actual_xml, xml)